# Example 1
class CXRVisDialDataset(Dataset):
    """Dataset of per-turn visual-dialog examples over chest X-ray images.

    Each item pairs one dialog turn (history, question, answer options)
    with the image feature vector of the corresponding frontal view.
    """

    def __init__(self,
                 image_features_path,
                 dialog_path,
                 vocab_path=None,
                 mode='concat',
                 permute=False):
        """
        A dataset class.

        Args:
            image_features_path (str): path to image features h5 file
            dialog_path (str): path to .json file with dialog data
            vocab_path (str, optional): path to word counts. If None, BERT vocabulary is used instead
            mode (str, optional): 'concat' concatenates dialog turns into one
                sequence; 'seq' keeps turns as separate sequences. Defaults to 'concat'
            permute (bool, optional): Whether to permute dialog turns in random order. Defaults to False

        Raises:
            ValueError: if ``mode`` is neither 'concat' nor 'seq'.
        """
        super().__init__()
        if mode not in ('concat', 'seq'):
            # Fail fast: an unknown mode would otherwise surface as an
            # UnboundLocalError deep inside get_all_dialogs().
            raise ValueError(f"mode must be 'concat' or 'seq', got {mode!r}")

        if vocab_path is not None:
            self.vocabulary = Vocabulary(vocab_path)
            # Raw string: '\w' in a plain literal is an invalid escape sequence.
            self.tokenizer = RegexpTokenizer(r'\w+')
            self.bert = False
        else:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            self.vocabulary = self.tokenizer.vocab
            self.bert = True

        # Read dialogs and corresponding image ids
        self.dialog_reader = DialogReader(dialog_path, self.vocabulary,
                                          self.tokenizer, mode, permute)
        self.image_ids = self.dialog_reader.image_ids

        # Read image vectors
        self.image_reader = ImageFeaturesReader(image_features_path,
                                                self.image_ids)

        # Get all possible questions and answers
        self.questions = self.dialog_reader.visdial_data['data']['questions']
        self.answers = self.dialog_reader.visdial_data['data']['answers']

        self.all_dialogs = self.get_all_dialogs(mode)

    def _to_ids(self, tokens):
        """Map one token sequence to ids via the active vocabulary (BERT or custom)."""
        if self.bert:
            return self.tokenizer.convert_tokens_to_ids(tokens)
        return self.vocabulary.to_indices(tokens)

    def get_all_dialogs(self, mode):
        """
        Extract all dialog examples.
        Args:
            mode (str): If 'seq', dialog history consists of separate sequences of turns.
                If 'concat', dialog turns are concatenated.

        Returns:
            a list of dictionaries with dialog examples. Tokens are replaced with their ids based on vocabulary.
                Returned dialogs are padded.

        Raises:
            ValueError: if ``mode`` is neither 'concat' nor 'seq'.

        """
        all_dialogs = []
        print('Extracting all possible data examples...')
        for dialog in tqdm(self.dialog_reader.dialogs):
            # NOTE(review): only the frontal view is used; lateral-view
            # concatenation was removed from an earlier revision.
            image_id_frontal = dialog['image_id']
            try:
                image_vector = self.image_reader.image_vectors[image_id_frontal]
            except KeyError:
                # Skip dialogs whose image features are missing; other
                # exceptions are real bugs and should propagate.
                continue
            num_turns = len(dialog['history'])
            for turn in range(num_turns):
                if mode == 'concat':
                    # One flat token sequence per turn; turn 0 of the
                    # history doubles as the image caption.
                    history_ids = self._to_ids(dialog['padded_history'][turn])
                    caption_ids = self._to_ids(dialog['padded_history'][0])
                    question_ids = self._to_ids(dialog['padded_question'][turn])
                elif mode == 'seq':
                    # One sequence of ids per dialog turn.
                    history_ids = [
                        self._to_ids(sequence)
                        for sequence in dialog['padded_history'][turn]
                    ]
                    caption_ids = history_ids[0]
                    question_ids = [
                        self._to_ids(sequence)
                        for sequence in dialog['padded_all_questions'][turn]
                    ]
                else:
                    raise ValueError(
                        f"mode must be 'concat' or 'seq', got {mode!r}")

                option_ids = [
                    self._to_ids(option)
                    for option in dialog['padded_options'][turn]
                ]

                # torch.tensor with an explicit dtype builds the long tensor
                # directly; torch.Tensor(...).long() round-trips through
                # float32 and corrupts ids above 2**24.
                history_ids = torch.tensor(history_ids, dtype=torch.long)
                question_ids = torch.tensor(question_ids, dtype=torch.long)
                option_ids = torch.tensor(option_ids, dtype=torch.long)
                caption_ids = torch.tensor(caption_ids, dtype=torch.long)
                answer_ind = dialog['answer_ind'][turn]

                all_dialogs.append({
                    'history_ids': history_ids,
                    'history': dialog['history'][turn],
                    'question': dialog['question'][turn],
                    'question_ids': question_ids,
                    'answer': dialog['answer'][turn],
                    'image': image_vector,
                    'options': option_ids,
                    'answer_ind': answer_ind,
                    'caption_ids': caption_ids,
                    'turn': turn
                })

        return all_dialogs

    def __len__(self):
        """Number of extracted (image, turn) examples."""
        return len(self.all_dialogs)

    def __getitem__(self, idx):
        """Return the pre-built example dict at position ``idx``."""
        return self.all_dialogs[idx]