def _callback_args(self, idx_in_batch, epoch_idx):
    """Build the positional-argument tuple passed to the source callback.

    Returns an empty tuple when ``self.accepts_arg`` is falsy. Otherwise the
    tuple holds exactly one element, chosen in this order of precedence:

    * ``_types.SampleInfo`` when ``idx_in_batch`` is given;
    * ``_types.BatchInfo`` when ``self._batch_info`` is truthy;
    * the bare ``self._current_iter`` value otherwise.
    """
    if not self.accepts_arg:
        return ()
    if idx_in_batch is None:
        info = (_types.BatchInfo(self._current_iter, epoch_idx)
                if self._batch_info else self._current_iter)
    else:
        info = _types.SampleInfo(
            self._current_sample + idx_in_batch,  # index of the sample overall
            idx_in_batch,
            self._current_iter,
            epoch_idx,
        )
    return (info,)
def __next__(self):
    """Produce the next sample and advance the position counters.

    When positioned at index 0 and ``CallableSampleIterator.first_value`` is
    not ``None``, that cached sample is returned and cleared on the class.
    Otherwise ``self.source`` is called with a ``types.SampleInfo`` describing
    the current position. The result is converted with ``sample_to_numpy``.

    NOTE(review): ``CallableSampleIterator`` and ``batch_size`` are free names
    here; they are presumably supplied by an enclosing scope (as they are for
    the inner class in ``get_sample_iterable_from_callback``) -- confirm.
    """
    reuse_cached = (self.idx_in_epoch == 0
                    and CallableSampleIterator.first_value is not None)
    if reuse_cached:
        sample = CallableSampleIterator.first_value
        CallableSampleIterator.first_value = None
    else:
        # Epoch index is always 0: iteration over a DALI Dataset has no notion
        # of epochs because the "raise" policy is not supported.
        info = types.SampleInfo(self.idx_in_epoch, self.idx_in_batch,
                                self.iteration, 0)
        sample = self.source(info)

    self.idx_in_epoch += 1
    self.idx_in_batch += 1
    if self.idx_in_batch == batch_size:
        # A full batch has been produced: wrap the in-batch index.
        self.idx_in_batch = 0
        self.iteration += 1
    return sample_to_numpy(sample, _tf_sample_error_msg)
def load_frames(sample_info=types.SampleInfo(0, 0, 0, 0), hint_grid=None):
    """Load ``alley.png`` and a warped copy of it as a two-frame array.

    The image is read from ``images_dir`` and converted from BGR to RGB. When
    ``sample_info.idx_in_epoch`` is odd, the image is first downscaled to half
    its width and half its height. The second frame is ``cv2.remap`` of the
    first, using the mapping returned by ``get_mapping``.

    Args:
        sample_info: Sample descriptor; only ``idx_in_epoch`` is read.
        hint_grid: Only compared against ``None``; its value is not used.

    Returns:
        ``np.array([img, warped])`` when ``hint_grid`` is ``None``. Otherwise a
        list ``[frames, zeros]``, where ``zeros`` is an all-zero ``uint8`` array
        with the same shape as ``frames``.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    image_path = os.path.join(images_dir, 'alley.png')
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None rather
        # than raising; fail here instead of with an obscure cvtColor error.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if sample_info.idx_in_epoch % 2:
        # cv2.resize expects dsize as (width, height) -- the reverse of numpy's
        # (rows, cols) shape order -- so swap to truly halve each dimension.
        height, width = img.shape[:2]
        img = cv2.resize(img, dsize=(width // 2, height // 2),
                         interpolation=cv2.INTER_AREA)
    xy, ofs = get_mapping(img.shape[:2])
    # NOTE(review): the 0.5 shift presumably converts pixel-center coordinates
    # from get_mapping into cv2.remap's integer-pixel convention -- confirm.
    remap = (xy + ofs - np.array([[[0.5, 0.5]]])).astype(np.float32)
    warped = cv2.remap(img, remap, None, interpolation=cv2.INTER_LINEAR)
    result = np.array([img, warped])
    if hint_grid is not None:
        result = [result, np.zeros(shape=result.shape, dtype=np.uint8)]
    return result
def get_sample_iterable_from_callback(source_desc: SourceDescription, batch_size):
    """Wrap a one-argument sample callback in an iterable class.

    The callback is called once eagerly with ``types.SampleInfo(0, 0, 0, 0)``
    so that ``_inspect_data`` can report the dtype and shape of its output.
    That probe sample is cached on the returned class and served as the first
    element, so the callback is not called a second time for index 0 on the
    first pass.

    Args:
        source_desc: Object whose ``source`` attribute is the callback.
        batch_size: Samples per iteration. ``idx_in_batch`` wraps to 0 and
            ``iteration`` advances whenever it reaches this value.

    Returns:
        Tuple ``(CallableSampleIterator, dtype, shape)``. The first element is
        the class itself, not an instance.
    """
    probe = source_desc.source(types.SampleInfo(0, 0, 0, 0))
    dtype, shape = _inspect_data(probe, False)

    class CallableSampleIterator:
        # Probe sample from above. The first __next__ call at index 0 returns it
        # and clears it on the class; later index-0 requests call the source.
        first_value = probe

        def __init__(self):
            self.idx_in_epoch = self.idx_in_batch = self.iteration = 0
            self.source = source_desc.source

        def __iter__(self):
            # Restart position tracking from the beginning.
            self.idx_in_epoch = self.idx_in_batch = self.iteration = 0
            return self

        def __next__(self):
            reuse_probe = (self.idx_in_epoch == 0
                           and CallableSampleIterator.first_value is not None)
            if reuse_probe:
                sample = CallableSampleIterator.first_value
                CallableSampleIterator.first_value = None
            else:
                # Epoch index is always 0: iteration over a DALI Dataset has no
                # notion of epochs because the "raise" policy is not supported.
                info = types.SampleInfo(self.idx_in_epoch, self.idx_in_batch,
                                        self.iteration, 0)
                sample = self.source(info)

            self.idx_in_epoch += 1
            self.idx_in_batch += 1
            if self.idx_in_batch == batch_size:
                # A full batch has been produced: wrap the in-batch index.
                self.idx_in_batch = 0
                self.iteration += 1
            return sample_to_numpy(sample, _tf_sample_error_msg)

    return CallableSampleIterator, dtype, shape