diff --git a/python/vyos/progressbar.py b/python/vyos/progressbar.py new file mode 100644 index 000000000..1793c445b --- /dev/null +++ b/python/vyos/progressbar.py @@ -0,0 +1,70 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +import math +import os +import signal +import subprocess +import sys + +from vyos.utils.io import print_error + +class Progressbar: + def __init__(self, step=None): + self.total = 0.0 + self.step = step + def __enter__(self): + # Recalculate terminal width with every window resize. + signal.signal(signal.SIGWINCH, lambda signum, frame: self._update_cols()) + # Disable line wrapping to prevent the staircase effect. + subprocess.run(['tput', 'rmam'], check=False) + self._update_cols() + # Print an empty progressbar with entry. + self.progress(0, 1) + return self + def __exit__(self, exc_type, kexc_val, exc_tb): + # Revert to the default SIGWINCH handler (ie nothing). + signal.signal(signal.SIGWINCH, signal.SIG_DFL) + # Reenable line wrapping. + subprocess.run(['tput', 'smam'], check=False) + def _update_cols(self): + # `os.get_terminal_size()' is fast enough for our purposes. + self.col = max(os.get_terminal_size().columns - 15, 20) + def increment(self): + """ + Stateful progressbar taking the step fraction at init and no input at + callback (for FTP) + """ + if self.step: + if self.total < 1.0: + self.total += self.step + if self.total >= 1.0: + self.total = 1.0 + # Ignore superfluous calls caused by fuzzy FTP size calculations. + self.step = None + self.progress(self.total, 1.0) + def progress(self, done, total): + """ + Stateless progressbar taking no input at init and current progress with + final size at callback (for SSH) + """ + if done <= total: + length = math.ceil(self.col * done / total) + percentage = str(math.ceil(100 * done / total)).rjust(3) + # Carriage return at the end will make sure the line will get overwritten. + print_error(f'[{length * "#"}{(self.col - length) * "_"}] {percentage}%', end='\r') + # Print a newline to make sure the full progressbar doesn't get overwritten by the next line. + if done == total: + print_error() diff --git a/python/vyos/remote.py b/python/vyos/remote.py index cf731c881..1ca8a9530 100644 --- a/python/vyos/remote.py +++ b/python/vyos/remote.py @@ -1,369 +1,371 @@ # Copyright 2021 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public # License as published by the Free Software Foundation; either # version 2.1 of the License, or (at your option) any later version. # # This library is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public # License along with this library. If not, see <http://www.gnu.org/licenses/>. import os import shutil import socket import ssl import stat import sys import tempfile import urllib.parse from ftplib import FTP from ftplib import FTP_TLS from paramiko import SSHClient, SSHException from paramiko import MissingHostKeyPolicy from requests import Session from requests.adapters import HTTPAdapter from requests.packages.urllib3 import PoolManager +from vyos.progressbar import Progressbar from vyos.utils.io import ask_yes_no -from vyos.utils.io import make_incremental_progressbar -from vyos.utils.io import make_progressbar from vyos.utils.io import print_error from vyos.utils.misc import begin from vyos.utils.process import cmd from vyos.version import get_version CHUNK_SIZE = 8192 class InteractivePolicy(MissingHostKeyPolicy): """ Paramiko policy for interactively querying the user on whether to proceed with SSH connections to unknown hosts. """ def missing_host_key(self, client, hostname, key): print_error(f"Host '{hostname}' not found in known hosts.") print_error('Fingerprint: ' + key.get_fingerprint().hex()) if sys.stdout.isatty() and ask_yes_no('Do you wish to continue?'): if client._host_keys_filename\ and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'): client._host_keys.add(hostname, key.get_name(), key) client.save_host_keys(client._host_keys_filename) else: raise SSHException(f"Cannot connect to unknown host '{hostname}'.") class SourceAdapter(HTTPAdapter): """ urllib3 transport adapter for setting source addresses per session. """ def __init__(self, source_pair, *args, **kwargs): # A source pair is a tuple of a source host string and source port respectively. # Supply '' and 0 respectively for default values. self._source_pair = source_pair super(SourceAdapter, self).__init__(*args, **kwargs) def init_poolmanager(self, connections, maxsize, block=False): self.poolmanager = PoolManager( num_pools=connections, maxsize=maxsize, block=block, source_address=self._source_pair) def check_storage(path, size): """ Check whether `path` has enough storage space for a transfer of `size` bytes. """ path = os.path.abspath(os.path.expanduser(path)) directory = path if os.path.isdir(path) else (os.path.dirname(os.path.expanduser(path)) or os.getcwd()) # `size` can be None or 0 to indicate unknown size. if not size: print_error('Warning: Cannot determine size of remote file. Bravely continuing regardless.') return if size < 1024 * 1024: print_error(f'The file is {size / 1024.0:.3f} KiB.') else: print_error(f'The file is {size / (1024.0 * 1024.0):.3f} MiB.') # Will throw `FileNotFoundError' if `directory' is absent. if size > shutil.disk_usage(directory).free: raise OSError(f'Not enough disk space available in "{directory}".') class FtpC: def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0, timeout=10): self.secure = url.scheme == 'ftps' self.hostname = url.hostname self.path = url.path self.username = url.username or os.getenv('REMOTE_USERNAME', 'anonymous') self.password = url.password or os.getenv('REMOTE_PASSWORD', '') self.port = url.port or 21 self.source = (source_host, source_port) self.progressbar = progressbar self.check_space = check_space self.timeout = timeout def _establish(self): if self.secure: return FTP_TLS(source_address=self.source, context=ssl.create_default_context(), timeout=self.timeout) else: return FTP(source_address=self.source, timeout=self.timeout) def download(self, location: str): # Open the file upfront before establishing connection. with open(location, 'wb') as f, self._establish() as conn: conn.connect(self.hostname, self.port) conn.login(self.username, self.password) # Set secure connection over TLS. if self.secure: conn.prot_p() # Almost all FTP servers support the `SIZE' command. + size = conn.size(self.path) if self.check_space: - check_storage(path, conn.size(self.path)) + check_storage(path, size) # No progressbar if we can't determine the size or if the file is too small. if self.progressbar and size and size > CHUNK_SIZE: - progress = make_incremental_progressbar(CHUNK_SIZE / size) - next(progress) - callback = lambda block: begin(f.write(block), next(progress)) + with Progressbar(CHUNK_SIZE / size) as p: + callback = lambda block: begin(f.write(block), p.increment()) + conn.retrbinary('RETR ' + self.path, callback, CHUNK_SIZE) else: - callback = f.write - conn.retrbinary('RETR ' + self.path, callback, CHUNK_SIZE) + conn.retrbinary('RETR ' + self.path, f.write, CHUNK_SIZE) def upload(self, location: str): size = os.path.getsize(location) with open(location, 'rb') as f, self._establish() as conn: conn.connect(self.hostname, self.port) conn.login(self.username, self.password) if self.secure: conn.prot_p() if self.progressbar and size and size > CHUNK_SIZE: - progress = make_incremental_progressbar(CHUNK_SIZE / size) - next(progress) - callback = lambda block: next(progress) + with Progressbar(CHUNK_SIZE / size) as p: + conn.storbinary('STOR ' + self.path, f, CHUNK_SIZE, lambda block: p.increment()) else: - callback = None - conn.storbinary('STOR ' + self.path, f, CHUNK_SIZE, callback) + conn.storbinary('STOR ' + self.path, f, CHUNK_SIZE) class SshC: known_hosts = os.path.expanduser('~/.ssh/known_hosts') def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0, timeout=10.0): self.hostname = url.hostname self.path = url.path self.username = url.username or os.getenv('REMOTE_USERNAME') self.password = url.password or os.getenv('REMOTE_PASSWORD') self.port = url.port or 22 self.source = (source_host, source_port) self.progressbar = progressbar self.check_space = check_space self.timeout = timeout def _establish(self): ssh = SSHClient() ssh.load_system_host_keys() # Try to load from a user-local known hosts file if one exists. if os.path.exists(self.known_hosts): ssh.load_host_keys(self.known_hosts) ssh.set_missing_host_key_policy(InteractivePolicy()) # `socket.create_connection()` automatically picks a NIC and an IPv4/IPv6 address family # for us on dual-stack systems. sock = socket.create_connection((self.hostname, self.port), self.timeout, self.source) ssh.connect(self.hostname, self.port, self.username, self.password, sock=sock) return ssh def download(self, location: str): - callback = make_progressbar() if self.progressbar else None with self._establish() as ssh, ssh.open_sftp() as sftp: if self.check_space: check_storage(location, sftp.stat(self.path).st_size) - sftp.get(self.path, location, callback=callback) + if self.progressbar: + with Progressbar() as p: + sftp.get(self.path, location, callback=p.progress) + else: + sftp.get(self.path, location) def upload(self, location: str): - callback = make_progressbar() if self.progressbar else None with self._establish() as ssh, ssh.open_sftp() as sftp: try: # If the remote path is a directory, use the original filename. if stat.S_ISDIR(sftp.stat(self.path).st_mode): path = os.path.join(self.path, os.path.basename(location)) # A file exists at this destination. We're simply going to clobber it. else: path = self.path # This path doesn't point at any existing file. We can freely use this filename. except IOError: path = self.path finally: - sftp.put(location, path, callback=callback) + if self.progressbar: + with Progressbar() as p: + sftp.put(location, path, callback=p.progress) + else: + sftp.put(location, path) class HttpC: def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0, timeout=10.0): self.urlstring = urllib.parse.urlunsplit(url) self.progressbar = progressbar self.check_space = check_space self.source_pair = (source_host, source_port) self.username = url.username or os.getenv('REMOTE_USERNAME') self.password = url.password or os.getenv('REMOTE_PASSWORD') self.timeout = timeout def _establish(self): session = Session() session.mount(self.urlstring, SourceAdapter(self.source_pair)) session.headers.update({'User-Agent': 'VyOS/' + get_version()}) if self.username: session.auth = self.username, self.password return session def download(self, location: str): with self._establish() as s: # We ask for uncompressed downloads so that we don't have to deal with decoding. # Not only would it potentially mess up with the progress bar but # `shutil.copyfileobj(request.raw, file)` does not handle automatic decoding. s.headers.update({'Accept-Encoding': 'identity'}) with s.head(self.urlstring, allow_redirects=True, timeout=self.timeout) as r: # Abort early if the destination is inaccessible. print('pre-3') r.raise_for_status() # If the request got redirected, keep the last URL we ended up with. final_urlstring = r.url if r.history and self.progressbar: print_error('Redirecting to ' + final_urlstring) # Check for the prospective file size. try: size = int(r.headers['Content-Length']) # In case the server does not supply the header. except KeyError: size = None if self.check_space: check_storage(location, size) with s.get(final_urlstring, stream=True, timeout=self.timeout) as r, open(location, 'wb') as f: if self.progressbar and size: - progress = make_incremental_progressbar(CHUNK_SIZE / size) - next(progress) - for chunk in iter(lambda: begin(next(progress), r.raw.read(CHUNK_SIZE)), b''): - f.write(chunk) + with Progressbar(CHUNK_SIZE / size) as p: + for chunk in iter(lambda: begin(p.increment(), r.raw.read(CHUNK_SIZE)), b''): + f.write(chunk) else: # We'll try to stream the download directly with `copyfileobj()` so that large # files (like entire VyOS images) don't occupy much memory. shutil.copyfileobj(r.raw, f) def upload(self, location: str): # Does not yet support progressbars. with self._establish() as s, open(location, 'rb') as f: s.post(self.urlstring, data=f, allow_redirects=True, timeout=self.timeout) class TftpC: # We simply allow `curl` to take over because # 1. TFTP is rather simple. # 2. Since there's no concept authentication, we don't need to deal with keys/passwords. # 3. It would be a waste to import, audit and maintain a third-party library for TFTP. # 4. I'd rather not implement the entire protocol here, no matter how simple it is. def __init__(self, url, progressbar=False, check_space=False, source_host=None, source_port=0, timeout=10): source_option = f'--interface {source_host} --local-port {source_port}' if source_host else '' progress_flag = '--progress-bar' if progressbar else '-s' self.command = f'curl {source_option} {progress_flag} --connect-timeout {timeout}' self.urlstring = urllib.parse.urlunsplit(url) def download(self, location: str): with open(location, 'wb') as f: f.write(cmd(f'{self.command} "{self.urlstring}"').encode()) def upload(self, location: str): with open(location, 'rb') as f: cmd(f'{self.command} -T - "{self.urlstring}"', input=f.read()) def urlc(urlstring, *args, **kwargs): """ Dynamically dispatch the appropriate protocol class. """ url_classes = {'http': HttpC, 'https': HttpC, 'ftp': FtpC, 'ftps': FtpC, \ 'sftp': SshC, 'ssh': SshC, 'scp': SshC, 'tftp': TftpC} url = urllib.parse.urlsplit(urlstring) try: return url_classes[url.scheme](url, *args, **kwargs) except KeyError: raise ValueError(f'Unsupported URL scheme: "{url.scheme}"') def download(local_path, urlstring, *args, **kwargs): try: urlc(urlstring, *args, **kwargs).download(local_path) except Exception as err: print_error(f'Unable to download "{urlstring}": {err}') def upload(local_path, urlstring, *args, **kwargs): try: urlc(urlstring, *args, **kwargs).upload(local_path) except Exception as err: print_error(f'Unable to upload "{urlstring}": {err}') def get_remote_config(urlstring, source_host='', source_port=0): """ Quietly download a file and return it as a string. """ temp = tempfile.NamedTemporaryFile(delete=False).name try: download(temp, urlstring, False, False, source_host, source_port) with open(temp, 'r') as f: return f.read() finally: os.remove(temp) def friendly_download(local_path, urlstring, source_host='', source_port=0): """ Download with a progress bar, reassuring messages and free space checks. """ try: print_error('Downloading...') download(local_path, urlstring, True, True, source_host, source_port) except KeyboardInterrupt: print_error('\nDownload aborted by user.') sys.exit(1) except: import traceback print_error(f'Failed to download {urlstring}.') # There are a myriad different reasons a download could fail. # SSH errors, FTP errors, I/O errors, HTTP errors (403, 404...) # We omit the scary stack trace but print the error nevertheless. exc_type, exc_value, exc_traceback = sys.exc_info() traceback.print_exception(exc_type, exc_value, None, 0, None, False) sys.exit(1) else: print_error('Download complete.') sys.exit(0) diff --git a/python/vyos/utils/io.py b/python/vyos/utils/io.py index 843494855..5fffa62f8 100644 --- a/python/vyos/utils/io.py +++ b/python/vyos/utils/io.py @@ -1,103 +1,64 @@ # Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public # License as published by the Free Software Foundation; either # version 2.1 of the License, or (at your option) any later version. # # This library is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public # License along with this library. If not, see <http://www.gnu.org/licenses/>. def print_error(str='', end='\n'): """ Print `str` to stderr, terminated with `end`. Used for warnings and out-of-band messages to avoid mangling precious stdout output. """ import sys sys.stderr.write(str) sys.stderr.write(end) sys.stderr.flush() -def make_progressbar(): - """ - Make a procedure that takes two arguments `done` and `total` and prints a - progressbar based on the ratio thereof, whose length is determined by the - width of the terminal. - """ - import shutil, math - col, _ = shutil.get_terminal_size() - col = max(col - 15, 20) - def print_progressbar(done, total): - if done <= total: - increment = total / col - length = math.ceil(done / increment) - percentage = str(math.ceil(100 * done / total)).rjust(3) - print_error(f'[{length * "#"}{(col - length) * "_"}] {percentage}%', '\r') - # Print a newline so that the subsequent prints don't overwrite the full bar. - if done == total: - print_error() - return print_progressbar - -def make_incremental_progressbar(increment: float): - """ - Make a generator that displays a progressbar that grows monotonically with - every iteration. - First call displays it at 0% and every subsequent iteration displays it - at `increment` increments where 0.0 < `increment` < 1.0. - Intended for FTP and HTTP transfers with stateless callbacks. - """ - print_progressbar = make_progressbar() - total = 0.0 - while total < 1.0: - print_progressbar(total, 1.0) - yield - total += increment - print_progressbar(1, 1) - # Ignore further calls. - while True: - yield - def ask_input(question, default='', numeric_only=False, valid_responses=[]): question_out = question if default: question_out += f' (Default: {default})' response = '' while True: response = input(question_out + ' ').strip() if not response and default: return default if numeric_only: if not response.isnumeric(): print("Invalid value, try again.") continue response = int(response) if valid_responses and response not in valid_responses: print("Invalid value, try again.") continue break return response def ask_yes_no(question, default=False) -> bool: """Ask a yes/no question via input() and return their answer.""" from sys import stdout default_msg = "[Y/n]" if default else "[y/N]" while True: try: stdout.write("%s %s " % (question, default_msg)) c = input().lower() if c == '': return default elif c in ("y", "ye", "yes"): return True elif c in ("n", "no"): return False else: stdout.write("Please respond with yes/y or no/n\n") except EOFError: stdout.write("\nPlease respond with yes/y or no/n\n")