示例#1
0
    def test_get_conn_from_connection(self, mock_get_connection, mock_protocol):
        """Check that ``get_conn`` builds the protocol from the looked-up connection.

        The hook is created with a connection id only. After ``get_conn`` the
        connection lookup mock must have been called once with that id, and the
        protocol mock must have been called once with the connection's login,
        password and its ``extra_dejson`` entries converted as grouped below.
        """
        connection = mock_get_connection.return_value
        hook = WinRMHook(ssh_conn_id='conn_id')

        hook.get_conn()

        mock_get_connection.assert_called_once_with(hook.ssh_conn_id)

        extra = connection.extra_dejson

        # Credentials are forwarded as they are.
        expected_kwargs = {
            'username': connection.login,
            'password': connection.password,
        }
        # Extras expected as plain strings.
        expected_kwargs.update({
            key: str(extra[key])
            for key in (
                'endpoint',
                'transport',
                'service',
                'keytab',
                'ca_trust_path',
                'cert_pem',
                'cert_key_pem',
                'server_cert_validation',
                'kerberos_hostname_override',
                'message_encryption',
            )
        })
        # Extras expected as booleans: true only when the text is 'true' (any case).
        expected_kwargs.update({
            key: str(extra[key]).lower() == 'true'
            for key in ('kerberos_delegation', 'credssp_disable_tlsv1_2', 'send_cbt')
        })
        # Extras expected as integers.
        expected_kwargs.update({
            key: int(extra[key])
            for key in ('read_timeout_sec', 'operation_timeout_sec')
        })

        mock_protocol.assert_called_once_with(**expected_kwargs)
示例#2
0
    def test_get_conn_no_endpoint(self, mock_protocol):
        """Check the hook's ``endpoint`` after ``get_conn`` when none was passed in.

        The hook is built from a host and a password only; ``endpoint`` is
        expected to equal ``http://<remote_host>:<remote_port>/wsman``.
        """
        hook = WinRMHook(remote_host='host', password='******')
        hook.get_conn()

        expected_endpoint = f'http://{hook.remote_host}:{hook.remote_port}/wsman'
        self.assertEqual(expected_endpoint, hook.endpoint)
示例#3
0
    def test_get_conn_exists(self, mock_protocol):
        """Check that ``get_conn`` returns the client already set on the hook."""
        hook = WinRMHook()
        # Pre-assign a client before asking the hook for a connection.
        preset_client = mock_protocol.return_value.open_shell.return_value
        hook.client = preset_client

        returned_client = hook.get_conn()

        self.assertEqual(returned_client, hook.client)
示例#4
0
class WinRMOperator(BaseOperator):
    """
    WinRMOperator to execute commands on given remote host using the winrm_hook.

    :param winrm_hook: predefined winrm_hook to use for remote execution
    :type winrm_hook: airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook
    :param ssh_conn_id: connection id from airflow Connections
    :type ssh_conn_id: str
    :param remote_host: remote host to connect
    :type remote_host: str
    :param command: command to execute on remote host. (templated)
    :type command: str
    :param ps_path: path to powershell, `powershell` for v5.1- and `pwsh` for v6+.
        If specified, it will execute the command as powershell script.
    :type ps_path: str
    :param output_encoding: the encoding used to decode stdout and stderr
    :type output_encoding: str
    :param timeout: timeout for executing the command.
    :type timeout: int
    """

    template_fields = ('command', )
    template_fields_renderers = {"command": "powershell"}

    def __init__(
        self,
        *,
        winrm_hook: Optional[WinRMHook] = None,
        ssh_conn_id: Optional[str] = None,
        remote_host: Optional[str] = None,
        command: Optional[str] = None,
        ps_path: Optional[str] = None,
        output_encoding: str = 'utf-8',
        timeout: int = 10,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.winrm_hook = winrm_hook
        self.ssh_conn_id = ssh_conn_id
        self.remote_host = remote_host
        self.command = command
        self.ps_path = ps_path
        self.output_encoding = output_encoding
        # NOTE(review): stored but never read by execute() in this class --
        # confirm whether something else is expected to consume it.
        self.timeout = timeout

    def execute(self, context: dict) -> Union[list, str]:
        """
        Run ``command`` through the WinRM hook and return its captured stdout.

        A hook is created from ``ssh_conn_id`` when none was supplied, and
        ``remote_host`` (if given) overrides the hook's host. When ``ps_path``
        is set, the command is UTF-16-LE/base64 encoded and passed to that
        executable via ``-encodedcommand``. Output is polled until the command
        reports completion; stdout lines are logged at info level and stderr
        lines at warning level.

        :param context: task context; not used by this method.
        :return: the list of raw stdout chunks when ``core.enable_xcom_pickling``
            is enabled, otherwise the base64 text of the joined chunks. Chunks
            are only collected when ``do_xcom_push`` is truthy.
        :raises AirflowException: when neither a hook nor a connection id is
            available, when no command is set, when anything fails while
            running the command (the original error is chained as the cause),
            or when the command exits with a non-zero return code.
        """
        if self.ssh_conn_id and not self.winrm_hook:
            self.log.info("Hook not found, creating...")
            self.winrm_hook = WinRMHook(ssh_conn_id=self.ssh_conn_id)

        if not self.winrm_hook:
            raise AirflowException(
                "Cannot operate without winrm_hook or ssh_conn_id.")

        if self.remote_host is not None:
            self.winrm_hook.remote_host = self.remote_host

        if not self.command:
            raise AirflowException(
                "No command specified so nothing to execute here.")

        winrm_client = self.winrm_hook.get_conn()

        try:
            try:
                if self.ps_path is not None:
                    self.log.info("Running command as powershell script: '%s'...",
                                  self.command)
                    encoded_ps = b64encode(
                        self.command.encode('utf_16_le')).decode('ascii')
                    command_id = self.winrm_hook.winrm_protocol.run_command(  # type: ignore[attr-defined]
                        winrm_client,
                        f'{self.ps_path} -encodedcommand {encoded_ps}')
                else:
                    self.log.info("Running command: '%s'...", self.command)
                    command_id = self.winrm_hook.winrm_protocol.run_command(  # type: ignore[attr-defined]
                        winrm_client, self.command)

                # See: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
                stdout_buffer = []
                stderr_buffer = []
                command_done = False
                while not command_done:
                    try:

                        (
                            stdout,
                            stderr,
                            return_code,
                            command_done,
                        ) = self.winrm_hook.winrm_protocol._raw_get_command_output(  # type: ignore[attr-defined]
                            winrm_client, command_id)

                        # Only buffer stdout if we need to so that we minimize memory usage.
                        if self.do_xcom_push:
                            stdout_buffer.append(stdout)
                        stderr_buffer.append(stderr)

                        for line in stdout.decode(
                                self.output_encoding).splitlines():
                            self.log.info(line)
                        for line in stderr.decode(
                                self.output_encoding).splitlines():
                            self.log.warning(line)
                    except WinRMOperationTimeoutError:
                        # this is an expected error when waiting for a
                        # long-running process, just silently retry
                        pass

                self.winrm_hook.winrm_protocol.cleanup_command(  # type: ignore[attr-defined]
                    winrm_client, command_id)
            finally:
                # Close the shell on every exit path; previously a failure
                # while running or polling the command skipped this call and
                # left the shell open.
                self.winrm_hook.winrm_protocol.close_shell(  # type: ignore[attr-defined]
                    winrm_client)

        except Exception as e:
            # Chain explicitly so the original error is recorded as the cause.
            raise AirflowException(f"WinRM operator error: {str(e)}") from e

        if return_code == 0:
            # returning output if do_xcom_push is set
            enable_pickling = conf.getboolean('core', 'enable_xcom_pickling')
            if enable_pickling:
                return stdout_buffer
            else:
                return b64encode(b''.join(stdout_buffer)).decode(
                    self.output_encoding)
        else:
            error_msg = "Error running cmd: {}, return code: {}, error: {}".format(
                self.command, return_code,
                b''.join(stderr_buffer).decode(self.output_encoding))
            raise AirflowException(error_msg)
示例#5
0
class WinRMOperator(BaseOperator):
    """
    WinRMOperator to execute commands on given remote host using the winrm_hook.

    :param winrm_hook: predefined winrm_hook to use for remote execution
    :type winrm_hook: airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook
    :param ssh_conn_id: connection id from airflow Connections
    :type ssh_conn_id: str
    :param remote_host: remote host to connect
    :type remote_host: str
    :param command: command to execute on remote host. (templated)
    :type command: str
    :param timeout: timeout for executing the command.
    :type timeout: int
    """
    template_fields = ('command', )

    @apply_defaults
    def __init__(self,
                 *,
                 winrm_hook=None,
                 ssh_conn_id=None,
                 remote_host=None,
                 command=None,
                 timeout=10,
                 **kwargs):
        super().__init__(**kwargs)
        self.winrm_hook = winrm_hook
        self.ssh_conn_id = ssh_conn_id
        self.remote_host = remote_host
        self.command = command
        # NOTE(review): stored but never read by execute() in this class --
        # confirm whether something else is expected to consume it.
        self.timeout = timeout

    def execute(self, context):
        """
        Run ``command`` through the WinRM hook and return its captured stdout.

        A hook is created from ``ssh_conn_id`` when none was supplied, and
        ``remote_host`` (if given) overrides the hook's host. Output is polled
        until the command reports completion; stdout lines are logged at info
        level and stderr lines at warning level, both decoded as UTF-8.

        :param context: task context; not used by this method.
        :return: the list of raw stdout chunks when ``core.enable_xcom_pickling``
            is enabled, otherwise the base64 text of the joined chunks. Chunks
            are only collected when ``do_xcom_push`` is truthy.
        :raises AirflowException: when neither a hook nor a connection id is
            available, when no command is set, when anything fails while
            running the command (the original error is chained as the cause),
            or when the command exits with a non-zero return code.
        """
        if self.ssh_conn_id and not self.winrm_hook:
            self.log.info("Hook not found, creating...")
            self.winrm_hook = WinRMHook(ssh_conn_id=self.ssh_conn_id)

        if not self.winrm_hook:
            raise AirflowException(
                "Cannot operate without winrm_hook or ssh_conn_id.")

        if self.remote_host is not None:
            self.winrm_hook.remote_host = self.remote_host

        if not self.command:
            raise AirflowException(
                "No command specified so nothing to execute here.")

        winrm_client = self.winrm_hook.get_conn()

        # pylint: disable=too-many-nested-blocks
        try:
            try:
                self.log.info("Running command: '%s'...", self.command)
                command_id = self.winrm_hook.winrm_protocol.run_command(
                    winrm_client, self.command)

                # See: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
                stdout_buffer = []
                stderr_buffer = []
                command_done = False
                while not command_done:
                    try:
                        # pylint: disable=protected-access
                        stdout, stderr, return_code, command_done = \
                            self.winrm_hook.winrm_protocol._raw_get_command_output(
                                winrm_client,
                                command_id
                            )

                        # Only buffer stdout if we need to so that we minimize memory usage.
                        if self.do_xcom_push:
                            stdout_buffer.append(stdout)
                        stderr_buffer.append(stderr)

                        for line in stdout.decode('utf-8').splitlines():
                            self.log.info(line)
                        for line in stderr.decode('utf-8').splitlines():
                            self.log.warning(line)
                    except WinRMOperationTimeoutError:
                        # this is an expected error when waiting for a
                        # long-running process, just silently retry
                        pass

                self.winrm_hook.winrm_protocol.cleanup_command(
                    winrm_client, command_id)
            finally:
                # Close the shell on every exit path; previously a failure
                # while running or polling the command skipped this call and
                # left the shell open.
                self.winrm_hook.winrm_protocol.close_shell(winrm_client)

        except Exception as e:
            # Chain explicitly so the original error is recorded as the cause.
            raise AirflowException("WinRM operator error: {0}".format(str(e))) from e

        if return_code == 0:
            # returning output if do_xcom_push is set
            enable_pickling = conf.getboolean('core', 'enable_xcom_pickling')
            if enable_pickling:
                return stdout_buffer
            else:
                return b64encode(b''.join(stdout_buffer)).decode('utf-8')
        else:
            error_msg = "Error running cmd: {0}, return code: {1}, error: {2}".format(
                self.command, return_code,
                b''.join(stderr_buffer).decode('utf-8'))
            raise AirflowException(error_msg)
示例#6
0
    def test_get_conn_no_username(self, mock_protocol, mock_getuser):
        """Check the hook's ``username`` after ``get_conn`` when none was passed in.

        The hook is built from a host and a password only; ``username`` is
        expected to equal the value returned by the ``mock_getuser`` mock.
        """
        hook = WinRMHook(remote_host='host', password='******')
        hook.get_conn()

        expected_username = mock_getuser.return_value
        self.assertEqual(expected_username, hook.username)