class CXRVisDialDataset(Dataset):
    """Visual-dialog dataset over chest-X-ray image features.

    Each item pairs one dialog turn with its image feature vector, the
    tokenized/padded dialog history, question, answer options, and the
    index of the correct answer.
    """

    def __init__(self, image_features_path, dialog_path, vocab_path=None,
                 mode='concat', permute=False):
        """Load image features and dialogs, then pre-extract all examples.

        Args:
            image_features_path (str): path to image features h5 file.
            dialog_path (str): path to .json file with dialog data.
            vocab_path (str, optional): path to word counts. If None, the
                BERT vocabulary and tokenizer are used instead.
            mode (str, optional): 'concat' concatenates dialog turns into a
                single sequence; 'seq' keeps turns as separate sequences.
            permute (bool, optional): whether to permute dialog turns in
                random order. Defaults to False.
        """
        super().__init__()
        if vocab_path is not None:
            self.vocabulary = Vocabulary(vocab_path)
            # Raw string: '\w' is an invalid escape sequence in a plain
            # string literal (warning since 3.6, SyntaxWarning in 3.12+).
            self.tokenizer = RegexpTokenizer(r'\w+')
            self.bert = False
        else:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            self.vocabulary = self.tokenizer.vocab
            self.bert = True

        # Read dialogs and corresponding image ids
        self.dialog_reader = DialogReader(dialog_path, self.vocabulary,
                                          self.tokenizer, mode, permute)
        self.image_ids = self.dialog_reader.image_ids

        # Read image vectors
        self.image_reader = ImageFeaturesReader(image_features_path,
                                                self.image_ids)

        # Get all possible questions and answers
        self.questions = self.dialog_reader.visdial_data['data']['questions']
        self.answers = self.dialog_reader.visdial_data['data']['answers']
        self.all_dialogs = self.get_all_dialogs(mode)

    def _to_ids(self, tokens):
        """Map a sequence of tokens to vocabulary ids (BERT or custom vocab)."""
        if self.bert:
            return self.tokenizer.convert_tokens_to_ids(tokens)
        return self.vocabulary.to_indices(tokens)

    def get_all_dialogs(self, mode):
        """Extract all dialog examples.

        Args:
            mode (str): If 'seq', dialog history consists of separate
                sequences of turns. If 'concat', dialog turns are
                concatenated.

        Returns:
            list[dict]: one entry per dialog turn. Tokens are replaced
            with their ids based on the vocabulary; returned dialogs are
            padded. Dialogs whose image features are missing are skipped.
        """
        all_dialogs = []
        print('Extracting all possible data examples...')
        for dialog in tqdm(self.dialog_reader.dialogs):
            image_id_frontal = dialog['image_id']
            # NOTE(review): lateral-view concatenation was stubbed out in the
            # original code; only the frontal image vector is used.
            try:
                image_vector = self.image_reader.image_vectors[image_id_frontal]
            except KeyError:
                # No features for this image id — skip the whole dialog.
                continue

            num_turns = len(dialog['history'])
            for turn in range(num_turns):
                if mode == 'concat':
                    # One flat (padded) token sequence per turn.
                    history_ids = self._to_ids(dialog['padded_history'][turn])
                    caption_ids = self._to_ids(dialog['padded_history'][0])
                    question_ids = self._to_ids(dialog['padded_question'][turn])
                elif mode == 'seq':
                    # One id sequence per history turn / per question.
                    history_ids = [
                        self._to_ids(sequence)
                        for sequence in dialog['padded_history'][turn]
                    ]
                    caption_ids = history_ids[0]
                    question_ids = [
                        self._to_ids(sequence)
                        for sequence in dialog['padded_all_questions'][turn]
                    ]
                option_ids = [
                    self._to_ids(option)
                    for option in dialog['padded_options'][turn]
                ]

                history_ids = torch.Tensor(history_ids).long()
                question_ids = torch.Tensor(question_ids).long()
                option_ids = torch.Tensor(option_ids).long()
                caption_ids = torch.Tensor(caption_ids).long()
                answer_ind = dialog['answer_ind'][turn]
                all_dialogs.append({
                    'history_ids': history_ids,
                    'history': dialog['history'][turn],
                    'question': dialog['question'][turn],
                    'question_ids': question_ids,
                    'answer': dialog['answer'][turn],
                    'image': image_vector,
                    'options': option_ids,
                    'answer_ind': answer_ind,
                    'caption_ids': caption_ids,
                    'turn': turn
                })
        return all_dialogs

    def __len__(self):
        return len(self.all_dialogs)

    def __getitem__(self, idx):
        return self.all_dialogs[idx]