示例#1
0
    def __call__(
        self,
        x: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
    ) -> List[torch.Tensor]:
        """Convert document images into normalized batches ready for the model.

        Args:
            x: either one 4D batch or a list of individual images.
                - A single ``np.ndarray`` batch is permuted from axis order
                  (0, 1, 2, 3) to (0, 3, 1, 2) before processing.
                - A single ``torch.Tensor`` batch is used as-is; its last two
                  axes are compared against the target size.
                - A list may hold ``np.ndarray`` and/or ``torch.Tensor`` items,
                  which go through ``sample_transforms`` and ``batch_inputs``.
                A single batch is resized here when its last two axes differ
                from ``self.resize.size``.

        Returns:
            list of normalized page batches (a single-batch input is cast to
            float32, and rescaled by 1/255 if it was uint8, before normalization)

        Raises:
            AssertionError: if a single array/tensor input is not 4D
            TypeError: if the input type, or a single batch's dtype, is unsupported
        """
        tensor_like = (np.ndarray, torch.Tensor)
        is_single_batch = isinstance(x, tensor_like)
        if not is_single_batch and not (
            isinstance(x, list) and all(isinstance(page, tensor_like) for page in x)
        ):
            raise TypeError(f"invalid input type: {type(x)}")

        if is_single_batch:
            if x.ndim != 4:
                raise AssertionError("expected 4D Tensor")

            if isinstance(x, np.ndarray):
                if x.dtype not in (np.uint8, np.float32):
                    raise TypeError("unsupported data type for numpy.ndarray")
                # Copy so the tensor does not share memory with the caller's array,
                # then move the last axis to position 1
                x = torch.from_numpy(x.copy()).permute(0, 3, 1, 2)
            elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
                raise TypeError("unsupported data type for torch.Tensor")

            target_size = self.resize.size
            height, width = x.shape[-2], x.shape[-1]
            if height != target_size[0] or width != target_size[1]:
                x = F.resize(x, target_size, interpolation=self.resize.interpolation)

            # Record whether rescaling is needed before the cast erases the dtype
            from_uint8 = x.dtype == torch.uint8  # type: ignore[union-attr]
            x = x.to(dtype=torch.float32)  # type: ignore[union-attr]
            if from_uint8:
                x = x.div(255).clip(0, 1)  # type: ignore[union-attr]
            batches = [x]
        else:
            # Per-sample transforms (to tensor + resize), then grouping into batches
            transformed = list(multithread_exec(self.sample_transforms, x))
            batches = self.batch_inputs(transformed)

        # Normalize every batch
        return list(multithread_exec(self.normalize, batches))
示例#2
0
    def __call__(
        self, x: Union[tf.Tensor, np.ndarray, List[Union[tf.Tensor,
                                                         np.ndarray]]]
    ) -> List[tf.Tensor]:
        """Prepare document data for model forwarding.

        Args:
            x: either a single 4D batch (``np.ndarray`` or ``tf.Tensor``), which is
                converted to float32 if it is uint8 and resized if axes 1 and 2
                differ from ``self.resize.output_size``, or a list of individual
                images (``np.ndarray`` / ``tf.Tensor``) that go through
                ``sample_transforms`` and ``batch_inputs``

        Returns:
            list of normalized page batches

        Raises:
            AssertionError: if a single array/tensor input is not 4D
            TypeError: if the input type, or a single batch's dtype, is unsupported
        """

        # Input type check
        if isinstance(x, (np.ndarray, tf.Tensor)):
            if x.ndim != 4:
                raise AssertionError("expected 4D Tensor")
            if isinstance(x, np.ndarray):
                if x.dtype not in (np.uint8, np.float32):
                    raise TypeError("unsupported data type for numpy.ndarray")
                x = tf.convert_to_tensor(x)
            elif x.dtype not in (tf.uint8, tf.float16, tf.float32):
                # Message previously (wrongly) referred to torch.Tensor
                raise TypeError("unsupported data type for tf.Tensor")

            # Data type & 255 division: convert_image_dtype rescales uint8 to [0, 1]
            if x.dtype == tf.uint8:
                x = tf.image.convert_image_dtype(x, dtype=tf.float32)
            # Resizing: only when the spatial axes (1 and 2) differ from the target
            if x.shape[1] != self.resize.output_size[0] or x.shape[
                    2] != self.resize.output_size[1]:
                x = tf.image.resize(x,
                                    self.resize.output_size,
                                    method=self.resize.method)

            batches = [x]

        elif isinstance(x, list) and all(
                isinstance(sample, (np.ndarray, tf.Tensor)) for sample in x):
            # Sample transform (to tensor, resize). multithread_exec's result is not
            # guaranteed to be a list, so materialize it before batching.
            samples = list(multithread_exec(self.sample_transforms, x))
            # Batching
            batches = self.batch_inputs(samples)
        else:
            raise TypeError(f"invalid input type: {type(x)}")

        # Batch transforms (normalize); materialize so callers get the declared List
        # rather than a possibly one-shot iterator
        batches = list(multithread_exec(self.normalize, batches))

        return batches
示例#3
0
    def __next__(self):
        """Load and collate the next batch of samples.

        Returns:
            whatever ``self.collate_fn`` builds from the list of samples obtained
            via ``self.dataset.__getitem__`` for this batch's indices

        Raises:
            StopIteration: once ``self.num_batches`` batches have been yielded
        """
        if self._num_yielded >= self.num_batches:
            raise StopIteration

        # Indices for this batch; the upper bound is capped at the dataset length
        start = self._num_yielded * self.batch_size
        indices = self.indices[start: min(len(self.dataset), start + self.batch_size)]

        # Load samples with `num_workers` threads. multithread_exec's result is not
        # guaranteed to be a list, so materialize it: collate_fn may need len(),
        # indexing or more than one pass over the samples.
        samples = list(multithread_exec(self.dataset.__getitem__, indices, threads=self.num_workers))

        batch_data = self.collate_fn(samples)

        self._num_yielded += 1
        return batch_data
示例#4
0
def test_multithread_exec(input_seq, func, output_seq):
    """Check multithread_exec yields `output_seq` both with its default threading
    and with an explicit ``0`` passed as the third positional argument."""
    for extra_args in ((), (0,)):
        results = list(multithread_exec(func, input_seq, *extra_args))
        assert results == output_seq