예제 #1
0
 def test_shell_error_bad_command(self, _, password):
     """Running a nonexistent command raises SshcpError naming the command."""
     sshcp = Sshcp(host=self.host,
                   port=self.port,
                   user=self.user,
                   password=password)
     # The command contains no "{}" placeholder, so it is passed literally
     # (formatting it with a directory would be a no-op).
     bad_command = "./some_bad_command.sh"
     with self.assertRaises(SshcpError) as ctx:
         sshcp.shell(bad_command)
     # assertIn reports both strings on failure, unlike assertTrue(x in y)
     self.assertIn(bad_command, str(ctx.exception))
예제 #2
0
 def test_shell_error_bad_host(self, _, password):
     """A bogus host name makes shell() raise a "Bad hostname" SshcpError."""
     sshcp = Sshcp(host="badhost",
                   port=self.port,
                   user=self.user,
                   password=password)
     with self.assertRaises(SshcpError) as ctx:
         sshcp.shell("cd {}; pwd".format(self.local_dir))
     # assertIn reports both strings on failure, unlike assertTrue(x in y)
     self.assertIn("Bad hostname", str(ctx.exception))
예제 #3
0
 def test_shell(self, _, password):
     """A ``cd <dir>; pwd`` round trip over SSH echoes back the directory."""
     client = Sshcp(host=self.host, port=self.port,
                    user=self.user, password=password)
     command = "cd {}; pwd".format(self.local_dir)
     # Shell output is raw bytes with trailing whitespace (newline)
     reported_dir = client.shell(command).decode().strip()
     self.assertEqual(self.local_dir, reported_dir)
예제 #4
0
 def test_shell_error_bad_password(self):
     """Authenticating with a wrong password raises "Incorrect password"."""
     # NOTE(review): assumes "******" is not the test user's real password
     wrong_password = "******"
     client = Sshcp(host=self.host, port=self.port,
                    user=self.user, password=wrong_password)
     command = "cd {}; pwd".format(self.local_dir)
     with self.assertRaises(SshcpError) as ctx:
         client.shell(command)
     self.assertEqual("Incorrect password", str(ctx.exception))
예제 #5
0
 def test_shell_error_bad_port(self, _, password):
     """Connecting on a port nothing listens on raises SshcpError.

     Port 6666 is presumed closed on the test host.
     """
     sshcp = Sshcp(host=self.host,
                   port=6666,
                   user=self.user,
                   password=password)
     with self.assertRaises(SshcpError) as ctx:
         sshcp.shell("cd {}; pwd".format(self.local_dir))
     # assertIn reports both strings on failure, unlike assertTrue(x in y)
     self.assertIn("Connection refused by server", str(ctx.exception))
예제 #6
0
 def test_copy_error_bad_password(self):
     """copy() with a wrong password raises "Incorrect password"."""
     # NOTE(review): assumes "******" is not the test user's real password
     wrong_password = "******"
     client = Sshcp(host=self.host, port=self.port,
                    user=self.user, password=wrong_password)
     with self.assertRaises(SshcpError) as ctx:
         client.copy(local_path=self.local_file, remote_path=self.remote_file)
     self.assertEqual("Incorrect password", str(ctx.exception))
예제 #7
0
 def test_copy_error_bad_host(self, _, password):
     """copy() to a bogus host name raises SshcpError."""
     sshcp = Sshcp(host="badhost",
                   port=self.port,
                   user=self.user,
                   password=password)
     with self.assertRaises(SshcpError) as ctx:
         sshcp.copy(local_path=self.local_file,
                    remote_path=self.remote_file)
     # NOTE(review): test_shell_error_bad_host expects "Bad hostname" for the
     # same host; confirm copy() really reports a refused connection instead.
     # assertIn reports both strings on failure, unlike assertTrue(x in y)
     self.assertIn("Connection refused by server", str(ctx.exception))
예제 #8
0
    def test_copy(self, _, password):
        """copy() transfers the local file so the remote copy is identical."""
        # Precondition: the destination must not exist before the copy
        self.assertFalse(os.path.exists(self.remote_file))

        client = Sshcp(host=self.host, port=self.port,
                       user=self.user, password=password)
        client.copy(local_path=self.local_file, remote_path=self.remote_file)

        files_match = filecmp.cmp(self.local_file, self.remote_file)
        self.assertTrue(files_match)
예제 #9
0
 def __init__(self, remote_address: str, remote_username: str,
              remote_password: Optional[str], remote_port: int,
              remote_path: str, file_name: str):
     """Configure deletion of ``file_name`` located in ``remote_path``.

     :param remote_address: remote host passed to Sshcp
     :param remote_username: SSH user name
     :param remote_password: SSH password, or None
     :param remote_port: SSH port
     :param remote_path: remote directory containing the target
     :param file_name: name of the entry to delete inside remote_path
     """
     super().__init__(name=self.__class__.__name__)
     self.__remote_path = remote_path
     self.__file_name = file_name
     self.__ssh = Sshcp(user=remote_username, password=remote_password,
                        host=remote_address, port=remote_port)
예제 #10
0
    def test_copy_error_missing_remote_dir(self, _, password):
        """Copying into a nonexistent remote directory raises SshcpError."""
        remote_file = os.path.join(self.remote_dir, "nodir", "file2.txt")
        # Precondition: the destination path does not exist yet
        self.assertFalse(os.path.exists(remote_file))

        sshcp = Sshcp(host=self.host,
                      port=self.port,
                      user=self.user,
                      password=password)
        with self.assertRaises(SshcpError) as ctx:
            sshcp.copy(local_path=self.local_file, remote_path=remote_file)
        # assertIn reports both strings on failure, unlike assertTrue(x in y)
        self.assertIn("No such file or directory", str(ctx.exception))
예제 #11
0
 def __init__(self, remote_address: str, remote_username: str,
              remote_password: Optional[str], remote_port: int,
              remote_path_to_scan: str, local_path_to_scan_script: str,
              remote_path_to_scan_script: str):
     """Store the scan configuration and build the SSH client.

     :param remote_address: remote host passed to Sshcp
     :param remote_username: SSH user name
     :param remote_password: SSH password, or None
     :param remote_port: SSH port
     :param remote_path_to_scan: remote directory to be scanned
     :param local_path_to_scan_script: local path of the scan script
     :param remote_path_to_scan_script: remote location of the scan script
     """
     self.logger = logging.getLogger("RemoteScanner")
     self.__remote_path_to_scan = remote_path_to_scan
     self.__local_path_to_scan_script = local_path_to_scan_script
     self.__remote_path_to_scan_script = remote_path_to_scan_script
     self.__ssh = Sshcp(user=remote_username, password=remote_password,
                        host=remote_address, port=remote_port)
     # presumably cleared by the scanning code after the first run — confirm
     self.__first_run = True
예제 #12
0
    def __init__(self, remote_address: str, remote_username: str,
                 remote_password: Optional[str], remote_port: int,
                 remote_path_to_scan: str, local_path_to_scan_script: str,
                 remote_path_to_scan_script: str):
        """Store the scan configuration and build the SSH client.

        If the last component of ``remote_path_to_scan_script`` is not the
        local script's file name, it is treated as a directory and the
        script's file name is appended to it.
        """
        self.logger = logging.getLogger("RemoteScanner")
        self.__remote_path_to_scan = remote_path_to_scan
        self.__local_path_to_scan_script = local_path_to_scan_script
        self.__ssh = Sshcp(user=remote_username, password=remote_password,
                           host=remote_address, port=remote_port)
        self.__first_run = True

        # Resolve the remote script path so it names the script file itself
        script_name = os.path.basename(local_path_to_scan_script)
        if os.path.basename(remote_path_to_scan_script) == script_name:
            self.__remote_path_to_scan_script = remote_path_to_scan_script
        else:
            self.__remote_path_to_scan_script = os.path.join(
                remote_path_to_scan_script, script_name)
예제 #13
0
class DeleteRemoteProcess(AppOneShotProcess):
    """One-shot process that runs ``rm -rf`` on a single remote path."""

    def __init__(self, remote_address: str, remote_username: str,
                 remote_password: Optional[str], remote_port: int,
                 remote_path: str, file_name: str):
        """
        :param remote_address: remote host passed to Sshcp
        :param remote_username: SSH user name
        :param remote_password: SSH password, or None
        :param remote_port: SSH port
        :param remote_path: remote directory containing the target
        :param file_name: name of the entry to delete inside remote_path
        """
        super().__init__(name=self.__class__.__name__)
        self.__remote_path = remote_path
        self.__file_name = file_name
        self.__ssh = Sshcp(user=remote_username, password=remote_password,
                           host=remote_address, port=remote_port)

    def run_once(self):
        """Delete the target over SSH; SshcpError is logged, not raised."""
        self.__ssh.set_base_logger(self.logger)
        target = os.path.join(self.__remote_path, self.__file_name)
        self.logger.debug("Deleting remote file {}".format(self.__file_name))
        # NOTE(review): the path is only wrapped in single quotes, so a name
        # containing "'" would break out of the quoting on the remote shell.
        # Confirm names are sanitized upstream.
        delete_command = "rm -rf '{}'".format(target)
        try:
            output = self.__ssh.shell(delete_command)
            self.logger.debug(
                "Remote delete output: {}".format(output.decode()))
        except SshcpError:
            self.logger.exception("Exception while deleting remote file")
예제 #14
0
    def test_shell_with_escape_characters(self, _, password):
        """Paths with spaces may be quoted with ' or ", but not both at once."""
        client = Sshcp(host=self.host, port=self.port,
                       user=self.user, password=password)

        # Each quoting style on its own must work: create a directory whose
        # name contains a space, cd into it and check pwd reports it.
        quoting_cases = [
            ("a a", "mkdir '{0}' && cd '{0}' && pwd"),
            ("a b", 'mkdir "{0}" && cd "{0}" && pwd'),
        ]
        for dir_name, template in quoting_cases:
            target = os.path.join(self.remote_dir, dir_name)
            reported = client.shell(template.format(target)).decode().strip()
            self.assertEqual(target, reported)

        # Mixing single and double quotes in one command is rejected
        target = os.path.join(self.remote_dir, "a b")
        with self.assertRaises(ValueError):
            client.shell('mkdir "{0}" && cd \'{0}\' && pwd'.format(target))
예제 #15
0
class RemoteScanner(IScanner):
    """
    Scanner implementation to scan the remote filesystem
    """
    # Number of retries allowed for suppressible SSH errors (see
    # __suppress_error) before the scan fails with AppError
    RETRY_COUNT = 5

    def __init__(self, remote_address: str, remote_username: str,
                 remote_password: Optional[str], remote_port: int,
                 remote_path_to_scan: str, local_path_to_scan_script: str,
                 remote_path_to_scan_script: str):
        """
        :param remote_address: remote host passed to Sshcp
        :param remote_username: SSH user name
        :param remote_password: SSH password, or None
        :param remote_port: SSH port
        :param remote_path_to_scan: remote directory passed to the scan script
        :param local_path_to_scan_script: local scan script to install
        :param remote_path_to_scan_script: remote path the script is copied
            to and executed from
        """
        self.logger = logging.getLogger("RemoteScanner")
        self.__remote_path_to_scan = remote_path_to_scan
        self.__local_path_to_scan_script = local_path_to_scan_script
        self.__remote_path_to_scan_script = remote_path_to_scan_script
        self.__ssh = Sshcp(host=remote_address,
                           port=remote_port,
                           user=remote_username,
                           password=remote_password)
        self.__first_run = True

    @overrides(IScanner)
    def set_base_logger(self, base_logger: logging.Logger):
        self.logger = base_logger.getChild("RemoteScanner")
        self.__ssh.set_base_logger(self.logger)

    @overrides(IScanner)
    def scan(self) -> List[SystemFile]:
        """Run the remote scan script and return its unpickled output.

        The scan script is installed on the first call. SSH errors matched by
        __suppress_error are retried up to RETRY_COUNT times.

        :raises AppError: on a non-suppressible SSH error, when retries are
            exhausted, or when the script output cannot be unpickled
        """
        if self.__first_run:
            self._install_scanfs()
            self.__first_run = False

        retries = 0
        out = None
        while out is None:
            try:
                out = self.__ssh.shell("'{}' '{}'".format(
                    self.__remote_path_to_scan_script,
                    self.__remote_path_to_scan))
            except SshcpError as e:
                # Suppress specific errors and retry a fixed number of times
                # Otherwise raise a fatal AppError
                if RemoteScanner.__suppress_error(
                        e) and retries < RemoteScanner.RETRY_COUNT:
                    self.logger.warning(
                        "Retrying remote scan after error: {}".format(str(e)))
                    retries += 1
                else:
                    self.logger.exception("Caught an SshError")
                    raise AppError(
                        Localization.Error.REMOTE_SERVER_SCAN) from e

        try:
            # NOTE(review): pickle.loads can execute arbitrary code; the data
            # comes from the remote host, which must therefore be trusted.
            remote_files = pickle.loads(out)
        except (pickle.UnpicklingError, EOFError, AttributeError,
                ImportError, IndexError) as err:
            # Besides UnpicklingError, pickle documents these exceptions for
            # bad input (e.g. empty output raises EOFError); all of them mean
            # the scan output is unusable.
            self.logger.error("Unpickling error: {}\n{}".format(str(err), out))
            raise AppError(Localization.Error.REMOTE_SERVER_SCAN) from err
        return remote_files

    def _install_scanfs(self):
        """Copy the local scan script to its remote location.

        :raises RemoteScannerError: if the local scan script does not exist
        :raises AppError: if the copy over SSH fails
        """
        self.logger.info("Installing local:{} to remote:{}".format(
            self.__local_path_to_scan_script,
            self.__remote_path_to_scan_script))
        if not os.path.isfile(self.__local_path_to_scan_script):
            raise RemoteScannerError(
                "Failed to find scanfs executable at {}".format(
                    self.__local_path_to_scan_script))
        try:
            self.__ssh.copy(local_path=self.__local_path_to_scan_script,
                            remote_path=self.__remote_path_to_scan_script)
        except SshcpError as e:
            self.logger.exception("Caught scp exception")
            raise AppError(Localization.Error.REMOTE_SERVER_INSTALL) from e

    @staticmethod
    def __suppress_error(error: SshcpError) -> bool:
        """Return True if the error message matches a known transient error."""
        error_str = str(error).lower()
        errors_to_suppress = [
            "text file busy", "ssh_exchange_identification",
            "cannot create temporary directory", "connection timed out"
        ]
        return any(e in error_str for e in errors_to_suppress)
예제 #16
0
class RemoteScanner(IScanner):
    """
    Scanner implementation to scan the remote filesystem
    """
    def __init__(self, remote_address: str, remote_username: str,
                 remote_password: Optional[str], remote_port: int,
                 remote_path_to_scan: str, local_path_to_scan_script: str,
                 remote_path_to_scan_script: str):
        """
        :param remote_address: remote host passed to Sshcp
        :param remote_username: SSH user name
        :param remote_password: SSH password, or None
        :param remote_port: SSH port
        :param remote_path_to_scan: remote directory passed to the scan script
        :param local_path_to_scan_script: local scan script to install
        :param remote_path_to_scan_script: remote script path, or a remote
            directory to which the script's file name is appended
        """
        self.logger = logging.getLogger("RemoteScanner")
        self.__remote_path_to_scan = remote_path_to_scan
        self.__local_path_to_scan_script = local_path_to_scan_script
        self.__remote_path_to_scan_script = remote_path_to_scan_script
        self.__ssh = Sshcp(host=remote_address,
                           port=remote_port,
                           user=remote_username,
                           password=remote_password)
        self.__first_run = True

        # Append scan script name to remote path if not there already
        script_name = os.path.basename(self.__local_path_to_scan_script)
        if os.path.basename(self.__remote_path_to_scan_script) != script_name:
            self.__remote_path_to_scan_script = os.path.join(
                self.__remote_path_to_scan_script, script_name)

    @overrides(IScanner)
    def set_base_logger(self, base_logger: logging.Logger):
        self.logger = base_logger.getChild("RemoteScanner")
        self.__ssh.set_base_logger(self.logger)

    @overrides(IScanner)
    def scan(self) -> List[SystemFile]:
        """Run the remote scan script and return its unpickled output.

        Until the first scan succeeds, each call (re)installs the scan script
        and every error is reported as non-recoverable.

        :raises ScannerError: recoverable=False for first-run errors, errors
            mentioning "SystemScannerError", install failures and
            unpicklable output; recoverable=True for other SSH errors
        """
        if self.__first_run:
            self._install_scanfs()

        try:
            out = self.__ssh.shell("'{}' '{}'".format(
                self.__remote_path_to_scan_script, self.__remote_path_to_scan))
        except SshcpError as e:
            self.logger.warning("Caught an SshcpError: {}".format(str(e)))
            recoverable = True
            # Any scanner errors are fatal
            if "SystemScannerError" in str(e):
                recoverable = False
            # First time errors are fatal
            # User should be prompted to correct these
            if self.__first_run:
                recoverable = False
            raise ScannerError(
                Localization.Error.REMOTE_SERVER_SCAN.format(str(e).strip()),
                recoverable=recoverable) from e

        try:
            # NOTE(review): pickle.loads can execute arbitrary code; the data
            # comes from the remote host, which must therefore be trusted.
            remote_files = pickle.loads(out)
        except (pickle.UnpicklingError, EOFError, AttributeError,
                ImportError, IndexError) as err:
            # Besides UnpicklingError, pickle documents these exceptions for
            # bad input (e.g. empty output raises EOFError).
            self.logger.error("Unpickling error: {}\n{}".format(str(err), out))
            raise ScannerError(
                Localization.Error.REMOTE_SERVER_SCAN.format(
                    "Invalid pickled data"),
                recoverable=False) from err

        # Only a fully successful scan ends the first-run phase
        self.__first_run = False
        return remote_files

    def _install_scanfs(self):
        """Install the scan script remotely unless an identical copy exists.

        :raises ScannerError: (recoverable=False) if the local script is
            missing, or if the md5 check or the copy over SSH fails
        """
        # Verify the local script exists before hashing it; otherwise open()
        # would raise a bare FileNotFoundError instead of a ScannerError.
        if not os.path.isfile(self.__local_path_to_scan_script):
            raise ScannerError(
                Localization.Error.REMOTE_SERVER_SCAN.format(
                    "Failed to find scanfs executable at {}".format(
                        self.__local_path_to_scan_script)),
                recoverable=False)

        # Check md5sum on remote to see if we can skip installation
        with open(self.__local_path_to_scan_script, "rb") as f:
            local_md5sum = hashlib.md5(f.read()).hexdigest()
        self.logger.debug("Local scanfs md5sum = {}".format(local_md5sum))
        try:
            # Quote the path, as the scan command does, so paths containing
            # spaces are handled
            out = self.__ssh.shell(
                "md5sum '{}' | awk '{{print $1}}' || echo".format(
                    self.__remote_path_to_scan_script))
        except SshcpError as e:
            self.logger.exception("Caught scp exception")
            raise ScannerError(
                Localization.Error.REMOTE_SERVER_INSTALL.format(
                    str(e).strip()),
                recoverable=False) from e
        # awk's print adds a trailing newline; strip it or the comparison
        # with the bare hex digest can never match
        if out.decode().strip() == local_md5sum:
            self.logger.info(
                "Skipping remote scanfs installation: already installed")
            return

        # Go ahead and install
        self.logger.info("Installing local:{} to remote:{}".format(
            self.__local_path_to_scan_script,
            self.__remote_path_to_scan_script))
        try:
            self.__ssh.copy(local_path=self.__local_path_to_scan_script,
                            remote_path=self.__remote_path_to_scan_script)
        except SshcpError as e:
            self.logger.exception("Caught scp exception")
            raise ScannerError(
                Localization.Error.REMOTE_SERVER_INSTALL.format(
                    str(e).strip()),
                recoverable=False) from e
예제 #17
0
 def test_ctor(self):
     """Sshcp can be constructed from just a host and a port."""
     self.assertIsNotNone(Sshcp(host=self.host, port=self.port))