Page Menu
Home
VyOS Platform
Search
Configure Global Search
Log In
Files
F38930340
nat_cgnat.py
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Flag For Later
Award Token
Size
17 KB
Referenced Files
None
Subscribers
None
nat_cgnat.py
View Options
#!/usr/bin/env python3
#
# Copyright (C) 2024 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import
ipaddress
import
jmespath
import
logging
import
os
from
sys
import
exit
from
logging.handlers
import
SysLogHandler
from
vyos.config
import
Config
from
vyos.configdict
import
is_node_changed
from
vyos.template
import
render
from
vyos.utils.process
import
cmd
from
vyos.utils.process
import
run
from
vyos
import
ConfigError
from
vyos
import
airbag
# Turn on vyos.airbag before any other module-level work runs.
airbag.enable()

# Where generate() renders the CGNAT ruleset and apply() loads it from.
nftables_cgnat_config = '/run/nftables-cgnat.nft'

# Logging: allocation records go to the local syslog socket, tagged with
# the logger name ("cgnat"). The logger itself passes DEBUG and above, but
# the syslog handler only emits INFO and above.
formatter = logging.Formatter('%(name)s: %(message)s')

syslog_handler = SysLogHandler(address="/dev/log")
syslog_handler.setFormatter(formatter)
syslog_handler.setLevel(logging.INFO)

logger = logging.getLogger('cgnat')
logger.setLevel(logging.DEBUG)
logger.addHandler(syslog_handler)
class IPOperations:
    """Address arithmetic for a single pool range.

    ``ip_prefix`` is either a CIDR prefix (``192.0.2.0/30``) or an inclusive
    address range written as ``start-end`` (``192.0.2.1-192.0.2.5``). A bare
    address with neither ``/`` nor ``-`` is not supported by the counting and
    listing methods.
    """

    def __init__(self, ip_prefix: str):
        self.ip_prefix = ip_prefix
        # Only CIDR notation is parsed into a network object; a range keeps None.
        # ip_network() is strict, so a prefix with host bits set raises ValueError.
        self.ip_network = ipaddress.ip_network(ip_prefix) if '/' in ip_prefix else None

    def _range_bounds(self):
        """Return the first and last address objects of a ``start-end`` range."""
        start_ip, end_ip = self.ip_prefix.split('-')
        return ipaddress.ip_address(start_ip), ipaddress.ip_address(end_ip)

    def get_ips_count(self) -> int:
        """Returns the number of IPs in a prefix or range.

        Network and broadcast addresses are counted too.

        Example:
        % ip = IPOperations('192.0.2.0/30')
        % ip.get_ips_count()
        4
        % ip = IPOperations('192.0.2.0-192.0.2.2')
        % ip.get_ips_count()
        3
        """
        if '-' in self.ip_prefix:
            start_ip, end_ip = self._range_bounds()
            return int(end_ip) - int(start_ip) + 1
        # num_addresses already covers network + hosts + broadcast (and the
        # /31 and /32 cases) in O(1); enumerating hosts() was linear in the
        # size of the prefix.
        return self.ip_network.num_addresses

    def convert_prefix_to_list_ips(self) -> list:
        """Converts a prefix or IP range to a list of IPs including the network and broadcast addresses.

        Example:
        % ip = IPOperations('192.0.2.0/30')
        % ip.convert_prefix_to_list_ips()
        ['192.0.2.0', '192.0.2.1', '192.0.2.2', '192.0.2.3']
        %
        % ip = IPOperations('192.0.2.1-192.0.2.5')
        % ip.convert_prefix_to_list_ips()
        ['192.0.2.1', '192.0.2.2', '192.0.2.3', '192.0.2.4', '192.0.2.5']
        """
        if '-' in self.ip_prefix:
            start_ip, end_ip = self._range_bounds()
            return [str(ipaddress.ip_address(ip)) for ip in range(int(start_ip), int(end_ip) + 1)]
        # Iterating a network yields every address it contains, network and
        # broadcast included, which also gives the right result for /31 and /32.
        return [str(ip) for ip in self.ip_network]

    def get_prefix_by_ip_range(self) -> list[ipaddress.IPv4Network]:
        """Return the common prefix for the address range

        A CIDR prefix is returned as-is (wrapped in a list). A ``start-end``
        range is summarized into the minimal list of IPv4 networks covering it.

        Example:
        % ip = IPOperations('100.64.0.1-100.64.0.5')
        % ip.get_prefix_by_ip_range()
        [IPv4Network('100.64.0.1/32'), IPv4Network('100.64.0.2/31'), IPv4Network('100.64.0.4/31')]

        Raises:
            ValueError: if ``ip_prefix`` is neither CIDR nor a ``-`` range.
        """
        # We do not need to convert the IP range to network
        # if it is already in network format
        if self.ip_network:
            return [self.ip_network]

        # Raise an error if the IP range is not in the correct format
        if '-' not in self.ip_prefix:
            raise ValueError('Invalid IP range format. Please provide the IP range in CIDR format or with "-" separator.')

        # Split the IP range and convert it to IPv4 address objects
        range_start, range_end = self.ip_prefix.split('-')
        range_start = ipaddress.IPv4Address(range_start)
        range_end = ipaddress.IPv4Address(range_end)

        # Return the summarized IP networks list
        return list(ipaddress.summarize_address_range(range_start, range_end))
def _delete_conntrack_entries(source_prefixes: list[ipaddress.IPv4Network]) -> None:
    """Drop conntrack state for every source prefix in ``source_prefixes``.

    One ``conntrack -D -s <prefix>`` command is executed per prefix, in the
    order given; the exit status of each command is not inspected.
    """
    commands = (f'conntrack -D -s {prefix}' for prefix in source_prefixes)
    for command in commands:
        run(command)
def generate_port_rules(
    external_hosts: list,
    internal_hosts: list,
    port_count: int,
    global_port_range: str = '1024-65535',
) -> list:
    """Generates a list of nftables option rules for the batch file.

    Every internal host gets a contiguous block of ``port_count`` ports on one
    external host. Blocks are handed out in order from the start of
    ``global_port_range``. Once an external host cannot fit another full
    block, allocation moves on to the next external host. Each external host
    therefore serves ``range_size // port_count`` internal hosts.

    Args:
        external_hosts (list): A list of external host IPs.
        internal_hosts (list): A list of internal host IPs.
        port_count (int): The number of ports required per host.
        global_port_range (str): The global port range to be used. Default is '1024-65535'.

    Returns:
        list: A list containing two elements:
            - proto_map_elements (list): ``'<int> : <ext> . <first>-<last>'`` entries.
            - other_map_elements (list): ``'<int> : <ext>'`` entries.

    Raises:
        ValueError: if ``port_count`` is not between 1 and the size of
            ``global_port_range``, or if the external hosts do not have enough
            port blocks for every internal host. Previously the first case
            looped forever, and the second either raised IndexError or reused
            port blocks that were already allocated.
    """
    proto_map_elements = []
    other_map_elements = []

    start_port, end_port = map(int, global_port_range.split('-'))
    if not internal_hosts:
        return [proto_map_elements, other_map_elements]

    total_possible_ports = (end_port - start_port) + 1
    if not 1 <= port_count <= total_possible_ports:
        raise ValueError(
            f'port_count {port_count} must be between 1 and {total_possible_ports}, '
            f'the size of port range {global_port_range}'
        )

    # Number of whole port blocks that fit on a single external host
    blocks_per_external_host = total_possible_ports // port_count
    capacity = blocks_per_external_host * len(external_hosts)
    if len(internal_hosts) > capacity:
        raise ValueError(
            f'{len(internal_hosts)} internal hosts need more port blocks than the '
            f'{capacity} available on {len(external_hosts)} external hosts'
        )

    for index, internal_host in enumerate(internal_hosts):
        # Fill each external host completely before moving on to the next one
        external_host = external_hosts[index // blocks_per_external_host]
        block_start = start_port + (index % blocks_per_external_host) * port_count
        block_end = block_start + port_count - 1
        proto_map_elements.append(f'{internal_host} : {external_host} . {block_start}-{block_end}')
        other_map_elements.append(f'{internal_host} : {external_host}')

    return [proto_map_elements, other_map_elements]
def get_config(config=None):
    """Return the ``nat cgnat`` configuration dict used by verify/generate/apply.

    Besides the CLI values (read with recursive defaults), the dict may carry:
      * ``delete_conntrack_entries`` - set when ``nat cgnat`` no longer exists
        or its ``pool`` subtree changed;
      * ``effective`` - the running (effective) ``nat cgnat`` config, if any;
      * ``deleted`` - set when ``nat cgnat`` no longer exists.

    Args:
        config: optional Config object to read from; a fresh ``Config()`` is
            used when it is falsy.
    """
    conf = config if config else Config()
    base = ['nat', 'cgnat']
    dict_options = {
        'get_first_key': True,
        'key_mangling': ('-', '_'),
        'no_tag_node_value_mangle': True,
    }

    cgnat = conf.get_config_dict(base, with_recursive_defaults=True, **dict_options)
    running = conf.get_config_dict(base, effective=True, **dict_options)

    cgnat_exists = conf.exists(base)

    # Check if the pool configuration has changed
    if not cgnat_exists or is_node_changed(conf, base + ['pool']):
        cgnat['delete_conntrack_entries'] = {}

    # add running config
    if running:
        cgnat['effective'] = running

    if not cgnat_exists:
        cgnat['deleted'] = {}

    return cgnat
def verify(config):
    """Validate the ``nat cgnat`` configuration produced by get_config().

    Checks that:
      * at least one rule exists, and both an external and an internal pool
        section exist;
      * every pool defines ``range``;
      * every rule names an existing internal pool (``source.pool``) and
        external pool (``translation.pool``), and no pool is shared by two rules;
      * the external pool's port range and per-user port limit leave enough
        port blocks for every internal host.

    Returns None immediately when the configuration is being deleted.

    Raises:
        ConfigError: on the first violated constraint.
    """
    # bail out early - looks like removal from running config
    if 'deleted' in config:
        return None

    if 'pool' not in config:
        raise ConfigError('Pool must be defined!')
    if 'rule' not in config:
        raise ConfigError('Rule must be defined!')

    for pool in ('external', 'internal'):
        if pool not in config['pool']:
            raise ConfigError(f'{pool} pool must be defined!')
        for pool_name, pool_config in config['pool'][pool].items():
            if 'range' not in pool_config:
                raise ConfigError(f'Range for "{pool} pool {pool_name}" must be defined!')

    external_pools: list = jmespath.search("keys(pool.external)", config)
    internal_pools: list = jmespath.search("keys(pool.internal)", config)

    used_external_pools = {}
    used_internal_pools = {}
    for rule, rule_config in config['rule'].items():
        if 'pool' not in rule_config.get('source', {}):
            raise ConfigError(f'Rule "{rule}" source pool must be defined!')
        # Check the translation pool the same way as the source pool; a
        # translation node without a pool used to surface as a KeyError.
        if 'pool' not in rule_config.get('translation', {}):
            raise ConfigError(f'Rule "{rule}" translation pool must be defined!')

        # Check if pool exists
        internal_pool = rule_config['source']['pool']
        if internal_pool not in internal_pools:
            raise ConfigError(f'Internal pool "{internal_pool}" does not exist!')
        external_pool = rule_config['translation']['pool']
        if external_pool not in external_pools:
            raise ConfigError(f'External pool "{external_pool}" does not exist!')

        # Check pool duplication in different rules
        if external_pool in used_external_pools:
            raise ConfigError(
                f'External pool "{external_pool}" is already used in rule '
                f'{used_external_pools[external_pool]} and cannot be used in '
                f'rule {rule}!'
            )
        if internal_pool in used_internal_pools:
            raise ConfigError(
                f'Internal pool "{internal_pool}" is already used in rule '
                f'{used_internal_pools[internal_pool]} and cannot be used in '
                f'rule {rule}!'
            )
        used_external_pools[external_pool] = rule
        used_internal_pools[internal_pool] = rule

        # Check calculation for allocation
        external_pool_config = config['pool']['external'][external_pool]
        internal_pool_config = config['pool']['internal'][internal_pool]

        external_port_range: str = external_pool_config['external_port_range']
        start_port, end_port = map(int, external_port_range.split('-'))
        ports_per_range_count: int = (end_port - start_port) + 1

        # Only host counts are needed here, so the ranges are not expanded
        # into address lists (generate() does that when it builds the maps).
        external_host_count = sum(IPOperations(ext_range).get_ips_count() for ext_range in external_pool_config['range'])
        internal_host_count = sum(IPOperations(int_range).get_ips_count() for int_range in internal_pool_config['range'])

        ports_per_user: int = int(external_pool_config['per_user_limit']['port'])
        # Each external address serves as many users as whole port blocks fit
        # into the port range (same allocation rule as generate_port_rules()).
        users_per_extip = ports_per_range_count // ports_per_user
        max_users = users_per_extip * external_host_count

        if internal_host_count > max_users:
            raise ConfigError(
                f'Rule "{rule}" does not have enough ports available for the '
                f'specified parameters'
            )
def _expand_ranges(ip_ranges) -> list:
    """Return every address (as a string) covered by ``ip_ranges``, in order."""
    hosts = []
    for ip_range in ip_ranges:
        hosts.extend(IPOperations(ip_range).convert_prefix_to_list_ips())
    return hosts


def generate(config):
    """Render the CGNAT nftables ruleset and dry-run it.

    For every rule, the internal pool's addresses are mapped onto the external
    pool's addresses and port blocks with generate_port_rules(). The joined
    results are stored in ``config['proto_map_elements']`` and
    ``config['other_map_elements']``, and the template is rendered to
    ``nftables_cgnat_config``.

    Returns None without doing anything when the configuration is being deleted.

    Raises:
        ConfigError: if ``nft --check`` rejects the rendered file.
    """
    if 'deleted' in config:
        return None

    proto_maps = []
    other_maps = []

    for rule_config in config['rule'].values():
        ext_pool_name: str = rule_config['translation']['pool']
        int_pool_name: str = rule_config['source']['pool']
        ext_pool = config['pool']['external'][ext_pool_name]
        int_pool = config['pool']['internal'][int_pool_name]

        # Sort the external ranges by sequence. Ranges without 'seq' use
        # 999999 as their key; sorted() is stable, so ties keep config order.
        external_ranges: list = sorted(
            ext_pool['range'],
            key=lambda r: int(ext_pool['range'][r].get('seq', 999999)),
        )

        external_list_hosts = _expand_ranges(external_ranges)
        internal_list_hosts = _expand_ranges(int_pool['range'])

        # Same keys verify() reads, accessed the same way
        ports_per_user = int(ext_pool['per_user_limit']['port'])
        external_port_range: str = ext_pool['external_port_range']

        rule_proto_maps, rule_other_maps = generate_port_rules(
            external_list_hosts, internal_list_hosts, ports_per_user, external_port_range
        )

        proto_maps.extend(rule_proto_maps)
        other_maps.extend(rule_other_maps)

    config['proto_map_elements'] = ', '.join(proto_maps)
    config['other_map_elements'] = ', '.join(other_maps)

    render(nftables_cgnat_config, 'firewall/nftables-cgnat.j2', config)

    # dry-run newly generated configuration
    tmp = run(f'nft --check --file {nftables_cgnat_config}')
    if tmp > 0:
        raise ConfigError('Configuration file errors encountered!')
def apply(config):
    """Install or remove the CGNAT ruleset and handle related side effects.

    * If the config was deleted, drop the ``ip cgnat`` nftables table and
      remove the generated file. Otherwise, load the generated file with nft.
    * If ``delete_conntrack_entries`` and ``effective`` are both present,
      delete conntrack entries sourced from the internal pools of the
      effective (running) rules.
    * If the ``log_allocation`` key is present, log every internal-to-external
      port block from ``config['proto_map_elements']`` via the ``cgnat`` logger.
    """
    if 'deleted' not in config:
        cmd(f'nft --file {nftables_cgnat_config}')
    else:
        # Cleanup cgnat
        cmd('nft delete table ip cgnat')
        if os.path.isfile(nftables_cgnat_config):
            os.unlink(nftables_cgnat_config)

    # Delete conntrack entries if the pool configuration has changed
    if 'delete_conntrack_entries' in config and 'effective' in config:
        running = config['effective']
        stale_prefixes: list[ipaddress.IPv4Network] = []
        for running_rule in running.get('rule', {}).values():
            pool_name = running_rule['source']['pool']
            for ip_range in running['pool']['internal'][pool_name]['range']:
                stale_prefixes.extend(IPOperations(ip_range).get_prefix_by_ip_range())
        _delete_conntrack_entries(stale_prefixes)

    # Logging allocations
    if 'log_allocation' not in config:
        return None

    for allocation in config['proto_map_elements'].split(','):
        try:
            # Entries look like '<internal> : <external> . <port range>'
            internal_host, rest = allocation.split(' : ')
            external_host, port_range = rest.split(' . ')
        except ValueError as err:
            logger.error(f"Error processing line '{allocation}': {err}")
        else:
            logger.info(f'Internal host: {internal_host.lstrip()}, external host: {external_host}, Port range: {port_range}')
if __name__ == '__main__':
    # Standard conf-mode pipeline: read, validate, render, then apply.
    # Any ConfigError is printed and turned into exit status 1.
    try:
        cgnat_config = get_config()
        for stage in (verify, generate, apply):
            stage(cgnat_config)
    except ConfigError as err:
        print(err)
        exit(1)
File Metadata
Details
Attached
Mime Type
text/x-script.python
Expires
Tue, Dec 16, 4:00 AM (1 d, 12 h)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3060572
Default Alt Text
nat_cgnat.py (17 KB)
Attached To
Mode
rVYOSONEX vyos-1x
Attached
Detach File
Event Timeline
Log In to Comment