Example 1
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        text_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'))
        annotation_data = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'))
        annotated_text = AnnotatedText(
            text_data=text_data,
            annotation_data=annotation_data,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )
        relation_dataset = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.relations'))

        dataset = SemEval2010Task8Dataset(
            annotation_text=annotated_text,
            relation_dataset=relation_dataset,
            dictionary=self.dictionary,
            seed=self.seed,
        )
        dataset = PrependTokenDataset(dataset, self.dictionary.bos(), ['text'],
                                      ['annotation'])

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
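Most examples below end with the same optional cap: when n_<split>_examples is set, the split is truncated via FixedSizeDataset. A minimal sketch of what such a wrapper plausibly does (an illustration under assumed semantics, not the repository's actual class):

import numpy as np

class FixedSizeDatasetSketch:
    """Keep a seeded random subset of at most `size` examples."""

    def __init__(self, dataset, size, seed):
        self.dataset = dataset
        rng = np.random.RandomState(seed)
        # A seeded permutation makes the subset reproducible across runs.
        self.indices = rng.permutation(len(dataset))[:min(size, len(dataset))]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index):
        return self.dataset[self.indices[index]]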
Example 2
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        text_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'))
        annotation_data = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'))
        annotated_text = AnnotatedText(
            text_data=text_data,
            annotation_data=annotation_data,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )

        graph_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.graph'))
        graph = GraphDataset(
            edges=graph_data,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        dataset = TripletDataset(
            annotated_text=annotated_text,
            graph=graph,
            k_negative=self.args.k_negative,
            n_entities=len(self.entity_dictionary),
            seed=self.args.seed,
            dictionary=self.dictionary,
            same_replace_heads_for_all_negatives=self.args.arch.startswith(
                'encoder_dual'),
            negative_split_probs=self.args.negative_split_probs or [1, 0, 0],
            use_sentence_negatives=self.args.use_sentence_negatives,
        )
        if split == 'train' and self.args.epoch_size is not None:
            dataset = EpochSplitDataset(
                dataset=dataset,
                epoch_size=self.args.epoch_size,
                seed=self.args.seed,
            )

        dataset = PrependTokenDataset(dataset, self.dictionary.bos(), 'text')

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
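This variant adds a training-only EpochSplitDataset wrapper driven by args.epoch_size. A hedged sketch of the idea, assuming a fairseq-style set_epoch hook and that each epoch sees a reshuffled fixed-size slice (names and details are guesses):

import numpy as np

class EpochSplitDatasetSketch:
    """Expose a different `epoch_size`-sized slice of `dataset` each epoch."""

    def __init__(self, dataset, epoch_size, seed):
        self.dataset, self.epoch_size, self.seed = dataset, epoch_size, seed
        self.set_epoch(0)

    def set_epoch(self, epoch):
        # Re-shuffle with a per-epoch seed so each epoch sees a fresh slice.
        rng = np.random.RandomState(self.seed + epoch)
        self.indices = rng.permutation(len(self.dataset))[:self.epoch_size]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index):
        return self.dataset[self.indices[index]]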
Example 3
    def load_dataset(self, split, prune_type=None, prune_param=None, epoch=0, combine=False, **kwargs):
        text_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'),
        )
        annotation_data = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'),
        )
        annotated_text = AnnotatedText(
            text_data=text_data,
            annotation_data=annotation_data,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )
        relation_dataset = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.relations')
        )

        dataset = FewRelDataset(
            annotation_text=annotated_text,
            relation_dataset=relation_dataset,
            dictionary=self.dictionary,
            n_way=self.args.n_way,
            n_shot=self.args.n_shot,
            seed=self.seed,
        )

        if prune_type == 'n_train_relations':
            assert prune_param is not None
            if prune_param < 64:
                dataset.prune_by_num_relations(prune_param)
        elif prune_type == 'n_train_examples_per_relation':
            assert prune_param is not None
            if prune_param < 700:
                dataset.prune_by_num_examples_per_relation(prune_param)

        dataset = PrependTokenDataset(
            dataset, self.dictionary.bos(),
            ['text', 'exemplars'], ['annotation', 'exemplars_annotation'])

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
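FewRelDataset is parameterized by n_way and n_shot, the standard few-shot episode settings. For intuition only, a hypothetical episode sampler over a {relation: [examples]} mapping (not the repository's code):

import random

def sample_episode(examples_by_relation, n_way, n_shot, seed=0):
    """Sample one N-way K-shot episode from {relation: [examples]}."""
    rng = random.Random(seed)
    # Draw n_way relations, then n_shot support exemplars for each.
    relations = rng.sample(sorted(examples_by_relation), n_way)
    support = {r: rng.sample(examples_by_relation[r], n_shot) for r in relations}
    # The query is a held-out example from one randomly chosen target relation.
    target = rng.randrange(n_way)
    held_out = [e for e in examples_by_relation[relations[target]]
                if e not in support[relations[target]]]
    return rng.choice(held_out), support, target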
Example 4
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        text_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'))
        annotation_data = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'))
        annotated_text = AnnotatedText(
            text_data=text_data,
            annotation_data=annotation_data,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )

        edges = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.graph'))

        dataset = ETPRelationDataset(
            annotated_text=annotated_text,
            edges=edges,
            dictionary=self.dictionary,
            n_entities=len(self.entity_dictionary),
            total_negatives=self.args.total_negatives,
            mask_negative_prob=self.args.mask_negative_prob,
            max_positions=self.args.max_positions,
            num_workers=self.args.num_workers,
            seed=self.args.seed,
        )

        if split == 'train' and self.args.epoch_size is not None:
            dataset = EpochSplitDataset(
                dataset=dataset,
                epoch_size=self.args.epoch_size,
                seed=self.args.seed,
            )

        dataset = PrependTokenDataset(dataset, self.dictionary.bos(), 'text',
                                      ['mask_annotation', 'all_annotations'])

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
Example 5
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        text_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'))
        annotation_data = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'))
        annotated_text = AnnotatedText(
            text_data=text_data,
            annotation_data=annotation_data,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )

        graph_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.graph'))
        graph = GraphDataset(
            edges=graph_data,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        dataset = GNNEvalDataset(
            annotated_text=annotated_text,
            graph=graph,
            dictionary=self.dictionary,
            max_positions=self.args.max_positions,
            num_workers=self.args.num_workers,
            seed=self.args.seed,
        )

        dataset = PrependTokenDataset(dataset,
                                      self.dictionary.bos(),
                                      keys=['target', 'support'])

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
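Every example routes its output through PrependTokenDataset with self.dictionary.bos() plus one or more text keys and, in some tasks, annotation keys. A speculative per-item sketch, assuming the annotation fields hold token offsets that must shift right when a token is prepended (hypothetical helper, not the actual class):

import torch

def prepend_token(item, token, text_keys, annotation_keys=()):
    """Prepend `token` to the named text fields of one example dict."""
    item = dict(item)
    for key in text_keys:
        item[key] = torch.cat([item[key].new_tensor([token]), item[key]])
    for key in annotation_keys:
        # Every token moved one position to the right, so offsets shift by 1.
        item[key] = item[key] + 1
    return item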
Example 6
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        text_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'))
        annotation_data = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'))
        annotated_text = AnnotatedText(
            text_data=text_data,
            annotation_data=annotation_data,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )
        relation_dataset = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.relations'))

        dataset = TACREDDataset(
            annotation_text=annotated_text,
            relation_dataset=relation_dataset,
            dictionary=self.dictionary,
            seed=self.seed,
        )
        dataset = PrependTokenDataset(dataset, self.dictionary.bos(), ['text'])
        dataset.annotated_text = annotated_text
        dataset.relation_dataset = relation_dataset

        probing_dataset = TACREDProbingDataset(
            tacred_dataset=dataset,
            n_rules=self.args.n_rules,
            n_texts=self.args.n_texts,
            n_strong_negs=self.args.n_strong_negs,
            dictionary=self.dictionary,
            seed=self.seed)

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            probing_dataset = FixedSizeDataset(
                # Wrap the probing dataset itself, not the inner TACRED dataset.
                dataset=probing_dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = probing_dataset
Example 7
def main(args):
    dict_path = os.path.join(args.data_path, 'dict.txt')
    dictionary = CustomDictionary.load(dict_path)

    entity_dict_path = os.path.join(args.data_path, 'entity.dict.txt')
    entity_dictionary = EntityDictionary.load(entity_dict_path)

    logger.info('dictionary: {} types'.format(len(dictionary)))
    logger.info('entity dictionary: {} types'.format(len(entity_dictionary)))

    text_data = safe_load_indexed_dataset(
        os.path.join(args.data_path, args.split + '.text'))
    annotation_data = MMapNumpyArray(
        os.path.join(args.data_path, args.split + '.annotations.npy'))
    annotated_text = AnnotatedText(
        text_data=text_data,
        annotation_data=annotation_data,
        dictionary=dictionary,
        mask_type=args.mask_type,
        non_mask_rate=args.non_mask_rate,
    )

    graph_data = safe_load_indexed_dataset(
        os.path.join(args.data_path, args.split + '.graph'))
    graph = GraphDataset(
        edges=graph_data,
        subsampling_strategy=args.subsampling_strategy,
        subsampling_cap=args.subsampling_cap,
        seed=args.seed,
    )
    graph.set_epoch(1)

    entity_pair_counter_sum = 0

    with numpy_seed('SubgraphSampler', args.seed):
        random_perm = np.random.permutation(len(graph))

        if args.save_subgraph is not None:
            for index in random_perm:
                subgraph, _ = sample_subgraph(graph, annotated_text, index,
                                              None, None, args)
                if subgraph is not None:
                    break

            path = '%s.max_tokens=%d.max_sentences=%d.min_common_neighbors=%d' % (
                args.save_subgraph,
                args.max_tokens,
                args.max_sentences,
                args.min_common_neighbors,
            )
            save_subgraph(subgraph, dictionary, entity_dictionary, path,
                          args.save_text)
        else:
            num_subgraphs, total_edges, total_covered_edges = 0, 0, 0
            total_relative_coverage_mean, total_relative_coverage_median = 0, 0
            total_full_batch = 0
            entity_pair_counter, relation_statement_counter = Counter(), Counter()
            with trange(len(graph), desc='Sampling') as progress_bar:
                for i in progress_bar:
                    subgraph, fill_successfully = sample_subgraph(
                        graph,
                        annotated_text,
                        random_perm[i],
                        entity_pair_counter,
                        entity_pair_counter_sum,
                        args,
                    )
                    if subgraph is None:
                        continue

                    num_subgraphs += 1
                    relation_statement_counter.update([
                        hash(x) for x in subgraph.relation_statements.values()
                    ])
                    # entity_pair_counter.update([(min(h, t), max(h, t)) for (h, t) in subgraph.relation_statements.keys()])
                    entity_pair_counter.update([
                        (h, t)
                        for (h, t) in subgraph.relation_statements.keys()
                    ])
                    entity_pair_counter_sum += len(
                        subgraph.relation_statements)
                    total_edges += len(subgraph.relation_statements)
                    total_covered_edges += len(subgraph.covered_entity_pairs)
                    relative_coverages = subgraph.relative_coverages()
                    total_relative_coverage_mean += np.mean(relative_coverages)
                    total_relative_coverage_median += np.median(
                        relative_coverages)
                    total_full_batch += int(fill_successfully)

                    entity_pairs_counts = np.array(
                        list(entity_pair_counter.values()))
                    relation_statement_counts = np.array(
                        list(relation_statement_counter.values()))

                    progress_bar.set_postfix(
                        # n=num_subgraphs,
                        mean=entity_pair_counter_sum / len(graph),
                        m_r=relation_statement_counter.most_common(1)[0][1],
                        m_e=entity_pair_counter.most_common(1)[0][1],
                        w_e=wasserstein_distance(
                            entity_pairs_counts,
                            np.ones_like(entity_pairs_counts)),
                        w_r=wasserstein_distance(
                            relation_statement_counts,
                            np.ones_like(relation_statement_counts)),
                        y=total_covered_edges / total_edges,
                        e=total_edges / num_subgraphs,
                        # cov_e=total_covered_edges / num_subgraphs,
                        rel_cov=total_relative_coverage_mean / num_subgraphs,
                        # rel_cov_median=total_relative_coverage_median / num_subgraphs,
                        # f=total_full_batch / num_subgraphs,
                    )
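The w_e and w_r progress metrics use scipy.stats.wasserstein_distance to quantify how far the entity-pair and relation-statement count distributions drift from uniform. A tiny self-contained illustration:

import numpy as np
from scipy.stats import wasserstein_distance

counts = np.array([5, 1, 1, 1])  # one entity pair sampled far more than the rest
print(wasserstein_distance(counts, np.ones_like(counts)))  # > 0: skewed
print(wasserstein_distance(np.ones(4), np.ones(4)))        # 0.0: perfectly uniform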
Example 8
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):

        text_data_A = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'),
        )
        annotation_data_A = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy')
        )
        annotated_text_A = AnnotatedText(
            text_data=text_data_A,
            annotation_data=annotation_data_A,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )

        graph_data_A = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, 'mtb_' + split + '.graph'),
        )
        graph_A = GraphDataset(
            edges=graph_data_A,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        if self.args.split_mode:
            annotated_text_B = annotated_text_A
            graph_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, split + '.graph'),
            )
        else:
            text_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, 'train.text'),
            )
            annotation_data_B = MMapNumpyArray(
                os.path.join(self.args.data_path, 'train.annotations.npy')
            )
            annotated_text_B = AnnotatedText(
                text_data=text_data_B,
                annotation_data=annotation_data_B,
                dictionary=self.dictionary,
                mask_type=self.args.mask_type,
                non_mask_rate=self.args.non_mask_rate,
            )
            graph_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, 'train.graph'),
            )

        graph_B = GraphDataset(
            edges=graph_data_B,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        dataset = GraphDistanceDataset(
            split=split,
            annotated_text_A=annotated_text_A,
            annotated_text_B=annotated_text_B,
            graph_A=graph_A,
            graph_B=graph_B,
            seed=self.args.seed,
            dictionary=self.dictionary,
            class_probabilities=self.args.class_probabilities,
            n_tries_entity=self.args.n_tries_entity
        )
        if split == 'train' and self.args.epoch_size is not None:
            dataset = EpochSplitDataset(
                dataset=dataset,
                epoch_size=self.args.epoch_size,
                seed=self.args.seed,
            )

        dataset = PrependTokenDataset(
            dataset,
            self.dictionary.bos(),
            ['textA', 'textB']
        )

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
Example 9
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):

        text_data_A = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'))
        annotation_data_A = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'))
        annotated_text_A = AnnotatedText(
            text_data=text_data_A,
            annotation_data=annotation_data_A,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )

        graph_data_A = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, 'mtb_' + split + '.graph'))
        graph_A = GraphDataset(
            edges=graph_data_A,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        if self.args.split_mode:
            annotated_text_B = annotated_text_A
            graph_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, split + '.graph'))
        else:
            text_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, 'train.text'))
            annotation_data_B = MMapNumpyArray(
                os.path.join(self.args.data_path, 'train.annotations.npy'))
            annotated_text_B = AnnotatedText(
                text_data=text_data_B,
                annotation_data=annotation_data_B,
                dictionary=self.dictionary,
                mask_type=self.args.mask_type,
                non_mask_rate=self.args.non_mask_rate,
            )
            graph_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, 'train.graph'))

        graph_B = GraphDataset(
            edges=graph_data_B,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        if ((self.args.strong_negatives
                and self.args.strong_negative_type == 'similarity')
                or self.args.similar_positives):
            similar_entities = MMapNumpyArray(
                os.path.join(self.args.data_path, 'entity.candidates.idx.npy'))
            similarity_scores = MMapNumpyArray(
                os.path.join(self.args.data_path,
                             'entity.candidates.scores.npy'))
        else:
            similar_entities = None
            similarity_scores = None

        dataset = PMTBDataset(
            split=split,
            annotated_text_A=annotated_text_A,
            annotated_text_B=annotated_text_B,
            graph_A=graph_A,
            graph_B=graph_B,
            similar_entities=similar_entities,
            similarity_scores=similarity_scores,
            seed=self.args.seed,
            dictionary=self.dictionary,
            entity_dictionary=self.entity_dictionary,
            k_weak_negs=self.args.k_weak_negs,
            n_tries_entity=self.args.n_tries_entity,
            strong_negatives=self.args.strong_negatives,
            strong_negative_type=self.args.strong_negative_type,
            negative_temperature=getattr(self.args, 'negative_temperature',
                                         None),
            replace_tail=self.args.replace_tail,
            mutual_positives=self.args.mutual_positives,
            similar_positives=self.args.similar_positives,
            positive_temperature=getattr(self.args, 'positive_temperature',
                                         None),
        )
        if split == 'train' and self.args.epoch_size is not None:
            dataset = EpochSplitDataset(
                dataset=dataset,
                epoch_size=self.args.epoch_size,
                seed=self.args.seed,
            )

        dataset = PrependTokenDataset(dataset, self.dictionary.bos(),
                                      ['textA', 'textB'])

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
Example 10
    def load_dataset(self,
                     split,
                     prune_type=None,
                     prune_param=None,
                     epoch=0,
                     combine=False,
                     **kwargs):

        questions = safe_load_indexed_dataset(
            os.path.join(self.args.qa_data_path,
                         split + '.questions_entities'))

        answers = MMapNumpyArray(
            os.path.join(self.args.qa_data_path,
                         split + '.answer_entities.npy'))

        with open(
                os.path.join(self.args.qa_data_path,
                             split + '.processed_annotations.json')) as f:
            annotations = json.load(f)

        dataset = TriviaQADataset(questions, answers, annotations)

        task_framing = self.args.task_framing

        if task_framing == 'predict_mask':
            edges = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, split + '.graph'))
            dataset = ETPDownstreamDataset(
                dataset=dataset,
                edges=edges,
                dictionary=self.dictionary,
                n_entities=len(self.entity_dictionary),
                seed=self.args.seed,
                split=split,
            )
            dataset = PrependTokenDataset(dataset, self.dictionary.bos(),
                                          ['text'], ['annotation'])

        elif task_framing == 'predict_mask_relation':
            edges = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, split + '.graph'))
            dataset = ETPRelationDownstreamDataset(
                dataset=dataset,
                edges=edges,
                dictionary=self.dictionary,
                n_entities=len(self.entity_dictionary),
                seed=self.args.seed,
                split=split,
            )
            dataset = PrependTokenDataset(
                dataset, self.dictionary.bos(), ['text'],
                ['mask_annotation', 'all_annotations'])
        else:
            raise ValueError('Unknown task_framing: %s' % task_framing)

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
Example 11
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        text_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'))
        annotation_data = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'))
        annotated_text = AnnotatedText(
            text_data=text_data,
            annotation_data=annotation_data,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )
        dataset = TokenBlockAnnotatedDataset(
            annotated_text=annotated_text,
            max_positions=self.max_positions() - 5,  # <cls>, e1/e2 start/end
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            seed=self.seed,
            document_sep_len=1,
        )
        if split == 'train' and self.args.epoch_size is not None:
            dataset = EpochSplitDataset(
                dataset=dataset,
                epoch_size=self.args.epoch_size,
                seed=self.args.seed,
            )

        dataset = PrependTokenDataset(dataset, self.dictionary.bos())

        # create masked input and targets
        mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
            if self.args.mask_whole_words else None

        src_dataset, tgt_dataset = CustomMaskTokensDataset.apply_mask(
            dataset,
            self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.dictionary.mask(),
            seed=self.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )

        dataset = DictionaryDataset(
            {
                'id': IdDataset(),
                'src_tokens': PadDataset(
                    src_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                'src_lengths': NumelDataset(src_dataset, reduce=False),
                'target': PadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_dataset, reduce=True),
            },
            main_key='src_tokens',
        )

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
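CustomMaskTokensDataset.apply_mask takes the usual BERT-style masking knobs (mask_prob, leave_unmasked_prob, random_token_prob). For reference, a self-contained sketch of that standard scheme; the repository's class may differ, e.g. in whole-word masking and freq_weighted_replacement:

import numpy as np

def apply_mask_sketch(tokens, mask_idx, vocab_size, pad_idx, seed=0,
                      mask_prob=0.15, leave_unmasked_prob=0.1,
                      random_token_prob=0.1):
    """Return (masked_input, target); pad_idx marks untargeted positions."""
    rng = np.random.RandomState(seed)
    tokens = np.asarray(tokens)
    target = np.full_like(tokens, pad_idx)
    chosen = rng.rand(len(tokens)) < mask_prob  # positions to predict
    target[chosen] = tokens[chosen]
    roll = rng.rand(len(tokens))
    # Of the chosen positions: keep some as-is, randomize some, mask the rest.
    masked = chosen & (roll >= leave_unmasked_prob + random_token_prob)
    randomized = chosen & (roll >= leave_unmasked_prob) & ~masked
    src = tokens.copy()
    src[masked] = mask_idx
    src[randomized] = rng.randint(0, vocab_size, int(randomized.sum()))
    return src, target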
Example 12
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):

        text_data_A = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'))
        annotation_data_A = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'))
        annotated_text_A = AnnotatedText(
            text_data=text_data_A,
            annotation_data=annotation_data_A,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )

        graph_data_A = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, 'mtb_' + split + '.graph'))
        graph_A = GraphDataset(
            edges=graph_data_A,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        if self.args.split_mode:
            annotated_text_B = annotated_text_A
            graph_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, split + '.graph'))
        else:
            text_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, 'train.text'))
            annotation_data_B = MMapNumpyArray(
                os.path.join(self.args.data_path, 'train.annotations.npy'))
            annotated_text_B = AnnotatedText(
                text_data=text_data_B,
                annotation_data=annotation_data_B,
                dictionary=self.dictionary,
                mask_type=self.args.mask_type,
                non_mask_rate=self.args.non_mask_rate,
            )
            graph_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, 'train.graph'))

        graph_B = GraphDataset(
            edges=graph_data_B,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        if self.args.data_path in [
                '../data/nki/bin-v5-threshold20',
                '../data/nki/bin-v5-threshold20-small'
        ]:
            similar_entities = MMapNumpyArray(
                os.path.join(self.args.data_path,
                             'entity.candidates_remap.idx.npy'))
            similarity_scores = MMapNumpyArray(
                os.path.join(self.args.data_path,
                             'entity.scores_remap.idx.npy'))
        else:
            raise Exception(
                "Top 1000 similar entities/scores data not available for the given dataset."
            )

        dataset = BoRDataset(
            split=split,
            annotated_text_A=annotated_text_A,
            annotated_text_B=annotated_text_B,
            graph_A=graph_A,
            graph_B=graph_B,
            similar_entities=similar_entities,
            similarity_scores=similarity_scores,
            seed=self.args.seed,
            dictionary=self.dictionary,
            n_strong_candidates=self.args.n_strong_candidates,
            n_weak_candidates=self.args.n_weak_candidates,
            head_tail_weight=self.args.head_tail_weight,
            n_tries_entity=self.args.n_tries_entity,
        )
        if split == 'train' and self.args.epoch_size is not None:
            dataset = EpochSplitDataset(
                dataset=dataset,
                epoch_size=self.args.epoch_size,
                seed=self.args.seed,
            )

        dataset = PrependTokenDataset(dataset, self.dictionary.bos(),
                                      ['textA', 'textB'])

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
Example 13
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        text_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'),
        )
        annotation_data = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy')
        )
        annotated_text = AnnotatedText(
            text_data=text_data,
            annotation_data=annotation_data,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )

        graph_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.graph'),
        )
        graph = GraphDataset(
            edges=graph_data,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        dataset = GNNDataset(
            annotated_text=annotated_text,
            graph=graph,
            dictionary=self.dictionary,
            min_common_neighbors=self.args.min_common_neighbors,
            max_common_neighbors=self.args.max_common_neighbors,
            required_min_common_neighbors=getattr(self.args, 'required_min_common_neighbors', 1),
            max_entities_size=self.args.max_entities_size,
            max_entities_from_queue=self.args.max_entities_from_queue,
            cover_random_prob=self.args.cover_random_prob,
            total_negatives=self.args.total_negatives,
            max_hard_negatives=self.args.max_hard_negatives,
            max_tokens=self.args.max_tokens - 1, # for bos
            max_sentences=self.args.max_sentences,
            num_text_chunks=self.args.num_text_chunks,
            entity_pair_counter_cap=getattr(self.args, 'entity_pair_counter_cap', None),
            num_workers=self.args.num_workers,
            seed=self.args.seed,
        )
        if split == 'train' and self.args.epoch_size is not None:
            dataset = EpochSplitDataset(
                dataset=dataset,
                epoch_size=self.args.epoch_size,
                seed=self.args.seed,
            )

        dataset = PrependTokenDataset(dataset, self.dictionary.bos(), 'text')

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
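All thirteen examples go through safe_load_indexed_dataset, which by its name wraps fairseq's dataset loading with an existence check. A hypothetical reconstruction on top of fairseq's real data_utils.load_indexed_dataset (the actual helper may differ):

from fairseq.data import data_utils

def safe_load_indexed_dataset_sketch(path):
    """Load a fairseq indexed dataset, failing loudly on a missing path."""
    dataset = data_utils.load_indexed_dataset(path, dictionary=None)
    if dataset is None:
        raise FileNotFoundError('Indexed dataset not found: ' + path)
    return dataset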