示例#1
0
 def __init__(self, batch_size, num_threads, device_id, seed, image_dir,
              function):
     """Initialise the parent pipeline and attach a PyTorch python-function op.

     Args:
         batch_size, num_threads, device_id, seed, image_dir: forwarded
             positionally, in exactly this order, to the parent class's
             ``__init__``.
         function: callable wrapped by ``dalitorch.TorchPythonFunction``;
             the resulting op is declared with ``num_outputs=2``.
     """
     parent_args = (batch_size, num_threads, device_id, seed, image_dir)
     super(TorchPythonFunctionPipeline, self).__init__(*parent_args)
     # The op is declared with two outputs, so the wrapped callable is
     # expected to return two results.
     self.torch_function = dalitorch.TorchPythonFunction(
         num_outputs=2,
         function=function,
     )
示例#2
0
 def __init__(self, function, device, bp=False, batch_size=BATCH_SIZE, num_threads=NUM_WORKERS,
              device_id=DEVICE_ID, image_dir=images_dir):
     """Initialise the parent pipeline and a device-aware PyTorch function op.

     Args:
         function: callable wrapped by ``dalitorch.TorchPythonFunction``
             (declared with ``num_outputs=2``).
         device: stored on ``self.device`` and passed to the op as its
             ``device`` argument.
         bp: passed to the op as ``batch_processing``.
         batch_size, num_threads, device_id, image_dir: forwarded
             positionally, in this order, to the parent class's
             ``__init__``. Their defaults come from module-level names
             defined elsewhere in the file.
     """
     base_init = super(TorchPythonFunctionPipeline, self).__init__
     base_init(batch_size, num_threads, device_id, image_dir)
     self.device = device
     op_kwargs = dict(
         function=function,
         num_outputs=2,
         device=device,
         batch_processing=bp,
     )
     self.torch_function = dalitorch.TorchPythonFunction(**op_kwargs)
示例#3
0
    def __init__(self,
                 tfrecords,
                 batch_size,
                 target_size,
                 preproc_param,
                 num_threads,
                 num_shards,
                 device_ids,
                 training=False):
        """Set up a TFRecord-reading detection pipeline.

        Initialises both base classes and creates a ``<tfrecord>_idx`` index
        file (via the ``tfrecord2idx`` tool) for every TFRecord that lacks
        one. It stores the total number of index lines in ``self.length``
        and then builds the TFRecord reader and helper ops.

        Args:
            tfrecords: iterable of TFRecord file paths. Presumably a list;
                a bare string would be iterated character by character
                (TODO confirm with callers).
            batch_size: forwarded to ``Pipeline``.
            target_size: forwarded to ``DaliPipeline``.
            preproc_param: forwarded to ``DaliPipeline``.
            num_threads: forwarded to ``Pipeline`` as both ``num_threads``
                and ``prefetch_queue_depth``.
            num_shards: forwarded to the TFRecord reader.
            device_ids: forwarded to ``Pipeline`` as ``device_id``.
            training: forwarded to ``DaliPipeline``. It also enables
                ``random_shuffle`` in the reader and is stored on
                ``self.training``.

        Raises:
            RuntimeError: if ``tfrecord2idx`` exits with a non-zero status.
        """
        # NOTE(review): `device_id` receives `device_ids` (plural) and
        # `prefetch_queue_depth` reuses `num_threads`; confirm both are
        # intentional.
        Pipeline.__init__(self,
                          batch_size=batch_size,
                          num_threads=num_threads,
                          device_id=device_ids,
                          prefetch_queue_depth=num_threads,
                          seed=42,
                          exec_async=False,
                          exec_pipelined=False)
        DaliPipeline.__init__(self,
                              target_size=target_size,
                              preproc_param=preproc_param,
                              training=training)

        tfrecords_idx = [tfrecord + "_idx" for tfrecord in tfrecords]
        for tfrecord, tfrecord_idx in zip(tfrecords, tfrecords_idx):
            if os.path.exists(tfrecord_idx):
                continue
            # `call` (presumably subprocess.call) reports failure through
            # its return value instead of raising. Check it, so a failed
            # indexing run is reported here rather than later as a missing
            # or partial index file.
            status = call(["tfrecord2idx", tfrecord, tfrecord_idx])
            if status != 0:
                raise RuntimeError(
                    "tfrecord2idx failed with exit status {} for {}".format(
                        status, tfrecord))

        # Total line count over all index files. Stream each file and close
        # it promptly instead of loading every line and leaking the handle.
        self.length = 0
        for tfrecord_idx in tfrecords_idx:
            with open(tfrecord_idx) as idx_file:
                self.length += sum(1 for _ in idx_file)

        # NOTE(review): `num_shards` is passed without a `shard_id`, and
        # `self.length` counts all records rather than one shard's share;
        # confirm this is intended when num_shards > 1.
        self.input = ops.TFRecordReader(
            path=tfrecords,
            index_path=tfrecords_idx,
            features={
                'image/height':
                tfrec.FixedLenFeature([1], tfrec.int64, -1),
                'image/width':
                tfrec.FixedLenFeature([1], tfrec.int64, -1),
                'image/encoded':
                tfrec.FixedLenFeature((), tfrec.string, ""),
                'image/format':
                tfrec.FixedLenFeature((), tfrec.string, ""),
                'image/object/bbox/xmin':
                tfrec.VarLenFeature(tfrec.float32, 0.0),
                'image/object/bbox/ymin':
                tfrec.VarLenFeature(tfrec.float32, 0.0),
                'image/object/bbox/xmax':
                tfrec.VarLenFeature(tfrec.float32, 0.0),
                'image/object/bbox/ymax':
                tfrec.VarLenFeature(tfrec.float32, 0.0),
                'image/object/class/text':
                tfrec.FixedLenFeature([], tfrec.string, ''),
                'image/object/class/label':
                tfrec.VarLenFeature(tfrec.int64, -1)
            },
            num_shards=num_shards,
            random_shuffle=training)
        self.training = training
        # Concatenates the four inputs, reshapes the result to (4, N) and
        # transposes it to (N, 4), so row i is [left_i, top_i, right_i,
        # bottom_i]. Assumes each input is 1-D with the same length (TODO
        # confirm). They may be pixel-scaled ([l*w, t*h, r*w, b*h]) or
        # normalised ([l, t, r, b]) coordinates.
        self.cat = dalitorch.TorchPythonFunction(
            function=lambda left, top, right, bottom: torch.cat(
                [left, top, right, bottom]).view(4, -1).permute(1, 0))
        self.cast = ops.Cast(dtype=types.DALIDataType.INT32)