Esempio n. 1
0
    def new_get_run_state(_run_id):
        """Stand-in for ``get_run_state`` that replays a scripted state sequence.

        Each invocation bumps ``calls["num_calls"]`` (a dict owned by the
        enclosing scope). The first call reports ``Pending``, the second
        ``Running``, and every later call returns ``calls["final_state"]``.

        Args:
            _run_id: Ignored; accepted only to match the patched signature.
        """
        calls["num_calls"] += 1
        nth_call = calls["num_calls"]

        # In-flight states pass None for the two remaining positional fields.
        if nth_call == 1:
            return DatabricksRunState(DatabricksRunLifeCycleState.Pending, None, None)
        if nth_call == 2:
            return DatabricksRunState(DatabricksRunLifeCycleState.Running, None, None)
        return calls["final_state"]
Esempio n. 2
0
def test_databricks_wait_for_run(mock_submit_run, databricks_run_config):
    """Exercise ``wait_for_run_to_complete`` for a successful and a failed run.

    The client's ``get_run_state`` is patched with a fake that reports
    ``Pending``, then ``Running``, then a configurable terminal state. With a
    ``Success`` terminal state the wait must return normally; with a
    ``Failed`` terminal state it must raise ``DatabricksError``.

    Args:
        mock_submit_run: Mock supplied from outside this function (presumably
            by a ``mock.patch`` decorator that is not visible here).
        databricks_run_config: Run configuration dict; its ``'task'`` entry is
            popped and passed to ``submit_run`` separately.
    """
    mock_submit_run.return_value = {'run_id': 1}

    context = create_test_pipeline_execution_context()
    runner = DatabricksJobRunner(HOST, TOKEN, poll_interval_sec=0.01)
    task = databricks_run_config.pop('task')
    databricks_run_id = runner.submit_run(databricks_run_config, task)

    # Mutable script shared with the fake below: how many polls have happened
    # so far and which terminal state to report from the third poll onwards.
    poll_script = {
        'num_calls': 0,
        'final_state': DatabricksRunState(
            DatabricksRunLifeCycleState.Terminated,
            DatabricksRunResultState.Success,
            'Finished',
        ),
    }

    def fake_get_run_state(_run_id):
        """Report Pending, then Running, then the scripted final state."""
        poll_script['num_calls'] += 1
        nth_call = poll_script['num_calls']
        if nth_call == 1:
            return DatabricksRunState(DatabricksRunLifeCycleState.Pending, None, None)
        if nth_call == 2:
            return DatabricksRunState(DatabricksRunLifeCycleState.Running, None, None)
        return poll_script['final_state']

    def patched_get_run_state():
        """Build a fresh patcher that swaps in the fake on the runner's client."""
        return mock.patch.object(runner.client, 'get_run_state', new=fake_get_run_state)

    # Successful run: the wait should simply return.
    with patched_get_run_state():
        runner.wait_for_run_to_complete(context.log, databricks_run_id)

    # Failed run: rewind the script and switch the terminal state.
    poll_script['num_calls'] = 0
    poll_script['final_state'] = DatabricksRunState(
        DatabricksRunLifeCycleState.Terminated,
        DatabricksRunResultState.Failed,
        'Failed',
    )
    with pytest.raises(DatabricksError) as exc_info, patched_get_run_state():
        runner.wait_for_run_to_complete(context.log, databricks_run_id)
    assert 'Run 1 failed with result state' in str(exc_info.value)
Esempio n. 3
0
def test_pyspark_databricks(mock_get_run_state, mock_get_step_events,
                            mock_put_file, mock_read_file, mock_submit_run):
    """Run the do-nothing pipeline through the Databricks step launcher with mocks.

    The pipeline is first executed in ``"local"`` mode to harvest the event
    log entries for ``do_nothing_solid``; those entries become the return
    value of ``mock_get_step_events``. The pipeline is then executed in
    ``"test"`` mode while ``mock_get_run_state`` reports ``Running`` five
    times followed by a successful ``Terminated`` state, and the call counts
    of every mock are checked.

    All five arguments are mocks supplied from outside this function
    (presumably by ``mock.patch`` decorators that are not visible here).
    """
    mock_submit_run.return_value = 12345
    mock_read_file.return_value = b"somefilecontents"

    # Scripted polling sequence: five in-flight responses, then success.
    still_running = DatabricksRunState(DatabricksRunLifeCycleState.Running, None, "")
    finished_ok = DatabricksRunState(
        DatabricksRunLifeCycleState.Terminated, DatabricksRunResultState.Success, ""
    )
    scripted_states = [still_running] * 5
    scripted_states.append(finished_ok)
    mock_get_run_state.side_effect = scripted_states

    with instance_for_test() as instance:
        execute_pipeline(
            pipeline=reconstructable(define_do_nothing_pipe),
            mode="local",
            instance=instance,
        )
        # Collect the events while the instance is still open.
        log_entries = [record.event_log_entry for record in instance.get_event_records()]
        mock_get_step_events.return_value = [
            entry for entry in log_entries if entry.step_key == "do_nothing_solid"
        ]

    launcher_config = BASE_DATABRICKS_PYSPARK_STEP_LAUNCHER_CONFIG.copy()
    del launcher_config["local_pipeline_package_path"]
    launcher_overrides = {
        "databricks_host": "",
        "databricks_token": "",
        "poll_interval_sec": 0.1,
        "local_dagster_job_package_path": os.path.abspath(os.path.dirname(__file__)),
    }
    run_config = {
        "resources": {
            "pyspark_step_launcher": {
                "config": deep_merge_dicts(launcher_config, launcher_overrides),
            },
        },
    }

    result = execute_pipeline(
        pipeline=reconstructable(define_do_nothing_pipe),
        mode="test",
        run_config=run_config,
    )

    assert result.success
    # One get_run_state call per scripted state (5 Running + 1 Terminated).
    assert mock_get_run_state.call_count == 6
    assert mock_get_step_events.call_count == 6
    assert mock_put_file.call_count == 4
    assert mock_read_file.call_count == 2
    assert mock_submit_run.call_count == 1
Esempio n. 4
0
def test_run_create_databricks_job_solid(
    mock_submit_run, mock_get_run_state, databricks_run_config
):
    """Execute a one-solid pipeline built from ``create_databricks_job_solid``.

    ``submit_run`` is mocked to return a fixed run id and ``get_run_state`` to
    report a successfully terminated run straight away. The test then checks
    that the pipeline succeeds, that the job was submitted exactly once with
    the configured job dict, and that the run state was polled exactly once
    for the returned run id.

    Args:
        mock_submit_run: Mock supplied from outside this function (presumably
            by a ``mock.patch`` decorator that is not visible here).
        mock_get_run_state: Mock supplied the same way.
        databricks_run_config: Job configuration dict handed to the solid.
    """
    expected_run_id = 1
    mock_submit_run.return_value = expected_run_id
    mock_get_run_state.return_value = DatabricksRunState(
        life_cycle_state=DatabricksRunLifeCycleState.Terminated,
        result_state=DatabricksRunResultState.Success,
        state_message="",
    )

    client_resource = databricks_client.configured({"host": "a", "token": "fdshj"})
    solid_config = {"job": databricks_run_config, "poll_interval_sec": 0.01}

    @pipeline(
        mode_defs=[ModeDefinition(resource_defs={"databricks_client": client_resource})]
    )
    def test_pipe():
        job_solid = create_databricks_job_solid(num_inputs=0).configured(
            solid_config, name="test"
        )
        job_solid()

    result = execute_pipeline(test_pipe)
    assert result.success

    assert mock_submit_run.call_count == 1
    # NOTE(review): mock's call equality appears to treat a 1-tuple holding a
    # dict as keyword arguments, so this would match
    # submit_run(**databricks_run_config) - confirm before changing its shape.
    assert mock_submit_run.call_args_list[0] == (databricks_run_config,)
    assert mock_get_run_state.call_count == 1
    assert mock_get_run_state.call_args[0][0] == expected_run_id