def test_on_pipeline_error_hook(self, caplog, mock_session):
        """Verify one ``on_pipeline_error`` log record after a failed run.

        Running the session must raise ``ValueError`` matching ``"broken"``.
        Exactly one captured log record may have ``funcName`` equal to
        ``"on_pipeline_error"``; that record is checked for the expected
        parameter names and its ``error`` attribute is compared against
        ``ValueError("broken")``.
        """
        hook_name = "on_pipeline_error"
        expected_params = ["error", "run_params", "pipeline", "catalog"]

        with pytest.raises(ValueError, match="broken"):
            mock_session.run()

        # Keep only the records emitted from the hook under test.
        hook_records = list(
            filter(lambda rec: rec.funcName == hook_name, caplog.records)
        )
        assert len(hook_records) == 1

        hook_record = hook_records[0]
        _assert_hook_call_record_has_expected_parameters(
            hook_record, expected_params
        )
        assert_exceptions_equal(hook_record.error, ValueError("broken"))
    def test_on_node_error_hook_sequential_runner(self, caplog, mock_session):
        """Verify one ``on_node_error`` log record when ``node1`` fails.

        The session is run with ``node_names=["node1"]`` and must raise
        ``ValueError`` matching ``"broken"``. Exactly one captured log record
        may have ``funcName`` equal to ``"on_node_error"``; it is checked for
        the expected parameter names and its ``error`` attribute is compared
        against ``ValueError("broken")``.
        """
        expected_params = [
            "error", "node", "catalog", "inputs", "is_async", "run_id",
        ]

        with pytest.raises(ValueError, match="broken"):
            mock_session.run(node_names=["node1"])

        # Collect the records whose originating function is the hook.
        node_error_records = []
        for rec in caplog.records:
            if rec.funcName == "on_node_error":
                node_error_records.append(rec)
        assert len(node_error_records) == 1

        node_error_record = node_error_records[0]
        _assert_hook_call_record_has_expected_parameters(
            node_error_record, expected_params
        )
        assert_exceptions_equal(
            node_error_record.error, ValueError("broken")
        )
    def test_on_node_error_hook_parallel_runner(self, mock_session,
                                                logs_listener):
        """Verify two ``on_node_error`` log records under ``ParallelRunner``.

        The session is run for ``node1`` and ``node2`` with a
        ``ParallelRunner(max_workers=2)`` and must raise ``ValueError``
        matching ``"broken"``. Records are read from ``logs_listener.logs``
        rather than ``caplog``; exactly two of them may have ``funcName``
        equal to ``"on_node_error"``, and each is checked for the expected
        parameter names and compared against ``ValueError("broken")``.
        """
        expected_params = [
            "error", "node", "catalog", "inputs", "is_async", "run_id",
        ]

        with pytest.raises(ValueError, match="broken"):
            mock_session.run(
                runner=ParallelRunner(max_workers=2),
                node_names=["node1", "node2"],
            )

        node_error_records = list(
            filter(
                lambda rec: rec.funcName == "on_node_error",
                logs_listener.logs,
            )
        )
        assert len(node_error_records) == 2

        # Every record gets both checks, in the same order per record.
        for node_error_record in node_error_records:
            _assert_hook_call_record_has_expected_parameters(
                node_error_record, expected_params
            )
            assert_exceptions_equal(
                node_error_record.error, ValueError("broken")
            )