def _test_scalar(device, as_tensors):
    """Test propagation of scalars from external source.

    A source pipeline produces batches of ``np.float32`` scalars from a
    callable ``external_source``. Each batch is fed into the ``external_source``
    named "ext" of a second pipeline, and that pipeline's output is compared
    with the batch that was fed in.

    Args:
        device: Device passed to both ``external_source`` operators.
        as_tensors: If true, feed the batch as a list of per-sample tensors
            (built by indexing the batch) instead of the batch object itself.
    """
    batch_size = 4

    def scalar_batch(iteration):
        # NOTE(review): assumes the single argument is the iteration index
        # supplied by external_source. The previous comprehension reused ``i``
        # as its loop variable, which shadowed this argument, so every
        # iteration yielded the same batch. Mixing the iteration and sample
        # indices makes each batch distinct, so replayed data is detectable.
        return [np.float32(iteration * 10 + sample + 1)
                for sample in range(batch_size)]

    src_pipe = Pipeline(batch_size, 1, 0)
    src_ext = fn.external_source(source=scalar_batch, device=device)
    src_pipe.set_outputs(src_ext)
    src_pipe.build()

    # Destination pipeline runs synchronously and non-pipelined; presumably so
    # each fed batch is consumed by the very next run() -- TODO confirm.
    dst_pipe = Pipeline(batch_size,
                        1,
                        0,
                        exec_async=False,
                        exec_pipelined=False)
    dst_pipe.set_outputs(fn.external_source(name="ext", device=device))
    dst_pipe.build()

    for _ in range(3):
        src = src_pipe.run()
        data = src[0]
        if as_tensors:
            # Split the batch into a plain list of per-sample tensors.
            data = [data[sample_idx] for sample_idx in range(len(data))]
        dst_pipe.feed_input("ext", data)
        dst = dst_pipe.run()
        check_batch(src[0], dst[0], batch_size, 0, 0, "")


# Example #2
def test_external_source_with_serialized_pipe():
    """Feed input into an external source of a deserialized pipeline.

    A pipeline holding a single ``external_source`` named "es" is defined via
    ``pipeline_def`` and serialized. The serialized form is then deserialized
    and built into a fresh ``Pipeline``, and a batch is fed to "es" on that
    restored pipeline.
    """
    batch_size = 10

    @pipeline_def
    def serialized_pipe():
        return fn.external_source(name="es")

    # Build the original pipeline only to obtain its serialized description.
    original = serialized_pipe(batch_size=batch_size, num_threads=3, device_id=0)
    blob = original.serialize()

    # Restore it into a new Pipeline (deliberately with a different thread count).
    restored = Pipeline(batch_size, 4, 0)
    restored.deserialize_and_build(blob)

    # NOTE(review): presumably interpreted as batch_size samples of shape (10,).
    restored.feed_input("es", np.zeros([batch_size, 10]))


def _test_feed_input(device):
    """Round-trip source-pipeline batches through a fed ``external_source``.

    For three iterations, the first output of the pipeline returned by
    ``build_src_pipe`` is fed into the ``external_source`` named "ext" of a
    destination pipeline. The destination output is then compared with the
    fed batch via ``check_batch`` (the trailing "XY" is presumably the expected
    layout -- TODO confirm against ``check_batch``).

    Args:
        device: Device passed to ``build_src_pipe`` and to the destination
            ``external_source``.
    """
    src_pipe, batch_size = build_src_pipe(device)

    dst_pipe = Pipeline(batch_size, 1, 0, exec_async=False, exec_pipelined=False)
    dst_pipe.set_outputs(fn.external_source(name="ext", device=device))
    dst_pipe.build()

    for _ in range(3):
        src_outputs = src_pipe.run()
        fed_batch = src_outputs[0]
        dst_pipe.feed_input("ext", fed_batch)
        received_batch = dst_pipe.run()[0]
        check_batch(received_batch, fed_batch, batch_size, 0, 0, "XY")


def test_layout_changing():
    """Feed the same external source twice with different layouts.

    A single-sample batch is fed to the "input" external source first with
    layout "W" and then with layout "H".

    NOTE(review): nothing in the visible lines checks the outcome of the
    second ``feed_input``. Confirm whether a layout change is expected to be
    accepted or to raise (e.g. via a decorator or code not shown here).
    """
    src_pipe = Pipeline(1, 1, 0)
    src_pipe.set_outputs(fn.external_source(name="input"))
    src_pipe.build()
    # `(1)` is the int 1, not a tuple: each sample is a 1-element 1-D zero array.
    src_pipe.feed_input("input", [np.zeros((1))], layout="W")
    # Same data, different layout string.
    src_pipe.feed_input("input", [np.zeros((1))], layout="H")