コード例 #1
0
 def _callback_args(self, idx_in_batch, epoch_idx):
     """Build the tuple of positional arguments for the source callback.

     Args:
         idx_in_batch: Position of the sample within the current batch, or
             ``None`` to request batch-level (not per-sample) information.
         epoch_idx: Epoch index forwarded into the info object.

     Returns:
         ``()`` when ``self.accepts_arg`` is falsy. Otherwise a 1-tuple with:
         a ``SampleInfo`` when ``idx_in_batch`` is given; a ``BatchInfo`` when
         ``idx_in_batch`` is ``None`` and ``self._batch_info`` is truthy;
         or the bare ``self._current_iter`` in the remaining case.
     """
     if not self.accepts_arg:
         return ()

     if idx_in_batch is None:
         # Batch-level request: either a BatchInfo or just the iteration number.
         if self._batch_info:
             return (_types.BatchInfo(self._current_iter, epoch_idx), )
         return (self._current_iter, )

     # Per-sample request. The epoch-wide sample index is _current_sample
     # offset by the position in the batch (presumably _current_sample is the
     # index of the batch's first sample - confirm against the caller).
     sample_idx = self._current_sample + idx_in_batch
     info = _types.SampleInfo(sample_idx, idx_in_batch,
                              self._current_iter, epoch_idx)
     return (info, )
コード例 #2
0
 def __next__(self):
     """Produce the next sample, converted with ``sample_to_numpy``.

     A sample cached in ``CallableSampleIterator.first_value`` is handed out
     once, when ``idx_in_epoch`` is 0, and the cache is then cleared. Every
     other call invokes ``self.source`` with a ``types.SampleInfo`` built from
     the current counters.
     """
     if self.idx_in_epoch != 0 or CallableSampleIterator.first_value is None:
         # Iterating over a DALI Dataset has no notion of epochs (the "raise"
         # policy is not supported there), so the epoch index is always 0.
         info = types.SampleInfo(self.idx_in_epoch, self.idx_in_batch,
                                 self.iteration, 0)
         sample = self.source(info)
     else:
         # Consume the cached sample exactly once, then drop the cache.
         sample = CallableSampleIterator.first_value
         CallableSampleIterator.first_value = None

     self.idx_in_epoch += 1
     self.idx_in_batch += 1
     # Wrap the in-batch position and advance the iteration on a full batch.
     if self.idx_in_batch == batch_size:
         self.iteration += 1
         self.idx_in_batch = 0
     return sample_to_numpy(sample, _tf_sample_error_msg)
コード例 #3
0
def load_frames(sample_info=types.SampleInfo(0, 0, 0, 0), hint_grid=None):
    """Load ``alley.png`` and return it together with a warped copy.

    For odd ``sample_info.idx_in_epoch`` the image is first downscaled to half
    its width and height, so consecutive samples have different sizes.

    Args:
        sample_info: Sample descriptor; only ``idx_in_epoch`` is read.
        hint_grid: When not ``None``, a zero-filled ``uint8`` array with the
            same shape as the frame pair is returned alongside it. The value
            itself is not otherwise used.

    Returns:
        ``np.ndarray`` stacking ``[img, warped]`` when ``hint_grid`` is
        ``None``; otherwise the list ``[frames, zeros]``.
    """
    img = cv2.imread(os.path.join(images_dir, 'alley.png'))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if sample_info.idx_in_epoch % 2:
        height, width = img.shape[:2]
        # cv2.resize expects dsize as (width, height) - the reverse of numpy's
        # (rows, cols) order. Passing shape[0] first transposed the image.
        img = cv2.resize(img,
                         dsize=(width // 2, height // 2),
                         interpolation=cv2.INTER_AREA)

    xy, ofs = get_mapping(img.shape[:2])
    # Half-pixel shift; presumably get_mapping yields pixel-centre coordinates
    # while cv2.remap samples at integer pixel positions - TODO confirm.
    remap = (xy + ofs - np.array([[[0.5, 0.5]]])).astype(np.float32)

    warped = cv2.remap(img, remap, None, interpolation=cv2.INTER_LINEAR)
    frames = np.array([img, warped])

    if hint_grid is None:
        return frames
    return [frames, np.zeros(shape=frames.shape, dtype=np.uint8)]
コード例 #4
0
def get_sample_iterable_from_callback(source_desc: SourceDescription,
                                      batch_size):
    """Wrap a one-argument sample callback in an iterator class.

    The callback ``source_desc.source`` is called once up front with
    ``types.SampleInfo(0, 0, 0, 0)`` so ``_inspect_data`` can determine the
    data's dtype and shape. That sample is cached on the returned class and
    served by the first ``__next__`` call made at the start of an epoch, so the
    callback is not invoked twice for it. After that, the callback is called
    for every sample.

    Args:
        source_desc: Description whose ``source`` attribute is the callback;
            it receives a single ``types.SampleInfo``.
        batch_size: Samples per iteration; controls when ``idx_in_batch``
            wraps to 0 and ``iteration`` advances.

    Returns:
        Tuple ``(iterator_class, dtype, shape)``. ``iterator_class`` takes no
        constructor arguments and yields samples passed through
        ``sample_to_numpy``. ``dtype`` and ``shape`` are what ``_inspect_data``
        reports for the first sample.
    """
    probe = source_desc.source(types.SampleInfo(0, 0, 0, 0))
    dtype, shape = _inspect_data(probe, False)

    class CallableSampleIterator:
        # Sample already produced for inspection; shared by all instances and
        # cleared once it has been handed out.
        first_value = probe

        def __init__(self):
            self.source = source_desc.source
            self._rewind()

        def _rewind(self):
            """Reset all position counters to the start of an epoch."""
            self.idx_in_epoch = 0
            self.idx_in_batch = 0
            self.iteration = 0

        def __iter__(self):
            self._rewind()
            return self

        def __next__(self):
            no_cached = CallableSampleIterator.first_value is None
            if self.idx_in_epoch != 0 or no_cached:
                # Iterating over a DALI Dataset has no notion of epochs (the
                # "raise" policy is not supported there), so epoch is always 0.
                info = types.SampleInfo(self.idx_in_epoch, self.idx_in_batch,
                                        self.iteration, 0)
                sample = self.source(info)
            else:
                sample = CallableSampleIterator.first_value
                CallableSampleIterator.first_value = None

            self.idx_in_epoch += 1
            self.idx_in_batch += 1
            if self.idx_in_batch == batch_size:
                self.iteration += 1
                self.idx_in_batch = 0
            return sample_to_numpy(sample, _tf_sample_error_msg)

    return CallableSampleIterator, dtype, shape