コード例 #1
0
    def __init__(self, source, product, dag, name, client=None, params=None):
        """Initialize the task and choose the client that runs the script.

        A truthy ``client`` argument wins; otherwise the client registered
        in ``dag.clients`` under this task's class is used and, when that
        lookup yields ``None``, a ``ShellClient`` built with no arguments.
        """
        super().__init__(source, product, dag, name, params)

        chosen = client or self.dag.clients.get(type(self))
        self.client = ShellClient() if chosen is None else chosen
コード例 #2
0
def test_shell_client(tmp_directory):
    """A default ShellClient runs the script and the file gets created."""
    target = Path(tmp_directory, 'a_file')
    shell_client = ShellClient()
    script = """
    touch a_file
    """

    # the file must only show up as a result of executing the script
    assert not target.exists()
    shell_client.execute(script)
    assert target.exists()
コード例 #3
0
ファイル: tasks.py プロジェクト: israelrico007/ploomber
class ShellScript(Task):
    """Task that executes a shell script through a client

    Parameters
    ----------
    source: str or pathlib.Path
        Script source. If str, the content is interpreted as the actual
        script; if pathlib.Path, the content of the file is loaded. The
        source code must include the {{product}} tag
    product: ploomber.products.Product
        Product generated upon successful execution
    dag: ploomber.DAG
        A DAG to add this task to
    name: str, optional
        A str to identify this task. Should not already exist in the dag
    client: ploomber.clients.ShellClient or RemoteShellClient, optional
        The client used to execute the script. When not passed, the client
        declared at the dag level using dag.clients[class] is used; if
        there is none, a ShellClient created with no arguments is used
    params: dict, optional
        Parameters to pass to the script. By default, the script is
        rendered with a "product" parameter (which will contain the product
        object). It will also include an "upstream" parameter if the task
        has upstream dependencies, along with any parameters declared
        here. The source code is converted to a jinja2.Template for passing
        parameters, refer to jinja2 documentation for details
    """
    def __init__(self, source, product, dag, name=None, client=None,
                 params=None):
        # forward the dag-level hot_reload setting to the source object
        source_options = {'hot_reload': dag._params.hot_reload}
        # NOTE(review): the source is assigned before the base initializer
        # runs - presumably the base class reads it; confirm before moving
        self._source = type(self)._init_source(source, source_options)
        super().__init__(product, dag, name, params)

        # precedence: explicit client, dag-level client, default ShellClient
        candidate = client or self.dag.clients.get(type(self))
        self.client = ShellClient() if candidate is None else candidate

    @staticmethod
    def _init_source(source, kwargs):
        """Wrap ``source`` in a GenericSource that requires {{product}}

        ``kwargs`` are forwarded to GenericSource as keyword arguments.
        """
        missing_tag_msg = 'ShellScript must include {{product}} in its source'

        return GenericSource(source,
                             **kwargs,
                             required={'product': missing_tag_msg})

    def run(self):
        """Execute the string form of the task source using the client"""
        script = str(self.source)
        self.client.execute(script)
コード例 #4
0
def test_shell_client_with_custom_template(tmp_directory):
    """A ShellClient with a ruby run_template executes ruby code."""
    target = Path(tmp_directory, 'a_file')
    ruby_client = ShellClient(run_template='ruby {{path_to_code}}')
    ruby_code = """
    require 'fileutils'
    FileUtils.touch "a_file"
    """

    # the file must only show up as a result of executing the code
    assert not target.exists()
    ruby_client.execute(ruby_code)
    assert target.exists()
コード例 #5
0
def test_shell_client_tmp_file_is_deleted(tmp_directory, monkeypatch):
    """ShellClient.execute calls Path.unlink exactly once."""
    shell_client = ShellClient()
    script = """
    echo 'hello'
    """

    # spy on Path.unlink as seen from the shell module
    unlink_spy = Mock()
    monkeypatch.setattr(shell.Path, 'unlink', unlink_spy)

    # replace subprocess.run with a stub that reports a zero return code
    fake_run = Mock(return_value=Mock(returncode=0))
    monkeypatch.setattr(shell.subprocess, 'run', fake_run)

    shell_client.execute(script)

    unlink_spy.assert_called_once()
コード例 #6
0
ファイル: tasks.py プロジェクト: israelrico007/ploomber
    def __init__(self, source, product, dag, name=None, client=None,
                 params=None):
        """Build the task source, register the task and choose a client.

        A truthy ``client`` argument wins; otherwise the client registered
        in ``dag.clients`` under this task's class is used and, when that
        lookup yields ``None``, a ``ShellClient`` built with no arguments.
        """
        # forward the dag-level hot_reload setting to the source object
        source_options = {'hot_reload': dag._params.hot_reload}
        # NOTE(review): the source is assigned before the base initializer
        # runs - presumably the base class reads it; confirm before moving
        self._source = type(self)._init_source(source, source_options)
        super().__init__(product, dag, name, params)

        candidate = client or self.dag.clients.get(type(self))
        self.client = ShellClient() if candidate is None else candidate
コード例 #7
0
class ShellScript(Task):
    """Task that executes a shell script through a client

    Parameters
    ----------
    source: str or pathlib.Path
        Script source. It must be a template, since the product is passed
        to it as a parameter
    product: ploomber.products.Product
        Product generated upon successful execution
    dag: ploomber.DAG
        A DAG to add this task to
    name: str
        A str to identify this task. Should not already exist in the dag
    client: ploomber.clients.ShellClient or RemoteShellClient, optional
        The client used to execute the script. When not passed, the client
        declared at the dag level using dag.clients[class] is used; if
        there is none, a ShellClient created with no arguments is used
    params: dict, optional
        Parameters to pass to the script. By default, the script is
        rendered with a "product" parameter (which will contain the product
        object). It will also include an "upstream" parameter if the task
        has upstream dependencies, along with any parameters declared
        here. The source code is converted to a jinja2.Template for passing
        parameters, refer to jinja2 documentation for details
    """
    def __init__(self, source, product, dag, name, client=None, params=None):
        super().__init__(source, product, dag, name, params)

        # precedence: explicit client, dag-level client, default ShellClient
        candidate = client or self.dag.clients.get(type(self))
        self.client = ShellClient() if candidate is None else candidate

    def _init_source(self, source):
        """Wrap ``str(source)`` in a GenericSource, rejecting non-templates

        Raises
        ------
        SourceInitializationError
            If the wrapped source reports that it does not need rendering
        """
        # NOTE(review): a pathlib.Path is converted with str() here; whether
        # GenericSource then loads the file content is not visible - confirm
        wrapped = GenericSource(str(source))

        if wrapped.needs_render:
            return wrapped

        raise SourceInitializationError('The source for this task '
                                        'must be a template since the '
                                        'product will be passed as '
                                        'parameter')

    def run(self):
        """Execute the string form of the task source using the client"""
        script = str(self.source)
        self.client.execute(script)
コード例 #8
0
 def client(self):
     """Return the parent's client, or a new ShellClient if it is missing.

     When the parent lookup raises MissingClientError, a ShellClient built
     with no arguments is stored in ``self._client`` and returned.
     """
     # NOTE(review): presumably decorated as a property outside this
     # excerpt, since the parent value is read as an attribute - confirm
     try:
         return super().client
     except MissingClientError:
         self._client = ShellClient()
         return self._client
コード例 #9
0
def test_shell_client_execute(run_template, tmp_directory, monkeypatch):
    """execute() passes [command, path-to-code] to subprocess.run.

    With a falsy ``run_template`` a default client is built and the command
    is expected to be 'bash'; otherwise the command is expected to be the
    first space-separated word of the template.
    """
    if not run_template:
        client = ShellClient()
        expected_command = 'bash'
    else:
        client = ShellClient(run_template=run_template)
        expected_command = run_template.split(' ')[0]

    script = """
    echo 'hello'
    """

    # stub subprocess.run with a result that reports a zero return code
    fake_run = Mock(return_value=Mock(returncode=0))
    monkeypatch.setattr(shell.subprocess, 'run', fake_run)
    # prevent tmp file from being removed so we can check contents
    monkeypatch.setattr(shell.Path, 'unlink', Mock())

    client.execute(script)

    # first positional argument of the run call is the (command, path) pair
    positional = fake_run.call_args[0]
    cmd, code_path = positional[0]

    assert cmd == expected_command
    assert Path(code_path).read_text() == script
コード例 #10
0
def test_task_level_shell_client(tmp_directory, monkeypatch):
    """Building the dag runs the rendered script via the registered client.

    Checks that the client's execute is called once and that subprocess.run
    receives 'ruby', a file holding the rendered code, and the expected
    keyword arguments.
    """
    path = Path(tmp_directory, 'a_file')
    dag = DAG()
    client = ShellClient(run_template='ruby {{path_to_code}}')
    dag.clients[ShellScript] = client

    ruby_source = """
    require 'fileutils'
    FileUtils.touch "{{product}}"
    """
    ShellScript(ruby_source, product=File(path), dag=dag, name='ruby_script')

    # spy that still delegates to the real execute method
    execute_spy = Mock(wraps=client.execute)
    monkeypatch.setattr(client, 'execute', execute_spy)

    completed = Mock(returncode=0)

    def fake_run(*args, **kwargs):
        # subprocess.run is stubbed, so create the file here instead
        # (presumably so the build finds the product; relies on the
        # working directory being tmp_directory - confirm)
        Path('a_file').touch()
        return completed

    run_mock = Mock(side_effect=fake_run)
    monkeypatch.setattr(shell.subprocess, 'run', run_mock)
    # prevent tmp file from being removed so we can check contents
    monkeypatch.setattr(shell.Path, 'unlink', Mock())

    dag.build()

    execute_spy.assert_called_once()

    positional, keyword = run_mock.call_args
    cmd, path_arg = positional[0]

    expected_code = """
    require 'fileutils'
    FileUtils.touch "{path}"
    """.format(path=path)

    assert cmd == 'ruby'
    assert Path(path_arg).read_text() == expected_code
    assert keyword == {
        'stderr': subprocess.PIPE,
        'stdout': subprocess.PIPE,
        'shell': False
    }
コード例 #11
0
def test_custom_client_in_dag(tmp_directory):
    """A client registered in dag.clients is used when the dag is built."""
    target = Path(tmp_directory, 'a_file')
    dag = DAG()

    # register a ruby-based client for every ShellScript in this dag
    ruby_client = ShellClient(run_template='ruby {{path_to_code}}')
    dag.clients[ShellScript] = ruby_client

    ruby_source = """
    require 'fileutils'
    FileUtils.touch "{{product}}"
    """
    ShellScript(ruby_source,
                product=File(target),
                dag=dag,
                name='ruby_script')

    # the product must only show up as a result of building the dag
    assert not target.exists()
    dag.build()
    assert target.exists()