def __call__(
    self,
    x: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]
) -> List[torch.Tensor]:
    """Turn raw document pages into normalized batches for the model

    Args:
        x: either one 4D array/tensor that is already batched, or a list of pages
            (np.ndarray or torch.Tensor) that still go through the per-sample transforms

    Returns:
        list of page batches

    Raises:
        AssertionError: if a single array/tensor input is not 4D
        TypeError: if the input type or its data type is not supported
    """
    array_types = (np.ndarray, torch.Tensor)

    if isinstance(x, array_types):
        # Single pre-batched input: validate rank, then dtype per container type
        if x.ndim != 4:
            raise AssertionError("expected 4D Tensor")
        if isinstance(x, np.ndarray):
            if x.dtype not in (np.uint8, np.float32):
                raise TypeError("unsupported data type for numpy.ndarray")
            # Copy before wrapping, then move the last axis to position 1
            # (presumably NHWC -> NCHW; confirm against callers)
            x = torch.from_numpy(x.copy()).permute(0, 3, 1, 2)
        elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
            raise TypeError("unsupported data type for torch.Tensor")

        # Resize only when the two trailing dims differ from the configured size
        target_size = self.resize.size
        if not (x.shape[-2] == target_size[0] and x.shape[-1] == target_size[1]):
            x = F.resize(x, target_size, interpolation=self.resize.interpolation)

        # Everything ends up float32; uint8 inputs are additionally rescaled by 255
        was_uint8 = x.dtype == torch.uint8  # type: ignore[union-attr]
        x = x.to(dtype=torch.float32)  # type: ignore[union-attr]
        if was_uint8:
            x = x.div(255).clip(0, 1)  # type: ignore[union-attr]

        batches = [x]

    elif isinstance(x, list) and all(isinstance(page, array_types) for page in x):
        # Per-sample transforms first, then grouping into batches
        transformed = list(multithread_exec(self.sample_transforms, x))
        batches = self.batch_inputs(transformed)

    else:
        raise TypeError(f"invalid input type: {type(x)}")

    # Batch-level transform (normalize), materialized as a list
    return list(multithread_exec(self.normalize, batches))
def __call__(
    self,
    x: Union[tf.Tensor, np.ndarray, List[Union[tf.Tensor, np.ndarray]]]
) -> List[tf.Tensor]:
    """Prepare document data for model forwarding

    Args:
        x: either one 4D array/tensor that is already batched, or a list of pages
            (np.ndarray or tf.Tensor) that still go through the per-sample transforms

    Returns:
        list of page batches

    Raises:
        AssertionError: if a single array/tensor input is not 4D
        TypeError: if the input type or its data type is not supported
    """
    # Input type check
    if isinstance(x, (np.ndarray, tf.Tensor)):
        if x.ndim != 4:
            raise AssertionError("expected 4D Tensor")
        if isinstance(x, np.ndarray):
            if x.dtype not in (np.uint8, np.float32):
                raise TypeError("unsupported data type for numpy.ndarray")
            x = tf.convert_to_tensor(x)
        elif x.dtype not in (tf.uint8, tf.float16, tf.float32):
            # This is the TensorFlow code path, so the message names tf.Tensor
            raise TypeError("unsupported data type for tf.Tensor")

        # Data type & 255 division (uint8 inputs only)
        if x.dtype == tf.uint8:
            x = tf.image.convert_image_dtype(x, dtype=tf.float32)
        # Resizing, only when dims 1 and 2 differ from the configured output size
        if x.shape[1] != self.resize.output_size[0] or x.shape[2] != self.resize.output_size[1]:
            x = tf.image.resize(x, self.resize.output_size, method=self.resize.method)

        batches = [x]

    elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, tf.Tensor)) for sample in x):
        # Sample transform (to tensor, resize)
        # Materialized so that batching receives a real list, whatever iterable
        # multithread_exec hands back
        samples = list(multithread_exec(self.sample_transforms, x))
        # Batching
        batches = self.batch_inputs(samples)
    else:
        raise TypeError(f"invalid input type: {type(x)}")

    # Batch transforms (normalize)
    # Materialized so the return value matches the annotated List[tf.Tensor]
    batches = list(multithread_exec(self.normalize, batches))

    return batches
def __next__(self):
    """Return the next collated batch

    Returns:
        whatever `self.collate_fn` produces for the samples of the current batch

    Raises:
        StopIteration: once `self.num_batches` batches have been yielded
    """
    if self._num_yielded >= self.num_batches:
        raise StopIteration

    # Get next indices; the upper bound is capped at the dataset length so the
    # last batch may be shorter
    idx = self._num_yielded * self.batch_size
    indices = self.indices[idx: min(len(self.dataset), idx + self.batch_size)]

    # Materialize the fetched samples: multithread_exec may hand back a one-shot
    # iterable, while a collate function may need len(), indexing or several passes
    samples = list(multithread_exec(self.dataset.__getitem__, indices, threads=self.num_workers))

    batch_data = self.collate_fn(samples)

    self._num_yielded += 1
    return batch_data
def test_multithread_exec(input_seq, func, output_seq):
    """Check that multithread_exec maps `func` over `input_seq` and yields `output_seq`"""
    # Call with the helper's default settings
    default_result = list(multithread_exec(func, input_seq))
    assert default_result == output_seq

    # Call with 0 as third positional argument (presumably the thread count; confirm
    # against the helper's signature)
    explicit_result = list(multithread_exec(func, input_seq, 0))
    assert explicit_result == output_seq