Example No. 1
    def _get_images_product_texts(self, image_ids: List[int],
                                  num_products: int):
        """Get images and product texts of a response.

        Args:
            image_ids (List[int]): Image IDs; an ID of 0 marks the end
                of the valid entries.
            num_products (int): Maximum number of products; outputs are
                padded to this length.

        Returns:
            num_products (int): Number of products (excluding padding).
            images: Images (num_products, 3, image_size, image_size).
            product_texts: Product texts (num_products, product_text_max_len).
            product_text_lengths: Product text lengths (num_products, ).

        """
        images = []
        product_texts = []
        product_text_lengths = []
        for img_id in image_ids:
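            # An image ID of 0 is a padding sentinel: stop at the first one.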
            if img_id == 0:
                break
            image_name = self.image_paths[img_id]
            image_path = join(DatasetConfig.image_data_directory, image_name)
            product_path = get_product_path(image_name)

            # Image: load, convert to RGB, and apply the dataset transform.
            raw_image = Image.open(image_path).convert("RGB")
            image = DatasetConfig.transform(raw_image)
            images.append(image)

            # Text: tokenize the product description and map each word to a
            # vocabulary ID, falling back to UNK_ID for unknown words.
            text = Dataset._get_product_text(product_path)
            text = [
                self.dialog_vocab.get(word, UNK_ID)
                for word in word_tokenize(text)
            ]
            text, text_len = pad_or_clip_text(
                text, DatasetConfig.product_text_max_len)
            product_texts.append(text)
            product_text_lengths.append(text_len)

        # Padding: fill up to the requested maximum with empty placeholders.
        num_pads = num_products - len(images)
        # Record the real product count (excluding padding) before extending,
        # so the returned value matches the docstring.
        num_products = len(images)
        images.extend([self.EMPTY_IMAGE] * num_pads)
        product_texts.extend([self.EMPTY_PRODUCT_TEXT] * num_pads)
        product_text_lengths.extend([1] * num_pads)

        # To tensors.
        images = torch.stack(images)
        product_texts = torch.stack([torch.tensor(t) for t in product_texts])
        product_text_lengths = torch.tensor(product_text_lengths)
        return num_products, images, product_texts, product_text_lengths
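
For context, a minimal sketch of what the surrounding module likely provides; the imports and the pad_or_clip_text signature are assumptions inferred from the call sites above, not the project's actual code:

# Assumed imports, inferred from the call sites in the method above
# (not shown in the original excerpt).
from os.path import join
from typing import List, Tuple

import torch
from PIL import Image
from nltk.tokenize import word_tokenize

UNK_ID = 1  # assumed vocabulary ID for out-of-vocabulary words


def pad_or_clip_text(text: List[int], max_len: int) -> Tuple[List[int], int]:
    """Plausible sketch: clip a token-ID list to max_len or pad it with 0s.

    Returns the fixed-length list and its unpadded length (at least 1,
    consistent with the length of 1 used for padded entries above).
    """
    length = max(1, min(len(text), max_len))
    if len(text) >= max_len:
        return text[:max_len], length
    return text + [0] * (max_len - len(text)), length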
Example No. 2
    def _get_context_dialog(self, dialog: TidyDialog):
        """Get context dialog.

        Note: The last utterance of the context dialog is the system response.

        Args:
            dialog (TidyDialog): Dialog.

        Returns:
            texts: Texts (dialog_context_size + 1, dialog_text_max_len).
            text_lengths: Text lengths (dialog_context_size + 1, ).
            images: Images (dialog_context_size + 1, pos_images_max_num, 3,
                           image_size, image_size).
            utter_type (int): The type of the last user utterance.

        """
        # Text.
        text_list: List[List[int]] = [utter.text for utter in dialog]
        text_length_list: List[int] = [utter.text_len for utter in dialog]

        # Text tensors.
        texts = torch.stack([torch.tensor(text) for text in text_list])
        # (dialog_context_size + 1, dialog_text_max_len)
        text_lengths = torch.tensor(text_length_list)
        # (dialog_context_size + 1, )

        # Image: one list of image tensors per utterance in the context
        # (plus the system response).
        image_list = [[] for _ in range(DatasetConfig.dialog_context_size + 1)]

        for idx, utter in enumerate(dialog):
            for img_id in utter.pos_images:
                path = self.image_paths[img_id]
                if path:
                    path = join(DatasetConfig.image_data_directory, path)
                if path and isfile(path):
                    try:
                        raw_image = Image.open(path).convert("RGB")
                        image = DatasetConfig.transform(raw_image)
                        image_list[idx].append(image)
                    except OSError:
                        # Unreadable or corrupt image file.
                        image_list[idx].append(Dataset.EMPTY_IMAGE)
                else:
                    # Missing ID or file: fall back to the empty placeholder.
                    image_list[idx].append(Dataset.EMPTY_IMAGE)

        images = torch.stack(list(map(torch.stack, image_list)))
        # (dialog_context_size + 1, pos_images_max_num,
        # 3, image_size, image_size)

        # Utterance type: dialog[-1] is the system response, so the last
        # user utterance is at index -2.
        utter_type = dialog[-2].utter_type
        return texts, text_lengths, images, utter_type
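
A hypothetical usage sketch; the dataset instance and the dialog object are assumptions based on the attribute accesses above, not names from the project:

# dataset is assumed to be an instance of the Dataset class these methods
# belong to, and dialog a TidyDialog whose last utterance is the system
# response.
texts, text_lengths, images, utter_type = dataset._get_context_dialog(dialog)

# Expected shapes, per the docstring:
#   texts:        (dialog_context_size + 1, dialog_text_max_len)
#   text_lengths: (dialog_context_size + 1,)
#   images:       (dialog_context_size + 1, pos_images_max_num,
#                  3, image_size, image_size)
assert texts.size(0) == DatasetConfig.dialog_context_size + 1
assert text_lengths.size(0) == DatasetConfig.dialog_context_size + 1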