Example #1
def test_node_inputs():
    env = make_test_env()
    g = Graph(env)
    df = Snap(snap_t1_source)
    node = g.create_node(key="node", snap=df)
    df = Snap(snap_t1_sink)
    node1 = g.create_node(key="node1", snap=df, input=node)
    pi = node1.get_interface()
    assert len(pi.inputs) == 1
    assert pi.output == make_default_output()
    assert list(node1.declared_inputs.keys()) == ["input"]
Example #2
def test_node_inputs():
    env = make_test_env()
    g = Graph(env)
    df = datafunction(function_t1_source)
    node = g.create_node(key="node", function=df)
    df = datafunction(function_t1_sink)
    node1 = g.create_node(key="node1", function=df, input=node)
    pi = node1.get_interface()
    assert len(pi.inputs) == 1
    assert pi.outputs == DEFAULT_OUTPUTS
    assert list(node1.declared_inputs.keys()) == ["input"]
Example #3
def test_node_inputs():
    env = make_test_env()
    g = Graph(env)
    df = pipe(pipe_t1_source)
    node = g.create_node(key="node", pipe=df)
    df = pipe(pipe_t1_sink)
    node1 = g.create_node(key="node1", pipe=df, upstream=node)
    pi = node1.get_interface()
    assert len(pi.inputs) == 1
    assert pi.output == make_default_output_annotation()
    assert list(node1.declared_inputs.keys()) == ["input"]
Example #4
def test_repeated_runs():
    env = get_env()
    g = Graph(env)
    s = env._local_python_storage
    # Initial graph
    N = 2 * 4  # two batches of four records
    g.create_node(key="source",
                  pipe=customer_source,
                  config={"total_records": N})
    metrics = g.create_node(key="metrics",
                            pipe=shape_metrics,
                            upstream="source")
    # Run first time
    output = env.produce("metrics", g, target_storage=s)
    assert output.nominal_schema_key.endswith("Metric")
    records = output.as_records()
    expected_records = [
        {
            "metric": "row_count",
            "value": 4
        },
        {
            "metric": "col_count",
            "value": 3
        },
    ]
    assert records == expected_records
    # Run again, should get next batch
    output = env.produce("metrics", g, target_storage=s)
    records = output.as_records()
    assert records == expected_records
    # Test latest_output
    output = env.latest_output(metrics)
    records = output.as_records()
    assert records == expected_records
    # Run again, should be exhausted
    output = env.produce("metrics", g, target_storage=s)
    assert output is None
    # Run again, should still be exhausted
    output = env.produce("metrics", g, target_storage=s)
    assert output is None

    # Now add a new node and process everything at once
    g.create_node(key="new_accumulator",
                  pipe="core.dataframe_accumulator",
                  upstream="source")
    output = env.produce("new_accumulator", g, target_storage=s)
    records = output.as_records()
    assert len(records) == N
    output = env.produce("new_accumulator", g, target_storage=s)
    assert output is None
Example #5
def test_worker_output():
    env = make_test_env()
    env.add_module(core)
    g = Graph(env)
    # env.add_storage("python://test")
    with env.session_scope() as sess:
        rt = env.runtimes[0]
        # TODO: this errors because there is currently no data copy between SAME storage engines (but DIFFERENT storage urls)
        # ec = env.get_run_context(g, current_runtime=rt, target_storage=env.storages[0])
        ec = env.get_run_context(g, current_runtime=rt, target_storage=rt.as_storage())
        output_alias = "node_output"
        node = g.create_node(key="node", pipe=pipe_dl_source, output_alias=output_alias)
        w = Worker(ec)
        dfi_mgr = NodeInterfaceManager(ec, sess, node)
        bdfi = dfi_mgr.get_bound_interface()
        r = Executable(
            node.key,
            CompiledPipe(node.pipe.key, node.pipe),
            bdfi,
        )
        run_result = w.execute(r)
        outputblock = run_result.output_block
        assert outputblock is not None
        outputblock = sess.merge(outputblock)
        block = outputblock.as_managed_data_block(ec, sess)
        assert block.as_records() == mock_dl_output
        assert block.nominal_schema is TestSchema4
        assert len(block.realized_schema.fields) == len(TestSchema4.fields)
        # Test alias was created correctly
        assert (
            sess.query(Alias).filter(Alias.alias == output_alias).first().data_block_id
            == block.data_block_id
        )
Example #6
def test_declared_schema_translation():
    ec = make_test_run_context()
    env = ec.env
    g = Graph(env)
    translation = {"f1": "mapped_f1"}
    n1 = g.create_node(
        key="node1",
        function=function_t1_to_t2,
        input="n0",
        schema_translation=translation,
    )
    pi = n1.get_interface()
    # im = NodeInterfaceManager(ctx=ec, node=n1)
    block = DataBlockMetadata(
        nominal_schema_key="_test.TestSchema1",
        realized_schema_key="_test.TestSchema1",
    )
    # stream = block_as_stream(block, ec, pi.inputs[0].schema(env), translation)
    # bi = im.get_bound_stream_interface({"input": stream})
    # assert len(bi.inputs) == 1
    # input: StreamInput = bi.inputs[0]
    with env.md_api.begin():
        schema_translation = get_schema_translation(
            env,
            block.realized_schema(env),
            target_schema=env.get_schema(
                pi.get_single_non_recursive_input().schema_like),
            declared_schema_translation=translation,
        )
        assert schema_translation.as_dict() == translation
Example #7
def test_exe_output():
    env = make_test_env()
    env.add_module(core)
    g = Graph(env)
    # env.add_storage("python://test")
    # rt = env.runtimes[0]
    # TODO: this errors because there is currently no data copy between SAME storage engines (but DIFFERENT storage urls)
    # ec = env.get_run_context(g, current_runtime=rt, target_storage=env.storages[0])
    # ec = env.get_run_context(g, current_runtime=rt, target_storage=rt.as_storage())
    output_alias = "node_output"
    node = g.create_node(key="node",
                         snap=snap_dl_source,
                         output_alias=output_alias)
    exe = env.get_executable(node)
    result = ExecutionManager(exe).execute()
    with env.md_api.begin():
        block = result.get_output_block(env)
        assert block is not None
        assert block.as_records() == mock_dl_output
        assert block.nominal_schema is TestSchema4
        assert len(block.realized_schema.fields) == len(TestSchema4.fields)
        # Test alias was created correctly
        assert (env.md_api.execute(
            select(Alias).filter(Alias.alias == output_alias)).
                scalar_one_or_none().data_block_id == block.data_block_id)
        assert env.md_api.count(select(DataBlockLog)) == 1
        dbl = env.md_api.execute(select(DataBlockLog)).scalar_one_or_none()
        assert dbl.data_block_id == block.data_block_id
        assert dbl.direction == Direction.OUTPUT
Example #8
def test_natural_schema_translation():
    # TODO
    ec = make_test_run_context()
    env = ec.env
    g = Graph(env)
    translation = {"f1": "mapped_f1"}
    n1 = g.create_node(
        key="node1",
        function=function_t1_to_t2,
        input="n0",
        schema_translation=translation,
    )
    pi = n1.get_interface()
    # im = NodeInterfaceManager(ctx=ec, node=n1)
    block = DataBlockMetadata(
        nominal_schema_key="_test.TestSchema1",
        realized_schema_key="_test.TestSchema1",
    )
    with env.md_api.begin():
        schema_translation = get_schema_translation(
            env,
            block.realized_schema(env),
            target_schema=env.get_schema(
                pi.get_single_non_recursive_input().schema_like),
            declared_schema_translation=translation,
        )
        assert schema_translation.as_dict() == translation
Example #9
def test_function_failure():
    env = get_env()
    g = Graph(env)
    s = env._local_python_storage
    # Initial graph
    batches = 2
    cfg = {"batches": batches, "fail": True}
    source = g.create_node(customer_source, params=cfg)
    blocks = produce(source, graph=g, target_storage=s, env=env)
    assert len(blocks) == 1
    records = blocks[0].as_records()
    assert len(records) == 2
    with env.md_api.begin():
        assert env.md_api.count(select(DataFunctionLog)) == 1
        assert env.md_api.count(select(DataBlockLog)) == 1
        pl = env.md_api.execute(select(DataFunctionLog)).scalar_one_or_none()
        assert pl.node_key == source.key
        assert pl.graph_id == g.get_metadata_obj().hash
        assert pl.node_start_state == {}
        assert pl.node_end_state == {"records_imported": chunk_size}
        assert pl.function_key == source.function.key
        assert pl.function_params == cfg
        assert pl.error is not None
        assert FAIL_MSG in pl.error["error"]
        ns = env.md_api.execute(
            select(NodeState).filter(NodeState.node_key == pl.node_key)
        ).scalar_one_or_none()
        assert ns.state == {"records_imported": chunk_size}

    # Run again without failing; should see a different result
    source.params["fail"] = False
    blocks = produce(source, graph=g, target_storage=s, env=env)
    assert len(blocks) == 1
    records = blocks[0].as_records()
    assert len(records) == batch_size
    with env.md_api.begin():
        assert env.md_api.count(select(DataFunctionLog)) == 2
        assert env.md_api.count(select(DataBlockLog)) == 2
        pl = (
            env.md_api.execute(
                select(DataFunctionLog).order_by(DataFunctionLog.completed_at.desc())
            )
            .scalars()
            .first()
        )
        assert pl.node_key == source.key
        assert pl.graph_id == g.get_metadata_obj().hash
        assert pl.node_start_state == {"records_imported": chunk_size}
        assert pl.node_end_state == {"records_imported": chunk_size + batch_size}
        assert pl.function_key == source.function.key
        assert pl.function_params == cfg
        assert pl.error is None
        ns = env.md_api.execute(
            select(NodeState).filter(NodeState.node_key == pl.node_key)
        ).scalar_one_or_none()
        assert ns.state == {"records_imported": chunk_size + batch_size}
Example #10
def test_node_no_inputs():
    env = make_test_env()
    g = Graph(env)
    df = datafunction(function_t1_source)
    node1 = g.create_node(key="node1", function=df)
    assert {node1: node1}[node1] is node1  # Test hash
    pi = node1.get_interface()
    assert pi.inputs == {}
    assert pi.outputs != {}
    assert node1.declared_inputs == {}
Example #11
def test_non_terminating_snap():
    def never_stop(input: Optional[DataBlock] = None) -> DataFrame:
        pass

    env = make_test_env()
    g = Graph(env)
    node = g.create_node(key="node", snap=never_stop)
    exe = env.get_executable(node)
    result = ExecutionManager(exe).execute()
    assert result.get_output_block(env) is None
Example #12
def test_node_no_inputs():
    env = make_test_env()
    g = Graph(env)
    df = pipe(pipe_t1_source)
    node1 = g.create_node(key="node1", pipe=df)
    assert {node1: node1}[node1] is node1  # Test hash
    pi = node1.get_interface()
    assert pi.inputs == []
    assert pi.output is not None
    assert node1.declared_inputs == {}
Example #13
def make_graph() -> Graph:
    env = make_test_env()
    env.add_module(core)
    g = Graph(env)
    g.create_node(key="node1", function=function_t1_source)
    g.node(key="node2", function=function_t1_source)
    g.node(key="node3", function=function_t1_to_t2, input="node1")
    g.node(key="node4", function=function_t1_to_t2, input="node2")
    g.node(key="node5", function=function_generic, input="node4")
    g.node(key="node6", function=function_self, input="node4")
    g.node(
        key="node7",
        function=function_multiple_input,
        inputs={
            "input": "node4",
            "other_t2": "node3"
        },
    )
    return g
Example #14
def test_alternate_apis():
    env = get_env()
    g = Graph(env)
    s = env._local_python_storage
    # Initial graph
    batches = 2
    source = g.create_node(customer_source, params={"batches": batches})
    metrics = g.create_node(shape_metrics, input=source)
    # Run first time
    blocks = produce(metrics, graph=g, target_storage=s, env=env)
    assert len(blocks) == 1
    output = blocks[0]
    assert output.nominal_schema_key.endswith("Metric")
    records = blocks[0].as_records()
    expected_records = [
        {"metric": "row_count", "value": batch_size},
        {"metric": "col_count", "value": 3},
    ]
    assert records == expected_records
Example #15
def test_non_terminating_function_with_reference_input():
    def never_stop(input: Optional[Reference]) -> DataFrame:
        # Does not use the input, but that doesn't matter because it is a reference input
        pass

    env = make_test_env()
    g = Graph(env)
    source = g.create_node(
        function="core.import_dataframe",
        params={"dataframe": pd.DataFrame({"a": range(10)})},
    )
    node = g.create_node(key="node", function=never_stop, input=source)
    exe = env.get_executable(source)
    # TODO: should reference inputs be logged too? (So they know when to update)
    # with env.md_api.begin():
    #     assert env.md_api.count(select(DataBlockLog)) == 1
    result = ExecutionManager(exe).execute()
    exe = env.get_executable(node)
    result = ExecutionManager(exe).execute()
    assert result.get_output_block(env) is None
Example #16
def test_node_params():
    env = make_test_env()
    g = Graph(env)
    param_vals = []

    def function_ctx(ctx: DataFunctionContext, test: int):
        param_vals.append(test)

    n = g.create_node(key="ctx", function=function_ctx, params={"test": 1})
    env.run_node(n, g)
    assert param_vals == [1]
Example #17
def test_repeated_runs():
    env = get_env()
    g = Graph(env)
    s = env._local_python_storage
    # Initial graph
    batches = 2
    N = batches * batch_size
    g.create_node(key="source", function=customer_source, params={"batches": batches})
    metrics = g.create_node(key="metrics", function=shape_metrics, input="source")
    # Run first time
    blocks = env.produce("metrics", g, target_storage=s)
    assert blocks[0].nominal_schema_key.endswith("Metric")
    records = blocks[0].as_records()
    expected_records = [
        {"metric": "row_count", "value": batch_size},
        {"metric": "col_count", "value": 3},
    ]
    assert records == expected_records
    # Run again, should get next batch
    blocks = env.produce("metrics", g, target_storage=s)
    records = blocks[0].as_records()
    assert records == expected_records
    # Test latest_output
    block = env.get_latest_output(metrics)
    records = block.as_records()
    assert records == expected_records
    # Run again, should be exhausted
    blocks = env.produce("metrics", g, target_storage=s)
    assert len(blocks) == 0
    # Run again, should still be exhausted
    blocks = env.produce("metrics", g, target_storage=s)
    assert len(blocks) == 0

    # Now add a new node and process everything at once
    g.create_node(key="new_accumulator", function="core.accumulator", input="source")
    blocks = env.produce("new_accumulator", g, target_storage=s)
    assert len(blocks) == 1
    records = blocks[0].as_records()
    assert len(records) == N
    blocks = env.produce("new_accumulator", g, target_storage=s)
    assert len(blocks) == 0
Example #18
def test_node_config():
    env = make_test_env()
    g = Graph(env)
    config_vals = []

    def pipe_ctx(ctx: PipeContext):
        config_vals.append(ctx.get_config_value("test"))

    n = g.create_node(key="ctx", pipe=pipe_ctx, config={"test": 1, "extra_arg": 2})
    with env.run(g) as exe:
        exe.execute(n)
    assert config_vals == [1]
Example #19
def test_non_terminating_pipe():
    def never_stop(input: Optional[DataBlock] = None) -> DataFrame:
        pass

    env = make_test_env()
    g = Graph(env)
    rt = env.runtimes[0]
    ec = env.get_run_context(g, current_runtime=rt)
    node = g.create_node(key="node", pipe=never_stop)
    em = ExecutionManager(ec)
    output = em.execute(node, to_exhaustion=True)
    assert output is None
Example #20
def test_alternate_apis():
    env = get_env()
    g = Graph(env)
    s = env._local_python_storage
    # Initial graph
    N = 2 * 4  # two batches of four records
    source = g.create_node(customer_source, config={"total_records": N})
    metrics = g.create_node(shape_metrics, upstream=source)
    # Run first time
    output = produce(metrics, graph=g, target_storage=s, env=env)
    assert output.nominal_schema_key.endswith("Metric")
    records = output.as_records()
    expected_records = [
        {
            "metric": "row_count",
            "value": 4
        },
        {
            "metric": "col_count",
            "value": 3
        },
    ]
    assert records == expected_records
Example #21
def test_exe():
    env = make_test_env()
    g = Graph(env)
    node = g.create_node(key="node", snap=snap_t1_source)
    exe = env.get_executable(node)
    result = ExecutionManager(exe).execute()
    with env.md_api.begin():
        assert not result.output_blocks
        assert env.md_api.count(select(SnapLog)) == 1
        pl = env.md_api.execute(select(SnapLog)).scalar_one_or_none()
        assert pl.node_key == node.key
        assert pl.graph_id == g.get_metadata_obj().hash
        assert pl.node_start_state == {}
        assert pl.node_end_state == {}
        assert pl.snap_key == node.snap.key
        assert pl.snap_params == {}
Example #22
def test_node_params():
    env = make_test_env()
    g = Graph(env)
    param_vals = []

    @Param("test", "int")
    def snap_ctx(ctx: SnapContext):
        param_vals.append(ctx.get_param("test"))

    n = g.create_node(key="ctx",
                      snap=snap_ctx,
                      params={
                          "test": 1,
                          "extra_arg": 2
                      })
    env.run_node(n, g)
    assert param_vals == [1]
Example #23
def test_generic_schema_resolution():
    ec = make_test_run_context()
    env = ec.env
    g = Graph(env)
    n1 = g.create_node(key="node1", pipe=pipe_generic, upstream="n0")
    # pi = n1.get_interface()
    with env.session_scope() as sess:
        im = NodeInterfaceManager(ctx=ec, sess=sess, node=n1)
        block = DataBlockMetadata(
            nominal_schema_key="_test.TestSchema1",
            realized_schema_key="_test.TestSchema2",
        )
        sess.add(block)
        sess.flush([block])
        stream = block_as_stream(block, ec, sess)
        bi = im.get_bound_interface({"input": stream})
        assert len(bi.inputs) == 1
        assert bi.resolve_nominal_output_schema(env, sess) is TestSchema1
Example #24
def test_worker():
    env = make_test_env()
    g = Graph(env)
    rt = env.runtimes[0]
    ec = env.get_run_context(g, current_runtime=rt)
    with env.session_scope() as sess:
        node = g.create_node(key="node", pipe=pipe_t1_source)
        w = Worker(ec)
        dfi_mgr = NodeInterfaceManager(ec, sess, node)
        bdfi = dfi_mgr.get_bound_interface()
        r = Executable(
            node.key,
            CompiledPipe(node.pipe.key, node.pipe),
            bdfi,
        )
        run_result = w.execute(r)
        output = run_result.output_block
        assert output is None
Example #25
def test_generic_schema_resolution():
    ec = make_test_run_context()
    env = ec.env
    g = Graph(env)
    n1 = g.create_node(key="node1", function=function_generic, input="n0")
    # pi = n1.get_interface()
    with env.md_api.begin():
        exe = Executable(node=n1, function=n1.function, execution_context=ec)
        im = NodeInterfaceManager(exe)
        block = DataBlockMetadata(
            nominal_schema_key="_test.TestSchema1",
            realized_schema_key="_test.TestSchema2",
        )
        env.md_api.add(block)
        env.md_api.flush([block])
        stream = block_as_stream(block, ec)
        bi = im.get_bound_interface({"input": stream})
        assert len(bi.inputs) == 1
        assert bi.resolve_nominal_output_schema(env) is TestSchema1
Example #26
def test_natural_schema_translation():
    # TODO
    ec = make_test_run_context()
    env = ec.env
    g = Graph(env)
    translation = {"f1": "mapped_f1"}
    n1 = g.create_node(
        key="node1", pipe=pipe_t1_to_t2, upstream="n0", schema_translation=translation
    )
    pi = n1.get_interface()
    # im = NodeInterfaceManager(ctx=ec, node=n1)
    block = DataBlockMetadata(
        nominal_schema_key="_test.TestSchema1",
        realized_schema_key="_test.TestSchema1",
    )
    with env.session_scope() as sess:
        schema_translation = get_schema_translation(
            env,
            sess,
            block.realized_schema(env, sess),
            target_schema=pi.inputs[0].schema(env, sess),
            declared_schema_translation=translation,
        )
        assert schema_translation.as_dict() == translation
Example #27
class TestStreams:
    def setup(self):
        ctx = make_test_run_context()
        self.ctx = ctx
        self.env = ctx.env
        self.g = Graph(self.env)
        self.graph = self.g.get_metadata_obj()
        self.dr1t1 = DataBlockMetadata(
            nominal_schema_key="_test.TestSchema1",
            realized_schema_key="_test.TestSchema1",
        )
        self.dr2t1 = DataBlockMetadata(
            nominal_schema_key="_test.TestSchema1",
            realized_schema_key="_test.TestSchema1",
        )
        self.dr1t2 = DataBlockMetadata(
            nominal_schema_key="_test.TestSchema2",
            realized_schema_key="_test.TestSchema2",
        )
        self.dr2t2 = DataBlockMetadata(
            nominal_schema_key="_test.TestSchema2",
            realized_schema_key="_test.TestSchema2",
        )
        self.node_source = self.g.create_node(key="pipe_source",
                                              pipe=pipe_t1_source)
        self.node1 = self.g.create_node(key="pipe1",
                                        pipe=pipe_t1_sink,
                                        upstream="pipe_source")
        self.node2 = self.g.create_node(key="pipe2",
                                        pipe=pipe_t1_to_t2,
                                        upstream="pipe_source")
        self.node3 = self.g.create_node(key="pipe3",
                                        pipe=pipe_generic,
                                        upstream="pipe_source")
        self.sess = self.env._get_new_metadata_session()
        self.sess.add(self.dr1t1)
        self.sess.add(self.dr2t1)
        self.sess.add(self.dr1t2)
        self.sess.add(self.dr2t2)
        self.sess.add(self.graph)

    def teardown(self):
        self.sess.close()

    def test_stream_unprocessed_pristine(self):
        s = StreamBuilder(nodes=self.node_source)
        s = s.filter_unprocessed(self.node1)
        assert s.get_query(self.ctx, self.sess).first() is None

    def test_stream_unprocessed_eligible(self):
        dfl = PipeLog(
            graph_id=self.graph.hash,
            node_key=self.node_source.key,
            pipe_key=self.node_source.pipe.key,
            runtime_url="test",
        )
        drl = DataBlockLog(
            pipe_log=dfl,
            data_block=self.dr1t1,
            direction=Direction.OUTPUT,
        )
        self.sess.add_all([dfl, drl])

        s = StreamBuilder(nodes=self.node_source)
        s = s.filter_unprocessed(self.node1)
        assert s.get_query(self.ctx, self.sess).first() == self.dr1t1

    def test_stream_unprocessed_ineligible_already_input(self):
        dfl = PipeLog(
            graph_id=self.graph.hash,
            node_key=self.node_source.key,
            pipe_key=self.node_source.pipe.key,
            runtime_url="test",
        )
        drl = DataBlockLog(
            pipe_log=dfl,
            data_block=self.dr1t1,
            direction=Direction.OUTPUT,
        )
        dfl2 = PipeLog(
            graph_id=self.graph.hash,
            node_key=self.node1.key,
            pipe_key=self.node1.pipe.key,
            runtime_url="test",
        )
        drl2 = DataBlockLog(
            pipe_log=dfl2,
            data_block=self.dr1t1,
            direction=Direction.INPUT,
        )
        self.sess.add_all([dfl, drl, dfl2, drl2])

        s = StreamBuilder(nodes=self.node_source)
        s = s.filter_unprocessed(self.node1)
        assert s.get_query(self.ctx, self.sess).first() is None

    def test_stream_unprocessed_ineligible_already_output(self):
        """
        By default we don't input a block that has already been output by a DF, _even if that block was never input_,
        UNLESS input is a self reference (`this`). This is to prevent infinite loops.
        """
        dfl = PipeLog(
            graph_id=self.graph.hash,
            node_key=self.node_source.key,
            pipe_key=self.node_source.pipe.key,
            runtime_url="test",
        )
        drl = DataBlockLog(
            pipe_log=dfl,
            data_block=self.dr1t1,
            direction=Direction.OUTPUT,
        )
        dfl2 = PipeLog(
            graph_id=self.graph.hash,
            node_key=self.node1.key,
            pipe_key=self.node1.pipe.key,
            runtime_url="test",
        )
        drl2 = DataBlockLog(
            pipe_log=dfl2,
            data_block=self.dr1t1,
            direction=Direction.OUTPUT,
        )
        self.sess.add_all([dfl, drl, dfl2, drl2])

        s = StreamBuilder(nodes=self.node_source)
        s1 = s.filter_unprocessed(self.node1)
        assert s1.get_query(self.ctx, self.sess).first() is None

        # But OK with a self reference
        s2 = s.filter_unprocessed(self.node1, allow_cycle=True)
        assert s2.get_query(self.ctx, self.sess).first() == self.dr1t1

    def test_stream_unprocessed_eligible_schema(self):
        dfl = PipeLog(
            graph_id=self.graph.hash,
            node_key=self.node_source.key,
            pipe_key=self.node_source.pipe.key,
            runtime_url="test",
        )
        drl = DataBlockLog(
            pipe_log=dfl,
            data_block=self.dr1t1,
            direction=Direction.OUTPUT,
        )
        self.sess.add_all([dfl, drl])

        s = StreamBuilder(nodes=self.node_source, schema="TestSchema1")
        s = s.filter_unprocessed(self.node1)
        assert s.get_query(self.ctx, self.sess).first() == self.dr1t1

        s = StreamBuilder(nodes=self.node_source, schema="TestSchema2")
        s = s.filter_unprocessed(self.node1)
        assert s.get_query(self.ctx, self.sess).first() is None

    def test_operators(self):
        dfl = PipeLog(
            graph_id=self.graph.hash,
            node_key=self.node_source.key,
            pipe_key=self.node_source.pipe.key,
            runtime_url="test",
        )
        drl = DataBlockLog(
            pipe_log=dfl,
            data_block=self.dr1t1,
            direction=Direction.OUTPUT,
        )
        drl2 = DataBlockLog(
            pipe_log=dfl,
            data_block=self.dr2t1,
            direction=Direction.OUTPUT,
        )
        self.sess.add_all([dfl, drl, drl2])

        self._cnt = 0

        @operator
        def count(stream: DataBlockStream) -> DataBlockStream:
            for db in stream:
                self._cnt += 1
                yield db

        sb = StreamBuilder(nodes=self.node_source)
        expected_cnt = sb.get_query(self.ctx, self.sess).count()
        assert expected_cnt == 2
        list(count(sb).as_managed_stream(self.ctx, self.sess))
        assert self._cnt == expected_cnt

        # Test composed operators
        self._cnt = 0
        list(count(latest(sb)).as_managed_stream(self.ctx, self.sess))
        assert self._cnt == 1

        # Test kwargs
        self._cnt = 0
        list(
            count(filter(sb, function=lambda db: False)).as_managed_stream(
                self.ctx, self.sess))
        assert self._cnt == 0
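
The @operator decorator used in test_operators above composes, and operators accept extra keyword arguments, as the filter(sb, function=lambda db: False) call shows. As a sketch of the same pattern, a hypothetical take operator (not part of the library) could be written and chained the same way:

@operator
def take(stream: DataBlockStream, n: int = 1) -> DataBlockStream:
    # Yield at most n blocks from the upstream stream, then stop
    for i, db in enumerate(stream):
        if i >= n:
            break
        yield db

# Chained like latest/filter above, e.g.:
#   list(take(sb, n=1).as_managed_stream(self.ctx, self.sess))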
Example #28
class TestStreams:
    def setup(self):
        ctx = make_test_run_context()
        self.ctx = ctx
        self.env = ctx.env
        self.sess = self.env.md_api.begin()
        self.sess.__enter__()
        self.g = Graph(self.env)
        self.graph = self.g.get_metadata_obj()
        self.dr1t1 = DataBlockMetadata(
            nominal_schema_key="_test.TestSchema1",
            realized_schema_key="_test.TestSchema1",
        )
        self.dr2t1 = DataBlockMetadata(
            nominal_schema_key="_test.TestSchema1",
            realized_schema_key="_test.TestSchema1",
        )
        self.dr1t2 = DataBlockMetadata(
            nominal_schema_key="_test.TestSchema2",
            realized_schema_key="_test.TestSchema2",
        )
        self.dr2t2 = DataBlockMetadata(
            nominal_schema_key="_test.TestSchema2",
            realized_schema_key="_test.TestSchema2",
        )
        self.node_source = self.g.create_node(key="snap_source",
                                              snap=snap_t1_source)
        self.node1 = self.g.create_node(key="snap1",
                                        snap=snap_t1_sink,
                                        input="snap_source")
        self.node2 = self.g.create_node(key="snap2",
                                        snap=snap_t1_to_t2,
                                        input="snap_source")
        self.node3 = self.g.create_node(key="snap3",
                                        snap=snap_generic,
                                        input="snap_source")
        self.env.md_api.add(self.dr1t1)
        self.env.md_api.add(self.dr2t1)
        self.env.md_api.add(self.dr1t2)
        self.env.md_api.add(self.dr2t2)
        self.env.md_api.add(self.graph)

    def teardown(self):
        self.sess.__exit__(None, None, None)

    def test_stream_unprocessed_pristine(self):
        s = stream(nodes=self.node_source)
        s = s.filter_unprocessed(self.node1)
        assert s.get_query_result(self.env).scalar_one_or_none() is None

    def test_stream_unprocessed_eligible(self):
        dfl = SnapLog(
            graph_id=self.graph.hash,
            node_key=self.node_source.key,
            snap_key=self.node_source.snap.key,
            runtime_url="test",
        )
        drl = DataBlockLog(
            snap_log=dfl,
            data_block=self.dr1t1,
            direction=Direction.OUTPUT,
        )
        self.env.md_api.add_all([dfl, drl])

        s = stream(nodes=self.node_source)
        s = s.filter_unprocessed(self.node1)
        assert s.get_query_result(self.env).scalar_one_or_none() == self.dr1t1

    def test_stream_unprocessed_ineligible_already_input(self):
        dfl = SnapLog(
            graph_id=self.graph.hash,
            node_key=self.node_source.key,
            snap_key=self.node_source.snap.key,
            runtime_url="test",
        )
        drl = DataBlockLog(
            snap_log=dfl,
            data_block=self.dr1t1,
            direction=Direction.OUTPUT,
        )
        dfl2 = SnapLog(
            graph_id=self.graph.hash,
            node_key=self.node1.key,
            snap_key=self.node1.snap.key,
            runtime_url="test",
        )
        drl2 = DataBlockLog(
            snap_log=dfl2,
            data_block=self.dr1t1,
            direction=Direction.INPUT,
        )
        self.env.md_api.add_all([dfl, drl, dfl2, drl2])

        s = stream(nodes=self.node_source)
        s = s.filter_unprocessed(self.node1)
        assert s.get_query_result(self.env).scalar_one_or_none() is None

    def test_stream_unprocessed_ineligible_already_output(self):
        """
        By default we don't input a block that has already been output by a DF, _even if that block was never input_,
        UNLESS input is a self reference (`this`). This is to prevent infinite loops.
        """
        dfl = SnapLog(
            graph_id=self.graph.hash,
            node_key=self.node_source.key,
            snap_key=self.node_source.snap.key,
            runtime_url="test",
        )
        drl = DataBlockLog(
            snap_log=dfl,
            data_block=self.dr1t1,
            direction=Direction.OUTPUT,
        )
        dfl2 = SnapLog(
            graph_id=self.graph.hash,
            node_key=self.node1.key,
            snap_key=self.node1.snap.key,
            runtime_url="test",
        )
        drl2 = DataBlockLog(
            snap_log=dfl2,
            data_block=self.dr1t1,
            direction=Direction.OUTPUT,
        )
        self.env.md_api.add_all([dfl, drl, dfl2, drl2])

        s = stream(nodes=self.node_source)
        s1 = s.filter_unprocessed(self.node1)
        assert s1.get_query_result(self.env).scalar_one_or_none() is None

        # But OK with a self reference
        s2 = s.filter_unprocessed(self.node1, allow_cycle=True)
        assert s2.get_query_result(self.env).scalar_one_or_none() == self.dr1t1

    def test_stream_unprocessed_eligible_schema(self):
        dfl = SnapLog(
            graph_id=self.graph.hash,
            node_key=self.node_source.key,
            snap_key=self.node_source.snap.key,
            runtime_url="test",
        )
        drl = DataBlockLog(
            snap_log=dfl,
            data_block=self.dr1t1,
            direction=Direction.OUTPUT,
        )
        self.env.md_api.add_all([dfl, drl])

        s = stream(nodes=self.node_source, schema="TestSchema1")
        s = s.filter_unprocessed(self.node1)
        assert s.get_query_result(self.env).scalar_one_or_none() == self.dr1t1

        s = stream(nodes=self.node_source, schema="TestSchema2")
        s = s.filter_unprocessed(self.node1)
        assert s.get_query_result(self.env).scalar_one_or_none() is None

    def test_operators(self):
        dfl = SnapLog(
            graph_id=self.graph.hash,
            node_key=self.node_source.key,
            snap_key=self.node_source.snap.key,
            runtime_url="test",
        )
        drl = DataBlockLog(
            snap_log=dfl,
            data_block=self.dr1t1,
            direction=Direction.OUTPUT,
        )
        drl2 = DataBlockLog(
            snap_log=dfl,
            data_block=self.dr2t1,
            direction=Direction.OUTPUT,
        )
        self.env.md_api.add_all([dfl, drl, drl2])

        self._cnt = 0

        @operator
        def count(stream: DataBlockStream) -> DataBlockStream:
            for db in stream:
                self._cnt += 1
                yield db

        sb = stream(nodes=self.node_source)
        expected_cnt = sb.get_count(self.env)
        assert expected_cnt == 2
        list(count(sb).as_managed_stream(self.ctx))
        assert self._cnt == expected_cnt

        # Test composed operators
        self._cnt = 0
        list(count(latest(sb)).as_managed_stream(self.ctx))
        assert self._cnt == 1

        # Test kwargs
        self._cnt = 0
        list(
            count(filter(sb, function=lambda db: False)).as_managed_stream(
                self.ctx))
        assert self._cnt == 0

    def test_managed_stream(self):
        dfl = SnapLog(
            graph_id=self.graph.hash,
            node_key=self.node_source.key,
            snap_key=self.node_source.snap.key,
            runtime_url="test",
        )
        drl = DataBlockLog(
            snap_log=dfl,
            data_block=self.dr1t1,
            direction=Direction.OUTPUT,
        )
        dfl2 = SnapLog(
            graph_id=self.graph.hash,
            node_key=self.node1.key,
            snap_key=self.node1.snap.key,
            runtime_url="test",
        )
        drl2 = DataBlockLog(
            snap_log=dfl2,
            data_block=self.dr1t1,
            direction=Direction.INPUT,
        )
        self.env.md_api.add_all([dfl, drl, dfl2, drl2])

        s = stream(nodes=self.node_source)
        s = s.filter_unprocessed(self.node1)

        ctx = make_test_run_context()
        with ctx.env.md_api.begin():
            dbs = ManagedDataBlockStream(ctx, stream_builder=s)
            with pytest.raises(StopIteration):
                assert next(dbs) is None
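
For contrast with the StopIteration check above (the only block, dr1t1, was already consumed as an INPUT by node1, so the unprocessed filter excludes it), a sketch of defensively draining a possibly-empty managed stream, reusing the same fixtures:

with ctx.env.md_api.begin():
    dbs = ManagedDataBlockStream(ctx, stream_builder=s)
    # dr1t1 was already input by node1, so no blocks are eligible
    assert list(dbs) == []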