from os.path import isfile, join
from typing import List

import torch
from nltk.tokenize import word_tokenize
from PIL import Image

# DatasetConfig, TidyDialog, UNK_ID, get_product_path, pad_or_clip_text,
# and the Dataset class attributes (EMPTY_IMAGE, EMPTY_PRODUCT_TEXT) are
# defined elsewhere in this project.


def _get_images_product_texts(self, image_ids: List[int], num_products: int):
    """Get images and product texts of a response.

    Args:
        image_ids (List[int]): Image ids.
        num_products (int): Maximum number of products; outputs are
            padded up to this count.

    Returns:
        num_real_products (int): Number of real products (excluding
            padding).
        images: Images (num_products, 3, image_size, image_size).
        product_texts: Product texts (num_products, product_text_max_len).
        product_text_lengths: Product text lengths (num_products, ).

    """
    images = []
    product_texts = []
    product_text_lengths = []
    for img_id in image_ids:
        # Image id 0 marks padding; no real products follow it.
        if img_id == 0:
            break
        image_name = self.image_paths[img_id]
        image_path = join(DatasetConfig.image_data_directory, image_name)
        product_path = get_product_path(image_name)

        # Image.
        raw_image = Image.open(image_path).convert("RGB")
        image = DatasetConfig.transform(raw_image)
        images.append(image)

        # Text.
        text = Dataset._get_product_text(product_path)
        text = [self.dialog_vocab.get(word, UNK_ID)
                for word in word_tokenize(text)]
        text, text_len = pad_or_clip_text(
            text, DatasetConfig.product_text_max_len)
        product_texts.append(text)
        product_text_lengths.append(text_len)

    # Count real products before padding, so the returned count excludes
    # padding as documented.
    num_real_products = len(images)

    # Padding.
    num_pads = num_products - num_real_products
    images.extend([self.EMPTY_IMAGE] * num_pads)
    product_texts.extend([self.EMPTY_PRODUCT_TEXT] * num_pads)
    product_text_lengths.extend([1] * num_pads)

    # To tensors.
    images = torch.stack(images)
    product_texts = torch.stack(list(map(torch.tensor, product_texts)))
    product_text_lengths = torch.tensor(product_text_lengths)
    return num_real_products, images, product_texts, product_text_lengths
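# pad_or_clip_text is not shown in this excerpt. Below is a minimal sketch
# of the behavior the caller above relies on (truncate over-long token
# lists, pad short ones, and report the pre-padding length); the pad_id
# default and exact signature are assumptions, not the project's actual
# code:
def pad_or_clip_text(text, max_len, pad_id=0):
    """Pad with `pad_id` or clip `text` to exactly `max_len` tokens.

    Returns the fixed-length token list and the effective text length
    (the original length, capped at `max_len`).
    """
    text_len = min(len(text), max_len)
    padded = text[:max_len] + [pad_id] * (max_len - len(text))
    return padded, text_len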
def _get_context_dialog(self, dialog: TidyDialog):
    """Get context dialog.

    Note: The last utterance of the context dialog is the system
    response.

    Args:
        dialog (TidyDialog): Dialog.

    Returns:
        texts: Texts (dialog_context_size + 1, dialog_text_max_len).
        text_lengths: Text lengths (dialog_context_size + 1, ).
        images: Images (dialog_context_size + 1, pos_images_max_num,
            3, image_size, image_size).
        utter_type (int): The type of the last user utterance.

    """
    # Text.
    text_list: List[List[int]] = [utter.text for utter in dialog]
    text_length_list: List[int] = [utter.text_len for utter in dialog]

    # Text tensors.
    # (dialog_context_size + 1, dialog_text_max_len)
    texts = torch.stack([torch.tensor(text) for text in text_list])
    # (dialog_context_size + 1, )
    text_lengths = torch.tensor(text_length_list)

    # Image: one image per positive image id, with EMPTY_IMAGE standing
    # in for padding ids, missing files, and unreadable image files.
    image_list = [[] for _ in range(DatasetConfig.dialog_context_size + 1)]
    for idx, utter in enumerate(dialog):
        for img_id in utter.pos_images:
            path = self.image_paths[img_id]
            if path:
                path = join(DatasetConfig.image_data_directory, path)
            if path and isfile(path):
                try:
                    raw_image = Image.open(path).convert("RGB")
                    image = DatasetConfig.transform(raw_image)
                    image_list[idx].append(image)
                except OSError:
                    image_list[idx].append(Dataset.EMPTY_IMAGE)
            else:
                image_list[idx].append(Dataset.EMPTY_IMAGE)
    # (dialog_context_size + 1, pos_images_max_num,
    #  3, image_size, image_size)
    images = torch.stack(list(map(torch.stack, image_list)))

    # Utterance type: dialog[-1] is the system response, so dialog[-2]
    # is the last user utterance.
    utter_type = dialog[-2].utter_type
    return texts, text_lengths, images, utter_type
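# DatasetConfig is defined elsewhere in the project; the methods above only
# require that DatasetConfig.transform turn a PIL image into a
# (3, image_size, image_size) float tensor. A minimal torchvision-based
# sketch of such a transform follows (the Resize/ToTensor/Normalize
# pipeline and the image_size value are illustrative assumptions, not the
# project's actual configuration):
from torchvision import transforms

image_size = 64  # assumed value, for illustration only
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),  # fix spatial size
    transforms.ToTensor(),  # PIL HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),  # ImageNet stats
])

# With such a transform, EMPTY_IMAGE can be an all-zero tensor of the same
# shape, so real and padded slots stack into a single tensor:
# EMPTY_IMAGE = torch.zeros(3, image_size, image_size)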