diff --git a/debian/rules b/debian/rules index 9ada2bf87..e6bbeeafb 100755 --- a/debian/rules +++ b/debian/rules @@ -1,133 +1,130 @@ #!/usr/bin/make -f DIR := debian/tmp VYOS_SBIN_DIR := usr/sbin VYOS_BIN_DIR := usr/bin VYOS_LIBEXEC_DIR := usr/libexec/vyos VYOS_DATA_DIR := usr/share/vyos VYOS_CFG_TMPL_DIR := opt/vyatta/share/vyatta-cfg/templates VYOS_OP_TMPL_DIR := opt/vyatta/share/vyatta-op/templates VYOS_MIBS_DIR := usr/share/snmp/mibs VYOS_LOCALUI_DIR := srv/localui MIGRATION_SCRIPTS_DIR := opt/vyatta/etc/config-migrate/migrate SYSTEM_SCRIPTS_DIR := usr/libexec/vyos/system SERVICES_DIR := usr/libexec/vyos/services DEB_TARGET_ARCH := $(shell dpkg-architecture -qDEB_TARGET_ARCH) %: dh $@ --with python3, --with quilt # Skip dh_strip_nondeterminism - this is very time consuming # and we have no non deterministic output (yet) override_dh_strip_nondeterminism: override_dh_gencontrol: dh_gencontrol -- -v$(shell (git describe --tags --long --match 'vyos/*' --dirty 2>/dev/null || echo 0.0-no.git.tag) | sed -E 's%vyos/%%' | sed -E 's%-dirty%+dirty%') override_dh_auto_build: make all override_dh_auto_install: dh_auto_install - # convert the XML to dictionaries - env PYTHONPATH=python python3 python/vyos/xml/generate.py - cd python; python3 setup.py install --install-layout=deb --root ../$(DIR); cd .. # Install scripts mkdir -p $(DIR)/$(VYOS_SBIN_DIR) mkdir -p $(DIR)/$(VYOS_BIN_DIR) cp -r src/utils/* $(DIR)/$(VYOS_BIN_DIR) cp src/shim/vyshim $(DIR)/$(VYOS_SBIN_DIR) # Install conf mode scripts mkdir -p $(DIR)/$(VYOS_LIBEXEC_DIR)/conf_mode cp -r src/conf_mode/* $(DIR)/$(VYOS_LIBEXEC_DIR)/conf_mode # Install op mode scripts mkdir -p $(DIR)/$(VYOS_LIBEXEC_DIR)/op_mode cp -r src/op_mode/* $(DIR)/$(VYOS_LIBEXEC_DIR)/op_mode # Install op mode scripts mkdir -p $(DIR)/$(VYOS_LIBEXEC_DIR)/init cp -r src/init/* $(DIR)/$(VYOS_LIBEXEC_DIR)/init # Install validators mkdir -p $(DIR)/$(VYOS_LIBEXEC_DIR)/validators cp -r src/validators/* $(DIR)/$(VYOS_LIBEXEC_DIR)/validators # Install completion helpers mkdir -p $(DIR)/$(VYOS_LIBEXEC_DIR)/completion cp -r src/completion/* $(DIR)/$(VYOS_LIBEXEC_DIR)/completion # Install helper scripts cp -r src/helpers/* $(DIR)/$(VYOS_LIBEXEC_DIR)/ # Install migration scripts mkdir -p $(DIR)/$(MIGRATION_SCRIPTS_DIR) cp -r src/migration-scripts/* $(DIR)/$(MIGRATION_SCRIPTS_DIR) # Install system scripts mkdir -p $(DIR)/$(SYSTEM_SCRIPTS_DIR) cp -r src/system/* $(DIR)/$(SYSTEM_SCRIPTS_DIR) # Install system services mkdir -p $(DIR)/$(SERVICES_DIR) cp -r src/services/* $(DIR)/$(SERVICES_DIR) # Install configuration command definitions mkdir -p $(DIR)/$(VYOS_CFG_TMPL_DIR) cp -r templates-cfg/* $(DIR)/$(VYOS_CFG_TMPL_DIR) # Install operational command definitions mkdir -p $(DIR)/$(VYOS_OP_TMPL_DIR) cp -r templates-op/* $(DIR)/$(VYOS_OP_TMPL_DIR) # Install data files mkdir -p $(DIR)/$(VYOS_DATA_DIR) cp -r data/* $(DIR)/$(VYOS_DATA_DIR) # Create localui dir mkdir -p $(DIR)/$(VYOS_LOCALUI_DIR) # Install SNMP MIBs mkdir -p $(DIR)/$(VYOS_MIBS_DIR) cp -d mibs/* $(DIR)/$(VYOS_MIBS_DIR) # Install etc configuration files mkdir -p $(DIR)/etc cp -r src/etc/* $(DIR)/etc # Install PAM configuration snippets mkdir -p $(DIR)/usr/share/pam-configs cp -r src/pam-configs/* $(DIR)/usr/share/pam-configs # Install systemd service units mkdir -p $(DIR)/lib/systemd/system cp -r src/systemd/* $(DIR)/lib/systemd/system # Make directory for generated configuration file mkdir -p $(DIR)/etc/vyos # Install smoke test scripts mkdir -p $(DIR)/$(VYOS_LIBEXEC_DIR)/tests/smoke/ cp -r smoketest/scripts/* $(DIR)/$(VYOS_LIBEXEC_DIR)/tests/smoke # Install smoke test configs mkdir -p $(DIR)/$(VYOS_LIBEXEC_DIR)/tests/config/ cp -r smoketest/configs/* $(DIR)/$(VYOS_LIBEXEC_DIR)/tests/config # Install system programs mkdir -p $(DIR)/$(VYOS_BIN_DIR) cp -r smoketest/bin/* $(DIR)/$(VYOS_BIN_DIR) # Install udev script mkdir -p $(DIR)/usr/lib/udev cp src/helpers/vyos_net_name $(DIR)/usr/lib/udev override_dh_installsystemd: dh_installsystemd -pvyos-1x --name vyos-router vyos-router.service dh_installsystemd -pvyos-1x --name vyos vyos.target diff --git a/python/vyos/component_version.py b/python/vyos/component_version.py index a4e318d08..84e0ae51a 100644 --- a/python/vyos/component_version.py +++ b/python/vyos/component_version.py @@ -1,192 +1,192 @@ # Copyright 2022 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public # License as published by the Free Software Foundation; either # version 2.1 of the License, or (at your option) any later version. # # This library is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public # License along with this library. If not, see <http://www.gnu.org/licenses/>. """ Functions for reading/writing component versions. The config file version string has the following form: VyOS 1.3/1.4: // Warning: Do not remove the following line. // vyos-config-version: "broadcast-relay@1:cluster@1:config-management@1:conntrack@3:conntrack-sync@2:dhcp-relay@2:dhcp-server@6:dhcpv6-server@1:dns-forwarding@3:firewall@5:https@2:interfaces@22:ipoe-server@1:ipsec@5:isis@1:l2tp@3:lldp@1:mdns@1:nat@5:ntp@1:pppoe-server@5:pptp@2:qos@1:quagga@8:rpki@1:salt@1:snmp@2:ssh@2:sstp@3:system@21:vrrp@2:vyos-accel-ppp@2:wanloadbalance@3:webproxy@2:zone-policy@1" // Release version: 1.3.0 VyOS 1.2: /* Warning: Do not remove the following line. */ /* === vyatta-config-version: "broadcast-relay@1:cluster@1:config-management@1:conntrack-sync@1:conntrack@1:dhcp-relay@2:dhcp-server@5:dns-forwarding@1:firewall@5:ipsec@5:l2tp@1:mdns@1:nat@4:ntp@1:pppoe-server@2:pptp@1:qos@1:quagga@7:snmp@1:ssh@1:system@10:vrrp@2:wanloadbalance@3:webgui@1:webproxy@2:zone-policy@1" === */ /* Release version: 1.2.8 */ """ import os import re import sys import fileinput -from vyos.xml import component_version +from vyos.xml_ref import component_version from vyos.version import get_version from vyos.defaults import directories DEFAULT_CONFIG_PATH = os.path.join(directories['config'], 'config.boot') def from_string(string_line, vintage='vyos'): """ Get component version dictionary from string. Return empty dictionary if string contains no config information or raise error if component version string malformed. """ version_dict = {} if vintage == 'vyos': if re.match(r'// vyos-config-version:.+', string_line): if not re.match(r'// vyos-config-version:\s+"([\w,-]+@\d+:)+([\w,-]+@\d+)"\s*', string_line): raise ValueError(f"malformed configuration string: {string_line}") for pair in re.findall(r'([\w,-]+)@(\d+)', string_line): version_dict[pair[0]] = int(pair[1]) elif vintage == 'vyatta': if re.match(r'/\* === vyatta-config-version:.+=== \*/$', string_line): if not re.match(r'/\* === vyatta-config-version:\s+"([\w,-]+@\d+:)+([\w,-]+@\d+)"\s+=== \*/$', string_line): raise ValueError(f"malformed configuration string: {string_line}") for pair in re.findall(r'([\w,-]+)@(\d+)', string_line): version_dict[pair[0]] = int(pair[1]) else: raise ValueError("Unknown config string vintage") return version_dict def from_file(config_file_name=DEFAULT_CONFIG_PATH, vintage='vyos'): """ Get component version dictionary parsing config file line by line """ with open(config_file_name, 'r') as f: for line_in_config in f: version_dict = from_string(line_in_config, vintage=vintage) if version_dict: return version_dict # no version information return {} def from_system(): """ Get system component version dict. """ return component_version() def legacy_from_system(): """ Get system component version dict from legacy location. This is for a transitional sanity check; the directory will eventually be removed. """ system_versions = {} legacy_dir = directories['current'] # To be removed: if not os.path.isdir(legacy_dir): return system_versions try: version_info = os.listdir(legacy_dir) except OSError as err: sys.exit(repr(err)) for info in version_info: if re.match(r'[\w,-]+@\d+', info): pair = info.split('@') system_versions[pair[0]] = int(pair[1]) return system_versions def format_string(ver: dict) -> str: """ Version dict to string. """ keys = list(ver) keys.sort() l = [] for k in keys: v = ver[k] l.append(f'{k}@{v}') sep = ':' return sep.join(l) def version_footer(ver: dict, vintage='vyos') -> str: """ Version footer as string. """ ver_str = format_string(ver) release = get_version() if vintage == 'vyos': ret_str = (f'// Warning: Do not remove the following line.\n' + f'// vyos-config-version: "{ver_str}"\n' + f'// Release version: {release}\n') elif vintage == 'vyatta': ret_str = (f'/* Warning: Do not remove the following line. */\n' + f'/* === vyatta-config-version: "{ver_str}" === */\n' + f'/* Release version: {release} */\n') else: raise ValueError("Unknown config string vintage") return ret_str def system_footer(vintage='vyos') -> str: """ System version footer as string. """ ver_d = from_system() return version_footer(ver_d, vintage=vintage) def write_version_footer(ver: dict, file_name, vintage='vyos'): """ Write version footer to file. """ footer = version_footer(ver=ver, vintage=vintage) if file_name: with open(file_name, 'a') as f: f.write(footer) else: sys.stdout.write(footer) def write_system_footer(file_name, vintage='vyos'): """ Write system version footer to file. """ ver_d = from_system() return write_version_footer(ver_d, file_name=file_name, vintage=vintage) def remove_footer(file_name): """ Remove old version footer. """ for line in fileinput.input(file_name, inplace=True): if re.match(r'/\* Warning:.+ \*/$', line): continue if re.match(r'/\* === vyatta-config-version:.+=== \*/$', line): continue if re.match(r'/\* Release version:.+ \*/$', line): continue if re.match('// vyos-config-version:.+', line): continue if re.match('// Warning:.+', line): continue if re.match('// Release version:.+', line): continue sys.stdout.write(line) diff --git a/python/vyos/configdict.py b/python/vyos/configdict.py index 2a47e88f9..71a06b625 100644 --- a/python/vyos/configdict.py +++ b/python/vyos/configdict.py @@ -1,652 +1,651 @@ # Copyright 2019-2022 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public # License as published by the Free Software Foundation; either # version 2.1 of the License, or (at your option) any later version. # # This library is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public # License along with this library. If not, see <http://www.gnu.org/licenses/>. """ A library for retrieving value dicts from VyOS configs in a declarative fashion. """ import os import json from vyos.utils.dict import dict_search -from vyos.xml import defaults from vyos.utils.process import cmd def retrieve_config(path_hash, base_path, config): """ Retrieves a VyOS config as a dict according to a declarative description The description dict, passed in the first argument, must follow this format: ``field_name : <path, type, [inner_options_dict]>``. Supported types are: ``str`` (for normal nodes), ``list`` (returns a list of strings, for multi nodes), ``bool`` (returns True if valueless node exists), ``dict`` (for tag nodes, returns a dict indexed by node names, according to description in the third item of the tuple). Args: path_hash (dict): Declarative description of the config to retrieve base_path (list): A base path to prepend to all option paths config (vyos.config.Config): A VyOS config object Returns: dict: config dict """ config_hash = {} for k in path_hash: if type(path_hash[k]) != tuple: raise ValueError("In field {0}: expected a tuple, got a value {1}".format(k, str(path_hash[k]))) if len(path_hash[k]) < 2: raise ValueError("In field {0}: field description must be a tuple of at least two items, path (list) and type".format(k)) path = path_hash[k][0] if type(path) != list: raise ValueError("In field {0}: path must be a list, not a {1}".format(k, type(path))) typ = path_hash[k][1] if type(typ) != type: raise ValueError("In field {0}: type must be a type, not a {1}".format(k, type(typ))) path = base_path + path path_str = " ".join(path) if typ == str: config_hash[k] = config.return_value(path_str) elif typ == list: config_hash[k] = config.return_values(path_str) elif typ == bool: config_hash[k] = config.exists(path_str) elif typ == dict: try: inner_hash = path_hash[k][2] except IndexError: raise ValueError("The type of the \'{0}\' field is dict, but inner options hash is missing from the tuple".format(k)) config_hash[k] = {} nodes = config.list_nodes(path_str) for node in nodes: config_hash[k][node] = retrieve_config(inner_hash, path + [node], config) return config_hash def dict_merge(source, destination): """ Merge two dictionaries. Only keys which are not present in destination will be copied from source, anything else will be kept untouched. Function will return a new dict which has the merged key/value pairs. """ from copy import deepcopy tmp = deepcopy(destination) for key, value in source.items(): if key not in tmp: tmp[key] = value elif isinstance(source[key], dict): tmp[key] = dict_merge(source[key], tmp[key]) return tmp def list_diff(first, second): """ Diff two dictionaries and return only unique items """ second = set(second) return [item for item in first if item not in second] def is_node_changed(conf, path): from vyos.configdiff import get_config_diff D = get_config_diff(conf, key_mangling=('-', '_')) return D.is_node_changed(path) def leaf_node_changed(conf, path): """ Check if a leaf node was altered. If it has been altered - values has been changed, or it was added/removed, we will return a list containing the old value(s). If nothing has been changed, None is returned. NOTE: path must use the real CLI node name (e.g. with a hyphen!) """ from vyos.configdiff import get_config_diff D = get_config_diff(conf, key_mangling=('-', '_')) (new, old) = D.get_value_diff(path) if new != old: if isinstance(old, dict): # valueLess nodes return {} if node is deleted return True if old is None and isinstance(new, dict): # valueLess nodes return {} if node was added return True if old is None: return [] if isinstance(old, str): return [old] if isinstance(old, list): if isinstance(new, str): new = [new] elif isinstance(new, type(None)): new = [] return list_diff(old, new) return None def node_changed(conf, path, key_mangling=None, recursive=False): """ Check if a leaf node was altered. If it has been altered - values has been changed, or it was added/removed, we will return the old value. If nothing has been changed, None is returned """ from vyos.configdiff import get_config_diff, Diff D = get_config_diff(conf, key_mangling) # get_child_nodes() will return dict_keys(), mangle this into a list with PEP448 keys = D.get_child_nodes_diff(path, expand_nodes=Diff.DELETE, recursive=recursive)['delete'].keys() return list(keys) def get_removed_vlans(conf, path, dict): """ Common function to parse a dictionary retrieved via get_config_dict() and determine any added/removed VLAN interfaces - be it 802.1q or Q-in-Q. """ from vyos.configdiff import get_config_diff, Diff # Check vif, vif-s/vif-c VLAN interfaces for removal D = get_config_diff(conf, key_mangling=('-', '_')) D.set_level(conf.get_level()) # get_child_nodes() will return dict_keys(), mangle this into a list with PEP448 keys = D.get_child_nodes_diff(path + ['vif'], expand_nodes=Diff.DELETE)['delete'].keys() if keys: dict['vif_remove'] = [*keys] # get_child_nodes() will return dict_keys(), mangle this into a list with PEP448 keys = D.get_child_nodes_diff(path + ['vif-s'], expand_nodes=Diff.DELETE)['delete'].keys() if keys: dict['vif_s_remove'] = [*keys] for vif in dict.get('vif_s', {}).keys(): keys = D.get_child_nodes_diff(path + ['vif-s', vif, 'vif-c'], expand_nodes=Diff.DELETE)['delete'].keys() if keys: dict['vif_s'][vif]['vif_c_remove'] = [*keys] return dict def is_member(conf, interface, intftype=None): """ Checks if passed interface is member of other interface of specified type. intftype is optional, if not passed it will search all known types (currently bridge and bonding) Returns: dict empty -> Interface is not a member key -> Interface is a member of this interface """ from vyos.ifconfig import Section ret_val = {} intftypes = ['bonding', 'bridge'] if intftype not in intftypes + [None]: raise ValueError(( f'unknown interface type "{intftype}" or it cannot ' f'have member interfaces')) intftype = intftypes if intftype == None else [intftype] for iftype in intftype: base = ['interfaces', iftype] for intf in conf.list_nodes(base): member = base + [intf, 'member', 'interface', interface] if conf.exists(member): tmp = conf.get_config_dict(member, key_mangling=('-', '_'), get_first_key=True, no_tag_node_value_mangle=True) ret_val.update({intf : tmp}) return ret_val def is_mirror_intf(conf, interface, direction=None): """ Check whether the passed interface is used for port mirroring. Direction is optional, if not passed it will search all known direction (currently ingress and egress) Returns: None -> Interface is not a monitor interface Array() -> This interface is a monitor interface of interfaces """ from vyos.ifconfig import Section directions = ['ingress', 'egress'] if direction not in directions + [None]: raise ValueError(f'Unknown interface mirror direction "{direction}"') direction = directions if direction == None else [direction] ret_val = None base = ['interfaces'] for dir in direction: for iftype in conf.list_nodes(base): iftype_base = base + [iftype] for intf in conf.list_nodes(iftype_base): mirror = iftype_base + [intf, 'mirror', dir, interface] if conf.exists(mirror): path = ['interfaces', Section.section(intf), intf] tmp = conf.get_config_dict(path, key_mangling=('-', '_'), get_first_key=True) ret_val = {intf : tmp} return ret_val def has_address_configured(conf, intf): """ Checks if interface has an address configured. Checks the following config nodes: 'address', 'ipv6 address eui64', 'ipv6 address autoconf' Returns True if interface has address configured, False if it doesn't. """ from vyos.ifconfig import Section ret = False old_level = conf.get_level() conf.set_level([]) intfpath = 'interfaces ' + Section.get_config_path(intf) if ( conf.exists(f'{intfpath} address') or conf.exists(f'{intfpath} ipv6 address autoconf') or conf.exists(f'{intfpath} ipv6 address eui64') ): ret = True conf.set_level(old_level) return ret def has_vrf_configured(conf, intf): """ Checks if interface has a VRF configured. Returns True if interface has VRF configured, False if it doesn't. """ from vyos.ifconfig import Section ret = False old_level = conf.get_level() conf.set_level([]) tmp = ['interfaces', Section.get_config_path(intf), 'vrf'] if conf.exists(tmp): ret = True conf.set_level(old_level) return ret def has_vlan_subinterface_configured(conf, intf): """ Checks if interface has an VLAN subinterface configured. Checks the following config nodes: 'vif', 'vif-s' Return True if interface has VLAN subinterface configured. """ from vyos.ifconfig import Section ret = False intfpath = ['interfaces', Section.section(intf), intf] if ( conf.exists(intfpath + ['vif']) or conf.exists(intfpath + ['vif-s'])): ret = True return ret def is_source_interface(conf, interface, intftype=None): """ Checks if passed interface is configured as source-interface of other interfaces of specified type. intftype is optional, if not passed it will search all known types (currently pppoe, macsec, pseudo-ethernet, tunnel and vxlan) Returns: None -> Interface is not a member interface name -> Interface is a member of this interface False -> interface type cannot have members """ ret_val = None intftypes = ['macsec', 'pppoe', 'pseudo-ethernet', 'tunnel', 'vxlan'] if not intftype: intftype = intftypes if isinstance(intftype, str): intftype = [intftype] elif not isinstance(intftype, list): raise ValueError(f'Interface type "{type(intftype)}" must be either str or list!') if not all(x in intftypes for x in intftype): raise ValueError(f'unknown interface type "{intftype}" or it can not ' 'have a source-interface') for it in intftype: base = ['interfaces', it] for intf in conf.list_nodes(base): src_intf = base + [intf, 'source-interface'] if conf.exists(src_intf) and interface in conf.return_values(src_intf): ret_val = intf break return ret_val def get_dhcp_interfaces(conf, vrf=None): """ Common helper functions to retrieve all interfaces from current CLI sessions that have DHCP configured. """ dhcp_interfaces = {} dict = conf.get_config_dict(['interfaces'], get_first_key=True) if not dict: return dhcp_interfaces def check_dhcp(config): ifname = config['ifname'] tmp = {} if 'address' in config and 'dhcp' in config['address']: options = {} if dict_search('dhcp_options.default_route_distance', config) != None: options.update({'dhcp_options' : config['dhcp_options']}) if 'vrf' in config: if vrf == config['vrf']: tmp.update({ifname : options}) else: if vrf is None: tmp.update({ifname : options}) return tmp for section, interface in dict.items(): for ifname in interface: # always reset config level, as get_interface_dict() will alter it conf.set_level([]) # we already have a dict representation of the config from get_config_dict(), # but with the extended information from get_interface_dict() we also # get the DHCP client default-route-distance default option if not specified. _, ifconfig = get_interface_dict(conf, ['interfaces', section], ifname) tmp = check_dhcp(ifconfig) dhcp_interfaces.update(tmp) # check per VLAN interfaces for vif, vif_config in ifconfig.get('vif', {}).items(): tmp = check_dhcp(vif_config) dhcp_interfaces.update(tmp) # check QinQ VLAN interfaces for vif_s, vif_s_config in ifconfig.get('vif_s', {}).items(): tmp = check_dhcp(vif_s_config) dhcp_interfaces.update(tmp) for vif_c, vif_c_config in vif_s_config.get('vif_c', {}).items(): tmp = check_dhcp(vif_c_config) dhcp_interfaces.update(tmp) return dhcp_interfaces def get_pppoe_interfaces(conf, vrf=None): """ Common helper functions to retrieve all interfaces from current CLI sessions that have DHCP configured. """ pppoe_interfaces = {} conf.set_level([]) for ifname in conf.list_nodes(['interfaces', 'pppoe']): # always reset config level, as get_interface_dict() will alter it conf.set_level([]) # we already have a dict representation of the config from get_config_dict(), # but with the extended information from get_interface_dict() we also # get the DHCP client default-route-distance default option if not specified. _, ifconfig = get_interface_dict(conf, ['interfaces', 'pppoe'], ifname) options = {} if 'default_route_distance' in ifconfig: options.update({'default_route_distance' : ifconfig['default_route_distance']}) if 'no_default_route' in ifconfig: options.update({'no_default_route' : {}}) if 'vrf' in ifconfig: if vrf == ifconfig['vrf']: pppoe_interfaces.update({ifname : options}) else: if vrf is None: pppoe_interfaces.update({ifname : options}) return pppoe_interfaces def get_interface_dict(config, base, ifname='', recursive_defaults=True): """ Common utility function to retrieve and mangle the interfaces configuration from the CLI input nodes. All interfaces have a common base where value retrival is identical. This function must be used whenever possible when working on the interfaces node! Return a dictionary with the necessary interface config keys. """ if not ifname: from vyos import ConfigError # determine tagNode instance if 'VYOS_TAGNODE_VALUE' not in os.environ: raise ConfigError('Interface (VYOS_TAGNODE_VALUE) not specified') ifname = os.environ['VYOS_TAGNODE_VALUE'] # Check if interface has been removed. We must use exists() as # get_config_dict() will always return {} - even when an empty interface # node like the following exists. # +macsec macsec1 { # +} if not config.exists(base + [ifname]): dict = config.get_config_dict(base + [ifname], key_mangling=('-', '_'), get_first_key=True, no_tag_node_value_mangle=True) dict.update({'deleted' : {}}) else: # Get config_dict with default values dict = config.get_config_dict(base + [ifname], key_mangling=('-', '_'), get_first_key=True, no_tag_node_value_mangle=True, with_defaults=True, with_recursive_defaults=recursive_defaults) # If interface does not request an IPv4 DHCP address there is no need # to keep the dhcp-options key if 'address' not in dict or 'dhcp' not in dict['address']: if 'dhcp_options' in dict: del dict['dhcp_options'] # Add interface instance name into dictionary dict.update({'ifname': ifname}) # Check if QoS policy applied on this interface - See ifconfig.interface.set_mirror_redirect() if config.exists(['qos', 'interface', ifname]): dict.update({'traffic_policy': {}}) address = leaf_node_changed(config, base + [ifname, 'address']) if address: dict.update({'address_old' : address}) # Check if we are a member of a bridge device bridge = is_member(config, ifname, 'bridge') if bridge: dict.update({'is_bridge_member' : bridge}) # Check if it is a monitor interface mirror = is_mirror_intf(config, ifname) if mirror: dict.update({'is_mirror_intf' : mirror}) # Check if we are a member of a bond device bond = is_member(config, ifname, 'bonding') if bond: dict.update({'is_bond_member' : bond}) # Check if any DHCP options changed which require a client restat dhcp = is_node_changed(config, base + [ifname, 'dhcp-options']) if dhcp: dict.update({'dhcp_options_changed' : {}}) # Changine interface VRF assignemnts require a DHCP restart, too dhcp = is_node_changed(config, base + [ifname, 'vrf']) if dhcp: dict.update({'dhcp_options_changed' : {}}) # Some interfaces come with a source_interface which must also not be part # of any other bond or bridge interface as it is exclusivly assigned as the # Kernels "lower" interface to this new "virtual/upper" interface. if 'source_interface' in dict: # Check if source interface is member of another bridge tmp = is_member(config, dict['source_interface'], 'bridge') if tmp: dict.update({'source_interface_is_bridge_member' : tmp}) # Check if source interface is member of another bridge tmp = is_member(config, dict['source_interface'], 'bonding') if tmp: dict.update({'source_interface_is_bond_member' : tmp}) mac = leaf_node_changed(config, base + [ifname, 'mac']) if mac: dict.update({'mac_old' : mac}) eui64 = leaf_node_changed(config, base + [ifname, 'ipv6', 'address', 'eui64']) if eui64: tmp = dict_search('ipv6.address', dict) if not tmp: dict.update({'ipv6': {'address': {'eui64_old': eui64}}}) else: dict['ipv6']['address'].update({'eui64_old': eui64}) for vif, vif_config in dict.get('vif', {}).items(): # Add subinterface name to dictionary dict['vif'][vif].update({'ifname' : f'{ifname}.{vif}'}) if config.exists(['qos', 'interface', f'{ifname}.{vif}']): dict['vif'][vif].update({'traffic_policy': {}}) if 'deleted' not in dict: address = leaf_node_changed(config, base + [ifname, 'vif', vif, 'address']) if address: dict['vif'][vif].update({'address_old' : address}) # If interface does not request an IPv4 DHCP address there is no need # to keep the dhcp-options key if 'address' not in dict['vif'][vif] or 'dhcp' not in dict['vif'][vif]['address']: if 'dhcp_options' in dict['vif'][vif]: del dict['vif'][vif]['dhcp_options'] # Check if we are a member of a bridge device bridge = is_member(config, f'{ifname}.{vif}', 'bridge') if bridge: dict['vif'][vif].update({'is_bridge_member' : bridge}) # Check if any DHCP options changed which require a client restat dhcp = is_node_changed(config, base + [ifname, 'vif', vif, 'dhcp-options']) if dhcp: dict['vif'][vif].update({'dhcp_options_changed' : {}}) for vif_s, vif_s_config in dict.get('vif_s', {}).items(): # Add subinterface name to dictionary dict['vif_s'][vif_s].update({'ifname' : f'{ifname}.{vif_s}'}) if config.exists(['qos', 'interface', f'{ifname}.{vif_s}']): dict['vif_s'][vif_s].update({'traffic_policy': {}}) if 'deleted' not in dict: address = leaf_node_changed(config, base + [ifname, 'vif-s', vif_s, 'address']) if address: dict['vif_s'][vif_s].update({'address_old' : address}) # If interface does not request an IPv4 DHCP address there is no need # to keep the dhcp-options key if 'address' not in dict['vif_s'][vif_s] or 'dhcp' not in \ dict['vif_s'][vif_s]['address']: if 'dhcp_options' in dict['vif_s'][vif_s]: del dict['vif_s'][vif_s]['dhcp_options'] # Check if we are a member of a bridge device bridge = is_member(config, f'{ifname}.{vif_s}', 'bridge') if bridge: dict['vif_s'][vif_s].update({'is_bridge_member' : bridge}) # Check if any DHCP options changed which require a client restat dhcp = is_node_changed(config, base + [ifname, 'vif-s', vif_s, 'dhcp-options']) if dhcp: dict['vif_s'][vif_s].update({'dhcp_options_changed' : {}}) for vif_c, vif_c_config in vif_s_config.get('vif_c', {}).items(): # Add subinterface name to dictionary dict['vif_s'][vif_s]['vif_c'][vif_c].update({'ifname' : f'{ifname}.{vif_s}.{vif_c}'}) if config.exists(['qos', 'interface', f'{ifname}.{vif_s}.{vif_c}']): dict['vif_s'][vif_s]['vif_c'][vif_c].update({'traffic_policy': {}}) if 'deleted' not in dict: address = leaf_node_changed(config, base + [ifname, 'vif-s', vif_s, 'vif-c', vif_c, 'address']) if address: dict['vif_s'][vif_s]['vif_c'][vif_c].update( {'address_old' : address}) # If interface does not request an IPv4 DHCP address there is no need # to keep the dhcp-options key if 'address' not in dict['vif_s'][vif_s]['vif_c'][vif_c] or 'dhcp' \ not in dict['vif_s'][vif_s]['vif_c'][vif_c]['address']: if 'dhcp_options' in dict['vif_s'][vif_s]['vif_c'][vif_c]: del dict['vif_s'][vif_s]['vif_c'][vif_c]['dhcp_options'] # Check if we are a member of a bridge device bridge = is_member(config, f'{ifname}.{vif_s}.{vif_c}', 'bridge') if bridge: dict['vif_s'][vif_s]['vif_c'][vif_c].update( {'is_bridge_member' : bridge}) # Check if any DHCP options changed which require a client restat dhcp = is_node_changed(config, base + [ifname, 'vif-s', vif_s, 'vif-c', vif_c, 'dhcp-options']) if dhcp: dict['vif_s'][vif_s]['vif_c'][vif_c].update({'dhcp_options_changed' : {}}) # Check vif, vif-s/vif-c VLAN interfaces for removal dict = get_removed_vlans(config, base + [ifname], dict) return ifname, dict def get_vlan_ids(interface): """ Get the VLAN ID of the interface bound to the bridge """ vlan_ids = set() bridge_status = cmd('bridge -j vlan show', shell=True) vlan_filter_status = json.loads(bridge_status) if vlan_filter_status is not None: for interface_status in vlan_filter_status: ifname = interface_status['ifname'] if interface == ifname: vlans_status = interface_status['vlans'] for vlan_status in vlans_status: vlan_id = vlan_status['vlan'] vlan_ids.add(vlan_id) return vlan_ids def get_accel_dict(config, base, chap_secrets): """ Common utility function to retrieve and mangle the Accel-PPP configuration from different CLI input nodes. All Accel-PPP services have a common base where value retrival is identical. This function must be used whenever possible when working with Accel-PPP services! Return a dictionary with the necessary interface config keys. """ from vyos.utils.system import get_half_cpus from vyos.template import is_ipv4 dict = config.get_config_dict(base, key_mangling=('-', '_'), get_first_key=True, no_tag_node_value_mangle=True, with_recursive_defaults=True) # set CPUs cores to process requests dict.update({'thread_count' : get_half_cpus()}) # we need to store the path to the secrets file dict.update({'chap_secrets_file' : chap_secrets}) # We can only have two IPv4 and three IPv6 nameservers - also they are # configured in a different way in the configuration, this is why we split # the configuration if 'name_server' in dict: ns_v4 = [] ns_v6 = [] for ns in dict['name_server']: if is_ipv4(ns): ns_v4.append(ns) else: ns_v6.append(ns) dict.update({'name_server_ipv4' : ns_v4, 'name_server_ipv6' : ns_v6}) del dict['name_server'] # Check option "disable-accounting" per server and replace default value from '1813' to '0' for server in (dict_search('authentication.radius.server', dict) or []): if 'disable_accounting' in dict['authentication']['radius']['server'][server]: dict['authentication']['radius']['server'][server]['acct_port'] = '0' return dict diff --git a/python/vyos/configdiff.py b/python/vyos/configdiff.py index 0caa204c3..1ec2dfafe 100644 --- a/python/vyos/configdiff.py +++ b/python/vyos/configdiff.py @@ -1,367 +1,375 @@ # Copyright 2020 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public # License as published by the Free Software Foundation; either # version 2.1 of the License, or (at your option) any later version. # # This library is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this library. If not, see <http://www.gnu.org/licenses/>. from enum import IntFlag, auto from vyos.config import Config from vyos.configtree import DiffTree from vyos.configdict import dict_merge from vyos.configdict import list_diff from vyos.utils.dict import get_sub_dict from vyos.utils.dict import mangle_dict_keys from vyos.utils.dict import dict_search_args -from vyos.xml import defaults +from vyos.xml_ref import get_defaults class ConfigDiffError(Exception): """ Raised on config dict access errors, for example, calling get_value on a non-leaf node. """ pass def enum_to_key(e): return e.name.lower() class Diff(IntFlag): MERGE = auto() DELETE = auto() ADD = auto() STABLE = auto() ALL = Diff.MERGE | Diff.DELETE | Diff.ADD | Diff.STABLE requires_effective = [enum_to_key(Diff.DELETE)] target_defaults = [enum_to_key(Diff.MERGE)] def _key_sets_from_dicts(session_dict, effective_dict): session_keys = list(session_dict) effective_keys = list(effective_dict) ret = {} stable_keys = [k for k in session_keys if k in effective_keys] ret[enum_to_key(Diff.MERGE)] = session_keys ret[enum_to_key(Diff.DELETE)] = [k for k in effective_keys if k not in stable_keys] ret[enum_to_key(Diff.ADD)] = [k for k in session_keys if k not in stable_keys] ret[enum_to_key(Diff.STABLE)] = stable_keys return ret def _dict_from_key_set(key_set, d): # This will always be applied to a key_set obtained from a get_sub_dict, # hence there is no possibility of KeyError, as get_sub_dict guarantees # a return type of dict ret = {k: d[k] for k in key_set} return ret def get_config_diff(config, key_mangling=None): """ Check type and return ConfigDiff instance. """ if not config or not isinstance(config, Config): raise TypeError("argument must me a Config instance") if key_mangling and not (isinstance(key_mangling, tuple) and \ (len(key_mangling) == 2) and \ isinstance(key_mangling[0], str) and \ isinstance(key_mangling[1], str)): raise ValueError("key_mangling must be a tuple of two strings") if hasattr(config, 'cached_diff_tree'): diff_t = getattr(config, 'cached_diff_tree') else: diff_t = DiffTree(config._running_config, config._session_config) setattr(config, 'cached_diff_tree', diff_t) if hasattr(config, 'cached_diff_dict'): diff_d = getattr(config, 'cached_diff_dict') else: diff_d = diff_t.dict setattr(config, 'cached_diff_dict', diff_d) return ConfigDiff(config, key_mangling, diff_tree=diff_t, diff_dict=diff_d) class ConfigDiff(object): """ The class of config changes as represented by comparison between the session config dict and the effective config dict. """ def __init__(self, config, key_mangling=None, diff_tree=None, diff_dict=None): self._level = config.get_level() self._session_config_dict = config.get_cached_root_dict(effective=False) self._effective_config_dict = config.get_cached_root_dict(effective=True) self._key_mangling = key_mangling self._diff_tree = diff_tree self._diff_dict = diff_dict # mirrored from Config; allow path arguments relative to level def _make_path(self, path): if isinstance(path, str): path = path.split() elif isinstance(path, list): pass else: raise TypeError("Path must be a whitespace-separated string or a list") ret = self._level + path return ret def set_level(self, path): """ Set the *edit level*, that is, a relative config dict path. Once set, all operations will be relative to this path, for example, after ``set_level("system")``, calling ``get_value("name-server")`` is equivalent to calling ``get_value("system name-server")`` without ``set_level``. Args: path (str|list): relative config path """ if isinstance(path, str): if path: self._level = path.split() else: self._level = [] elif isinstance(path, list): self._level = path.copy() else: raise TypeError("Level path must be either a whitespace-separated string or a list") def get_level(self): """ Gets the current edit level. Returns: str: current edit level """ ret = self._level.copy() return ret def _mangle_dict_keys(self, config_dict): config_dict = mangle_dict_keys(config_dict, self._key_mangling[0], self._key_mangling[1]) return config_dict def is_node_changed(self, path=[]): if self._diff_tree is None: raise NotImplementedError("diff_tree class not available") if (self._diff_tree.add.exists(self._make_path(path)) or self._diff_tree.sub.exists(self._make_path(path))): return True return False def get_child_nodes_diff_str(self, path=[]): ret = {'add': {}, 'change': {}, 'delete': {}} diff = self.get_child_nodes_diff(path, expand_nodes=Diff.ADD | Diff.DELETE | Diff.MERGE | Diff.STABLE, no_defaults=True) def parse_dict(diff_dict, diff_type, prefix=[]): for k, v in diff_dict.items(): if isinstance(v, dict): parse_dict(v, diff_type, prefix + [k]) else: path_str = ' '.join(prefix + [k]) if diff_type == 'add' or diff_type == 'delete': if isinstance(v, list): v = ', '.join(v) ret[diff_type][path_str] = v elif diff_type == 'merge': old_value = dict_search_args(diff['stable'], *prefix, k) if old_value and old_value != v: ret['change'][path_str] = [old_value, v] parse_dict(diff['merge'], 'merge') parse_dict(diff['add'], 'add') parse_dict(diff['delete'], 'delete') return ret def get_child_nodes_diff(self, path=[], expand_nodes=Diff(0), no_defaults=False, recursive=False): """ Args: path (str|list): config path expand_nodes=Diff(0): bit mask of enum indicating for which nodes to provide full dict; for example, Diff.MERGE will expand dict['merge'] into dict under value no_detaults=False: if expand_nodes & Diff.MERGE, do not merge default values to ret['merge'] recursive: if true, use config_tree diff algorithm provided by diff_tree class Returns: dict of lists, representing differences between session and effective config, under path dict['merge'] = session config values dict['delete'] = effective config values, not in session dict['add'] = session config values, not in effective dict['stable'] = config values in both session and effective """ session_dict = get_sub_dict(self._session_config_dict, self._make_path(path), get_first_key=True) if recursive: if self._diff_tree is None: raise NotImplementedError("diff_tree class not available") else: add = get_sub_dict(self._diff_dict, ['add'], get_first_key=True) sub = get_sub_dict(self._diff_dict, ['sub'], get_first_key=True) inter = get_sub_dict(self._diff_dict, ['inter'], get_first_key=True) ret = {} ret[enum_to_key(Diff.MERGE)] = session_dict ret[enum_to_key(Diff.DELETE)] = get_sub_dict(sub, self._make_path(path), get_first_key=True) ret[enum_to_key(Diff.ADD)] = get_sub_dict(add, self._make_path(path), get_first_key=True) ret[enum_to_key(Diff.STABLE)] = get_sub_dict(inter, self._make_path(path), get_first_key=True) for e in Diff: k = enum_to_key(e) if not (e & expand_nodes): ret[k] = list(ret[k]) else: if self._key_mangling: ret[k] = self._mangle_dict_keys(ret[k]) if k in target_defaults and not no_defaults: - default_values = defaults(self._make_path(path)) + default_values = get_defaults(self._make_path(path), + get_first_key=True, + recursive=True) ret[k] = dict_merge(default_values, ret[k]) return ret effective_dict = get_sub_dict(self._effective_config_dict, self._make_path(path), get_first_key=True) ret = _key_sets_from_dicts(session_dict, effective_dict) if not expand_nodes: return ret for e in Diff: if expand_nodes & e: k = enum_to_key(e) if k in requires_effective: ret[k] = _dict_from_key_set(ret[k], effective_dict) else: ret[k] = _dict_from_key_set(ret[k], session_dict) if self._key_mangling: ret[k] = self._mangle_dict_keys(ret[k]) if k in target_defaults and not no_defaults: - default_values = defaults(self._make_path(path)) + default_values = get_defaults(self._make_path(path), + get_first_key=True, + recursive=True) ret[k] = dict_merge(default_values, ret[k]) return ret def get_node_diff(self, path=[], expand_nodes=Diff(0), no_defaults=False, recursive=False): """ Args: path (str|list): config path expand_nodes=Diff(0): bit mask of enum indicating for which nodes to provide full dict; for example, Diff.MERGE will expand dict['merge'] into dict under value no_detaults=False: if expand_nodes & Diff.MERGE, do not merge default values to ret['merge'] recursive: if true, use config_tree diff algorithm provided by diff_tree class Returns: dict of lists, representing differences between session and effective config, at path dict['merge'] = session config values dict['delete'] = effective config values, not in session dict['add'] = session config values, not in effective dict['stable'] = config values in both session and effective """ session_dict = get_sub_dict(self._session_config_dict, self._make_path(path)) if recursive: if self._diff_tree is None: raise NotImplementedError("diff_tree class not available") else: add = get_sub_dict(self._diff_dict, ['add'], get_first_key=True) sub = get_sub_dict(self._diff_dict, ['sub'], get_first_key=True) inter = get_sub_dict(self._diff_dict, ['inter'], get_first_key=True) ret = {} ret[enum_to_key(Diff.MERGE)] = session_dict ret[enum_to_key(Diff.DELETE)] = get_sub_dict(sub, self._make_path(path)) ret[enum_to_key(Diff.ADD)] = get_sub_dict(add, self._make_path(path)) ret[enum_to_key(Diff.STABLE)] = get_sub_dict(inter, self._make_path(path)) for e in Diff: k = enum_to_key(e) if not (e & expand_nodes): ret[k] = list(ret[k]) else: if self._key_mangling: ret[k] = self._mangle_dict_keys(ret[k]) if k in target_defaults and not no_defaults: - default_values = defaults(self._make_path(path)) + default_values = get_defaults(self._make_path(path), + get_first_key=True, + recursive=True) ret[k] = dict_merge(default_values, ret[k]) return ret effective_dict = get_sub_dict(self._effective_config_dict, self._make_path(path)) ret = _key_sets_from_dicts(session_dict, effective_dict) if not expand_nodes: return ret for e in Diff: if expand_nodes & e: k = enum_to_key(e) if k in requires_effective: ret[k] = _dict_from_key_set(ret[k], effective_dict) else: ret[k] = _dict_from_key_set(ret[k], session_dict) if self._key_mangling: ret[k] = self._mangle_dict_keys(ret[k]) if k in target_defaults and not no_defaults: - default_values = defaults(self._make_path(path)) + default_values = get_defaults(self._make_path(path), + get_first_key=True, + recursive=True) ret[k] = dict_merge(default_values, ret[k]) return ret def get_value_diff(self, path=[]): """ Args: path (str|list): config path Returns: (new, old) tuple of values in session config/effective config """ # one should properly use is_leaf as check; for the moment we will # deduce from type, which will not catch call on non-leaf node if None new_value_dict = get_sub_dict(self._session_config_dict, self._make_path(path)) old_value_dict = get_sub_dict(self._effective_config_dict, self._make_path(path)) new_value = None old_value = None if new_value_dict: new_value = next(iter(new_value_dict.values())) if old_value_dict: old_value = next(iter(old_value_dict.values())) if new_value and isinstance(new_value, dict): raise ConfigDiffError("get_value_changed called on non-leaf node") if old_value and isinstance(old_value, dict): raise ConfigDiffError("get_value_changed called on non-leaf node") return new_value, old_value diff --git a/python/vyos/xml_ref/definition.py b/python/vyos/xml_ref/definition.py index 38e07f0a7..c90c5ddbc 100644 --- a/python/vyos/xml_ref/definition.py +++ b/python/vyos/xml_ref/definition.py @@ -1,302 +1,302 @@ # Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public # License as published by the Free Software Foundation; either # version 2.1 of the License, or (at your option) any later version. # # This library is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this library. If not, see <http://www.gnu.org/licenses/>. from typing import Optional, Union, Any, TYPE_CHECKING # https://peps.python.org/pep-0484/#forward-references # for type 'ConfigDict' if TYPE_CHECKING: from vyos.config import ConfigDict def set_source_recursive(o: Union[dict, str, list], b: bool): d = {} if not isinstance(o, dict): d = {'_source': b} else: for k, v in o.items(): d[k] = set_source_recursive(v, b) d |= {'_source': b} return d def source_dict_merge(src: dict, dest: dict): from copy import deepcopy dst = deepcopy(dest) from_src = {} for key, value in src.items(): if key not in dst: dst[key] = value from_src[key] = set_source_recursive(value, True) elif isinstance(src[key], dict): dst[key], f = source_dict_merge(src[key], dst[key]) f |= {'_source': False} from_src[key] = f return dst, from_src def ext_dict_merge(src: dict, dest: Union[dict, 'ConfigDict']): d, f = source_dict_merge(src, dest) if hasattr(d, '_from_defaults'): setattr(d, '_from_defaults', f) return d def from_source(d: dict, path: list) -> bool: for key in path: d = d[key] if key in d else {} if not d or not isinstance(d, dict): return False return d.get('_source', False) class Xml: def __init__(self): self.ref = {} def define(self, ref: dict): self.ref = ref def _get_ref_node_data(self, node: dict, data: str) -> Union[bool, str]: res = node.get('node_data', {}) if not res: raise ValueError("non-existent node data") if data not in res: raise ValueError("non-existent data field") return res.get(data) def _get_ref_path(self, path: list) -> dict: ref_path = path.copy() d = self.ref while ref_path and d: d = d.get(ref_path[0], {}) ref_path.pop(0) if self._is_tag_node(d) and ref_path: ref_path.pop(0) return d def _is_tag_node(self, node: dict) -> bool: res = self._get_ref_node_data(node, 'node_type') return res == 'tag' def is_tag(self, path: list) -> bool: ref_path = path.copy() d = self.ref while ref_path and d: d = d.get(ref_path[0], {}) ref_path.pop(0) if self._is_tag_node(d) and ref_path: if len(ref_path) == 1: return False ref_path.pop(0) return self._is_tag_node(d) def is_tag_value(self, path: list) -> bool: if len(path) < 2: return False return self.is_tag(path[:-1]) def _is_multi_node(self, node: dict) -> bool: b = self._get_ref_node_data(node, 'multi') assert isinstance(b, bool) return b def is_multi(self, path: list) -> bool: d = self._get_ref_path(path) return self._is_multi_node(d) def _is_valueless_node(self, node: dict) -> bool: b = self._get_ref_node_data(node, 'valueless') assert isinstance(b, bool) return b def is_valueless(self, path: list) -> bool: d = self._get_ref_path(path) return self._is_valueless_node(d) def _is_leaf_node(self, node: dict) -> bool: res = self._get_ref_node_data(node, 'node_type') return res == 'leaf' def is_leaf(self, path: list) -> bool: d = self._get_ref_path(path) return self._is_leaf_node(d) @staticmethod def _dict_get(d: dict, path: list) -> dict: for i in path: d = d.get(i, {}) if not isinstance(d, dict): return {} if not d: break return d def _dict_find(self, d: dict, key: str, non_local=False) -> bool: for k in list(d): if k in ('node_data', 'component_version'): continue if k == key: return True if non_local and isinstance(d[k], dict): if self._dict_find(d[k], key): return True return False def cli_defined(self, path: list, node: str, non_local=False) -> bool: d = self._dict_get(self.ref, path) return self._dict_find(d, node, non_local=non_local) def component_version(self) -> dict: d = {} - for k, v in self.ref['component_version']: + for k, v in self.ref['component_version'].items(): d[k] = int(v) return d def multi_to_list(self, rpath: list, conf: dict) -> dict: res: Any = {} for k in list(conf): d = self._get_ref_path(rpath + [k]) if self._is_leaf_node(d): if self._is_multi_node(d) and not isinstance(conf[k], list): res[k] = [conf[k]] else: res[k] = conf[k] else: res[k] = self.multi_to_list(rpath + [k], conf[k]) return res def _get_default_value(self, node: dict) -> Optional[str]: return self._get_ref_node_data(node, "default_value") def _get_default(self, node: dict) -> Optional[Union[str, list]]: default = self._get_default_value(node) if default is None: return None if self._is_multi_node(node): return default.split() return default def default_value(self, path: list) -> Optional[Union[str, list]]: d = self._get_ref_path(path) default = self._get_default_value(d) if default is None: return None if self._is_multi_node(d) or self._is_tag_node(d): return default.split() return default def get_defaults(self, path: list, get_first_key=False, recursive=False) -> dict: """Return dict containing default values below path Note that descent below path will not proceed beyond an encountered tag node, as no tag node value is known. For a default dict relative to an existing config dict containing tag node values, see function: 'relative_defaults' """ res: dict = {} if self.is_tag(path): return res d = self._get_ref_path(path) if self._is_leaf_node(d): default_value = self._get_default(d) if default_value is not None: return {path[-1]: default_value} if path else {} for k in list(d): if k in ('node_data', 'component_version') : continue if self._is_leaf_node(d[k]): default_value = self._get_default(d[k]) if default_value is not None: res |= {k: default_value} elif self.is_tag(path + [k]): # tag node defaults are used as suggestion, not default value; # should this change, append to path and continue if recursive pass else: if recursive: pos = self.get_defaults(path + [k], recursive=True) res |= pos if res: if get_first_key or not path: return res return {path[-1]: res} return {} def _well_defined(self, path: list, conf: dict) -> bool: # test disjoint path + conf for sensible config paths def step(c): return [next(iter(c.keys()))] if c else [] try: tmp = step(conf) if tmp and self.is_tag_value(path + tmp): c = conf[tmp[0]] if not isinstance(c, dict): raise ValueError tmp = tmp + step(c) self._get_ref_path(path + tmp) else: self._get_ref_path(path + tmp) except ValueError: return False return True def _relative_defaults(self, rpath: list, conf: dict, recursive=False) -> dict: res: dict = {} res = self.get_defaults(rpath, recursive=recursive, get_first_key=True) for k in list(conf): if isinstance(conf[k], dict): step = self._relative_defaults(rpath + [k], conf=conf[k], recursive=recursive) res |= step if res: return {rpath[-1]: res} if rpath else res return {} def relative_defaults(self, path: list, conf: dict, get_first_key=False, recursive=False) -> dict: """Return dict containing defaults along paths of a config dict """ if not conf: return self.get_defaults(path, get_first_key=get_first_key, recursive=recursive) if not self._well_defined(path, conf): # adjust for possible overlap: if path and path[-1] in list(conf): conf = conf[path[-1]] conf = {} if not isinstance(conf, dict) else conf if not self._well_defined(path, conf): print('path to config dict does not define full config paths') return {} res = self._relative_defaults(path, conf, recursive=recursive) if get_first_key and path: if res.values(): res = next(iter(res.values())) else: res = {} return res diff --git a/src/helpers/vyos-domain-resolver.py b/src/helpers/vyos-domain-resolver.py index 2036ca72e..7e2fe2462 100755 --- a/src/helpers/vyos-domain-resolver.py +++ b/src/helpers/vyos-domain-resolver.py @@ -1,183 +1,177 @@ #!/usr/bin/env python3 # # Copyright (C) 2022-2023 VyOS maintainers and contributors # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License version 2 or later as # published by the Free Software Foundation. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. import json import os import time from vyos.configdict import dict_merge from vyos.configquery import ConfigTreeQuery from vyos.firewall import fqdn_config_parse from vyos.firewall import fqdn_resolve from vyos.utils.commit import commit_in_progress from vyos.utils.dict import dict_search_args from vyos.utils.process import cmd from vyos.utils.process import run -from vyos.xml import defaults +from vyos.xml_ref import get_defaults base = ['firewall'] timeout = 300 cache = False domain_state = {} ipv4_tables = { 'ip vyos_mangle', 'ip vyos_filter', 'ip vyos_nat' } ipv6_tables = { 'ip6 vyos_mangle', 'ip6 vyos_filter' } def get_config(conf): firewall = conf.get_config_dict(base, key_mangling=('-', '_'), get_first_key=True, no_tag_node_value_mangle=True) - default_values = defaults(base) - for tmp in ['name', 'ipv6_name']: - if tmp in default_values: - del default_values[tmp] - - if 'zone' in default_values: - del default_values['zone'] + default_values = get_defaults(base, get_first_key=True) firewall = dict_merge(default_values, firewall) global timeout, cache if 'resolver_interval' in firewall: timeout = int(firewall['resolver_interval']) if 'resolver_cache' in firewall: cache = True fqdn_config_parse(firewall) return firewall def resolve(domains, ipv6=False): global domain_state ip_list = set() for domain in domains: resolved = fqdn_resolve(domain, ipv6=ipv6) if resolved and cache: domain_state[domain] = resolved elif not resolved: if domain not in domain_state: continue resolved = domain_state[domain] ip_list = ip_list | resolved return ip_list def nft_output(table, set_name, ip_list): output = [f'flush set {table} {set_name}'] if ip_list: ip_str = ','.join(ip_list) output.append(f'add element {table} {set_name} {{ {ip_str} }}') return output def nft_valid_sets(): try: valid_sets = [] sets_json = cmd('nft -j list sets') sets_obj = json.loads(sets_json) for obj in sets_obj['nftables']: if 'set' in obj: family = obj['set']['family'] table = obj['set']['table'] name = obj['set']['name'] valid_sets.append((f'{family} {table}', name)) return valid_sets except: return [] def update(firewall): conf_lines = [] count = 0 valid_sets = nft_valid_sets() domain_groups = dict_search_args(firewall, 'group', 'domain_group') if domain_groups: for set_name, domain_config in domain_groups.items(): if 'address' not in domain_config: continue nft_set_name = f'D_{set_name}' domains = domain_config['address'] ip_list = resolve(domains, ipv6=False) for table in ipv4_tables: if (table, nft_set_name) in valid_sets: conf_lines += nft_output(table, nft_set_name, ip_list) ip6_list = resolve(domains, ipv6=True) for table in ipv6_tables: if (table, nft_set_name) in valid_sets: conf_lines += nft_output(table, nft_set_name, ip6_list) count += 1 for set_name, domain in firewall['ip_fqdn'].items(): table = 'ip vyos_filter' nft_set_name = f'FQDN_{set_name}' ip_list = resolve([domain], ipv6=False) if (table, nft_set_name) in valid_sets: conf_lines += nft_output(table, nft_set_name, ip_list) count += 1 for set_name, domain in firewall['ip6_fqdn'].items(): table = 'ip6 vyos_filter' nft_set_name = f'FQDN_{set_name}' ip_list = resolve([domain], ipv6=True) if (table, nft_set_name) in valid_sets: conf_lines += nft_output(table, nft_set_name, ip_list) count += 1 nft_conf_str = "\n".join(conf_lines) + "\n" code = run(f'nft -f -', input=nft_conf_str) print(f'Updated {count} sets - result: {code}') if __name__ == '__main__': print(f'VyOS domain resolver') count = 1 while commit_in_progress(): if ( count % 60 == 0 ): print(f'Commit still in progress after {count}s - waiting') count += 1 time.sleep(1) conf = ConfigTreeQuery() firewall = get_config(conf) print(f'interval: {timeout}s - cache: {cache}') while True: update(firewall) time.sleep(timeout) diff --git a/src/op_mode/pki.py b/src/op_mode/pki.py index f638c51bc..aff4ad1ae 100755 --- a/src/op_mode/pki.py +++ b/src/op_mode/pki.py @@ -1,1080 +1,1077 @@ #!/usr/bin/env python3 # # Copyright (C) 2021-2023 VyOS maintainers and contributors # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License version 2 or later as # published by the Free Software Foundation. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. import argparse import ipaddress import os import re import sys import tabulate from cryptography import x509 from cryptography.x509.oid import ExtendedKeyUsageOID from vyos.config import Config from vyos.configquery import ConfigTreeQuery -from vyos.configdict import dict_merge from vyos.pki import encode_certificate, encode_public_key, encode_private_key, encode_dh_parameters from vyos.pki import get_certificate_fingerprint from vyos.pki import create_certificate, create_certificate_request, create_certificate_revocation_list from vyos.pki import create_private_key from vyos.pki import create_dh_parameters from vyos.pki import load_certificate, load_certificate_request, load_private_key from vyos.pki import load_crl, load_dh_parameters, load_public_key from vyos.pki import verify_certificate from vyos.utils.io import ask_input from vyos.utils.io import ask_yes_no from vyos.utils.misc import install_into_config from vyos.utils.process import cmd -from vyos.xml import defaults CERT_REQ_END = '-----END CERTIFICATE REQUEST-----' auth_dir = '/config/auth' # Helper Functions conf = ConfigTreeQuery() def get_default_values(): # Fetch default x509 values base = ['pki', 'x509', 'default'] x509_defaults = conf.get_config_dict(base, key_mangling=('-', '_'), + no_tag_node_value_mangle=True, get_first_key=True, - no_tag_node_value_mangle=True) - default_values = defaults(base) - x509_defaults = dict_merge(default_values, x509_defaults) + with_recursive_defaults=True) return x509_defaults def get_config_ca_certificate(name=None): # Fetch ca certificates from config base = ['pki', 'ca'] if not conf.exists(base): return False if name: base = base + [name] if not conf.exists(base + ['private', 'key']) or not conf.exists(base + ['certificate']): return False return conf.get_config_dict(base, key_mangling=('-', '_'), get_first_key=True, no_tag_node_value_mangle=True) def get_config_certificate(name=None): # Get certificates from config base = ['pki', 'certificate'] if not conf.exists(base): return False if name: base = base + [name] if not conf.exists(base + ['private', 'key']) or not conf.exists(base + ['certificate']): return False return conf.get_config_dict(base, key_mangling=('-', '_'), get_first_key=True, no_tag_node_value_mangle=True) def get_certificate_ca(cert, ca_certs): # Find CA certificate for given certificate if not ca_certs: return None for ca_name, ca_dict in ca_certs.items(): if 'certificate' not in ca_dict: continue ca_cert = load_certificate(ca_dict['certificate']) if not ca_cert: continue if verify_certificate(cert, ca_cert): return ca_name return None def get_config_revoked_certificates(): # Fetch revoked certificates from config ca_base = ['pki', 'ca'] cert_base = ['pki', 'certificate'] certs = [] if conf.exists(ca_base): ca_certificates = conf.get_config_dict(ca_base, key_mangling=('-', '_'), get_first_key=True, no_tag_node_value_mangle=True) certs.extend(ca_certificates.values()) if conf.exists(cert_base): certificates = conf.get_config_dict(cert_base, key_mangling=('-', '_'), get_first_key=True, no_tag_node_value_mangle=True) certs.extend(certificates.values()) return [cert_dict for cert_dict in certs if 'revoke' in cert_dict] def get_revoked_by_serial_numbers(serial_numbers=[]): # Return serial numbers of revoked certificates certs_out = [] certs = get_config_certificate() ca_certs = get_config_ca_certificate() if certs: for cert_name, cert_dict in certs.items(): if 'certificate' not in cert_dict: continue cert = load_certificate(cert_dict['certificate']) if cert.serial_number in serial_numbers: certs_out.append(cert_name) if ca_certs: for cert_name, cert_dict in ca_certs.items(): if 'certificate' not in cert_dict: continue cert = load_certificate(cert_dict['certificate']) if cert.serial_number in serial_numbers: certs_out.append(cert_name) return certs_out def install_certificate(name, cert='', private_key=None, key_type=None, key_passphrase=None, is_ca=False): # Show/install conf commands for certificate prefix = 'ca' if is_ca else 'certificate' base = f"pki {prefix} {name}" config_paths = [] if cert: cert_pem = "".join(encode_certificate(cert).strip().split("\n")[1:-1]) config_paths.append(f"{base} certificate '{cert_pem}'") if private_key: key_pem = "".join(encode_private_key(private_key, passphrase=key_passphrase).strip().split("\n")[1:-1]) config_paths.append(f"{base} private key '{key_pem}'") if key_passphrase: config_paths.append(f"{base} private password-protected") install_into_config(conf, config_paths) def install_crl(ca_name, crl): # Show/install conf commands for crl crl_pem = "".join(encode_certificate(crl).strip().split("\n")[1:-1]) install_into_config(conf, [f"pki ca {ca_name} crl '{crl_pem}'"]) def install_dh_parameters(name, params): # Show/install conf commands for dh params dh_pem = "".join(encode_dh_parameters(params).strip().split("\n")[1:-1]) install_into_config(conf, [f"pki dh {name} parameters '{dh_pem}'"]) def install_ssh_key(name, public_key, private_key, passphrase=None): # Show/install conf commands for ssh key key_openssh = encode_public_key(public_key, encoding='OpenSSH', key_format='OpenSSH') username = os.getlogin() type_key_split = key_openssh.split(" ") base = f"system login user {username} authentication public-keys {name}" install_into_config(conf, [ f"{base} key '{type_key_split[1]}'", f"{base} type '{type_key_split[0]}'" ]) print(encode_private_key(private_key, encoding='PEM', key_format='OpenSSH', passphrase=passphrase)) def install_keypair(name, key_type, private_key=None, public_key=None, passphrase=None, prompt=True): # Show/install conf commands for key-pair config_paths = [] if public_key: install_public_key = not prompt or ask_yes_no('Do you want to install the public key?', default=True) public_key_pem = encode_public_key(public_key) if install_public_key: install_public_pem = "".join(public_key_pem.strip().split("\n")[1:-1]) config_paths.append(f"pki key-pair {name} public key '{install_public_pem}'") else: print("Public key:") print(public_key_pem) if private_key: install_private_key = not prompt or ask_yes_no('Do you want to install the private key?', default=True) private_key_pem = encode_private_key(private_key, passphrase=passphrase) if install_private_key: install_private_pem = "".join(private_key_pem.strip().split("\n")[1:-1]) config_paths.append(f"pki key-pair {name} private key '{install_private_pem}'") if passphrase: config_paths.append(f"pki key-pair {name} private password-protected") else: print("Private key:") print(private_key_pem) install_into_config(conf, config_paths) def install_openvpn_key(name, key_data, key_version='1'): config_paths = [ f"pki openvpn shared-secret {name} key '{key_data}'", f"pki openvpn shared-secret {name} version '{key_version}'" ] install_into_config(conf, config_paths) def install_wireguard_key(interface, private_key, public_key): # Show conf commands for installing wireguard key pairs from vyos.ifconfig import Section if Section.section(interface) != 'wireguard': print(f'"{interface}" is not a WireGuard interface name!') exit(1) # Check if we are running in a config session - if yes, we can directly write to the CLI install_into_config(conf, [f"interfaces wireguard {interface} private-key '{private_key}'"]) print(f"Corresponding public-key to use on peer system is: '{public_key}'") def install_wireguard_psk(interface, peer, psk): from vyos.ifconfig import Section if Section.section(interface) != 'wireguard': print(f'"{interface}" is not a WireGuard interface name!') exit(1) # Check if we are running in a config session - if yes, we can directly write to the CLI install_into_config(conf, [f"interfaces wireguard {interface} peer {peer} preshared-key '{psk}'"]) def ask_passphrase(): passphrase = None print("Note: If you plan to use the generated key on this router, do not encrypt the private key.") if ask_yes_no('Do you want to encrypt the private key with a passphrase?'): passphrase = ask_input('Enter passphrase:') return passphrase def write_file(filename, contents): full_path = os.path.join(auth_dir, filename) directory = os.path.dirname(full_path) if not os.path.exists(directory): print('Failed to write file: directory does not exist') return False if os.path.exists(full_path) and not ask_yes_no('Do you want to overwrite the existing file?'): return False with open(full_path, 'w') as f: f.write(contents) print(f'File written to {full_path}') # Generation functions def generate_private_key(): key_type = ask_input('Enter private key type: [rsa, dsa, ec]', default='rsa', valid_responses=['rsa', 'dsa', 'ec']) size_valid = [] size_default = 0 if key_type in ['rsa', 'dsa']: size_default = 2048 size_valid = [512, 1024, 2048, 4096] elif key_type == 'ec': size_default = 256 size_valid = [224, 256, 384, 521] size = ask_input('Enter private key bits:', default=size_default, numeric_only=True, valid_responses=size_valid) return create_private_key(key_type, size), key_type def parse_san_string(san_string): if not san_string: return None output = [] san_split = san_string.strip().split(",") for pair_str in san_split: tag, value = pair_str.strip().split(":", 1) if tag == 'ipv4': output.append(ipaddress.IPv4Address(value)) elif tag == 'ipv6': output.append(ipaddress.IPv6Address(value)) elif tag == 'dns': output.append(value) return output def generate_certificate_request(private_key=None, key_type=None, return_request=False, name=None, install=False, file=False, ask_san=True): if not private_key: private_key, key_type = generate_private_key() default_values = get_default_values() subject = {} subject['country'] = ask_input('Enter country code:', default=default_values['country']) subject['state'] = ask_input('Enter state:', default=default_values['state']) subject['locality'] = ask_input('Enter locality:', default=default_values['locality']) subject['organization'] = ask_input('Enter organization name:', default=default_values['organization']) subject['common_name'] = ask_input('Enter common name:', default='vyos.io') subject_alt_names = None if ask_san and ask_yes_no('Do you want to configure Subject Alternative Names?'): print("Enter alternative names in a comma separate list, example: ipv4:1.1.1.1,ipv6:fe80::1,dns:vyos.net") san_string = ask_input('Enter Subject Alternative Names:') subject_alt_names = parse_san_string(san_string) cert_req = create_certificate_request(subject, private_key, subject_alt_names) if return_request: return cert_req passphrase = ask_passphrase() if not install and not file: print(encode_certificate(cert_req)) print(encode_private_key(private_key, passphrase=passphrase)) return None if install: print("Certificate request:") print(encode_certificate(cert_req) + "\n") install_certificate(name, private_key=private_key, key_type=key_type, key_passphrase=passphrase, is_ca=False) if file: write_file(f'{name}.csr', encode_certificate(cert_req)) write_file(f'{name}.key', encode_private_key(private_key, passphrase=passphrase)) def generate_certificate(cert_req, ca_cert, ca_private_key, is_ca=False, is_sub_ca=False): valid_days = ask_input('Enter how many days certificate will be valid:', default='365' if not is_ca else '1825', numeric_only=True) cert_type = None if not is_ca: cert_type = ask_input('Enter certificate type: (client, server)', default='server', valid_responses=['client', 'server']) return create_certificate(cert_req, ca_cert, ca_private_key, valid_days, cert_type, is_ca, is_sub_ca) def generate_ca_certificate(name, install=False, file=False): private_key, key_type = generate_private_key() cert_req = generate_certificate_request(private_key, key_type, return_request=True, ask_san=False) cert = generate_certificate(cert_req, cert_req, private_key, is_ca=True) passphrase = ask_passphrase() if not install and not file: print(encode_certificate(cert)) print(encode_private_key(private_key, passphrase=passphrase)) return None if install: install_certificate(name, cert, private_key, key_type, key_passphrase=passphrase, is_ca=True) if file: write_file(f'{name}.pem', encode_certificate(cert)) write_file(f'{name}.key', encode_private_key(private_key, passphrase=passphrase)) def generate_ca_certificate_sign(name, ca_name, install=False, file=False): ca_dict = get_config_ca_certificate(ca_name) if not ca_dict: print(f"CA certificate or private key for '{ca_name}' not found") return None ca_cert = load_certificate(ca_dict['certificate']) if not ca_cert: print("Failed to load signing CA certificate, aborting") return None ca_private = ca_dict['private'] ca_private_passphrase = None if 'password_protected' in ca_private: ca_private_passphrase = ask_input('Enter signing CA private key passphrase:') ca_private_key = load_private_key(ca_private['key'], passphrase=ca_private_passphrase) if not ca_private_key: print("Failed to load signing CA private key, aborting") return None private_key = None key_type = None cert_req = None if not ask_yes_no('Do you already have a certificate request?'): private_key, key_type = generate_private_key() cert_req = generate_certificate_request(private_key, key_type, return_request=True, ask_san=False) else: print("Paste certificate request and press enter:") lines = [] curr_line = '' while True: curr_line = input().strip() if not curr_line or curr_line == CERT_REQ_END: break lines.append(curr_line) if not lines: print("Aborted") return None wrap = lines[0].find('-----') < 0 # Only base64 pasted, add the CSR tags for parsing cert_req = load_certificate_request("\n".join(lines), wrap) if not cert_req: print("Invalid certificate request") return None cert = generate_certificate(cert_req, ca_cert, ca_private_key, is_ca=True, is_sub_ca=True) passphrase = ask_passphrase() if not install and not file: print(encode_certificate(cert)) print(encode_private_key(private_key, passphrase=passphrase)) return None if install: install_certificate(name, cert, private_key, key_type, key_passphrase=passphrase, is_ca=True) if file: write_file(f'{name}.pem', encode_certificate(cert)) write_file(f'{name}.key', encode_private_key(private_key, passphrase=passphrase)) def generate_certificate_sign(name, ca_name, install=False, file=False): ca_dict = get_config_ca_certificate(ca_name) if not ca_dict: print(f"CA certificate or private key for '{ca_name}' not found") return None ca_cert = load_certificate(ca_dict['certificate']) if not ca_cert: print("Failed to load CA certificate, aborting") return None ca_private = ca_dict['private'] ca_private_passphrase = None if 'password_protected' in ca_private: ca_private_passphrase = ask_input('Enter CA private key passphrase:') ca_private_key = load_private_key(ca_private['key'], passphrase=ca_private_passphrase) if not ca_private_key: print("Failed to load CA private key, aborting") return None private_key = None key_type = None cert_req = None if not ask_yes_no('Do you already have a certificate request?'): private_key, key_type = generate_private_key() cert_req = generate_certificate_request(private_key, key_type, return_request=True) else: print("Paste certificate request and press enter:") lines = [] curr_line = '' while True: curr_line = input().strip() if not curr_line or curr_line == CERT_REQ_END: break lines.append(curr_line) if not lines: print("Aborted") return None wrap = lines[0].find('-----') < 0 # Only base64 pasted, add the CSR tags for parsing cert_req = load_certificate_request("\n".join(lines), wrap) if not cert_req: print("Invalid certificate request") return None cert = generate_certificate(cert_req, ca_cert, ca_private_key, is_ca=False) passphrase = ask_passphrase() if not install and not file: print(encode_certificate(cert)) print(encode_private_key(private_key, passphrase=passphrase)) return None if install: install_certificate(name, cert, private_key, key_type, key_passphrase=passphrase, is_ca=False) if file: write_file(f'{name}.pem', encode_certificate(cert)) write_file(f'{name}.key', encode_private_key(private_key, passphrase=passphrase)) def generate_certificate_selfsign(name, install=False, file=False): private_key, key_type = generate_private_key() cert_req = generate_certificate_request(private_key, key_type, return_request=True) cert = generate_certificate(cert_req, cert_req, private_key, is_ca=False) passphrase = ask_passphrase() if not install and not file: print(encode_certificate(cert)) print(encode_private_key(private_key, passphrase=passphrase)) return None if install: install_certificate(name, cert, private_key=private_key, key_type=key_type, key_passphrase=passphrase, is_ca=False) if file: write_file(f'{name}.pem', encode_certificate(cert)) write_file(f'{name}.key', encode_private_key(private_key, passphrase=passphrase)) def generate_certificate_revocation_list(ca_name, install=False, file=False): ca_dict = get_config_ca_certificate(ca_name) if not ca_dict: print(f"CA certificate or private key for '{ca_name}' not found") return None ca_cert = load_certificate(ca_dict['certificate']) if not ca_cert: print("Failed to load CA certificate, aborting") return None ca_private = ca_dict['private'] ca_private_passphrase = None if 'password_protected' in ca_private: ca_private_passphrase = ask_input('Enter CA private key passphrase:') ca_private_key = load_private_key(ca_private['key'], passphrase=ca_private_passphrase) if not ca_private_key: print("Failed to load CA private key, aborting") return None revoked_certs = get_config_revoked_certificates() to_revoke = [] for cert_dict in revoked_certs: if 'certificate' not in cert_dict: continue cert_data = cert_dict['certificate'] try: cert = load_certificate(cert_data) if cert.issuer == ca_cert.subject: to_revoke.append(cert.serial_number) except ValueError: continue if not to_revoke: print("No revoked certificates to add to the CRL") return None crl = create_certificate_revocation_list(ca_cert, ca_private_key, to_revoke) if not crl: print("Failed to create CRL") return None if not install and not file: print(encode_certificate(crl)) return None if install: install_crl(ca_name, crl) if file: write_file(f'{name}.crl', encode_certificate(crl)) def generate_ssh_keypair(name, install=False, file=False): private_key, key_type = generate_private_key() public_key = private_key.public_key() passphrase = ask_passphrase() if not install and not file: print(encode_public_key(public_key, encoding='OpenSSH', key_format='OpenSSH')) print("") print(encode_private_key(private_key, encoding='PEM', key_format='OpenSSH', passphrase=passphrase)) return None if install: install_ssh_key(name, public_key, private_key, passphrase) if file: write_file(f'{name}.pem', encode_public_key(public_key, encoding='OpenSSH', key_format='OpenSSH')) write_file(f'{name}.key', encode_private_key(private_key, encoding='PEM', key_format='OpenSSH', passphrase=passphrase)) def generate_dh_parameters(name, install=False, file=False): bits = ask_input('Enter DH parameters key size:', default=2048, numeric_only=True) print("Generating parameters...") dh_params = create_dh_parameters(bits) if not dh_params: print("Failed to create DH parameters") return None if not install and not file: print("DH Parameters:") print(encode_dh_parameters(dh_params)) if install: install_dh_parameters(name, dh_params) if file: write_file(f'{name}.pem', encode_dh_parameters(dh_params)) def generate_keypair(name, install=False, file=False): private_key, key_type = generate_private_key() public_key = private_key.public_key() passphrase = ask_passphrase() if not install and not file: print(encode_public_key(public_key)) print("") print(encode_private_key(private_key, passphrase=passphrase)) return None if install: install_keypair(name, key_type, private_key, public_key, passphrase) if file: write_file(f'{name}.pem', encode_public_key(public_key)) write_file(f'{name}.key', encode_private_key(private_key, passphrase=passphrase)) def generate_openvpn_key(name, install=False, file=False): result = cmd('openvpn --genkey secret /dev/stdout | grep -o "^[^#]*"') if not result: print("Failed to generate OpenVPN key") return None if not install and not file: print(result) return None if install: key_lines = result.split("\n") key_data = "".join(key_lines[1:-1]) # Remove wrapper tags and line endings key_version = '1' version_search = re.search(r'BEGIN OpenVPN Static key V(\d+)', result) # Future-proofing (hopefully) if version_search: key_version = version_search[1] install_openvpn_key(name, key_data, key_version) if file: write_file(f'{name}.key', result) def generate_wireguard_key(interface=None, install=False): private_key = cmd('wg genkey') public_key = cmd('wg pubkey', input=private_key) if interface and install: install_wireguard_key(interface, private_key, public_key) else: print(f'Private key: {private_key}') print(f'Public key: {public_key}', end='\n\n') def generate_wireguard_psk(interface=None, peer=None, install=False): psk = cmd('wg genpsk') if interface and peer and install: install_wireguard_psk(interface, peer, psk) else: print(f'Pre-shared key: {psk}') # Import functions def import_ca_certificate(name, path=None, key_path=None): if path: if not os.path.exists(path): print(f'File not found: {path}') return cert = None with open(path) as f: cert_data = f.read() cert = load_certificate(cert_data, wrap_tags=False) if not cert: print(f'Invalid certificate: {path}') return install_certificate(name, cert, is_ca=True) if key_path: if not os.path.exists(key_path): print(f'File not found: {key_path}') return key = None passphrase = ask_input('Enter private key passphrase: ') or None with open(key_path) as f: key_data = f.read() key = load_private_key(key_data, passphrase=passphrase, wrap_tags=False) if not key: print(f'Invalid private key or passphrase: {path}') return install_certificate(name, private_key=key, is_ca=True) def import_certificate(name, path=None, key_path=None): if path: if not os.path.exists(path): print(f'File not found: {path}') return cert = None with open(path) as f: cert_data = f.read() cert = load_certificate(cert_data, wrap_tags=False) if not cert: print(f'Invalid certificate: {path}') return install_certificate(name, cert, is_ca=False) if key_path: if not os.path.exists(key_path): print(f'File not found: {key_path}') return key = None passphrase = ask_input('Enter private key passphrase: ') or None with open(key_path) as f: key_data = f.read() key = load_private_key(key_data, passphrase=passphrase, wrap_tags=False) if not key: print(f'Invalid private key or passphrase: {path}') return install_certificate(name, private_key=key, is_ca=False) def import_crl(name, path): if not os.path.exists(path): print(f'File not found: {path}') return crl = None with open(path) as f: crl_data = f.read() crl = load_crl(crl_data, wrap_tags=False) if not crl: print(f'Invalid certificate: {path}') return install_crl(name, crl) def import_dh_parameters(name, path): if not os.path.exists(path): print(f'File not found: {path}') return dh = None with open(path) as f: dh_data = f.read() dh = load_dh_parameters(dh_data, wrap_tags=False) if not dh: print(f'Invalid DH parameters: {path}') return install_dh_parameters(name, dh) def import_keypair(name, path=None, key_path=None): if path: if not os.path.exists(path): print(f'File not found: {path}') return key = None with open(path) as f: key_data = f.read() key = load_public_key(key_data, wrap_tags=False) if not key: print(f'Invalid public key: {path}') return install_keypair(name, None, public_key=key, prompt=False) if key_path: if not os.path.exists(key_path): print(f'File not found: {key_path}') return key = None passphrase = ask_input('Enter private key passphrase: ') or None with open(key_path) as f: key_data = f.read() key = load_private_key(key_data, passphrase=passphrase, wrap_tags=False) if not key: print(f'Invalid private key or passphrase: {path}') return install_keypair(name, None, private_key=key, prompt=False) def import_openvpn_secret(name, path): if not os.path.exists(path): print(f'File not found: {path}') return key_data = None key_version = '1' with open(path) as f: key_lines = f.read().split("\n") key_data = "".join(key_lines[1:-1]) # Remove wrapper tags and line endings version_search = re.search(r'BEGIN OpenVPN Static key V(\d+)', key_lines[0]) # Future-proofing (hopefully) if version_search: key_version = version_search[1] install_openvpn_key(name, key_data, key_version) # Show functions def show_certificate_authority(name=None, pem=False): headers = ['Name', 'Subject', 'Issuer CN', 'Issued', 'Expiry', 'Private Key', 'Parent'] data = [] certs = get_config_ca_certificate() if certs: for cert_name, cert_dict in certs.items(): if name and name != cert_name: continue if 'certificate' not in cert_dict: continue cert = load_certificate(cert_dict['certificate']) if name and pem: print(encode_certificate(cert)) return parent_ca_name = get_certificate_ca(cert, certs) cert_issuer_cn = cert.issuer.rfc4514_string().split(",")[0] if not parent_ca_name or parent_ca_name == cert_name: parent_ca_name = 'N/A' if not cert: continue have_private = 'Yes' if 'private' in cert_dict and 'key' in cert_dict['private'] else 'No' data.append([cert_name, cert.subject.rfc4514_string(), cert_issuer_cn, cert.not_valid_before, cert.not_valid_after, have_private, parent_ca_name]) print("Certificate Authorities:") print(tabulate.tabulate(data, headers)) def show_certificate(name=None, pem=False): headers = ['Name', 'Type', 'Subject CN', 'Issuer CN', 'Issued', 'Expiry', 'Revoked', 'Private Key', 'CA Present'] data = [] certs = get_config_certificate() if certs: ca_certs = get_config_ca_certificate() for cert_name, cert_dict in certs.items(): if name and name != cert_name: continue if 'certificate' not in cert_dict: continue cert = load_certificate(cert_dict['certificate']) if not cert: continue if name and pem: print(encode_certificate(cert)) return ca_name = get_certificate_ca(cert, ca_certs) cert_subject_cn = cert.subject.rfc4514_string().split(",")[0] cert_issuer_cn = cert.issuer.rfc4514_string().split(",")[0] cert_type = 'Unknown' ext = cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage) if ext and ExtendedKeyUsageOID.SERVER_AUTH in ext.value: cert_type = 'Server' elif ext and ExtendedKeyUsageOID.CLIENT_AUTH in ext.value: cert_type = 'Client' revoked = 'Yes' if 'revoke' in cert_dict else 'No' have_private = 'Yes' if 'private' in cert_dict and 'key' in cert_dict['private'] else 'No' have_ca = f'Yes ({ca_name})' if ca_name else 'No' data.append([ cert_name, cert_type, cert_subject_cn, cert_issuer_cn, cert.not_valid_before, cert.not_valid_after, revoked, have_private, have_ca]) print("Certificates:") print(tabulate.tabulate(data, headers)) def show_certificate_fingerprint(name, hash): cert = get_config_certificate(name=name) cert = load_certificate(cert['certificate']) print(get_certificate_fingerprint(cert, hash)) def show_crl(name=None, pem=False): headers = ['CA Name', 'Updated', 'Revokes'] data = [] certs = get_config_ca_certificate() if certs: for cert_name, cert_dict in certs.items(): if name and name != cert_name: continue if 'crl' not in cert_dict: continue crls = cert_dict['crl'] if isinstance(crls, str): crls = [crls] for crl_data in cert_dict['crl']: crl = load_crl(crl_data) if not crl: continue if name and pem: print(encode_certificate(crl)) continue certs = get_revoked_by_serial_numbers([revoked.serial_number for revoked in crl]) data.append([cert_name, crl.last_update, ", ".join(certs)]) if name and pem: return print("Certificate Revocation Lists:") print(tabulate.tabulate(data, headers)) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--action', help='PKI action', required=True) # X509 parser.add_argument('--ca', help='Certificate Authority', required=False) parser.add_argument('--certificate', help='Certificate', required=False) parser.add_argument('--crl', help='Certificate Revocation List', required=False) parser.add_argument('--sign', help='Sign certificate with specified CA', required=False) parser.add_argument('--self-sign', help='Self-sign the certificate', action='store_true') parser.add_argument('--pem', help='Output using PEM encoding', action='store_true') parser.add_argument('--fingerprint', help='Show fingerprint and exit', action='store') # SSH parser.add_argument('--ssh', help='SSH Key', required=False) # DH parser.add_argument('--dh', help='DH Parameters', required=False) # Key pair parser.add_argument('--keypair', help='Key pair', required=False) # OpenVPN parser.add_argument('--openvpn', help='OpenVPN TLS key', required=False) # WireGuard parser.add_argument('--wireguard', help='Wireguard', action='store_true') group = parser.add_mutually_exclusive_group() group.add_argument('--key', help='Wireguard key pair', action='store_true', required=False) group.add_argument('--psk', help='Wireguard pre shared key', action='store_true', required=False) parser.add_argument('--interface', help='Install generated keys into running-config for named interface', action='store') parser.add_argument('--peer', help='Install generated keys into running-config for peer', action='store') # Global parser.add_argument('--file', help='Write generated keys into specified filename', action='store_true') parser.add_argument('--install', help='Install generated keys into running-config', action='store_true') parser.add_argument('--filename', help='Write certificate into specified filename', action='store') parser.add_argument('--key-filename', help='Write key into specified filename', action='store') args = parser.parse_args() try: if args.action == 'generate': if args.ca: if args.sign: generate_ca_certificate_sign(args.ca, args.sign, install=args.install, file=args.file) else: generate_ca_certificate(args.ca, install=args.install, file=args.file) elif args.certificate: if args.sign: generate_certificate_sign(args.certificate, args.sign, install=args.install, file=args.file) elif args.self_sign: generate_certificate_selfsign(args.certificate, install=args.install, file=args.file) else: generate_certificate_request(name=args.certificate, install=args.install, file=args.file) elif args.crl: generate_certificate_revocation_list(args.crl, install=args.install, file=args.file) elif args.ssh: generate_ssh_keypair(args.ssh, install=args.install, file=args.file) elif args.dh: generate_dh_parameters(args.dh, install=args.install, file=args.file) elif args.keypair: generate_keypair(args.keypair, install=args.install, file=args.file) elif args.openvpn: generate_openvpn_key(args.openvpn, install=args.install, file=args.file) elif args.wireguard: # WireGuard supports writing key directly into the CLI, but this # requires the vyos_libexec_dir environment variable to be set os.environ["vyos_libexec_dir"] = "/usr/libexec/vyos" if args.key: generate_wireguard_key(args.interface, install=args.install) if args.psk: generate_wireguard_psk(args.interface, peer=args.peer, install=args.install) elif args.action == 'import': if args.ca: import_ca_certificate(args.ca, path=args.filename, key_path=args.key_filename) elif args.certificate: import_certificate(args.certificate, path=args.filename, key_path=args.key_filename) elif args.crl: import_crl(args.crl, args.filename) elif args.dh: import_dh_parameters(args.dh, args.filename) elif args.keypair: import_keypair(args.keypair, path=args.filename, key_path=args.key_filename) elif args.openvpn: import_openvpn_secret(args.openvpn, args.filename) elif args.action == 'show': if args.ca: ca_name = None if args.ca == 'all' else args.ca if ca_name: if not conf.exists(['pki', 'ca', ca_name]): print(f'CA "{ca_name}" does not exist!') exit(1) show_certificate_authority(ca_name, args.pem) elif args.certificate: cert_name = None if args.certificate == 'all' else args.certificate if cert_name: if not conf.exists(['pki', 'certificate', cert_name]): print(f'Certificate "{cert_name}" does not exist!') exit(1) if args.fingerprint is None: show_certificate(None if args.certificate == 'all' else args.certificate, args.pem) else: show_certificate_fingerprint(args.certificate, args.fingerprint) elif args.crl: show_crl(None if args.crl == 'all' else args.crl, args.pem) else: show_certificate_authority() show_certificate() show_crl() except KeyboardInterrupt: print("Aborted") sys.exit(0) diff --git a/src/op_mode/show_openconnect_otp.py b/src/op_mode/show_openconnect_otp.py index 415a5f72c..3771fb385 100755 --- a/src/op_mode/show_openconnect_otp.py +++ b/src/op_mode/show_openconnect_otp.py @@ -1,109 +1,107 @@ #!/usr/bin/env python3 # Copyright 2017-2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public # License as published by the Free Software Foundation; either # version 2.1 of the License, or (at your option) any later version. # # This library is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public # License along with this library. If not, see <http://www.gnu.org/licenses/>. import argparse import os +from base64 import b32encode from vyos.config import Config -from vyos.xml import defaults -from vyos.configdict import dict_merge +from vyos.utils.dict import dict_search_args from vyos.utils.process import popen -from base64 import b32encode otp_file = '/run/ocserv/users.oath' def check_uname_otp(username): """ Check if "username" exists and have an OTP key """ config = Config() base_key = ['vpn', 'openconnect', 'authentication', 'local-users', 'username', username, 'otp', 'key'] if not config.exists(base_key): - return None + return False return True def get_otp_ocserv(username): config = Config() base = ['vpn', 'openconnect'] if not config.exists(base): return None - ocserv = config.get_config_dict(base, key_mangling=('-', '_'), get_first_key=True) - # We have gathered the dict representation of the CLI, but there are default - # options which we need to update into the dictionary retrived. - default_values = defaults(base) - ocserv = dict_merge(default_values, ocserv) - # workaround a "know limitation" - https://vyos.dev/T2665 - del ocserv['authentication']['local_users']['username']['otp'] - if not ocserv["authentication"]["local_users"]["username"]: + + ocserv = config.get_config_dict(base, key_mangling=('-', '_'), + get_first_key=True, + with_recursive_defaults=True) + + user_path = ['authentication', 'local_users', 'username'] + users = dict_search_args(ocserv, *user_path) + + if users is None: return None - default_ocserv_usr_values = default_values['authentication']['local_users']['username']['otp'] - for user, params in ocserv['authentication']['local_users']['username'].items(): - # Not every configuration requires OTP settings - if ocserv['authentication']['local_users']['username'][user].get('otp'): - ocserv['authentication']['local_users']['username'][user]['otp'] = dict_merge(default_ocserv_usr_values, ocserv['authentication']['local_users']['username'][user]['otp']) - result = ocserv['authentication']['local_users']['username'][username] + + # function is called conditionally, if check_uname_otp true, so username + # exists + result = users[username] + return result def display_otp_ocserv(username, params, info): hostname = os.uname()[1] key_hex = params['otp']['key'] otp_length = params['otp']['otp_length'] interval = params['otp']['interval'] token_type = params['otp']['token_type'] if token_type == 'hotp-time': token_type_acrn = 'totp' key_base32 = b32encode(bytes.fromhex(key_hex)).decode() otp_url = ''.join(["otpauth://",token_type_acrn,"/",username,"@",hostname,"?secret=",key_base32,"&digits=",otp_length,"&period=",interval]) qrcode,err = popen('qrencode -t ansiutf8', input=otp_url) if info == 'full': print("# You can share it with the user, he just needs to scan the QR in his OTP app") print("# username: ", username) print("# OTP KEY: ", key_base32) print("# OTP URL: ", otp_url) print(qrcode) print('# To add this OTP key to configuration, run the following commands:') print(f"set vpn openconnect authentication local-users username {username} otp key '{key_hex}'") if interval != "30": print(f"set vpn openconnect authentication local-users username {username} otp interval '{interval}'") if otp_length != "6": print(f"set vpn openconnect authentication local-users username {username} otp otp-length '{otp_length}'") elif info == 'key-hex': print("# OTP key in hexadecimal: ") print(key_hex) elif info == 'key-b32': print("# OTP key in Base32: ") print(key_base32) elif info == 'qrcode': print(f"# QR code for OpenConnect user '{username}'") print(qrcode) elif info == 'uri': print(f"# URI for OpenConnect user '{username}'") print(otp_url) if __name__ == '__main__': parser = argparse.ArgumentParser(add_help=False, description='Show OTP authentication information for selected user') parser.add_argument('--user', action="store", type=str, default='', help='Username') parser.add_argument('--info', action="store", type=str, default='full', help='Wich information to display') args = parser.parse_args() - check_otp = check_uname_otp(args.user) - if check_otp: + if check_uname_otp(args.user): user_otp_params = get_otp_ocserv(args.user) display_otp_ocserv(args.user, user_otp_params, args.info) else: print(f'There is no such user ("{args.user}") with an OTP key configured') diff --git a/src/tests/test_initial_setup.py b/src/tests/test_initial_setup.py index cb843ff09..ba50d06cc 100644 --- a/src/tests/test_initial_setup.py +++ b/src/tests/test_initial_setup.py @@ -1,102 +1,104 @@ #!/usr/bin/env python3 # # Copyright (C) 2018 VyOS maintainers and contributors # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License version 2 or later as # published by the Free Software Foundation. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. import os import tempfile import unittest import vyos.configtree import vyos.initialsetup as vis from unittest import TestCase -from vyos import xml +from vyos.xml_ref import definition +from vyos.xml_ref.pkg_cache.vyos_1x_cache import reference class TestInitialSetup(TestCase): def setUp(self): with open('tests/data/config.boot.default', 'r') as f: config_string = f.read() self.config = vyos.configtree.ConfigTree(config_string) - self.xml = xml.load_configuration() + self.xml = definition.Xml() + self.xml.define(reference) def test_set_user_password(self): vis.set_user_password(self.config, 'vyos', 'vyosvyos') # Old password hash from the default config old_pw = '$6$QxPS.uk6mfo$9QBSo8u1FkH16gMyAVhus6fU3LOzvLR9Z9.82m3tiHFAxTtIkhaZSWssSgzt4v4dGAL8rhVQxTg0oAG9/q11h/' new_pw = self.config.return_value(["system", "login", "user", "vyos", "authentication", "encrypted-password"]) # Just check it changed the hash, don't try to check if hash is good self.assertNotEqual(old_pw, new_pw) def test_disable_user_password(self): vis.disable_user_password(self.config, 'vyos') new_pw = self.config.return_value(["system", "login", "user", "vyos", "authentication", "encrypted-password"]) self.assertEqual(new_pw, '!') def test_set_ssh_key_with_name(self): test_ssh_key = " ssh-rsa fakedata vyos@vyos " vis.set_user_ssh_key(self.config, 'vyos', test_ssh_key) key_type = self.config.return_value(["system", "login", "user", "vyos", "authentication", "public-keys", "vyos@vyos", "type"]) key_data = self.config.return_value(["system", "login", "user", "vyos", "authentication", "public-keys", "vyos@vyos", "key"]) self.assertEqual(key_type, 'ssh-rsa') self.assertEqual(key_data, 'fakedata') self.assertTrue(self.xml.is_tag(["system", "login", "user", "vyos", "authentication", "public-keys"])) def test_set_ssh_key_without_name(self): # If key file doesn't include a name, the function will use user name for the key name test_ssh_key = " ssh-rsa fakedata " vis.set_user_ssh_key(self.config, 'vyos', test_ssh_key) key_type = self.config.return_value(["system", "login", "user", "vyos", "authentication", "public-keys", "vyos", "type"]) key_data = self.config.return_value(["system", "login", "user", "vyos", "authentication", "public-keys", "vyos", "key"]) self.assertEqual(key_type, 'ssh-rsa') self.assertEqual(key_data, 'fakedata') self.assertTrue(self.xml.is_tag(["system", "login", "user", "vyos", "authentication", "public-keys"])) def test_create_user(self): vis.create_user(self.config, 'jrandomhacker', password='qwerty', key=" ssh-rsa fakedata jrandomhacker@foovax ") self.assertTrue(self.config.exists(["system", "login", "user", "jrandomhacker"])) self.assertTrue(self.config.exists(["system", "login", "user", "jrandomhacker", "authentication", "public-keys", "jrandomhacker@foovax"])) self.assertTrue(self.config.exists(["system", "login", "user", "jrandomhacker", "authentication", "encrypted-password"])) self.assertEqual(self.config.return_value(["system", "login", "user", "jrandomhacker", "level"]), "admin") def test_set_hostname(self): vis.set_host_name(self.config, "vyos-test") self.assertEqual(self.config.return_value(["system", "host-name"]), "vyos-test") def test_set_name_servers(self): vis.set_name_servers(self.config, ["192.0.2.10", "203.0.113.20"]) servers = self.config.return_values(["system", "name-server"]) self.assertIn("192.0.2.10", servers) self.assertIn("203.0.113.20", servers) def test_set_gateway(self): vis.set_default_gateway(self.config, '192.0.2.1') self.assertTrue(self.config.exists(['protocols', 'static', 'route', '0.0.0.0/0', 'next-hop', '192.0.2.1'])) self.assertTrue(self.xml.is_tag(['protocols', 'static', 'multicast', 'route', '0.0.0.0/0', 'next-hop'])) self.assertTrue(self.xml.is_tag(['protocols', 'static', 'multicast', 'route'])) if __name__ == "__main__": unittest.main()