Ejemplo n.º 1
0
def fetch_log(host: Hostname,
              user: Username,
              task_id: int,
              tail: bool = False) -> Tuple[Iterator[str], str]:
    """Fetch a task's log file from a remote host without keeping any state.

    A dedicated SSH client is built for ``host``/``user`` on every call and the
    file ``~/TensorHiveLogs/task_<task_id>.log`` is printed remotely with
    ``cat`` (full content) or ``tail`` (only the last 10 lines).

    Args:
        host: remote host to read the log from.
        user: account used for the SSH connection.
        task_id: id of the task whose log is fetched.
        tail: when True, return only the last lines instead of the whole file.

    Returns:
        Tuple of (stdout of the remote command for ``host``, remote log path).

    Raises:
        Any exception reported by the SSH layer (pssh) for ``host`` is re-raised.
        ExitCodeError: the remote command finished with a non-zero exit code.
    """
    # TODO Path should be configurable from config files (see config.py)
    log_path = '~/TensorHiveLogs/task_{}.log'.format(task_id)
    if tail:
        reader = 'tail'
    else:
        reader = 'cat'
    remote_command = '{} {}'.format(reader, log_path)

    config, pconfig = ssh.build_dedicated_config_for(host, user)
    client = ssh.get_client(config, pconfig)
    host_result = ssh.run_command(client, remote_command)[host]

    if host_result.exception:
        # Propagate the SSH-level failure to the caller unchanged
        raise host_result.exception
    if host_result.exit_code != 0:
        raise ExitCodeError(log_path)
    return host_result.stdout, log_path
Ejemplo n.º 2
0
    def spawn(self,
              client: ParallelSSHClient,
              name_appendix: Optional[str] = None) -> int:
        """Spawn the task's command on the remote host via ``client``.

        The command is wrapped by ``self._command_builder.spawn`` in a session
        named ``tensorhive_task``. When ``name_appendix`` is given, the session
        name gets the suffix ``_<name_appendix>`` and the captured output uses
        the custom log name ``task_<name_appendix>``.

        Args:
            client: SSH client used to run the spawn command.
            name_appendix: optional suffix for the session and log names.

        Returns:
            pid of the spawned process (also stored in ``self.pid``).

        Raises:
            ValueError: stdout for ``self.hostname`` is empty or whitespace-only,
                so no pid can be read, or its last token is not an integer.
        """
        sess_name = 'tensorhive_task'
        log_name = None
        if name_appendix:
            sess_name += '_' + name_appendix
            log_name = 'task_' + name_appendix

        command = self._command_builder.spawn(self.command,
                                              session_name=sess_name,
                                              capture_output=True,
                                              custom_log_name=log_name,
                                              keep_alive=False)
        output = ssh.run_command(client, command)
        stdout = ssh.get_stdout(host=self.hostname, output=output)

        # Check the tokens rather than the raw string: whitespace-only stdout is
        # truthy but splits into [], which previously made pop() raise IndexError.
        tokens = stdout.split() if stdout else []
        if not tokens:
            reason = output[self.hostname].exception
            raise ValueError(
                'Unable to acquire pid from empty stdout, reason: {}'.format(
                    reason))
        # FIXME May want to decouple it somehow
        # The pid is taken from the last whitespace-separated token of stdout;
        # presumably the spawn command prints it last - verify with the builder.
        self.pid = int(tokens[-1])
        return self.pid
Ejemplo n.º 3
0
    def kill(self, client: ParallelSSHClient) -> int:
        """Kill the task using its pid.

        Args:
            client: SSH client used to run the kill command.

        Returns:
            Exit code of the operation as reported for ``self.hostname``.

        Raises:
            AssertionError: ``self.pid`` is not set (falsy).
        """
        # Explicit check instead of ``assert`` so the guard survives ``python -O``;
        # AssertionError is kept so callers see the same exception type as before.
        if not self.pid:
            raise AssertionError(
                'You must first spawn the task or provide pid manually.')
        command = self._command_builder.kill(self.pid)
        output = ssh.run_command(client, command)
        return output[self.hostname].exit_code
Ejemplo n.º 4
0
    def interrupt(self, client: ParallelSSHClient) -> int:
        """Interrupt the task gracefully by sending it SIGINT.

        Args:
            client: SSH client used to run the interrupt command.

        Returns:
            Exit code of the operation as reported for ``self.hostname``.

        Raises:
            AssertionError: ``self.pid`` is not set (falsy).
        """
        # Explicit check instead of ``assert`` so the guard survives ``python -O``;
        # AssertionError is kept so callers see the same exception type as before.
        if not self.pid:
            raise AssertionError(
                'You must first spawn the task or provide pid manually.')
        command = self._command_builder.interrupt(self.pid)
        output = ssh.run_command(client, command)
        return output[self.hostname].exit_code
Ejemplo n.º 5
0
def running(host: Hostname, user: Username) -> List[int]:
    """Return pids of TensorHive task sessions currently running on ``host``.

    Stateless: a dedicated SSH client is built for ``host``/``user`` on every
    call. Only sessions whose names match the pattern ``.*tensorhive_task.*``
    are requested via ``ScreenCommandBuilder.get_active_sessions``.

    Returns:
        List of pids parsed from session names of the form ``<pid>.<name>``;
        an empty list when the remote command produced no output.

    Raises:
        ValueError: the text before the first '.' of a non-blank output line
            is not an integer.
    """
    config, pconfig = ssh.build_dedicated_config_for(host, user)
    client = ssh.get_client(config, pconfig)
    pattern = '.*tensorhive_task.*'
    command = ScreenCommandBuilder.get_active_sessions(pattern)
    output = ssh.run_command(client, command)
    stdout = ssh.get_stdout(host, output)
    if not stdout:
        return []

    # '4321.foobar_session' -> 4321. Blank lines (e.g. caused by a trailing
    # newline) are skipped; previously int('') raised ValueError on them.
    pids = [int(line.split('.')[0])
            for line in stdout.splitlines()
            if line.strip()]
    log.debug('Running pids: {}'.format(pids))
    return pids
Ejemplo n.º 6
0
    def trigger_action(self, violation_data: Dict[str, Any]) -> None:
        """Kill every violating process with ``sudo kill`` over existing connections.

        Keys read from ``violation_data``:
            'INTRUDER_USERNAME': user reported in the log messages,
            'VIOLATION_PIDS': mapping hostname -> pids to kill on that host,
            'SSH_CONNECTIONS': mapping hostname -> SSH connection for that host.

        Exceptions reported by the SSH layer are logged as warnings, not raised.
        """
        username = violation_data['INTRUDER_USERNAME']

        for hostname, pids in violation_data['VIOLATION_PIDS'].items():
            connection = violation_data['SSH_CONNECTIONS'][hostname]

            for pid in pids:
                command = 'sudo kill {}'.format(pid)

                log.warning(
                    'Sudo killing process {} on host {}, user: {}'.format(
                        pid, hostname, username))
                # Send the command exactly once, through ssh.run_command only;
                # an extra raw connection.run_command() would kill a second time.
                output = ssh.run_command(connection, command)

                error = output[hostname].exception
                if error:
                    log.warning(
                        'Cannot kill process on host {}, user: {}, reason: {}'.
                        format(hostname, username, error))
Ejemplo n.º 7
0
    def trigger_action(self, violation_data: Dict[str, Any]) -> None:
        """Kill violating processes as the intruder, using one new SSH client per host.

        Keys read from ``violation_data``:
            'INTRUDER_USERNAME': account used to connect, also reported in logs,
            'VIOLATION_PIDS': mapping hostname -> pids to kill on that host.

        Exceptions reported by the SSH layer are logged as warnings, not raised.
        """
        intruder = violation_data['INTRUDER_USERNAME']
        pids_by_host = violation_data['VIOLATION_PIDS']

        for hostname in pids_by_host:
            config, pconfig = ssh.build_dedicated_config_for(
                hostname, intruder)
            client = ssh.get_client(config, pconfig)

            for pid in pids_by_host[hostname]:
                kill_command = 'kill {}'.format(pid)
                log.warning('Killing process {} on host {}, user: {}'.format(
                    pid, hostname, intruder))
                result = ssh.run_command(client, kill_command)

                error = result[hostname].exception
                if not error:
                    continue
                log.warning(
                    'Cannot kill process on host {}, user: {}, reason: {}'.
                    format(hostname, intruder, error))