# Shared imports for the DALI snippets below.
import os
import numpy as np
from nvidia.dali import fn, pipeline_def, types
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.tfrecord as tfrec
from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy


def pipe(max_batch_size, input_data, device):
    pipe = Pipeline(batch_size=max_batch_size, num_threads=4, device_id=0)
    with pipe:
        data = fn.external_source(source=input_data, cycle=False, device=device,
                                  layout="FHWC")
        processed, _ = fn.element_extract(data, element_map=[0, 3])
        pipe.set_outputs(processed)
    return pipe
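
A minimal sketch of driving the pipeline above; the input data (two batches of 5-frame RGB sequences) and the batch size are illustrative assumptions, not part of the original snippet:

# Hypothetical input: an iterable of batches, each a list of FHWC arrays.
input_data = [
    [np.random.randint(0, 255, size=(5, 10, 20, 3), dtype=np.uint8)
     for _ in range(8)]
    for _ in range(2)
]
p = pipe(max_batch_size=8, input_data=input_data, device="cpu")
p.build()
(frames,) = p.run()  # frame 0 of each sequence; frame 3 was extracted but discarded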
# Presumably a @pipeline_def from the DALI test suite: batch_size, num_threads
# and device_id are supplied at instantiation time. RandomlyShapedDataIterator
# and the module-level batch_size come from the surrounding test utilities.
@pipeline_def
def element_extract_pipe(shape, layout, element_map, dev, dtype):
    min_shape = [s // 2 if s > 1 else 1 for s in shape]
    min_shape[0] = shape[0]  # keep the sequence length fixed so element_map stays valid
    min_shape = tuple(min_shape)
    input = fn.external_source(source=RandomlyShapedDataIterator(
        batch_size, min_shape=min_shape, max_shape=shape, dtype=dtype),
                               layout=layout)
    if dev == "gpu":
        input = input.gpu()
    elements = fn.element_extract(input, element_map=element_map)
    # element_extract returns a list of data nodes when extracting more than
    # one element, and a single node otherwise.
    result = (input, ) + tuple(elements) if len(element_map) > 1 else (
        input, elements)
    return result
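
Assuming the @pipeline_def form above, instantiation would look roughly like this; the shapes, element_map and the module-level batch_size are illustrative assumptions:

batch_size = 8  # assumed module-level global, also referenced inside the pipe

p = element_extract_pipe((4, 10, 20, 3), "FHWC", [0, 2], "cpu", np.uint8,
                         batch_size=batch_size, num_threads=2, device_id=None)
p.build()
inp, elem0, elem2 = p.run()  # the input plus the two extracted elements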
Example 3
def test_element_extract_cpu():
    pipe = Pipeline(batch_size=batch_size, num_threads=4, device_id=None)
    test_data_shape = [5, 10, 20, 3]

    def get_data():
        out = [
            np.random.randint(0, 255, size=test_data_shape, dtype=np.uint8)
            for _ in range(batch_size)
        ]
        return out

    data = fn.external_source(source=get_data, layout="FHWC")
    processed, _ = fn.element_extract(data, element_map=[0, 3])
    pipe.set_outputs(processed)
    pipe.build()
    for _ in range(3):
        pipe.run()
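
For reference, fn.element_extract with element_map=[0, 3] indexes along the outermost ("F") dimension and drops it, so each output sample is HWC. A rough numpy analogue for a single sample:

sample = np.random.randint(0, 255, size=(5, 10, 20, 3), dtype=np.uint8)  # FHWC
frame0, frame3 = sample[0], sample[3]  # each has shape (10, 20, 3), i.e. HWC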
def dali_dataloader(
        tfrec_filenames,
        tfrec_idx_filenames,
        shard_id=0, num_shards=1,
        batch_size=128, num_threads=os.cpu_count(),
        image_size=224, num_workers=1, training=True):
    pipe = Pipeline(batch_size=batch_size,
                    num_threads=num_threads, device_id=0)
    with pipe:
        inputs = fn.readers.tfrecord(
            path=tfrec_filenames,
            index_path=tfrec_idx_filenames,
            random_shuffle=training,
            shard_id=shard_id,
            num_shards=num_shards,
            initial_fill=10000,
            read_ahead=True,
            pad_last_batch=True,
            prefetch_queue_depth=num_workers,
            name='Reader',
            features={
                'image/encoded': tfrec.FixedLenFeature((), tfrec.string, ""),
                'image/class/label': tfrec.FixedLenFeature([1], tfrec.int64, -1),
            })
        jpegs = inputs["image/encoded"]
        if training:
            images = fn.decoders.image_random_crop(
                jpegs,
                device="mixed",
                output_type=types.RGB,
                random_aspect_ratio=[0.8, 1.25],
                random_area=[0.1, 1.0],
                num_attempts=100)
            images = fn.resize(images,
                               device='gpu',
                               resize_x=image_size,
                               resize_y=image_size,
                               interp_type=types.INTERP_TRIANGULAR)
            mirror = fn.random.coin_flip(probability=0.5)
        else:
            images = fn.decoders.image(jpegs,
                                       device='mixed',
                                       output_type=types.RGB)
            images = fn.resize(images,
                               device='gpu',
                               size=int(image_size / 0.875),
                               mode="not_smaller",
                               interp_type=types.INTERP_TRIANGULAR)
            mirror = False

        images = fn.crop_mirror_normalize(
            images.gpu(),
            dtype=types.FLOAT,
            crop=(image_size, image_size),
            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
            std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
            mirror=mirror)
        label = inputs["image/class/label"] - 1  # 0-999
        label = fn.element_extract(label, element_map=0)  # Flatten
        label = label.gpu()
        pipe.set_outputs(images, label)

    pipe.build()
    last_batch_policy = LastBatchPolicy.DROP if training else LastBatchPolicy.PARTIAL
    loader = DALIClassificationIterator(
        pipe, reader_name="Reader", auto_reset=True, last_batch_policy=last_batch_policy)
    return loader
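
A hedged usage sketch for the loader above; the file patterns and hyperparameters below are placeholders, not values from the original:

import glob

train_loader = dali_dataloader(
    tfrec_filenames=sorted(glob.glob("/data/tfrecords/train-*")),
    tfrec_idx_filenames=sorted(glob.glob("/data/tfrecords/train_idx/*")),
    batch_size=64, image_size=224, training=True)

for batch in train_loader:
    images = batch[0]["data"]   # float32 NCHW tensors, already on the GPU
    labels = batch[0]["label"]  # class ids in [0, 999]
    # ... training step ...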