Example #1
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        text_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'))
        annotation_data = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'))
        annotated_text = AnnotatedText(
            text_data=text_data,
            annotation_data=annotation_data,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )
        relation_dataset = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.relations'))

        dataset = SemEval2010Task8Dataset(
            annotation_text=annotated_text,
            relation_dataset=relation_dataset,
            dictionary=self.dictionary,
            seed=self.seed,
        )
        dataset = PrependTokenDataset(dataset, self.dictionary.bos(), ['text'],
                                      ['annotation'])

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
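All twelve examples share one skeleton: load the split's indexed text and memory-mapped annotations, join them in an AnnotatedText, wrap that in a task-specific dataset, prepend a <bos> token, and optionally cap the split at n_<split>_examples. The FixedSizeDataset cap is the common tail; its implementation is not shown here, so the following is only a sketch of what it presumably does (deterministic, seeded subsampling):

import numpy as np

class FixedSizeDatasetSketch:
    """Hypothetical stand-in: keep at most `size` items, chosen with `seed`."""

    def __init__(self, dataset, size, seed):
        rng = np.random.RandomState(seed)
        # Sample without replacement so the subset is stable for a given seed.
        self.indices = rng.choice(
            len(dataset), size=min(size, len(dataset)), replace=False)
        self.dataset = dataset

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        return self.dataset[self.indices[i]]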
Example #2
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        text_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'))
        annotation_data = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'))
        annotated_text = AnnotatedText(
            text_data=text_data,
            annotation_data=annotation_data,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )

        graph_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.graph'))
        graph = GraphDataset(
            edges=graph_data,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        dataset = TripletDataset(
            annotated_text=annotated_text,
            graph=graph,
            k_negative=self.args.k_negative,
            n_entities=len(self.entity_dictionary),
            seed=self.args.seed,
            dictionary=self.dictionary,
            same_replace_heads_for_all_negatives=self.args.arch.startswith(
                'encoder_dual'),
            negative_split_probs=self.args.negative_split_probs or [1, 0, 0],
            use_sentence_negatives=self.args.use_sentence_negatives,
        )
        if split == 'train' and self.args.epoch_size is not None:
            dataset = EpochSplitDataset(
                dataset=dataset,
                epoch_size=self.args.epoch_size,
                seed=self.args.seed,
            )

        dataset = PrependTokenDataset(dataset, self.dictionary.bos(), 'text')

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
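New here relative to Example #1 are the GraphDataset (entity-pair edges with seeded subsampling) and the train-only EpochSplitDataset, which turns one huge dataset into fixed-size "epochs". EpochSplitDataset's implementation is not shown, so this is a sketch of the presumed behavior, with set_epoch as the hook fairseq calls at each epoch boundary:

class EpochSplitDatasetSketch:
    """Hypothetical stand-in: each epoch sees a different epoch_size slice."""

    def __init__(self, dataset, epoch_size, seed):
        self.dataset = dataset
        self.epoch_size = min(epoch_size, len(dataset))
        self.seed = seed  # unused in this sketch; the real class likely shuffles
        self.epoch = 1  # fairseq epochs start at 1

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __len__(self):
        return self.epoch_size

    def __getitem__(self, i):
        # Slide a window over the underlying data as epochs advance.
        offset = ((self.epoch - 1) * self.epoch_size) % len(self.dataset)
        return self.dataset[(offset + i) % len(self.dataset)]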
Example #3
    def load_dataset(self, split, prune_type=None, prune_param=None, epoch=0, combine=False, **kwargs):
        text_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'),
        )
        annotation_data = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'),
        )
        annotated_text = AnnotatedText(
            text_data=text_data,
            annotation_data=annotation_data,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )
        relation_dataset = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.relations')
        )

        dataset = FewRelDataset(
            annotation_text=annotated_text,
            relation_dataset=relation_dataset,
            dictionary=self.dictionary,
            n_way=self.args.n_way,
            n_shot=self.args.n_shot,
            seed=self.seed,
        )

        if prune_type == 'n_train_relations':
            assert prune_param is not None
            if prune_param < 64:
                dataset.prune_by_num_relations(prune_param)
        elif prune_type == 'n_train_examples_per_relation':
            assert prune_param is not None
            if prune_param < 700:
                dataset.prune_by_num_examples_per_relation(prune_param)

        dataset = PrependTokenDataset(
            dataset, self.dictionary.bos(),
            ['text', 'exemplars'], ['annotation', 'exemplars_annotation'])

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
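The magic numbers in the pruning branch line up with FewRel's training split, which has 64 relations and 700 examples per relation: a prune_param of 64 or more relations (or 700 or more examples per relation) already covers the full split, so pruning is skipped as a no-op.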
Example #4
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        text_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'))
        annotation_data = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'))
        annotated_text = AnnotatedText(
            text_data=text_data,
            annotation_data=annotation_data,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )

        edges = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.graph'))

        dataset = ETPRelationDataset(
            annotated_text=annotated_text,
            edges=edges,
            dictionary=self.dictionary,
            n_entities=len(self.entity_dictionary),
            total_negatives=self.args.total_negatives,
            mask_negative_prob=self.args.mask_negative_prob,
            max_positions=self.args.max_positions,
            num_workers=self.args.num_workers,
            seed=self.args.seed,
        )

        if split == 'train' and self.args.epoch_size is not None:
            dataset = EpochSplitDataset(
                dataset=dataset,
                epoch_size=self.args.epoch_size,
                seed=self.args.seed,
            )

        dataset = PrependTokenDataset(dataset, self.dictionary.bos(), 'text',
                                      ['mask_annotation', 'all_annotations'])

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
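PrependTokenDataset is called with text key(s) plus annotation key(s) here ('text', then 'mask_annotation' and 'all_annotations'). Presumably the annotation spans must shift right by one once <bos> is inserted at position 0; a hedged per-item sketch of that transform, with the field layout assumed rather than taken from the real class:

import torch

def prepend_bos(item, bos, text_key='text', annotation_key='mask_annotation'):
    # Insert <bos> in front of the token sequence.
    item[text_key] = torch.cat(
        [item[text_key].new_tensor([bos]), item[text_key]])
    if annotation_key in item:
        # Assumed layout: the first two columns are start/end token offsets.
        item[annotation_key][..., :2] += 1
    return item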
Example #5
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        text_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'))
        annotation_data = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'))
        annotated_text = AnnotatedText(
            text_data=text_data,
            annotation_data=annotation_data,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )

        graph_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.graph'))
        graph = GraphDataset(
            edges=graph_data,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        dataset = GNNEvalDataset(
            annotated_text=annotated_text,
            graph=graph,
            dictionary=self.dictionary,
            max_positions=self.args.max_positions,
            num_workers=self.args.num_workers,
            seed=self.args.seed,
        )

        dataset = PrependTokenDataset(dataset,
                                      self.dictionary.bos(),
                                      keys=['target', 'support'])

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
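GraphDataset appears with the same three knobs in every graph-based task: the split's edge list, a subsampling strategy, and a cap. The names suggest it limits how many edges (sentence mentions) any one entity pair contributes, sampled reproducibly from seed, but that reading is an inference from the argument names, not from code shown here.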
Example #6
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        text_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'))
        annotation_data = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'))
        annotated_text = AnnotatedText(
            text_data=text_data,
            annotation_data=annotation_data,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )
        relation_dataset = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.relations'))

        dataset = TACREDDataset(
            annotation_text=annotated_text,
            relation_dataset=relation_dataset,
            dictionary=self.dictionary,
            seed=self.seed,
        )
        dataset = PrependTokenDataset(dataset, self.dictionary.bos(), ['text'])
        dataset.annotated_text = annotated_text
        dataset.relation_dataset = relation_dataset

        probing_dataset = TACREDProbingDataset(
            tacred_dataset=dataset,
            n_rules=self.args.n_rules,
            n_texts=self.args.n_texts,
            n_strong_negs=self.args.n_strong_negs,
            dictionary=self.dictionary,
            seed=self.seed)

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            probing_dataset = FixedSizeDataset(
                dataset=probing_dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = probing_dataset
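The two attribute assignments above the probing wrapper are what make it work: annotated_text and relation_dataset are stapled onto the TACRED dataset so TACREDProbingDataset can resample raw texts and relations directly, with n_rules, n_texts, and n_strong_negs sizing the probe. The exact sampling scheme is not visible in this snippet.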
Example #7
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):

        text_data_A = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'),
        )
        annotation_data_A = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy')
        )
        annotated_text_A = AnnotatedText(
            text_data=text_data_A,
            annotation_data=annotation_data_A,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )

        graph_data_A = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, 'mtb_' + split + '.graph'),
        )
        graph_A = GraphDataset(
            edges=graph_data_A,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        if self.args.split_mode:
            annotated_text_B = annotated_text_A
            graph_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, split + '.graph'),
            )
        else:
            text_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, 'train.text'),
            )
            annotation_data_B = MMapNumpyArray(
                os.path.join(self.args.data_path, 'train.annotations.npy')
            )
            annotated_text_B = AnnotatedText(
                text_data=text_data_B,
                annotation_data=annotation_data_B,
                dictionary=self.dictionary,
                mask_type=self.args.mask_type,
                non_mask_rate=self.args.non_mask_rate,
            )
            graph_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, 'train.graph'),
            )

        graph_B = GraphDataset(
            edges=graph_data_B,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        dataset = GraphDistanceDataset(
            split=split,
            annotated_text_A=annotated_text_A,
            annotated_text_B=annotated_text_B,
            graph_A=graph_A,
            graph_B=graph_B,
            seed=self.args.seed,
            dictionary=self.dictionary,
            class_probabilities=self.args.class_probabilities,
            n_tries_entity=self.args.n_tries_entity
        )
        if split == 'train' and self.args.epoch_size is not None:
            dataset = EpochSplitDataset(
                dataset=dataset,
                epoch_size=self.args.epoch_size,
                seed=self.args.seed,
            )

        dataset = PrependTokenDataset(
            dataset,
            self.dictionary.bos(),
            ['textA', 'textB']
        )

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
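The A/B asymmetry is the point of split_mode: side A always reads the current split against its MTB-filtered graph (mtb_<split>.graph), while side B either reuses the same split's text with the unfiltered graph (split_mode on) or falls back to the train split entirely, so that evaluation pairs can still draw their companion sentence from training data.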
Example #8
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):

        text_data_A = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'))
        annotation_data_A = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'))
        annotated_text_A = AnnotatedText(
            text_data=text_data_A,
            annotation_data=annotation_data_A,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )

        graph_data_A = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, 'mtb_' + split + '.graph'))
        graph_A = GraphDataset(
            edges=graph_data_A,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        if self.args.split_mode:
            annotated_text_B = annotated_text_A
            graph_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, split + '.graph'))
        else:
            text_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, 'train.text'))
            annotation_data_B = MMapNumpyArray(
                os.path.join(self.args.data_path, 'train.annotations.npy'))
            annotated_text_B = AnnotatedText(
                text_data=text_data_B,
                annotation_data=annotation_data_B,
                dictionary=self.dictionary,
                mask_type=self.args.mask_type,
                non_mask_rate=self.args.non_mask_rate,
            )
            graph_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, 'train.graph'))

        graph_B = GraphDataset(
            edges=graph_data_B,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        if ((self.args.strong_negatives
             and self.args.strong_negative_type == 'similarity')
                or self.args.similar_positives):
            similar_entities = MMapNumpyArray(
                os.path.join(self.args.data_path, 'entity.candidates.idx.npy'))
            similarity_scores = MMapNumpyArray(
                os.path.join(self.args.data_path,
                             'entity.candidates.scores.npy'))
        else:
            similar_entities = None
            similarity_scores = None

        dataset = PMTBDataset(
            split=split,
            annotated_text_A=annotated_text_A,
            annotated_text_B=annotated_text_B,
            graph_A=graph_A,
            graph_B=graph_B,
            similar_entities=similar_entities,
            similarity_scores=similarity_scores,
            seed=self.args.seed,
            dictionary=self.dictionary,
            entity_dictionary=self.entity_dictionary,
            k_weak_negs=self.args.k_weak_negs,
            n_tries_entity=self.args.n_tries_entity,
            strong_negatives=self.args.strong_negatives,
            strong_negative_type=self.args.strong_negative_type,
            negative_temperature=getattr(self.args, 'negative_temperature',
                                         None),
            replace_tail=self.args.replace_tail,
            mutual_positives=self.args.mutual_positives,
            similar_positives=self.args.similar_positives,
            positive_temperature=getattr(self.args, 'positive_temperature',
                                         None),
        )
        if split == 'train' and self.args.epoch_size is not None:
            dataset = EpochSplitDataset(
                dataset=dataset,
                epoch_size=self.args.epoch_size,
                seed=self.args.seed,
            )

        dataset = PrependTokenDataset(dataset, self.dictionary.bos(),
                                      ['textA', 'textB'])

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
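Compared with Example #7, this task adds similarity-driven sampling: the memory-mapped candidate files (precomputed top-k similar entities and their scores) are loaded only when similarity strong negatives or similar positives are enabled, and otherwise stay None. The two getattr(..., None) calls keep the temperature arguments optional for configs that predate them.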
Example #9
    def load_dataset(self,
                     split,
                     prune_type=None,
                     prune_param=None,
                     epoch=0,
                     combine=False,
                     **kwargs):

        questions = safe_load_indexed_dataset(
            os.path.join(self.args.qa_data_path,
                         split + '.questions_entities'))

        answers = MMapNumpyArray(
            os.path.join(self.args.qa_data_path,
                         split + '.answer_entities.npy'))

        with open(
                os.path.join(self.args.qa_data_path,
                             split + '.processed_annotations.json')) as f:
            annotations = json.load(f)

        dataset = TriviaQADataset(questions, answers, annotations)

        task_framing = self.args.task_framing

        if task_framing == 'predict_mask':
            edges = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, split + '.graph'))
            dataset = ETPDownstreamDataset(
                dataset=dataset,
                edges=edges,
                dictionary=self.dictionary,
                n_entities=len(self.entity_dictionary),
                seed=self.args.seed,
                split=split,
            )
            dataset = PrependTokenDataset(dataset, self.dictionary.bos(),
                                          ['text'], ['annotation'])

        elif task_framing == 'predict_mask_relation':
            edges = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, split + '.graph'))
            dataset = ETPRelationDownstreamDataset(
                dataset=dataset,
                edges=edges,
                dictionary=self.dictionary,
                n_entities=len(self.entity_dictionary),
                seed=self.args.seed,
                split=split,
            )
            dataset = PrependTokenDataset(
                dataset, self.dictionary.bos(), ['text'],
                ['mask_annotation', 'all_annotations'])
        else:
            raise ValueError(f'unknown task_framing: {task_framing}')

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
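Both framings share the TriviaQA front end and diverge only in the downstream wrapper and in which annotation fields ride along through the <bos> prepend: predict_mask carries a single 'annotation' field, while predict_mask_relation carries 'mask_annotation' and 'all_annotations', mirroring the ETP relation setup in Example #4.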
Example #10
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        text_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'))
        annotation_data = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'))
        annotated_text = AnnotatedText(
            text_data=text_data,
            annotation_data=annotation_data,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )
        dataset = TokenBlockAnnotatedDataset(
            annotated_text=annotated_text,
            max_positions=self.max_positions() - 5,  # <cls>, e1/e2 start/end
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            seed=self.seed,
            document_sep_len=1,
        )
        if split == 'train' and self.args.epoch_size is not None:
            dataset = EpochSplitDataset(
                dataset=dataset,
                epoch_size=self.args.epoch_size,
                seed=self.args.seed,
            )

        dataset = PrependTokenDataset(dataset, self.dictionary.bos())

        # create masked input and targets
        mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
            if self.args.mask_whole_words else None

        src_dataset, tgt_dataset = CustomMaskTokensDataset.apply_mask(
            dataset,
            self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.dictionary.mask(),
            seed=self.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )

        dataset = DictionaryDataset(
            {
                'id': IdDataset(),
                'src_tokens': PadDataset(
                    src_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                'src_lengths': NumelDataset(src_dataset, reduce=False),
                'target': PadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_dataset, reduce=True),
            },
            main_key='src_tokens',
        )

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
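The DictionaryDataset at the end follows fairseq's usual masked-LM batch layout. A hedged sketch of one collated batch, with field names taken from the code and shapes assumed:

# batch = {
#     'id':          LongTensor [B],
#     'src_tokens':  LongTensor [B, T], right-padded with dictionary.pad(),
#     'src_lengths': LongTensor [B], unpadded lengths,
#     'target':      LongTensor [B, T], pad everywhere except masked positions,
#     'nsentences':  B,
#     'ntokens':     total source tokens in the batch,
# }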
Example #11
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):

        text_data_A = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'))
        annotation_data_A = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'))
        annotated_text_A = AnnotatedText(
            text_data=text_data_A,
            annotation_data=annotation_data_A,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )

        graph_data_A = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, 'mtb_' + split + '.graph'))
        graph_A = GraphDataset(
            edges=graph_data_A,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        if self.args.split_mode:
            annotated_text_B = annotated_text_A
            graph_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, split + '.graph'))
        else:
            text_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, 'train.text'))
            annotation_data_B = MMapNumpyArray(
                os.path.join(self.args.data_path, 'train.annotations.npy'))
            annotated_text_B = AnnotatedText(
                text_data=text_data_B,
                annotation_data=annotation_data_B,
                dictionary=self.dictionary,
                mask_type=self.args.mask_type,
                non_mask_rate=self.args.non_mask_rate,
            )
            graph_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, 'train.graph'))

        graph_B = GraphDataset(
            edges=graph_data_B,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        if self.args.data_path in [
                '../data/nki/bin-v5-threshold20',
                '../data/nki/bin-v5-threshold20-small'
        ]:
            similar_entities = MMapNumpyArray(
                os.path.join(self.args.data_path,
                             'entity.candidates_remap.idx.npy'))
            similarity_scores = MMapNumpyArray(
                os.path.join(self.args.data_path,
                             'entity.scores_remap.idx.npy'))
        else:
            raise Exception(
                "Top 1000 similar entities/scores data not available for the given dataset."
            )

        dataset = BoRDataset(
            split=split,
            annotated_text_A=annotated_text_A,
            annotated_text_B=annotated_text_B,
            graph_A=graph_A,
            graph_B=graph_B,
            similar_entities=similar_entities,
            similarity_scores=similarity_scores,
            seed=self.args.seed,
            dictionary=self.dictionary,
            n_strong_candidates=self.args.n_strong_candidates,
            n_weak_candidates=self.args.n_weak_candidates,
            head_tail_weight=self.args.head_tail_weight,
            n_tries_entity=self.args.n_tries_entity,
        )
        if split == 'train' and self.args.epoch_size is not None:
            dataset = EpochSplitDataset(
                dataset=dataset,
                epoch_size=self.args.epoch_size,
                seed=self.args.seed,
            )

        dataset = PrependTokenDataset(dataset, self.dictionary.bos(),
                                      ['textA', 'textB'])

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
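Unlike Example #8, which gates the similarity files on the sampling flags, this task hard-codes the two dataset directories for which the top-1000 candidate and score files were precomputed and fails loudly for anything else; note it also reads the *_remap.idx.npy variants rather than the entity.candidates.*.npy files used in Example #8.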
Example #12
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        text_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'),
        )
        annotation_data = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy')
        )
        annotated_text = AnnotatedText(
            text_data=text_data,
            annotation_data=annotation_data,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )

        graph_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.graph'),
        )
        graph = GraphDataset(
            edges=graph_data,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        dataset = GNNDataset(
            annotated_text=annotated_text,
            graph=graph,
            dictionary=self.dictionary,
            min_common_neighbors=self.args.min_common_neighbors,
            max_common_neighbors=self.args.max_common_neighbors,
            required_min_common_neighbors=getattr(self.args, 'required_min_common_neighbors', 1),
            max_entities_size=self.args.max_entities_size,
            max_entities_from_queue=self.args.max_entities_from_queue,
            cover_random_prob=self.args.cover_random_prob,
            total_negatives=self.args.total_negatives,
            max_hard_negatives=self.args.max_hard_negatives,
            max_tokens=self.args.max_tokens - 1, # for bos
            max_sentences=self.args.max_sentences,
            num_text_chunks=self.args.num_text_chunks,
            entity_pair_counter_cap=getattr(self.args, 'entity_pair_counter_cap', None),
            num_workers=self.args.num_workers,
            seed=self.args.seed,
        )
        if split == 'train' and self.args.epoch_size is not None:
            dataset = EpochSplitDataset(
                dataset=dataset,
                epoch_size=self.args.epoch_size,
                seed=self.args.seed,
            )

        dataset = PrependTokenDataset(dataset, self.dictionary.bos(), 'text')

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset