Esempio n. 1
0
    def batchify_image_features(self, batch: Batch) -> Batch:
        """
        Replace ``batch.image`` with a single stacked Tensor of image features.

        Each ``torch.Tensor`` entry of ``batch.image`` is passed through
        ``self._process_image_features``. Entries that are ``None``, or slots for a
        ``batch.image`` that is ``None``/empty, are filled with a processed all-zeros
        vector of length ``self.image_features_dim``. Any other entry type is reported
        via ``warn_once`` and is also replaced by the zero vector. A single Tensor
        (rather than a list) is required so the batch can be used with data_parallel.

        :param batch:
            the batch whose ``image`` attribute is normalized in place.

        :return:
            the same ``batch`` object, with ``batch.image`` now a stacked Tensor.
        """
        batch_size = self._get_batch_size(batch)

        # Guarantee one slot per example in batch.image.
        if batch.image is not None and len(batch.image) > 0:
            assert len(batch.image) == batch_size
        else:
            batch.image = [None] * batch_size

        # Shared stand-in for any example whose features are missing or unusable.
        zero_features = self._process_image_features(
            torch.zeros((self.image_features_dim, ))
        )

        def _to_features(raw):
            # Real tensors are processed; everything else falls back to zeros.
            if isinstance(raw, torch.Tensor):
                return self._process_image_features(raw)
            if raw is not None:
                warn_once(
                    'Unsupported image feature format. Image features will be ignored!'
                )
            return zero_features

        # Stack along a new leading batch dimension (batchsize x image_features_dim
        # per the original intent) so DataParallel can split the Tensor.
        batch.image = torch.stack([_to_features(raw) for raw in batch.image])
        return batch
Esempio n. 2
0
    def batchify_image_features(self, batch: Batch) -> Batch:
        """
        Format and return the batched image features.

        If ``batch.image`` is a list containing at least one non-``None`` entry, each
        ``torch.Tensor`` entry is passed through ``self._process_image_features``
        (which sets it to the right type) and all other entries are kept unchanged.
        Otherwise ``batch.image`` is replaced by a list of ``None`` with one element
        per entry of ``batch.valid_indices``.

        :param batch:
            the batch whose ``image`` attribute is updated in place.

        :return:
            the same ``batch`` object.
        """
        images = batch.image
        # isinstance (not an exact type check) so list subclasses are processed too,
        # instead of having their image features silently replaced by None.
        if isinstance(images, list) and any(img is not None for img in images):
            batch.image = [
                self._process_image_features(img)
                if isinstance(img, torch.Tensor)
                else img
                for img in images
            ]
        else:
            # No usable image features: one None placeholder per valid example
            # (presumably so downstream code can still index per example).
            batch.image = [None] * len(batch.valid_indices)
        return batch