def run_tf_dataset_no_copy(max_shape, dtype, dataset_dev, es_dev, no_copy):
    """Run the graph-mode dataset test with a configurable ``no_copy`` flag.

    A ``RandomSampleIterator`` feeds the pipeline description built by
    ``external_source_tester``. A matching reference dataset comes from
    ``external_source_converter_with_callback``. Both are handed to
    ``run_tf_dataset_graph``.

    Args:
        max_shape: Maximum sample shape, forwarded to the iterator and to
            both helper factories.
        dtype: Element type; ``dtype(0)`` is given to the iterator.
        dataset_dev: First positional argument of ``run_tf_dataset_graph``
            (presumably the dataset's device — confirm).
        es_dev: Forwarded positionally to ``external_source_tester``
            (presumably the external-source device — confirm).
        no_copy: Forwarded positionally to ``external_source_tester``.
    """
    sample_source = RandomSampleIterator(max_shape, dtype(0))
    pipeline_desc = external_source_tester(max_shape, dtype, sample_source,
                                           es_dev, no_copy)
    # The reference side receives the iterator *class* and builds its own
    # instance through the callback converter.
    reference = external_source_converter_with_callback(
        RandomSampleIterator, max_shape, dtype)
    run_tf_dataset_graph(dataset_dev,
                         get_pipeline_desc=pipeline_desc,
                         to_dataset=reference)
Example #2
0
def run_tf_dataset_with_stop_iter(dev, max_shape, dtype, stop_samples):
    """Run the eager-mode dataset test on a source that stops after ``stop_samples``.

    Both the pipeline-side iterator and the reference converter are given
    the range ``start=0`` / ``stop=stop_samples``. The run is requested
    with ``to_stop_iter=True``.

    Args:
        dev: Device forwarded to ``run_tf_dataset_eager_mode``.
        max_shape: Maximum sample shape.
        dtype: Element type; ``dtype(0)`` is given to the iterator.
        stop_samples: Stop index for both the pipeline and reference iterators.
    """
    run_tf_dataset_eager_mode(
        dev,
        to_stop_iter=True,
        get_pipeline_desc=external_source_tester(
            max_shape, dtype,
            RandomSampleIterator(max_shape, dtype(0),
                                 start=0, stop=stop_samples)),
        # Positional 0 / stop_samples mirror start=0 / stop=stop_samples above.
        to_dataset=external_source_converter_with_callback(
            RandomSampleIterator, max_shape, dtype, 0, stop_samples),
    )
def run_tf_dataset_with_constant_input(dev, shape, value, dtype, batch):
    """Run the graph-mode dataset test on a constant-valued input tensor.

    A single array filled with ``value`` is built once. It is served by a
    ``FixedSampleIterator`` on the pipeline side and handed to
    ``external_source_converter_with_fixed_value`` on the reference side.

    Args:
        dev: Device forwarded to ``run_tf_dataset_graph``.
        shape: Shape of the constant array.
        value: Fill value for every element.
        dtype: Element type of the array.
        batch: Forwarded to both the pipeline tester and the converter.
    """
    constant = np.full(shape, value, dtype)
    source = FixedSampleIterator(constant)
    pipeline_desc = external_source_tester(shape, dtype, source, batch=batch)
    reference = external_source_converter_with_fixed_value(
        shape, dtype, constant, batch)
    run_tf_dataset_graph(dev,
                         get_pipeline_desc=pipeline_desc,
                         to_dataset=reference)
def run_tf_dataset_with_stop_iter(dev, max_shape, dtype, stop_samples):
    """Run the graph-mode dataset test on a source that stops after ``stop_samples``.

    The pipeline iterator and the reference converter share the range
    ``start=0`` / ``stop=stop_samples``. The run is requested with
    ``to_stop_iter=True``.

    Args:
        dev: Device forwarded to ``run_tf_dataset_graph``.
        max_shape: Maximum sample shape.
        dtype: Element type; ``dtype(0)`` is given to the iterator.
        stop_samples: Stop index for both the pipeline and reference iterators.
    """
    bounded_source = RandomSampleIterator(max_shape, dtype(0),
                                          start=0, stop=stop_samples)
    pipeline_desc = external_source_tester(max_shape, dtype, bounded_source)
    # Positional 0 / stop_samples mirror the iterator's start/stop above.
    reference = external_source_converter_with_callback(
        RandomSampleIterator, max_shape, dtype, 0, stop_samples)
    run_tf_dataset_graph(dev,
                         to_stop_iter=True,
                         get_pipeline_desc=pipeline_desc,
                         to_dataset=reference)
def run_tf_dataset_with_random_input(dev, max_shape, dtype, batch="dataset"):
    """Run the graph-mode dataset test on randomly shaped input samples.

    A minimum shape is derived from ``batch`` and ``max_shape`` via
    ``get_min_shape_helper``. Both the pipeline-side ``RandomSampleIterator``
    and the reference converter are then bounded by ``min_shape`` and
    ``max_shape``.

    Args:
        dev: Device forwarded to ``run_tf_dataset_graph``.
        max_shape: Maximum sample shape.
        dtype: Element type; ``dtype(0)`` is given to the iterator.
        batch: Forwarded to ``get_min_shape_helper``, ``external_source_tester``
            and the reference converter. Defaults to ``"dataset"`` for
            consistency with the eager-mode helper of the same name.
            Existing callers that pass ``batch`` explicitly are unaffected.
    """
    min_shape = get_min_shape_helper(batch, max_shape)
    iterator = RandomSampleIterator(max_shape, dtype(0), min_shape=min_shape)
    # NOTE(review): the 4th/5th positionals appear to be start/stop (compare
    # the stop-iter helpers). 1e10 is presumably a "never stop" sentinel for
    # the reference source — confirm against the converter's signature.
    reference_start = 0
    reference_stop = 1e10
    run_tf_dataset_graph(dev,
                         get_pipeline_desc=external_source_tester(max_shape,
                                                                  dtype,
                                                                  iterator,
                                                                  batch=batch),
                         to_dataset=external_source_converter_with_callback(
                             RandomSampleIterator,
                             max_shape,
                             dtype,
                             reference_start,
                             reference_stop,
                             min_shape,
                             batch=batch))
Example #6
0
def run_tf_dataset_with_random_input(dev, max_shape, dtype, batch="dataset"):
    """Run the eager-mode dataset test on randomly shaped input samples.

    ``get_min_shape_helper`` turns ``batch`` and ``max_shape`` into a
    minimum sample shape. That shape bounds both the pipeline-side
    ``RandomSampleIterator`` and the reference converter.

    Args:
        dev: Device forwarded to ``run_tf_dataset_eager_mode``.
        max_shape: Maximum sample shape.
        dtype: Element type; ``dtype(0)`` is given to the iterator.
        batch: Forwarded to ``get_min_shape_helper``, ``external_source_tester``
            and the reference converter; defaults to ``"dataset"``.
    """
    lower_bound = get_min_shape_helper(batch, max_shape)
    run_tf_dataset_eager_mode(
        dev,
        get_pipeline_desc=external_source_tester(
            max_shape, dtype,
            RandomSampleIterator(max_shape, dtype(0), min_shape=lower_bound),
            batch=batch),
        # 0 and 1e10 are passed positionally; 1e10 presumably acts as an
        # effectively unbounded stop for the reference source — confirm.
        to_dataset=external_source_converter_with_callback(
            RandomSampleIterator, max_shape, dtype, 0, 1e10, lower_bound,
            batch=batch),
    )