Exemple #1
0
def check_deserialization(batch_size, num_threads, shape):
    """Round-trip a TestPipeline through its in-memory serialized form.

    A reference pipeline is built, serialized with ``serialize()``, rebuilt
    with ``Pipeline.deserialize`` and both are compared over 3 iterations.

    Args:
        batch_size: batch size used for both pipelines and the comparison.
        num_threads: number of worker threads for the reference pipeline.
        shape: shape argument forwarded to ``TestPipeline``.
    """
    reference = TestPipeline(
        batch_size=batch_size, num_threads=num_threads, shape=shape
    )
    restored = Pipeline.deserialize(reference.serialize())
    test_utils.compare_pipelines(
        reference, restored, batch_size=batch_size, N_iterations=3
    )
Exemple #2
0
def check_deserialization_from_file(batch_size, num_threads, shape):
    """Round-trip a TestPipeline through a serialized file on disk.

    The reference pipeline is serialized to a file, rebuilt with
    ``Pipeline.deserialize(filename=...)`` and compared against the
    reference over 3 iterations.

    Args:
        batch_size: batch size used for both pipelines and the comparison.
        num_threads: number of worker threads for the reference pipeline.
        shape: shape argument forwarded to ``TestPipeline``.
    """
    import os
    import tempfile

    # A private temporary directory (instead of a fixed /tmp path) avoids
    # clashes between concurrent test runs, works on platforms without /tmp,
    # and guarantees the serialized file is removed afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        filename = os.path.join(tmp_dir, "dali.serialize.pipeline.test")
        ref_pipe = TestPipeline(batch_size=batch_size,
                                num_threads=num_threads,
                                shape=shape)
        ref_pipe.serialize(filename=filename)
        test_pipe = Pipeline.deserialize(filename=filename)
        # Compare while the file still exists, in case deserialization
        # reads it lazily.
        test_utils.compare_pipelines(ref_pipe,
                                     test_pipe,
                                     batch_size=batch_size,
                                     N_iterations=3)
Exemple #3
0
def check_deserialization_with_params(batch_size, num_threads, shape):
    """Check that deserialization honours overridden construction params.

    A pipeline serialized with ``(batch_size, num_threads)`` is deserialized
    with ``batch_size**2`` and ``num_threads + 1``. The result must match a
    reference pipeline built directly with those overridden values
    (compared over 3 iterations).

    Args:
        batch_size: batch size of the pipeline that gets serialized.
        num_threads: thread count of the pipeline that gets serialized.
        shape: shape argument forwarded to ``TestPipeline``.
    """
    big_batch = batch_size ** 2
    more_threads = num_threads + 1

    source = TestPipeline(
        batch_size=batch_size, num_threads=num_threads, shape=shape
    )
    blob = source.serialize()

    expected = TestPipeline(
        batch_size=big_batch, num_threads=more_threads, shape=shape
    )
    actual = Pipeline.deserialize(
        blob, batch_size=big_batch, num_threads=more_threads
    )
    test_utils.compare_pipelines(
        expected, actual, batch_size=big_batch, N_iterations=3
    )
Exemple #4
0
def check_deserialization_from_file_with_params(batch_size, num_threads,
                                                shape):
    """Check file-based deserialization with overridden construction params.

    A pipeline built with ``(batch_size, num_threads)`` is serialized to a
    file. It is then deserialized from that file with ``batch_size**2`` and
    ``num_threads + 1``, and compared over 10 iterations against a
    reference pipeline built directly with the overridden values.

    Args:
        batch_size: batch size of the pipeline that gets serialized.
        num_threads: thread count of the pipeline that gets serialized.
        shape: shape argument forwarded to ``TestPipeline``.
    """
    import os
    import tempfile

    new_batch_size = batch_size ** 2
    new_num_threads = num_threads + 1

    # A private temporary directory (instead of a fixed /tmp path) avoids
    # clashes between concurrent test runs, works on platforms without /tmp,
    # and guarantees the serialized file is removed afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        filename = os.path.join(tmp_dir, "dali.serialize.pipeline.test")
        init_pipe = TestPipeline(batch_size=batch_size,
                                 num_threads=num_threads,
                                 shape=shape)
        init_pipe.serialize(filename=filename)
        ref_pipe = TestPipeline(batch_size=new_batch_size,
                                num_threads=new_num_threads,
                                shape=shape)
        test_pipe = Pipeline.deserialize(filename=filename,
                                         batch_size=new_batch_size,
                                         num_threads=new_num_threads)
        # Compare while the file still exists, in case deserialization
        # reads it lazily.
        test_utils.compare_pipelines(ref_pipe,
                                     test_pipe,
                                     batch_size=new_batch_size,
                                     N_iterations=10)
Exemple #5
0
def test_incorrect_invocation_no_params():
    """Call ``Pipeline.deserialize`` with neither a pipeline nor a filename.

    NOTE(review): this body asserts nothing itself. The expected failure is
    presumably enforced by a decorator (e.g. ``@raises``) or harness outside
    this view; confirm.
    """
    _ = Pipeline.deserialize()
Exemple #6
0
def test_incorrect_invocation_mutually_exclusive_params():
    """Pass both ``serialized_pipeline`` and ``filename`` to deserialize.

    A small pipeline is serialized to a file, and the returned serialized
    value plus the filename are both handed to ``Pipeline.deserialize``.

    NOTE(review): this body asserts nothing itself. The expected failure for
    the mutually exclusive arguments is presumably enforced by a decorator
    (e.g. ``@raises``) outside this view; confirm.
    """
    import os
    import tempfile

    # A private temporary directory (instead of a fixed /tmp path) avoids
    # clashes between concurrent test runs and cleans up the file. Any
    # exception from the final call still propagates out of the ``with``.
    with tempfile.TemporaryDirectory() as tmp_dir:
        filename = os.path.join(tmp_dir, "dali.serialize.pipeline.test")
        pipe = TestPipeline(batch_size=3, num_threads=1, shape=[666])
        serialized = pipe.serialize(filename=filename)
        Pipeline.deserialize(serialized_pipeline=serialized,
                             filename=filename)