diff --git a/python/vyos/accel_ppp.py b/python/vyos/accel_ppp.py index a6f2ceb52..bae695fc3 100644 --- a/python/vyos/accel_ppp.py +++ b/python/vyos/accel_ppp.py @@ -1,72 +1,70 @@ -#!/usr/bin/env python3 -# # Copyright (C) 2022-2024 VyOS maintainers and contributors # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License version 2 or later as # published by the Free Software Foundation. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. from vyos.utils.process import rc_cmd def get_server_statistics(accel_statistics, pattern, sep=':') -> dict: import re stat_dict = {'sessions': {}} cpu = re.search(r'cpu(.*)', accel_statistics).group(0) # Find all lines with pattern, for example 'sstp:' data = re.search(rf'{pattern}(.*)', accel_statistics, re.DOTALL).group(0) session_starting = re.search(r'starting(.*)', data).group(0) session_active = re.search(r'active(.*)', data).group(0) for entry in {cpu, session_starting, session_active}: if sep in entry: key, value = entry.split(sep) if key in ['starting', 'active', 'finishing']: stat_dict['sessions'][key] = value.strip() continue if key == 'cpu': stat_dict['cpu_load_percentage'] = int(re.sub(r'%', '', value.strip())) continue stat_dict[key] = value.strip() return stat_dict def accel_cmd(port: int, command: str) -> str: _, output = rc_cmd(f'/usr/bin/accel-cmd -p{port} {command}') return output def accel_out_parse(accel_output: list[str]) -> list[dict[str, str]]: """ Parse accel-cmd show sessions output """ data_list: list[dict[str, str]] = list() field_names: list[str] = list() field_names_unstripped: list[str] = accel_output.pop(0).split('|') for field_name in field_names_unstripped: field_names.append(field_name.strip()) while accel_output: if '|' not in accel_output[0]: accel_output.pop(0) continue current_item: list[str] = accel_output.pop(0).split('|') item_dict: dict[str, str] = {} for field_index in range(len(current_item)): field_name: str = field_names[field_index] field_value: str = current_item[field_index].strip() item_dict[field_name] = field_value data_list.append(item_dict) return data_list diff --git a/python/vyos/cpu.py b/python/vyos/cpu.py index d2e5f6504..12b6285d0 100644 --- a/python/vyos/cpu.py +++ b/python/vyos/cpu.py @@ -1,103 +1,102 @@ -#!/usr/bin/env python3 -# Copyright 2022 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright (C) 2022-2024 maintainers and contributors # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public # License as published by the Free Software Foundation; either # version 2.1 of the License, or (at your option) any later version. # # This library is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public # License along with this library. If not, see <http://www.gnu.org/licenses/>. """ Retrieves (or at least attempts to retrieve) the total number of real CPU cores installed in a Linux system. The issue of core count is complicated by existence of SMT, e.g. Intel's Hyper Threading. GNU nproc returns the number of LOGICAL cores, which is 2x of the real cores if SMT is enabled. The idea is to find all physical CPUs and add up their core counts. It has special cases for x86_64 and MAY work correctly on other architectures, but nothing is certain. """ import re def _read_cpuinfo(): with open('/proc/cpuinfo', 'r') as f: lines = f.read().strip() return re.split(r'\n+', lines) def _split_line(l): l = l.strip() parts = re.split(r'\s*:\s*', l) return (parts[0], ":".join(parts[1:])) def _find_cpus(cpuinfo_lines): # Make a dict because it's more convenient to work with later, # when we need to find physicall distinct CPUs there. cpus = {} cpu_number = 0 for l in cpuinfo_lines: key, value = _split_line(l) if key == 'processor': cpu_number = value cpus[cpu_number] = {} else: cpus[cpu_number][key] = value return cpus def _find_physical_cpus(): cpus = _find_cpus(_read_cpuinfo()) phys_cpus = {} for num in cpus: if 'physical id' in cpus[num]: # On at least some architectures, CPUs in different sockets # have different 'physical id' field, e.g. on x86_64. phys_id = cpus[num]['physical id'] if phys_id not in phys_cpus: phys_cpus[phys_id] = cpus[num] else: # On other architectures, e.g. on ARM, there's no such field. # We just assume they are different CPUs, # whether single core ones or cores of physical CPUs. phys_cpus[num] = cpus[num] return phys_cpus def get_cpus(): """ Returns a list of /proc/cpuinfo entries that belong to different CPUs. """ cpus_dict = _find_physical_cpus() return list(cpus_dict.values()) def get_core_count(): """ Returns the total number of physical CPU cores (even if Hyper-Threading or another SMT is enabled and has inflated the number of cores in /proc/cpuinfo) """ physical_cpus = _find_physical_cpus() core_count = 0 for num in physical_cpus: # Some architectures, e.g. x86_64, include a field for core count. # Since we found unique physical CPU entries, we can sum their core counts. if 'cpu cores' in physical_cpus[num]: core_count += int(physical_cpus[num]['cpu cores']) else: core_count += 1 return core_count diff --git a/python/vyos/firewall.py b/python/vyos/firewall.py index 0713eb370..946050a82 100644 --- a/python/vyos/firewall.py +++ b/python/vyos/firewall.py @@ -1,662 +1,660 @@ -#!/usr/bin/env python3 -# # Copyright (C) 2021-2024 VyOS maintainers and contributors # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License version 2 or later as # published by the Free Software Foundation. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. import csv import gzip import os import re from pathlib import Path from socket import AF_INET from socket import AF_INET6 from socket import getaddrinfo from time import strftime from vyos.remote import download from vyos.template import is_ipv4 from vyos.template import render from vyos.utils.dict import dict_search_args from vyos.utils.dict import dict_search_recursive from vyos.utils.process import cmd from vyos.utils.process import run # Conntrack def conntrack_required(conf): required_nodes = ['nat', 'nat66', 'load-balancing wan'] for path in required_nodes: if conf.exists(path): return True firewall = conf.get_config_dict(['firewall'], key_mangling=('-', '_'), no_tag_node_value_mangle=True, get_first_key=True) for rules, path in dict_search_recursive(firewall, 'rule'): if any(('state' in rule_conf or 'connection_status' in rule_conf or 'offload_target' in rule_conf) for rule_conf in rules.values()): return True return False # Domain Resolver def fqdn_config_parse(firewall): firewall['ip_fqdn'] = {} firewall['ip6_fqdn'] = {} for domain, path in dict_search_recursive(firewall, 'fqdn'): hook_name = path[1] priority = path[2] fw_name = path[2] rule = path[4] suffix = path[5][0] set_name = f'{hook_name}_{priority}_{rule}_{suffix}' if (path[0] == 'ipv4') and (path[1] == 'forward' or path[1] == 'input' or path[1] == 'output' or path[1] == 'name'): firewall['ip_fqdn'][set_name] = domain elif (path[0] == 'ipv6') and (path[1] == 'forward' or path[1] == 'input' or path[1] == 'output' or path[1] == 'name'): if path[1] == 'name': set_name = f'name6_{priority}_{rule}_{suffix}' firewall['ip6_fqdn'][set_name] = domain def fqdn_resolve(fqdn, ipv6=False): try: res = getaddrinfo(fqdn, None, AF_INET6 if ipv6 else AF_INET) return set(item[4][0] for item in res) except: return None # End Domain Resolver def find_nftables_rule(table, chain, rule_matches=[]): # Find rule in table/chain that matches all criteria and return the handle results = cmd(f'sudo nft --handle list chain {table} {chain}').split("\n") for line in results: if all(rule_match in line for rule_match in rule_matches): handle_search = re.search('handle (\d+)', line) if handle_search: return handle_search[1] return None def remove_nftables_rule(table, chain, handle): cmd(f'sudo nft delete rule {table} {chain} handle {handle}') # Functions below used by template generation def nft_action(vyos_action): if vyos_action == 'accept': return 'return' return vyos_action def parse_rule(rule_conf, hook, fw_name, rule_id, ip_name): output = [] if ip_name == 'ip6': def_suffix = '6' family = 'ipv6' else: def_suffix = '' family = 'bri' if ip_name == 'bri' else 'ipv4' if 'state' in rule_conf and rule_conf['state']: states = ",".join([s for s in rule_conf['state']]) if states: output.append(f'ct state {{{states}}}') if 'conntrack_helper' in rule_conf: helper_map = {'h323': ['RAS', 'Q.931'], 'nfs': ['rpc'], 'sqlnet': ['tns']} helper_out = [] for helper in rule_conf['conntrack_helper']: if helper in helper_map: helper_out.extend(helper_map[helper]) else: helper_out.append(helper) if helper_out: helper_str = ','.join(f'"{s}"' for s in helper_out) output.append(f'ct helper {{{helper_str}}}') if 'connection_status' in rule_conf and rule_conf['connection_status']: status = rule_conf['connection_status'] if status['nat'] == 'destination': nat_status = 'dnat' output.append(f'ct status {nat_status}') if status['nat'] == 'source': nat_status = 'snat' output.append(f'ct status {nat_status}') if 'protocol' in rule_conf and rule_conf['protocol'] != 'all': proto = rule_conf['protocol'] operator = '' if proto[0] == '!': operator = '!=' proto = proto[1:] if proto == 'tcp_udp': proto = '{tcp, udp}' output.append(f'meta l4proto {operator} {proto}') for side in ['destination', 'source']: if side in rule_conf: prefix = side[0] side_conf = rule_conf[side] address_mask = side_conf.get('address_mask', None) if 'address' in side_conf: suffix = side_conf['address'] operator = '' exclude = suffix[0] == '!' if exclude: operator = '!= ' suffix = suffix[1:] if address_mask: operator = '!=' if exclude else '==' operator = f'& {address_mask} {operator} ' output.append(f'{ip_name} {prefix}addr {operator}{suffix}') if 'fqdn' in side_conf: fqdn = side_conf['fqdn'] hook_name = '' operator = '' if fqdn[0] == '!': operator = '!=' if hook == 'FWD': hook_name = 'forward' if hook == 'INP': hook_name = 'input' if hook == 'OUT': hook_name = 'output' if hook == 'NAM': hook_name = f'name{def_suffix}' output.append(f'{ip_name} {prefix}addr {operator} @FQDN_{hook_name}_{fw_name}_{rule_id}_{prefix}') if dict_search_args(side_conf, 'geoip', 'country_code'): operator = '' hook_name = '' if dict_search_args(side_conf, 'geoip', 'inverse_match') != None: operator = '!=' if hook == 'FWD': hook_name = 'forward' if hook == 'INP': hook_name = 'input' if hook == 'OUT': hook_name = 'output' if hook == 'NAM': hook_name = f'name' output.append(f'{ip_name} {prefix}addr {operator} @GEOIP_CC{def_suffix}_{hook_name}_{fw_name}_{rule_id}') if 'mac_address' in side_conf: suffix = side_conf["mac_address"] if suffix[0] == '!': suffix = f'!= {suffix[1:]}' output.append(f'ether {prefix}addr {suffix}') if 'port' in side_conf: proto = rule_conf['protocol'] port = side_conf['port'].split(',') ports = [] negated_ports = [] for p in port: if p[0] == '!': negated_ports.append(p[1:]) else: ports.append(p) if proto == 'tcp_udp': proto = 'th' if ports: ports_str = ','.join(ports) output.append(f'{proto} {prefix}port {{{ports_str}}}') if negated_ports: negated_ports_str = ','.join(negated_ports) output.append(f'{proto} {prefix}port != {{{negated_ports_str}}}') if 'group' in side_conf: group = side_conf['group'] if 'address_group' in group: group_name = group['address_group'] operator = '' exclude = group_name[0] == "!" if exclude: operator = '!=' group_name = group_name[1:] if address_mask: operator = '!=' if exclude else '==' operator = f'& {address_mask} {operator}' output.append(f'{ip_name} {prefix}addr {operator} @A{def_suffix}_{group_name}') elif 'dynamic_address_group' in group: group_name = group['dynamic_address_group'] operator = '' exclude = group_name[0] == "!" if exclude: operator = '!=' group_name = group_name[1:] output.append(f'{ip_name} {prefix}addr {operator} @DA{def_suffix}_{group_name}') # Generate firewall group domain-group elif 'domain_group' in group: group_name = group['domain_group'] operator = '' if group_name[0] == '!': operator = '!=' group_name = group_name[1:] output.append(f'{ip_name} {prefix}addr {operator} @D_{group_name}') elif 'network_group' in group: group_name = group['network_group'] operator = '' if group_name[0] == '!': operator = '!=' group_name = group_name[1:] output.append(f'{ip_name} {prefix}addr {operator} @N{def_suffix}_{group_name}') if 'mac_group' in group: group_name = group['mac_group'] operator = '' if group_name[0] == '!': operator = '!=' group_name = group_name[1:] output.append(f'ether {prefix}addr {operator} @M_{group_name}') if 'port_group' in group: proto = rule_conf['protocol'] group_name = group['port_group'] if proto == 'tcp_udp': proto = 'th' operator = '' if group_name[0] == '!': operator = '!=' group_name = group_name[1:] output.append(f'{proto} {prefix}port {operator} @P_{group_name}') if dict_search_args(rule_conf, 'action') == 'synproxy': output.append('ct state invalid,untracked') if 'hop_limit' in rule_conf: operators = {'eq': '==', 'gt': '>', 'lt': '<'} for op, operator in operators.items(): if op in rule_conf['hop_limit']: value = rule_conf['hop_limit'][op] output.append(f'ip6 hoplimit {operator} {value}') if 'inbound_interface' in rule_conf: operator = '' if 'name' in rule_conf['inbound_interface']: iiface = rule_conf['inbound_interface']['name'] if iiface[0] == '!': operator = '!=' iiface = iiface[1:] output.append(f'iifname {operator} {{{iiface}}}') elif 'group' in rule_conf['inbound_interface']: iiface = rule_conf['inbound_interface']['group'] if iiface[0] == '!': operator = '!=' iiface = iiface[1:] output.append(f'iifname {operator} @I_{iiface}') if 'outbound_interface' in rule_conf: operator = '' if 'name' in rule_conf['outbound_interface']: oiface = rule_conf['outbound_interface']['name'] if oiface[0] == '!': operator = '!=' oiface = oiface[1:] output.append(f'oifname {operator} {{{oiface}}}') elif 'group' in rule_conf['outbound_interface']: oiface = rule_conf['outbound_interface']['group'] if oiface[0] == '!': operator = '!=' oiface = oiface[1:] output.append(f'oifname {operator} @I_{oiface}') if 'ttl' in rule_conf: operators = {'eq': '==', 'gt': '>', 'lt': '<'} for op, operator in operators.items(): if op in rule_conf['ttl']: value = rule_conf['ttl'][op] output.append(f'ip ttl {operator} {value}') for icmp in ['icmp', 'icmpv6']: if icmp in rule_conf: if 'type_name' in rule_conf[icmp]: output.append(icmp + ' type ' + rule_conf[icmp]['type_name']) else: if 'code' in rule_conf[icmp]: output.append(icmp + ' code ' + rule_conf[icmp]['code']) if 'type' in rule_conf[icmp]: output.append(icmp + ' type ' + rule_conf[icmp]['type']) if 'packet_length' in rule_conf: lengths_str = ','.join(rule_conf['packet_length']) output.append(f'ip{def_suffix} length {{{lengths_str}}}') if 'packet_length_exclude' in rule_conf: negated_lengths_str = ','.join(rule_conf['packet_length_exclude']) output.append(f'ip{def_suffix} length != {{{negated_lengths_str}}}') if 'packet_type' in rule_conf: output.append(f'pkttype ' + rule_conf['packet_type']) if 'dscp' in rule_conf: dscp_str = ','.join(rule_conf['dscp']) output.append(f'ip{def_suffix} dscp {{{dscp_str}}}') if 'dscp_exclude' in rule_conf: negated_dscp_str = ','.join(rule_conf['dscp_exclude']) output.append(f'ip{def_suffix} dscp != {{{negated_dscp_str}}}') if 'ipsec' in rule_conf: if 'match_ipsec' in rule_conf['ipsec']: output.append('meta ipsec == 1') if 'match_none' in rule_conf['ipsec']: output.append('meta ipsec == 0') if 'fragment' in rule_conf: # Checking for fragmentation after priority -400 is not possible, # so we use a priority -450 hook to set a mark if 'match_frag' in rule_conf['fragment']: output.append('meta mark 0xffff1') if 'match_non_frag' in rule_conf['fragment']: output.append('meta mark != 0xffff1') if 'limit' in rule_conf: if 'rate' in rule_conf['limit']: output.append(f'limit rate {rule_conf["limit"]["rate"]}') if 'burst' in rule_conf['limit']: output.append(f'burst {rule_conf["limit"]["burst"]} packets') if 'recent' in rule_conf: count = rule_conf['recent']['count'] time = rule_conf['recent']['time'] output.append(f'add @RECENT{def_suffix}_{hook}_{fw_name}_{rule_id} {{ {ip_name} saddr limit rate over {count}/{time} burst {count} packets }}') if 'time' in rule_conf: output.append(parse_time(rule_conf['time'])) tcp_flags = dict_search_args(rule_conf, 'tcp', 'flags') if tcp_flags: output.append(parse_tcp_flags(tcp_flags)) # TCP MSS tcp_mss = dict_search_args(rule_conf, 'tcp', 'mss') if tcp_mss: output.append(f'tcp option maxseg size {tcp_mss}') if 'connection_mark' in rule_conf: conn_mark_str = ','.join(rule_conf['connection_mark']) output.append(f'ct mark {{{conn_mark_str}}}') if 'mark' in rule_conf: mark = rule_conf['mark'] operator = '' if mark[0] == '!': operator = '!=' mark = mark[1:] output.append(f'meta mark {operator} {{{mark}}}') if 'vlan' in rule_conf: if 'id' in rule_conf['vlan']: output.append(f'vlan id {rule_conf["vlan"]["id"]}') if 'priority' in rule_conf['vlan']: output.append(f'vlan pcp {rule_conf["vlan"]["priority"]}') if 'log' in rule_conf: action = rule_conf['action'] if 'action' in rule_conf else 'accept' #output.append(f'log prefix "[{fw_name[:19]}-{rule_id}-{action[:1].upper()}]"') output.append(f'log prefix "[{family}-{hook}-{fw_name}-{rule_id}-{action[:1].upper()}]"') ##{family}-{hook}-{fw_name}-{rule_id} if 'log_options' in rule_conf: if 'level' in rule_conf['log_options']: log_level = rule_conf['log_options']['level'] output.append(f'log level {log_level}') if 'group' in rule_conf['log_options']: log_group = rule_conf['log_options']['group'] output.append(f'log group {log_group}') if 'queue_threshold' in rule_conf['log_options']: queue_threshold = rule_conf['log_options']['queue_threshold'] output.append(f'queue-threshold {queue_threshold}') if 'snapshot_length' in rule_conf['log_options']: log_snaplen = rule_conf['log_options']['snapshot_length'] output.append(f'snaplen {log_snaplen}') output.append('counter') if 'add_address_to_group' in rule_conf: for side in ['destination_address', 'source_address']: if side in rule_conf['add_address_to_group']: prefix = side[0] side_conf = rule_conf['add_address_to_group'][side] dyn_group = side_conf['address_group'] if 'timeout' in side_conf: timeout_value = side_conf['timeout'] output.append(f'set update ip{def_suffix} {prefix}addr timeout {timeout_value} @DA{def_suffix}_{dyn_group}') else: output.append(f'set update ip{def_suffix} saddr @DA{def_suffix}_{dyn_group}') if 'set' in rule_conf: output.append(parse_policy_set(rule_conf['set'], def_suffix)) if 'action' in rule_conf: # Change action=return to action=action # #output.append(nft_action(rule_conf['action'])) if rule_conf['action'] == 'offload': offload_target = rule_conf['offload_target'] output.append(f'flow add @VYOS_FLOWTABLE_{offload_target}') else: output.append(f'{rule_conf["action"]}') if 'jump' in rule_conf['action']: target = rule_conf['jump_target'] output.append(f'NAME{def_suffix}_{target}') if 'queue' in rule_conf['action']: if 'queue' in rule_conf: target = rule_conf['queue'] output.append(f'num {target}') if 'queue_options' in rule_conf: queue_opts = ','.join(rule_conf['queue_options']) output.append(f'{queue_opts}') # Synproxy if 'synproxy' in rule_conf: synproxy_mss = dict_search_args(rule_conf, 'synproxy', 'tcp', 'mss') if synproxy_mss: output.append(f'mss {synproxy_mss}') synproxy_ws = dict_search_args(rule_conf, 'synproxy', 'tcp', 'window_scale') if synproxy_ws: output.append(f'wscale {synproxy_ws} timestamp sack-perm') else: output.append('return') output.append(f'comment "{family}-{hook}-{fw_name}-{rule_id}"') return " ".join(output) def parse_tcp_flags(flags): include = [flag for flag in flags if flag != 'not'] exclude = list(flags['not']) if 'not' in flags else [] return f'tcp flags & ({"|".join(include + exclude)}) == {"|".join(include) if include else "0x0"}' def parse_time(time): out = [] if 'startdate' in time: start = time['startdate'] if 'T' not in start and 'starttime' in time: start += f' {time["starttime"]}' out.append(f'time >= "{start}"') if 'starttime' in time and 'startdate' not in time: out.append(f'hour >= "{time["starttime"]}"') if 'stopdate' in time: stop = time['stopdate'] if 'T' not in stop and 'stoptime' in time: stop += f' {time["stoptime"]}' out.append(f'time < "{stop}"') if 'stoptime' in time and 'stopdate' not in time: out.append(f'hour < "{time["stoptime"]}"') if 'weekdays' in time: days = time['weekdays'].split(",") out_days = [f'"{day}"' for day in days if day[0] != '!'] out.append(f'day {{{",".join(out_days)}}}') return " ".join(out) def parse_policy_set(set_conf, def_suffix): out = [] if 'connection_mark' in set_conf: conn_mark = set_conf['connection_mark'] out.append(f'ct mark set {conn_mark}') if 'dscp' in set_conf: dscp = set_conf['dscp'] out.append(f'ip{def_suffix} dscp set {dscp}') if 'mark' in set_conf: mark = set_conf['mark'] out.append(f'meta mark set {mark}') if 'table' in set_conf: table = set_conf['table'] if table == 'main': table = '254' mark = 0x7FFFFFFF - int(table) out.append(f'meta mark set {mark}') if 'tcp_mss' in set_conf: mss = set_conf['tcp_mss'] out.append(f'tcp option maxseg size set {mss}') return " ".join(out) # GeoIP nftables_geoip_conf = '/run/nftables-geoip.conf' geoip_database = '/usr/share/vyos-geoip/dbip-country-lite.csv.gz' geoip_lock_file = '/run/vyos-geoip.lock' def geoip_load_data(codes=[]): data = None if not os.path.exists(geoip_database): return [] try: with gzip.open(geoip_database, mode='rt') as csv_fh: reader = csv.reader(csv_fh) out = [] for start, end, code in reader: if code.lower() in codes: out.append([start, end, code.lower()]) return out except: print('Error: Failed to open GeoIP database') return [] def geoip_download_data(): url = 'https://download.db-ip.com/free/dbip-country-lite-{}.csv.gz'.format(strftime("%Y-%m")) try: dirname = os.path.dirname(geoip_database) if not os.path.exists(dirname): os.mkdir(dirname) download(geoip_database, url) print("Downloaded GeoIP database") return True except: print("Error: Failed to download GeoIP database") return False class GeoIPLock(object): def __init__(self, file): self.file = file def __enter__(self): if os.path.exists(self.file): return False Path(self.file).touch() return True def __exit__(self, exc_type, exc_value, tb): os.unlink(self.file) def geoip_update(firewall, force=False): with GeoIPLock(geoip_lock_file) as lock: if not lock: print("Script is already running") return False if not firewall: print("Firewall is not configured") return True if not os.path.exists(geoip_database): if not geoip_download_data(): return False elif force: geoip_download_data() ipv4_codes = {} ipv6_codes = {} ipv4_sets = {} ipv6_sets = {} # Map country codes to set names for codes, path in dict_search_recursive(firewall, 'country_code'): set_name = f'GEOIP_CC_{path[1]}_{path[2]}_{path[4]}' if ( path[0] == 'ipv4'): for code in codes: ipv4_codes.setdefault(code, []).append(set_name) elif ( path[0] == 'ipv6' ): set_name = f'GEOIP_CC6_{path[1]}_{path[2]}_{path[4]}' for code in codes: ipv6_codes.setdefault(code, []).append(set_name) if not ipv4_codes and not ipv6_codes: if force: print("GeoIP not in use by firewall") return True geoip_data = geoip_load_data([*ipv4_codes, *ipv6_codes]) # Iterate IP blocks to assign to sets for start, end, code in geoip_data: ipv4 = is_ipv4(start) if code in ipv4_codes and ipv4: ip_range = f'{start}-{end}' if start != end else start for setname in ipv4_codes[code]: ipv4_sets.setdefault(setname, []).append(ip_range) if code in ipv6_codes and not ipv4: ip_range = f'{start}-{end}' if start != end else start for setname in ipv6_codes[code]: ipv6_sets.setdefault(setname, []).append(ip_range) render(nftables_geoip_conf, 'firewall/nftables-geoip-update.j2', { 'ipv4_sets': ipv4_sets, 'ipv6_sets': ipv6_sets }) result = run(f'nft --file {nftables_geoip_conf}') if result != 0: print('Error: GeoIP failed to update firewall') return False return True diff --git a/python/vyos/nat.py b/python/vyos/nat.py index da2613b16..2ada29add 100644 --- a/python/vyos/nat.py +++ b/python/vyos/nat.py @@ -1,313 +1,311 @@ -#!/usr/bin/env python3 -# # Copyright (C) 2022 VyOS maintainers and contributors # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License version 2 or later as # published by the Free Software Foundation. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. from vyos.template import is_ip_network from vyos.utils.dict import dict_search_args from vyos.template import bracketize_ipv6 def parse_nat_rule(rule_conf, rule_id, nat_type, ipv6=False): output = [] ip_prefix = 'ip6' if ipv6 else 'ip' log_prefix = ('DST' if nat_type == 'destination' else 'SRC') + f'-NAT-{rule_id}' log_suffix = '' if ipv6: log_prefix = log_prefix.replace("NAT-", "NAT66-") ignore_type_addr = False translation_str = '' if 'inbound_interface' in rule_conf: operator = '' if 'name' in rule_conf['inbound_interface']: iiface = rule_conf['inbound_interface']['name'] if iiface[0] == '!': operator = '!=' iiface = iiface[1:] output.append(f'iifname {operator} {{{iiface}}}') else: iiface = rule_conf['inbound_interface']['group'] if iiface[0] == '!': operator = '!=' iiface = iiface[1:] output.append(f'iifname {operator} @I_{iiface}') if 'outbound_interface' in rule_conf: operator = '' if 'name' in rule_conf['outbound_interface']: oiface = rule_conf['outbound_interface']['name'] if oiface[0] == '!': operator = '!=' oiface = oiface[1:] output.append(f'oifname {operator} {{{oiface}}}') else: oiface = rule_conf['outbound_interface']['group'] if oiface[0] == '!': operator = '!=' oiface = oiface[1:] output.append(f'oifname {operator} @I_{oiface}') if 'protocol' in rule_conf and rule_conf['protocol'] != 'all': protocol = rule_conf['protocol'] if protocol == 'tcp_udp': protocol = '{ tcp, udp }' output.append(f'meta l4proto {protocol}') if 'packet_type' in rule_conf: output.append(f'pkttype ' + rule_conf['packet_type']) if 'exclude' in rule_conf: translation_str = 'return' log_suffix = '-EXCL' elif 'translation' in rule_conf: addr = dict_search_args(rule_conf, 'translation', 'address') port = dict_search_args(rule_conf, 'translation', 'port') if 'redirect' in rule_conf['translation']: translation_output = [f'redirect'] redirect_port = dict_search_args(rule_conf, 'translation', 'redirect', 'port') if redirect_port: translation_output.append(f'to {redirect_port}') else: translation_prefix = nat_type[:1] translation_output = [f'{translation_prefix}nat'] if addr and is_ip_network(addr): if not ipv6: map_addr = dict_search_args(rule_conf, nat_type, 'address') if map_addr: if port: translation_output.append(f'{ip_prefix} prefix to {ip_prefix} {translation_prefix}addr map {{ {map_addr} : {addr} . {port} }}') else: translation_output.append(f'{ip_prefix} prefix to {ip_prefix} {translation_prefix}addr map {{ {map_addr} : {addr} }}') ignore_type_addr = True else: translation_output.append(f'prefix to {addr}') else: translation_output.append(f'prefix to {addr}') elif addr == 'masquerade': if port: addr = f'{addr} to ' translation_output = [addr] log_suffix = '-MASQ' else: translation_output.append('to') if addr: addr = bracketize_ipv6(addr) translation_output.append(addr) options = [] addr_mapping = dict_search_args(rule_conf, 'translation', 'options', 'address_mapping') port_mapping = dict_search_args(rule_conf, 'translation', 'options', 'port_mapping') if addr_mapping == 'persistent': options.append('persistent') if port_mapping and port_mapping != 'none': options.append(port_mapping) if ((not addr) or (addr and not is_ip_network(addr))) and port: translation_str = " ".join(translation_output) + (f':{port}') else: translation_str = " ".join(translation_output) if options: translation_str += f' {",".join(options)}' if not ipv6 and 'backend' in rule_conf['load_balance']: hash_input_items = [] current_prob = 0 nat_map = [] for trans_addr, addr in rule_conf['load_balance']['backend'].items(): item_prob = int(addr['weight']) upper_limit = current_prob + item_prob - 1 hash_val = str(current_prob) + '-' + str(upper_limit) element = hash_val + " : " + trans_addr nat_map.append(element) current_prob = current_prob + item_prob elements = ' , '.join(nat_map) if 'hash' in rule_conf['load_balance'] and 'random' in rule_conf['load_balance']['hash']: translation_str += ' numgen random mod 100 map ' + '{ ' + f'{elements}' + ' }' else: for input_param in rule_conf['load_balance']['hash']: if input_param == 'source-address': param = 'ip saddr' elif input_param == 'destination-address': param = 'ip daddr' elif input_param == 'source-port': prot = rule_conf['protocol'] param = f'{prot} sport' elif input_param == 'destination-port': prot = rule_conf['protocol'] param = f'{prot} dport' hash_input_items.append(param) hash_input = ' . '.join(hash_input_items) translation_str += f' jhash ' + f'{hash_input}' + ' mod 100 map ' + '{ ' + f'{elements}' + ' }' for target in ['source', 'destination']: if target not in rule_conf: continue side_conf = rule_conf[target] prefix = target[:1] addr = dict_search_args(side_conf, 'address') if addr and not (ignore_type_addr and target == nat_type): operator = '' if addr[:1] == '!': operator = '!=' addr = addr[1:] output.append(f'{ip_prefix} {prefix}addr {operator} {addr}') addr_prefix = dict_search_args(side_conf, 'prefix') if addr_prefix and ipv6: operator = '' if addr_prefix[:1] == '!': operator = '!=' addr_prefix = addr_prefix[1:] output.append(f'ip6 {prefix}addr {operator} {addr_prefix}') port = dict_search_args(side_conf, 'port') if port: protocol = rule_conf['protocol'] if protocol == 'tcp_udp': protocol = 'th' operator = '' if port[:1] == '!': operator = '!=' port = port[1:] output.append(f'{protocol} {prefix}port {operator} {{ {port} }}') if 'group' in side_conf: group = side_conf['group'] if 'address_group' in group and not (ignore_type_addr and target == nat_type): group_name = group['address_group'] operator = '' if group_name[0] == '!': operator = '!=' group_name = group_name[1:] output.append(f'{ip_prefix} {prefix}addr {operator} @A_{group_name}') # Generate firewall group domain-group elif 'domain_group' in group and not (ignore_type_addr and target == nat_type): group_name = group['domain_group'] operator = '' if group_name[0] == '!': operator = '!=' group_name = group_name[1:] output.append(f'{ip_prefix} {prefix}addr {operator} @D_{group_name}') elif 'network_group' in group and not (ignore_type_addr and target == nat_type): group_name = group['network_group'] operator = '' if group_name[0] == '!': operator = '!=' group_name = group_name[1:] output.append(f'{ip_prefix} {prefix}addr {operator} @N_{group_name}') if 'mac_group' in group: group_name = group['mac_group'] operator = '' if group_name[0] == '!': operator = '!=' group_name = group_name[1:] output.append(f'ether {prefix}addr {operator} @M_{group_name}') if 'port_group' in group: proto = rule_conf['protocol'] group_name = group['port_group'] if proto == 'tcp_udp': proto = 'th' operator = '' if group_name[0] == '!': operator = '!=' group_name = group_name[1:] output.append(f'{proto} {prefix}port {operator} @P_{group_name}') output.append('counter') if 'log' in rule_conf: output.append(f'log prefix "[{log_prefix}{log_suffix}]"') if translation_str: output.append(translation_str) output.append(f'comment "{log_prefix}"') return " ".join(output) def parse_nat_static_rule(rule_conf, rule_id, nat_type): output = [] log_prefix = ('STATIC-DST' if nat_type == 'destination' else 'STATIC-SRC') + f'-NAT-{rule_id}' log_suffix = '' ignore_type_addr = False translation_str = '' if 'inbound_interface' in rule_conf: ifname = rule_conf['inbound_interface'] ifprefix = 'i' if nat_type == 'destination' else 'o' if ifname != 'any': output.append(f'{ifprefix}ifname "{ifname}"') if 'exclude' in rule_conf: translation_str = 'return' log_suffix = '-EXCL' elif 'translation' in rule_conf: translation_prefix = nat_type[:1] translation_output = [f'{translation_prefix}nat'] addr = dict_search_args(rule_conf, 'translation', 'address') map_addr = dict_search_args(rule_conf, 'destination', 'address') if nat_type == 'source': addr, map_addr = map_addr, addr # Swap if addr and is_ip_network(addr): translation_output.append(f'ip prefix to ip {translation_prefix}addr map {{ {map_addr} : {addr} }}') ignore_type_addr = True elif addr: translation_output.append(f'to {addr}') options = [] addr_mapping = dict_search_args(rule_conf, 'translation', 'options', 'address_mapping') port_mapping = dict_search_args(rule_conf, 'translation', 'options', 'port_mapping') if addr_mapping == 'persistent': options.append('persistent') if port_mapping and port_mapping != 'none': options.append(port_mapping) if options: translation_output.append(",".join(options)) translation_str = " ".join(translation_output) prefix = nat_type[:1] addr = dict_search_args(rule_conf, 'translation' if nat_type == 'source' else nat_type, 'address') if addr and not ignore_type_addr: output.append(f'ip {prefix}addr {addr}') output.append('counter') if translation_str: output.append(translation_str) if 'log' in rule_conf: output.append(f'log prefix "[{log_prefix}{log_suffix}]"') output.append(f'comment "{log_prefix}"') return " ".join(output) diff --git a/python/vyos/pki.py b/python/vyos/pki.py index 02dece471..3c577db4d 100644 --- a/python/vyos/pki.py +++ b/python/vyos/pki.py @@ -1,455 +1,453 @@ -#!/usr/bin/env python3 -# # Copyright (C) 2023-2024 VyOS maintainers and contributors # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License version 2 or later as # published by the Free Software Foundation. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. import datetime import ipaddress from cryptography import x509 from cryptography.exceptions import InvalidSignature from cryptography.x509.extensions import ExtensionNotFound from cryptography.x509.oid import NameOID from cryptography.x509.oid import ExtendedKeyUsageOID from cryptography.x509.oid import ExtensionOID from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import dh from cryptography.hazmat.primitives.asymmetric import dsa from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.asymmetric import padding from cryptography.hazmat.primitives.asymmetric import rsa CERT_BEGIN='-----BEGIN CERTIFICATE-----\n' CERT_END='\n-----END CERTIFICATE-----' KEY_BEGIN='-----BEGIN PRIVATE KEY-----\n' KEY_END='\n-----END PRIVATE KEY-----' KEY_ENC_BEGIN='-----BEGIN ENCRYPTED PRIVATE KEY-----\n' KEY_ENC_END='\n-----END ENCRYPTED PRIVATE KEY-----' KEY_PUB_BEGIN='-----BEGIN PUBLIC KEY-----\n' KEY_PUB_END='\n-----END PUBLIC KEY-----' CRL_BEGIN='-----BEGIN X509 CRL-----\n' CRL_END='\n-----END X509 CRL-----' CSR_BEGIN='-----BEGIN CERTIFICATE REQUEST-----\n' CSR_END='\n-----END CERTIFICATE REQUEST-----' DH_BEGIN='-----BEGIN DH PARAMETERS-----\n' DH_END='\n-----END DH PARAMETERS-----' OVPN_BEGIN = '-----BEGIN OpenVPN Static key V{0}-----\n' OVPN_END = '\n-----END OpenVPN Static key V{0}-----' OPENSSH_KEY_BEGIN='-----BEGIN OPENSSH PRIVATE KEY-----\n' OPENSSH_KEY_END='\n-----END OPENSSH PRIVATE KEY-----' # Print functions encoding_map = { 'PEM': serialization.Encoding.PEM, 'OpenSSH': serialization.Encoding.OpenSSH } public_format_map = { 'SubjectPublicKeyInfo': serialization.PublicFormat.SubjectPublicKeyInfo, 'OpenSSH': serialization.PublicFormat.OpenSSH } private_format_map = { 'PKCS8': serialization.PrivateFormat.PKCS8, 'OpenSSH': serialization.PrivateFormat.OpenSSH } hash_map = { 'sha256': hashes.SHA256, 'sha384': hashes.SHA384, 'sha512': hashes.SHA512, } def get_certificate_fingerprint(cert, hash): hash_algorithm = hash_map[hash]() fp = cert.fingerprint(hash_algorithm) return fp.hex(':').upper() def encode_certificate(cert): return cert.public_bytes(encoding=serialization.Encoding.PEM).decode('utf-8') def encode_public_key(cert, encoding='PEM', key_format='SubjectPublicKeyInfo'): if encoding not in encoding_map: encoding = 'PEM' if key_format not in public_format_map: key_format = 'SubjectPublicKeyInfo' return cert.public_bytes( encoding=encoding_map[encoding], format=public_format_map[key_format]).decode('utf-8') def encode_private_key(private_key, encoding='PEM', key_format='PKCS8', passphrase=None): if encoding not in encoding_map: encoding = 'PEM' if key_format not in private_format_map: key_format = 'PKCS8' encryption = serialization.NoEncryption() if not passphrase else serialization.BestAvailableEncryption(bytes(passphrase, 'utf-8')) return private_key.private_bytes( encoding=encoding_map[encoding], format=private_format_map[key_format], encryption_algorithm=encryption).decode('utf-8') def encode_dh_parameters(dh_parameters): return dh_parameters.parameter_bytes( encoding=serialization.Encoding.PEM, format=serialization.ParameterFormat.PKCS3).decode('utf-8') # EC Helper def get_elliptic_curve(size): curve_func = None name = f'SECP{size}R1' if hasattr(ec, name): curve_func = getattr(ec, name) else: curve_func = ec.SECP256R1() # Default to SECP256R1 return curve_func() # Creation functions def create_private_key(key_type, key_size=None): private_key = None if key_type == 'rsa': private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size) elif key_type == 'dsa': private_key = dsa.generate_private_key(key_size=key_size) elif key_type == 'ec': curve = get_elliptic_curve(key_size) private_key = ec.generate_private_key(curve) return private_key def create_certificate_request(subject, private_key, subject_alt_names=[]): subject_obj = x509.Name([ x509.NameAttribute(NameOID.COUNTRY_NAME, subject['country']), x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, subject['state']), x509.NameAttribute(NameOID.LOCALITY_NAME, subject['locality']), x509.NameAttribute(NameOID.ORGANIZATION_NAME, subject['organization']), x509.NameAttribute(NameOID.COMMON_NAME, subject['common_name'])]) builder = x509.CertificateSigningRequestBuilder() \ .subject_name(subject_obj) if subject_alt_names: alt_names = [] for obj in subject_alt_names: if isinstance(obj, ipaddress.IPv4Address) or isinstance(obj, ipaddress.IPv6Address): alt_names.append(x509.IPAddress(obj)) elif isinstance(obj, str): alt_names.append(x509.DNSName(obj)) if alt_names: builder = builder.add_extension(x509.SubjectAlternativeName(alt_names), critical=False) return builder.sign(private_key, hashes.SHA256()) def add_key_identifier(ca_cert): try: ski_ext = ca_cert.extensions.get_extension_for_class(x509.SubjectKeyIdentifier) return x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(ski_ext.value) except: return x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_cert.public_key()) def create_certificate(cert_req, ca_cert, ca_private_key, valid_days=365, cert_type='server', is_ca=False, is_sub_ca=False): ext_key_usage = [] if is_ca: ext_key_usage = [ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH] elif cert_type == 'client': ext_key_usage = [ExtendedKeyUsageOID.CLIENT_AUTH] elif cert_type == 'server': ext_key_usage = [ExtendedKeyUsageOID.SERVER_AUTH] builder = x509.CertificateBuilder() \ .subject_name(cert_req.subject) \ .issuer_name(ca_cert.subject) \ .public_key(cert_req.public_key()) \ .serial_number(x509.random_serial_number()) \ .not_valid_before(datetime.datetime.utcnow()) \ .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=int(valid_days))) builder = builder.add_extension(x509.BasicConstraints(ca=is_ca, path_length=0 if is_sub_ca else None), critical=True) builder = builder.add_extension(x509.KeyUsage( digital_signature=True, content_commitment=False, key_encipherment=False, data_encipherment=False, key_agreement=False, key_cert_sign=is_ca, crl_sign=is_ca, encipher_only=False, decipher_only=False), critical=True) builder = builder.add_extension(x509.ExtendedKeyUsage(ext_key_usage), critical=False) builder = builder.add_extension(x509.SubjectKeyIdentifier.from_public_key(cert_req.public_key()), critical=False) if not is_ca or is_sub_ca: builder = builder.add_extension(add_key_identifier(ca_cert), critical=False) for ext in cert_req.extensions: builder = builder.add_extension(ext.value, critical=False) return builder.sign(ca_private_key, hashes.SHA256()) def create_certificate_revocation_list(ca_cert, ca_private_key, serial_numbers=[]): if not serial_numbers: return False builder = x509.CertificateRevocationListBuilder() \ .issuer_name(ca_cert.subject) \ .last_update(datetime.datetime.today()) \ .next_update(datetime.datetime.today() + datetime.timedelta(1, 0, 0)) for serial_number in serial_numbers: revoked_cert = x509.RevokedCertificateBuilder() \ .serial_number(serial_number) \ .revocation_date(datetime.datetime.today()) \ .build() builder = builder.add_revoked_certificate(revoked_cert) return builder.sign(private_key=ca_private_key, algorithm=hashes.SHA256()) def create_dh_parameters(bits=2048): if not bits or bits < 512: print("Invalid DH parameter key size") return False return dh.generate_parameters(generator=2, key_size=int(bits)) # Wrap functions def wrap_public_key(raw_data): return KEY_PUB_BEGIN + raw_data + KEY_PUB_END def wrap_private_key(raw_data, passphrase=None): return (KEY_ENC_BEGIN if passphrase else KEY_BEGIN) + raw_data + (KEY_ENC_END if passphrase else KEY_END) def wrap_openssh_public_key(raw_data, type): return f'{type} {raw_data}' def wrap_openssh_private_key(raw_data): return OPENSSH_KEY_BEGIN + raw_data + OPENSSH_KEY_END def wrap_certificate_request(raw_data): return CSR_BEGIN + raw_data + CSR_END def wrap_certificate(raw_data): return CERT_BEGIN + raw_data + CERT_END def wrap_crl(raw_data): return CRL_BEGIN + raw_data + CRL_END def wrap_dh_parameters(raw_data): return DH_BEGIN + raw_data + DH_END def wrap_openvpn_key(raw_data, version='1'): return OVPN_BEGIN.format(version) + raw_data + OVPN_END.format(version) # Load functions def load_public_key(raw_data, wrap_tags=True): if wrap_tags: raw_data = wrap_public_key(raw_data) try: return serialization.load_pem_public_key(bytes(raw_data, 'utf-8')) except ValueError: return False def load_private_key(raw_data, passphrase=None, wrap_tags=True): if wrap_tags: raw_data = wrap_private_key(raw_data, passphrase) if passphrase is not None: passphrase = bytes(passphrase, 'utf-8') try: return serialization.load_pem_private_key(bytes(raw_data, 'utf-8'), password=passphrase) except ValueError: return False def load_openssh_public_key(raw_data, type): try: return serialization.load_ssh_public_key(bytes(f'{type} {raw_data}', 'utf-8')) except ValueError: return False def load_openssh_private_key(raw_data, passphrase=None, wrap_tags=True): if wrap_tags: raw_data = wrap_openssh_private_key(raw_data) try: return serialization.load_ssh_private_key(bytes(raw_data, 'utf-8'), password=passphrase) except ValueError: return False def load_certificate_request(raw_data, wrap_tags=True): if wrap_tags: raw_data = wrap_certificate_request(raw_data) try: return x509.load_pem_x509_csr(bytes(raw_data, 'utf-8')) except ValueError: return False def load_certificate(raw_data, wrap_tags=True): if wrap_tags: raw_data = wrap_certificate(raw_data) try: return x509.load_pem_x509_certificate(bytes(raw_data, 'utf-8')) except ValueError: return False def load_crl(raw_data, wrap_tags=True): if wrap_tags: raw_data = wrap_crl(raw_data) try: return x509.load_pem_x509_crl(bytes(raw_data, 'utf-8')) except ValueError: return False def load_dh_parameters(raw_data, wrap_tags=True): if wrap_tags: raw_data = wrap_dh_parameters(raw_data) try: return serialization.load_pem_parameters(bytes(raw_data, 'utf-8')) except ValueError: return False # Verify def is_ca_certificate(cert): if not cert: return False try: ext = cert.extensions.get_extension_for_oid(ExtensionOID.BASIC_CONSTRAINTS) return ext.value.ca except ExtensionNotFound: return False def verify_certificate(cert, ca_cert): # Verify certificate was signed by specified CA if ca_cert.subject != cert.issuer: return False ca_public_key = ca_cert.public_key() try: if isinstance(ca_public_key, rsa.RSAPublicKeyWithSerialization): ca_public_key.verify( cert.signature, cert.tbs_certificate_bytes, padding=padding.PKCS1v15(), algorithm=cert.signature_hash_algorithm) elif isinstance(ca_public_key, dsa.DSAPublicKeyWithSerialization): ca_public_key.verify( cert.signature, cert.tbs_certificate_bytes, algorithm=cert.signature_hash_algorithm) elif isinstance(ca_public_key, ec.EllipticCurvePublicKeyWithSerialization): ca_public_key.verify( cert.signature, cert.tbs_certificate_bytes, signature_algorithm=ec.ECDSA(cert.signature_hash_algorithm)) else: return False # We cannot verify it return True except InvalidSignature: return False def verify_crl(crl, ca_cert): # Verify CRL was signed by specified CA if ca_cert.subject != crl.issuer: return False ca_public_key = ca_cert.public_key() try: if isinstance(ca_public_key, rsa.RSAPublicKeyWithSerialization): ca_public_key.verify( crl.signature, crl.tbs_certlist_bytes, padding=padding.PKCS1v15(), algorithm=crl.signature_hash_algorithm) elif isinstance(ca_public_key, dsa.DSAPublicKeyWithSerialization): ca_public_key.verify( crl.signature, crl.tbs_certlist_bytes, algorithm=crl.signature_hash_algorithm) elif isinstance(ca_public_key, ec.EllipticCurvePublicKeyWithSerialization): ca_public_key.verify( crl.signature, crl.tbs_certlist_bytes, signature_algorithm=ec.ECDSA(crl.signature_hash_algorithm)) else: return False # We cannot verify it return True except InvalidSignature: return False def verify_ca_chain(sorted_names, pki_node): if len(sorted_names) == 1: # Single cert, no chain return True for name in sorted_names: cert = load_certificate(pki_node[name]['certificate']) verified = False for ca_name in sorted_names: if name == ca_name: continue ca_cert = load_certificate(pki_node[ca_name]['certificate']) if verify_certificate(cert, ca_cert): verified = True break if not verified and name != sorted_names[-1]: # Only permit top-most certificate to fail verify (e.g. signed by public CA not explicitly in chain) return False return True # Certificate chain def find_parent(cert, ca_certs): for ca_cert in ca_certs: if verify_certificate(cert, ca_cert): return ca_cert return None def find_chain(cert, ca_certs): remaining = ca_certs.copy() chain = [cert] while remaining: parent = find_parent(chain[-1], remaining) if parent is None: # No parent in the list of remaining certificates or there's a circular dependency break elif parent == chain[-1]: # Self-signed: must be root CA (end of chain) break else: remaining.remove(parent) chain.append(parent) return chain def sort_ca_chain(ca_names, pki_node): def ca_cmp(ca_name1, ca_name2, pki_node): cert1 = load_certificate(pki_node[ca_name1]['certificate']) cert2 = load_certificate(pki_node[ca_name2]['certificate']) if verify_certificate(cert1, cert2): # cert1 is child of cert2 return -1 return 1 from functools import cmp_to_key return sorted(ca_names, key=cmp_to_key(lambda cert1, cert2: ca_cmp(cert1, cert2, pki_node)))