Page MenuHomeVyOS Platform

nat_cgnat.py
No OneTemporary

Size
17 KB
Referenced Files
None
Subscribers
None

nat_cgnat.py

#!/usr/bin/env python3
#
# Copyright (C) 2024 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import ipaddress
import jmespath
import logging
import os
from sys import exit
from logging.handlers import SysLogHandler
from vyos.config import Config
from vyos.configdict import is_node_changed
from vyos.template import render
from vyos.utils.process import cmd
from vyos.utils.process import run
from vyos import ConfigError
from vyos import airbag
# NOTE(review): presumably installs VyOS' crash/traceback reporting for this
# script - confirm against the vyos.airbag module
airbag.enable()
# Path of the nftables ruleset file that generate() renders and apply() loads
nftables_cgnat_config = '/run/nftables-cgnat.nft'
# Logging: the 'cgnat' logger accepts DEBUG and above, but the only handler
# attached here forwards INFO and above to the local syslog socket (/dev/log)
logger = logging.getLogger('cgnat')
logger.setLevel(logging.DEBUG)
syslog_handler = SysLogHandler(address="/dev/log")
syslog_handler.setLevel(logging.INFO)
# Each record is emitted as '<logger name>: <message>'
formatter = logging.Formatter('%(name)s: %(message)s')
syslog_handler.setFormatter(formatter)
logger.addHandler(syslog_handler)
class IPOperations:
    """Helper for counting and expanding an IP prefix or an IP range.

    The input is either a CIDR prefix ('192.0.2.0/30') or an inclusive
    range written as '<first>-<last>' ('192.0.2.0-192.0.2.2').
    """

    def __init__(self, ip_prefix: str):
        self.ip_prefix = ip_prefix
        # Only CIDR input is parsed eagerly; ranges are split on demand.
        # ip_network() is strict here, so a prefix with host bits set
        # raises ValueError.
        self.ip_network = ipaddress.ip_network(ip_prefix) if '/' in ip_prefix else None

    def _range_bounds(self) -> tuple:
        """Return the first and last address objects of a '<first>-<last>' range."""
        first, last = self.ip_prefix.split('-')
        return ipaddress.ip_address(first), ipaddress.ip_address(last)

    def _require_network(self):
        """Return the parsed network or raise if the input was not in CIDR format."""
        if self.ip_network is None:
            raise ValueError(
                'Invalid IP range format. Please provide the IP range in CIDR format or with "-" separator.'
            )
        return self.ip_network

    def get_ips_count(self) -> int:
        """Returns the number of IPs in a prefix or range.

        Network and broadcast addresses are counted as well.

        Example:
        % ip = IPOperations('192.0.2.0/30')
        % ip.get_ips_count()
        4
        % ip = IPOperations('192.0.2.0-192.0.2.2')
        % ip.get_ips_count()
        3

        Raises:
            ValueError: the input is neither a CIDR prefix nor a range.
        """
        if '-' in self.ip_prefix:
            first, last = self._range_bounds()
            return int(last) - int(first) + 1
        # num_addresses is computed from the prefix length, so nothing is
        # materialised even for very large pools, and /31 and /32 need no
        # special casing.
        return self._require_network().num_addresses

    def convert_prefix_to_list_ips(self) -> list:
        """Converts a prefix or IP range to a list of IPs including the network and broadcast addresses.

        Example:
        % ip = IPOperations('192.0.2.0/30')
        % ip.convert_prefix_to_list_ips()
        ['192.0.2.0', '192.0.2.1', '192.0.2.2', '192.0.2.3']
        %
        % ip = IPOperations('192.0.2.1-192.0.2.5')
        % ip.convert_prefix_to_list_ips()
        ['192.0.2.1', '192.0.2.2', '192.0.2.3', '192.0.2.4', '192.0.2.5']

        Raises:
            ValueError: the input is neither a CIDR prefix nor a range.
        """
        if '-' in self.ip_prefix:
            first, last = self._range_bounds()
            return [
                str(ipaddress.ip_address(number))
                for number in range(int(first), int(last) + 1)
            ]
        # Iterating a network object yields every address from the network
        # address up to and including the broadcast address.
        return [str(address) for address in self._require_network()]

    def get_prefix_by_ip_range(self) -> list[ipaddress.IPv4Network]:
        """Return the common prefix for the address range

        Example:
        % ip = IPOperations('100.64.0.1-100.64.0.5')
        % ip.get_prefix_by_ip_range()
        [IPv4Network('100.64.0.1/32'), IPv4Network('100.64.0.2/31'), IPv4Network('100.64.0.4/31')]

        Raises:
            ValueError: the input is neither a CIDR prefix nor a range.
        """
        # We do not need to convert the IP range to network
        # if it is already in network format
        if self.ip_network:
            return [self.ip_network]
        # Raise an error if the IP range is not in the correct format
        if '-' not in self.ip_prefix:
            raise ValueError(
                'Invalid IP range format. Please provide the IP range in CIDR format or with "-" separator.'
            )
        # Split the IP range and convert it to IP address objects
        range_start, range_end = self.ip_prefix.split('-')
        range_start = ipaddress.IPv4Address(range_start)
        range_end = ipaddress.IPv4Address(range_end)
        # Return the summarized IP networks list
        return list(ipaddress.summarize_address_range(range_start, range_end))
def _delete_conntrack_entries(source_prefixes: list[ipaddress.IPv4Network]) -> None:
"""Delete all conntrack entries for the list of prefixes"""
for source_prefix in source_prefixes:
run(f'conntrack -D -s {source_prefix}')
def generate_port_rules(
    external_hosts: list,
    internal_hosts: list,
    port_count: int,
    global_port_range: str = '1024-65535',
) -> list:
    """Generates a list of nftables option rules for the batch file.

    Every internal host receives its own block of ``port_count`` consecutive
    ports on one external host. An external host is filled from the start of
    the port range with as many whole blocks as fit, then allocation moves on
    to the next external host.

    Args:
        external_hosts (list): A list of external host IPs.
        internal_hosts (list): A list of internal host IPs.
        port_count (int): The number of ports required per host.
        global_port_range (str): The global port range to be used. Default is '1024-65535'.

    Returns:
        list: A list containing two elements:
            - proto_map_elements (list): A list of proto map elements.
            - other_map_elements (list): A list of other map elements.

    Raises:
        ValueError: if ``port_count`` is smaller than 1, if the port range
            cannot hold a single block of ``port_count`` ports, or if the
            external hosts cannot provide a distinct block for every
            internal host.
    """
    proto_map_elements = []
    other_map_elements = []
    start_port, end_port = map(int, global_port_range.split('-'))

    # Nothing to allocate
    if not internal_hosts:
        return [proto_map_elements, other_map_elements]

    if port_count < 1:
        raise ValueError('The number of ports per host must be at least 1')

    total_ports = (end_port - start_port) + 1
    # Number of whole port blocks that fit on a single external host
    hosts_per_external = total_ports // port_count
    if hosts_per_external < 1:
        raise ValueError(
            f'Port range {global_port_range} is too small for {port_count} ports per host'
        )

    # Fail loudly instead of looping forever, indexing past the end of
    # external_hosts or handing the same port block to two internal hosts.
    if len(internal_hosts) > hosts_per_external * len(external_hosts):
        raise ValueError(
            'Not enough external hosts and ports to allocate a port block '
            'for every internal host'
        )

    for index, internal_host in enumerate(internal_hosts):
        # external_index selects the external host, slot the block within it
        external_index, slot = divmod(index, hosts_per_external)
        external_host = external_hosts[external_index]
        first_port = start_port + slot * port_count
        last_port = first_port + port_count - 1
        proto_map_elements.append(
            f'{internal_host} : {external_host} . {first_port}-{last_port}'
        )
        other_map_elements.append(f'{internal_host} : {external_host}')

    return [proto_map_elements, other_map_elements]
def get_config(config=None):
    """Build the CGNAT configuration dictionary consumed by verify/generate/apply.

    The dictionary read from ['nat', 'cgnat'] (with recursive defaults) is
    extended with up to three helper keys:
      - 'delete_conntrack_entries': the node is absent or its 'pool' subtree changed
      - 'effective': the effective (running) configuration, when it is not empty
      - 'deleted': the node is absent
    """
    conf = config if config else Config()
    base = ['nat', 'cgnat']

    # Lookup options shared by the proposed and the effective read
    shared_options = {
        'get_first_key': True,
        'key_mangling': ('-', '_'),
        'no_tag_node_value_mangle': True,
    }
    config = conf.get_config_dict(base, with_recursive_defaults=True, **shared_options)
    effective_config = conf.get_config_dict(base, effective=True, **shared_options)

    # Check if the pool configuration has changed
    pool_path = base + ['pool']
    if not conf.exists(base) or is_node_changed(conf, pool_path):
        config['delete_conntrack_entries'] = {}

    # add running config
    if effective_config:
        config['effective'] = effective_config

    if not conf.exists(base):
        config['deleted'] = {}

    return config
def verify(config):
    """Validate the CGNAT configuration dictionary.

    Checks that pools and rules exist, that every pool has a range, that each
    rule references an existing internal (source) and external (translation)
    pool, that no pool is shared between rules, and that the external pool
    offers enough port blocks for every address of the internal pool.

    Returns None. Raises ConfigError on the first violation found.
    """
    # bail out early - looks like removal from running config
    if 'deleted' in config:
        return None

    if 'pool' not in config:
        raise ConfigError('Pool must be defined!')
    if 'rule' not in config:
        raise ConfigError('Rule must be defined!')

    for pool in ('external', 'internal'):
        if pool not in config['pool']:
            raise ConfigError(f'{pool} pool must be defined!')
        for pool_name, pool_config in config['pool'][pool].items():
            if 'range' not in pool_config:
                raise ConfigError(
                    f'Range for "{pool} pool {pool_name}" must be defined!'
                )

    # Both sub-dictionaries are known to exist after the loop above
    external_pools = config['pool']['external']
    internal_pools = config['pool']['internal']

    # Pool name -> rule that already uses it
    used_external_pools = {}
    used_internal_pools = {}

    for rule, rule_config in config['rule'].items():
        if 'source' not in rule_config or 'pool' not in rule_config['source']:
            raise ConfigError(f'Rule "{rule}" source pool must be defined!')
        # The pool leaf is checked too, so a translation node without a pool
        # is reported as a ConfigError rather than a bare KeyError below.
        if 'translation' not in rule_config or 'pool' not in rule_config['translation']:
            raise ConfigError(f'Rule "{rule}" translation pool must be defined!')

        # Check if pool exists
        internal_pool = rule_config['source']['pool']
        if internal_pool not in internal_pools:
            raise ConfigError(f'Internal pool "{internal_pool}" does not exist!')
        external_pool = rule_config['translation']['pool']
        if external_pool not in external_pools:
            raise ConfigError(f'External pool "{external_pool}" does not exist!')

        # Check pool duplication in different rules
        if external_pool in used_external_pools:
            raise ConfigError(
                f'External pool "{external_pool}" is already used in rule '
                f'{used_external_pools[external_pool]} and cannot be used in '
                f'rule {rule}!'
            )
        if internal_pool in used_internal_pools:
            raise ConfigError(
                f'Internal pool "{internal_pool}" is already used in rule '
                f'{used_internal_pools[internal_pool]} and cannot be used in '
                f'rule {rule}!'
            )
        used_external_pools[external_pool] = rule
        used_internal_pools[internal_pool] = rule

        # Check calculation for allocation
        external_pool_config = external_pools[external_pool]
        external_port_range: str = external_pool_config['external_port_range']
        start_port, end_port = map(int, external_port_range.split('-'))
        ports_per_range_count: int = (end_port - start_port) + 1

        # Only the address counts are needed here, not the expanded lists
        external_host_count = sum(
            IPOperations(ext_range).get_ips_count()
            for ext_range in external_pool_config['range']
        )
        internal_host_count = sum(
            IPOperations(int_range).get_ips_count()
            for int_range in internal_pools[internal_pool]['range']
        )

        ports_per_user: int = int(external_pool_config['per_user_limit']['port'])
        users_per_extip = ports_per_range_count // ports_per_user
        max_users = users_per_extip * external_host_count

        if internal_host_count > max_users:
            raise ConfigError(
                f'Rule "{rule}" does not have enough ports available for the '
                f'specified parameters'
            )
def generate(config):
    """Render the nftables CGNAT ruleset file and dry-run it with 'nft --check'.

    For every rule the external and internal pools are expanded to address
    lists and passed to generate_port_rules(). The resulting map elements are
    stored in config['proto_map_elements'] and config['other_map_elements']
    as comma-separated strings before the template is rendered.

    Returns None. Raises ConfigError when the nft check returns a status
    greater than zero. Nothing is done when 'deleted' is set in config.
    """
    if 'deleted' in config:
        return None

    proto_maps = []
    other_maps = []

    for rule_config in config['rule'].values():
        ext_pool_name: str = rule_config['translation']['pool']
        int_pool_name: str = rule_config['source']['pool']

        # Sort the external ranges by sequence; ranges without a sequence
        # number are placed last
        ext_range_config = config['pool']['external'][ext_pool_name]['range']
        external_ranges: list = sorted(
            ext_range_config,
            key=lambda r: int(ext_range_config[r].get('seq', 999999)),
        )
        internal_ranges: list = list(config['pool']['internal'][int_pool_name]['range'])

        # Expand every range into individual addresses
        external_list_hosts = []
        for ext_range in external_ranges:
            external_list_hosts.extend(IPOperations(ext_range).convert_prefix_to_list_ips())

        internal_list_hosts = []
        for int_range in internal_ranges:
            internal_list_hosts.extend(IPOperations(int_range).convert_prefix_to_list_ips())

        ports_per_user = int(
            jmespath.search(f'pool.external."{ext_pool_name}".per_user_limit.port', config)
        )
        external_port_range: str = jmespath.search(
            f'pool.external."{ext_pool_name}".external_port_range', config
        )

        rule_proto_maps, rule_other_maps = generate_port_rules(
            external_list_hosts, internal_list_hosts, ports_per_user, external_port_range
        )
        proto_maps.extend(rule_proto_maps)
        other_maps.extend(rule_other_maps)

    config['proto_map_elements'] = ', '.join(proto_maps)
    config['other_map_elements'] = ', '.join(other_maps)

    render(nftables_cgnat_config, 'firewall/nftables-cgnat.j2', config)

    # dry-run newly generated configuration
    tmp = run(f'nft --check --file {nftables_cgnat_config}')
    if tmp > 0:
        raise ConfigError('Configuration file errors encountered!')
def apply(config):
    """Load or remove the CGNAT nftables table, then do the follow-up work.

    Steps, in order:
      1. With 'deleted' set, delete the 'ip cgnat' table and remove the
         rendered ruleset file; otherwise load the rendered file with nft.
      2. With both 'delete_conntrack_entries' and 'effective' set, delete the
         conntrack entries of every internal pool range referenced by the
         effective rules.
      3. With 'log_allocation' set, log each element of
         config['proto_map_elements'].
    """
    if 'deleted' not in config:
        cmd(f'nft --file {nftables_cgnat_config}')
    else:
        # Cleanup cgnat
        cmd('nft delete table ip cgnat')
        if os.path.isfile(nftables_cgnat_config):
            os.unlink(nftables_cgnat_config)

    # Delete conntrack entries
    # if the pool configuration has changed
    if 'delete_conntrack_entries' in config and 'effective' in config:
        effective = config['effective']
        # Prefixes of all internal pools used by the effective rules
        stale_prefixes: list[ipaddress.IPv4Network] = []

        for effective_rule in effective.get('rule', {}).values():
            pool_name = effective_rule['source']['pool']
            pool_ranges: list[str] = effective['pool']['internal'][pool_name]['range']
            for pool_range in pool_ranges:
                # A range may be summarised into several prefixes
                stale_prefixes.extend(IPOperations(pool_range).get_prefix_by_ip_range())

        # Delete required sources for conntrack
        _delete_conntrack_entries(stale_prefixes)

    # Logging allocations
    if 'log_allocation' in config:
        for allocation in config['proto_map_elements'].split(','):
            try:
                # Element layout: '<internal> : <external> . <first>-<last>'
                internal_host, rest = allocation.split(' : ')
                external_host, port_range = rest.split(' . ')
            except ValueError as e:
                # Log error message
                logger.error(f"Error processing line '{allocation}': {e}")
            else:
                # Log the parsed data
                logger.info(
                    f'Internal host: {internal_host.lstrip()}, external host: {external_host}, Port range: {port_range}')
# Script entry point: run the get_config -> verify -> generate -> apply
# pipeline; a ConfigError is printed and turned into exit status 1.
if __name__ == '__main__':
    try:
        cgnat_config = get_config()
        verify(cgnat_config)
        generate(cgnat_config)
        apply(cgnat_config)
    except ConfigError as error:
        print(error)
        exit(1)

File Metadata

Mime Type
text/x-script.python
Expires
Tue, Dec 16, 4:00 AM (1 d, 12 h)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3060572
Default Alt Text
nat_cgnat.py (17 KB)

Event Timeline