    def test_before_and_after_dataset_saved_hooks_sequential_runner(
            self, mock_session, caplog, dummy_dataframe):
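        """Running ``node1`` should fire ``before_dataset_saved`` and
        ``after_dataset_saved`` exactly once each, with the saved dataset name
        (``planes``) and the saved data passed to both hooks."""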
        context = mock_session.load_context()
        context.catalog.save("cars", dummy_dataframe)
        mock_session.run(node_names=["node1"])

        # test before dataset saved hook
        before_dataset_saved_calls = [
            record for record in caplog.records
            if record.funcName == "before_dataset_saved"
        ]
        assert len(before_dataset_saved_calls) == 1
        call_record = before_dataset_saved_calls[0]
        _assert_hook_call_record_has_expected_parameters(
            call_record, ["dataset_name", "data"])

        assert call_record.dataset_name == "planes"
        assert call_record.data.to_dict() == dummy_dataframe.to_dict()

        # test after dataset saved hook
        after_dataset_saved_calls = [
            record for record in caplog.records
            if record.funcName == "after_dataset_saved"
        ]
        assert len(after_dataset_saved_calls) == 1
        call_record = after_dataset_saved_calls[0]
        _assert_hook_call_record_has_expected_parameters(
            call_record, ["dataset_name", "data"])

        assert call_record.dataset_name == "planes"
        assert call_record.data.to_dict() == dummy_dataframe.to_dict()

    def test_before_and_after_node_run_hooks_sequential_runner(
            self, caplog, mock_session, dummy_dataframe):
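        """Running ``node1`` should fire ``before_node_run`` and ``after_node_run``
        exactly once each, passing the ``cars`` input / ``planes`` output data and
        the session id as ``run_id``."""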
        context = mock_session.load_context()
        catalog = context.catalog
        catalog.save("cars", dummy_dataframe)
        mock_session.run(node_names=["node1"])

        # test before node run hook
        before_node_run_calls = [
            record for record in caplog.records
            if record.funcName == "before_node_run"
        ]
        assert len(before_node_run_calls) == 1
        call_record = before_node_run_calls[0]
        _assert_hook_call_record_has_expected_parameters(
            call_record, ["node", "catalog", "inputs", "is_async", "run_id"])
        # sanity check a couple of important parameters
        assert call_record.inputs["cars"].to_dict() == dummy_dataframe.to_dict()
        assert call_record.run_id == mock_session.session_id

        # test after node run hook
        after_node_run_calls = [
            record for record in caplog.records
            if record.funcName == "after_node_run"
        ]
        assert len(after_node_run_calls) == 1
        call_record = after_node_run_calls[0]
        _assert_hook_call_record_has_expected_parameters(
            call_record,
            ["node", "catalog", "inputs", "outputs", "is_async", "run_id"])
        # sanity check a couple of important parameters
        assert call_record.outputs["planes"].to_dict() == dummy_dataframe.to_dict()
        assert call_record.run_id == mock_session.session_id

    def test_before_and_after_pipeline_run_hooks(self, caplog, mock_session,
                                                 dummy_dataframe):
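        """A full ``session.run()`` should fire ``before_pipeline_run`` and
        ``after_pipeline_run`` exactly once each, passing the default pipeline,
        the catalog and the run parameters."""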
        context = mock_session.load_context()
        catalog = context.catalog
        default_pipeline = context.pipeline
        catalog.save("cars", dummy_dataframe)
        catalog.save("boats", dummy_dataframe)
        mock_session.run()

        # test before pipeline run hook
        before_pipeline_run_calls = [
            record for record in caplog.records
            if record.funcName == "before_pipeline_run"
        ]
        assert len(before_pipeline_run_calls) == 1
        call_record = before_pipeline_run_calls[0]
        assert call_record.pipeline is default_pipeline
        _assert_hook_call_record_has_expected_parameters(
            call_record, ["pipeline", "catalog", "run_params"])

        # test after pipeline run hook
        after_pipeline_run_calls = [
            record for record in caplog.records
            if record.funcName == "after_pipeline_run"
        ]
        assert len(after_pipeline_run_calls) == 1
        call_record = after_pipeline_run_calls[0]
        _assert_hook_call_record_has_expected_parameters(
            call_record, ["pipeline", "catalog", "run_params"])
        assert call_record.pipeline is default_pipeline

    def test_on_pipeline_error_hook(self, caplog, mock_session):
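        """A broken run should fire ``on_pipeline_error`` exactly once, passing the
        raised ``ValueError("broken")`` together with the run parameters, pipeline
        and catalog."""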
        with pytest.raises(ValueError, match="broken"):
            mock_session.run()

        on_pipeline_error_calls = [
            record for record in caplog.records
            if record.funcName == "on_pipeline_error"
        ]
        assert len(on_pipeline_error_calls) == 1
        call_record = on_pipeline_error_calls[0]
        _assert_hook_call_record_has_expected_parameters(
            call_record, ["error", "run_params", "pipeline", "catalog"])
        expected_error = ValueError("broken")
        assert_exceptions_equal(call_record.error, expected_error)

    def test_on_node_error_hook_sequential_runner(self, caplog, mock_session):
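        """A broken ``node1`` run should fire ``on_node_error`` exactly once, passing
        the raised ``ValueError("broken")`` plus the node, catalog, inputs,
        ``is_async`` flag and ``run_id``."""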
        with pytest.raises(ValueError, match="broken"):
            mock_session.run(node_names=["node1"])

        on_node_error_calls = [
            record for record in caplog.records
            if record.funcName == "on_node_error"
        ]
        assert len(on_node_error_calls) == 1
        call_record = on_node_error_calls[0]
        _assert_hook_call_record_has_expected_parameters(
            call_record,
            ["error", "node", "catalog", "inputs", "is_async", "run_id"])
        expected_error = ValueError("broken")
        assert_exceptions_equal(call_record.error, expected_error)

    def test_on_node_error_hook_parallel_runner(self, mock_session,
                                                logs_listener):
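        """With ``ParallelRunner`` each failing node should fire ``on_node_error``
        once (two calls, for ``node1`` and ``node2``), captured via the
        ``logs_listener`` fixture rather than ``caplog``."""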
        with pytest.raises(ValueError, match="broken"):
            mock_session.run(runner=ParallelRunner(max_workers=2),
                             node_names=["node1", "node2"])

        on_node_error_records = [
            r for r in logs_listener.logs if r.funcName == "on_node_error"
        ]
        assert len(on_node_error_records) == 2

        for call_record in on_node_error_records:
            _assert_hook_call_record_has_expected_parameters(
                call_record,
                ["error", "node", "catalog", "inputs", "is_async", "run_id"],
            )
            expected_error = ValueError("broken")
            assert_exceptions_equal(call_record.error, expected_error)