Example #1
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        text_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'), )
        annotation_data = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'))
        annotated_text = AnnotatedText(
            text_data=text_data,
            annotation_data=annotation_data,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )

        graph_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.graph'), )
        graph = GraphDataset(
            edges=graph_data,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        dataset = TripletDataset(
            annotated_text=annotated_text,
            graph=graph,
            k_negative=self.args.k_negative,
            n_entities=len(self.entity_dictionary),
            seed=self.args.seed,
            dictionary=self.dictionary,
            same_replace_heads_for_all_negatives=self.args.arch.startswith(
                'encoder_dual'),
            negative_split_probs=self.args.negative_split_probs or [1, 0, 0],
            use_sentence_negatives=self.args.use_sentence_negatives,
        )
        if split == 'train' and self.args.epoch_size is not None:
            dataset = EpochSplitDataset(
                dataset=dataset,
                epoch_size=self.args.epoch_size,
                seed=self.args.seed,
            )

        dataset = PrependTokenDataset(dataset, self.dictionary.bos(), 'text')

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
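A minimal usage sketch for the method above, assuming it lives on a fairseq-style task object that exposes a datasets dict; the task construction is elided and hypothetical:

# Hypothetical caller; `task` is an instance of the class defining
# load_dataset above, constructed elsewhere.
task.load_dataset('train', epoch=1)
train_set = task.datasets['train']
sample = train_set[0]  # a dict whose 'text' field now starts with bos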
Example #2
def _load_dataset(pickle_file='',
                  graph_file='',
                  label_file='',
                  target_file='',
                  n=-1):
    if '.pickle' in pickle_file:
        with open(pickle_file, 'rb') as f:  # binary mode is required for pickle under Python 3
            # pickle with the keys: adj_lists, labels, targets, unique_labels_dict
            pickled_dict = pickle.load(f)
            return GraphDataset(**pickled_dict)
    else:
        adj_lists = load_adj_lists(graph_file, n)
        labels, labels_dict = load_discrete_labels(label_file, n)
        targets = load_targets(target_file, n=n)
        graph_list = gen_graph_list(adj_lists, labels, labels_dict)
    return graph_list, targets
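A hedged usage sketch for _load_dataset; note the two return shapes (a GraphDataset from the pickle branch, a (graph_list, targets) tuple otherwise). File names here are placeholders:

# Pickle branch: returns a single GraphDataset built from the pickled dict.
dataset = _load_dataset(pickle_file='graphs.pickle')

# Raw-file branch: returns (graph_list, targets); n=-1 presumably loads
# everything, matching the default.
graph_list, targets = _load_dataset(graph_file='adj_lists.txt',
                                    label_file='labels.txt',
                                    target_file='targets.txt',
                                    n=-1)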
Example #3
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        text_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'), )
        annotation_data = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'))
        annotated_text = AnnotatedText(
            text_data=text_data,
            annotation_data=annotation_data,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )

        graph_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.graph'), )
        graph = GraphDataset(
            edges=graph_data,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        dataset = GNNEvalDataset(
            annotated_text=annotated_text,
            graph=graph,
            dictionary=self.dictionary,
            max_positions=self.args.max_positions,
            num_workers=self.args.num_workers,
            seed=self.args.seed,
        )

        dataset = PrependTokenDataset(dataset,
                                      self.dictionary.bos(),
                                      keys=['target', 'support'])

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
Example #4
def main(args):
    dict_path = os.path.join(args.data_path, 'dict.txt')
    dictionary = CustomDictionary.load(dict_path)

    entity_dict_path = os.path.join(args.data_path, 'entity.dict.txt')
    entity_dictionary = EntityDictionary.load(entity_dict_path)

    logger.info('dictionary: {} types'.format(len(dictionary)))
    logger.info('entity dictionary: {} types'.format(len(entity_dictionary)))

    text_data = safe_load_indexed_dataset(
        os.path.join(args.data_path, args.split + '.text'), )
    annotation_data = MMapNumpyArray(
        os.path.join(args.data_path, args.split + '.annotations.npy'))
    annotated_text = AnnotatedText(
        text_data=text_data,
        annotation_data=annotation_data,
        dictionary=dictionary,
        mask_type=args.mask_type,
        non_mask_rate=args.non_mask_rate,
    )

    graph_data = safe_load_indexed_dataset(
        os.path.join(args.data_path, args.split + '.graph'), )
    graph = GraphDataset(
        edges=graph_data,
        subsampling_strategy=args.subsampling_strategy,
        subsampling_cap=args.subsampling_cap,
        seed=args.seed,
    )
    graph.set_epoch(1)

    entity_pair_counter_sum = 0

    with numpy_seed('SubgraphSampler', args.seed):
        random_perm = np.random.permutation(len(graph))

        if args.save_subgraph is not None:
            for index in random_perm:
                subgraph, _ = sample_subgraph(graph, annotated_text, index,
                                              None, None, args)
                if subgraph is not None:
                    break

            path = '%s.max_tokens=%d.max_sentences=%d.min_common_neighbors=%d' % (
                args.save_subgraph,
                args.max_tokens,
                args.max_sentences,
                args.min_common_neighbors,
            )
            save_subgraph(subgraph, dictionary, entity_dictionary, path,
                          args.save_text)
        else:
            num_subgraphs, total_edges, total_covered_edges = 0, 0, 0
            total_relative_coverage_mean, total_relative_coverage_median = 0, 0
            total_full_batch = 0
            entity_pair_counter = Counter()
            relation_statement_counter = Counter()
            with trange(len(graph), desc='Sampling') as progress_bar:
                for i in progress_bar:
                    subgraph, fill_successfully = sample_subgraph(
                        graph,
                        annotated_text,
                        random_perm[i],
                        entity_pair_counter,
                        entity_pair_counter_sum,
                        args,
                    )
                    if subgraph is None:
                        continue

                    num_subgraphs += 1
                    relation_statement_counter.update([
                        hash(x) for x in subgraph.relation_statements.values()
                    ])
                    # entity_pair_counter.update([(min(h, t), max(h, t)) for (h, t) in subgraph.relation_statements.keys()])
                    entity_pair_counter.update([
                        (h, t)
                        for (h, t) in subgraph.relation_statements.keys()
                    ])
                    entity_pair_counter_sum += len(
                        subgraph.relation_statements)
                    total_edges += len(subgraph.relation_statements)
                    total_covered_edges += len(subgraph.covered_entity_pairs)
                    relative_coverages = subgraph.relative_coverages()
                    total_relative_coverage_mean += np.mean(relative_coverages)
                    total_relative_coverage_median += np.median(
                        relative_coverages)
                    total_full_batch += int(fill_successfully)

                    entity_pairs_counts = np.array(
                        list(entity_pair_counter.values()))
                    relation_statement_counts = np.array(
                        list(relation_statement_counter.values()))

                    progress_bar.set_postfix(
                        # n=num_subgraphs,
                        mean=entity_pair_counter_sum / len(graph),
                        m_r=relation_statement_counter.most_common(1)[0][1],
                        m_e=entity_pair_counter.most_common(1)[0][1],
                        w_e=wasserstein_distance(
                            entity_pairs_counts,
                            np.ones_like(entity_pairs_counts)),
                        w_r=wasserstein_distance(
                            relation_statement_counts,
                            np.ones_like(relation_statement_counts)),
                        y=total_covered_edges / total_edges,
                        e=total_edges / num_subgraphs,
                        # cov_e=total_covered_edges / num_subgraphs,
                        rel_cov=total_relative_coverage_mean / num_subgraphs,
                        # rel_cov_median=total_relative_coverage_median / num_subgraphs,
                        # f=total_full_batch / num_subgraphs,
                    )
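For reference, the w_e and w_r postfix values above are 1-Wasserstein distances between the observed count distributions and an all-ones reference; a self-contained illustration of that computation (the counts are made up):

import numpy as np
from scipy.stats import wasserstein_distance

counts = np.array([1, 1, 3, 7, 1])   # e.g. entity-pair hit counts
uniform = np.ones_like(counts)
# 0.0 would mean perfectly uniform sampling; here the distance is 1.6.
print(wasserstein_distance(counts, uniform))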
Example #5
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):

        text_data_A = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'),
        )
        annotation_data_A = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy')
        )
        annotated_text_A = AnnotatedText(
            text_data=text_data_A,
            annotation_data=annotation_data_A,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )

        graph_data_A = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, 'mtb_' + split + '.graph'),
        )
        graph_A = GraphDataset(
            edges=graph_data_A,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        if self.args.split_mode:
            annotated_text_B = annotated_text_A
            graph_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, split + '.graph'),
            )
        else:
            text_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, 'train.text'),
            )
            annotation_data_B = MMapNumpyArray(
                os.path.join(self.args.data_path, 'train.annotations.npy')
            )
            annotated_text_B = AnnotatedText(
                text_data=text_data_B,
                annotation_data=annotation_data_B,
                dictionary=self.dictionary,
                mask_type=self.args.mask_type,
                non_mask_rate=self.args.non_mask_rate,
            )
            graph_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, 'train.graph'),
            )

        graph_B = GraphDataset(
            edges=graph_data_B,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        dataset = GraphDistanceDataset(
            split=split,
            annotated_text_A=annotated_text_A,
            annotated_text_B=annotated_text_B,
            graph_A=graph_A,
            graph_B=graph_B,
            seed=self.args.seed,
            dictionary=self.dictionary,
            class_probabilities=self.args.class_probabilities,
            n_tries_entity=self.args.n_tries_entity
        )
        if split == 'train' and self.args.epoch_size is not None:
            dataset = EpochSplitDataset(
                dataset=dataset,
                epoch_size=self.args.epoch_size,
                seed=self.args.seed,
            )

        dataset = PrependTokenDataset(
            dataset,
            self.dictionary.bos(),
            ['textA', 'textB']
        )

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
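The keys argument to PrependTokenDataset suggests a key-aware wrapper; a minimal sketch of that assumed behavior (not the repository's implementation):

import torch

class PrependTokenSketch:
    # Prepend `token` to the 1-D tensors stored under `keys` in each
    # example dict; assumed semantics inferred from the call above.
    def __init__(self, dataset, token, keys):
        self.dataset, self.token, self.keys = dataset, token, keys

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        item = dict(self.dataset[index])
        for key in self.keys:
            prefix = torch.tensor([self.token], dtype=item[key].dtype)
            item[key] = torch.cat([prefix, item[key]])
        return item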
Example #6
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):

        text_data_A = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'), )
        annotation_data_A = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'))
        annotated_text_A = AnnotatedText(
            text_data=text_data_A,
            annotation_data=annotation_data_A,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )

        graph_data_A = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, 'mtb_' + split + '.graph'), )
        graph_A = GraphDataset(
            edges=graph_data_A,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        if self.args.split_mode:
            annotated_text_B = annotated_text_A
            graph_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, split + '.graph'), )
        else:
            text_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, 'train.text'), )
            annotation_data_B = MMapNumpyArray(
                os.path.join(self.args.data_path, 'train.annotations.npy'))
            annotated_text_B = AnnotatedText(
                text_data=text_data_B,
                annotation_data=annotation_data_B,
                dictionary=self.dictionary,
                mask_type=self.args.mask_type,
                non_mask_rate=self.args.non_mask_rate,
            )
            graph_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, 'train.graph'), )

        graph_B = GraphDataset(
            edges=graph_data_B,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        if (self.args.strong_negatives and self.args.strong_negative_type
                == 'similarity') or self.args.similar_positives:
            similar_entities = MMapNumpyArray(
                os.path.join(self.args.data_path, 'entity.candidates.idx.npy'))
            similarity_scores = MMapNumpyArray(
                os.path.join(self.args.data_path,
                             'entity.candidates.scores.npy'))
        else:
            similar_entities = None
            similarity_scores = None

        dataset = PMTBDataset(
            split=split,
            annotated_text_A=annotated_text_A,
            annotated_text_B=annotated_text_B,
            graph_A=graph_A,
            graph_B=graph_B,
            similar_entities=similar_entities,
            similarity_scores=similarity_scores,
            seed=self.args.seed,
            dictionary=self.dictionary,
            entity_dictionary=self.entity_dictionary,
            k_weak_negs=self.args.k_weak_negs,
            n_tries_entity=self.args.n_tries_entity,
            strong_negatives=self.args.strong_negatives,
            strong_negative_type=self.args.strong_negative_type,
            negative_temperature=getattr(self.args, 'negative_temperature',
                                         None),
            replace_tail=self.args.replace_tail,
            mutual_positives=self.args.mutual_positives,
            similar_positives=self.args.similar_positives,
            positive_temperature=getattr(self.args, 'positive_temperature',
                                         None),
        )
        if split == 'train' and self.args.epoch_size is not None:
            dataset = EpochSplitDataset(
                dataset=dataset,
                epoch_size=self.args.epoch_size,
                seed=self.args.seed,
            )

        dataset = PrependTokenDataset(dataset, self.dictionary.bos(),
                                      ['textA', 'textB'])

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
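A speculative sketch of how similarity scores and negative_temperature could combine to draw strong negatives (assumed behavior inferred from the PMTBDataset argument names above, not its actual code):

import numpy as np

def sample_candidate(candidates, scores, temperature, rng):
    # Softmax over similarity scores, sharpened by a low temperature so
    # the most similar entities are picked most often.
    logits = np.asarray(scores, dtype=np.float64) / temperature
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    return candidates[rng.choice(len(candidates), p=probs)]

rng = np.random.default_rng(0)
print(sample_candidate(np.array([11, 42, 7]), [0.9, 0.5, 0.1], 0.2, rng))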
Example #7
# Define loss function
criterion = torch.nn.MultiLabelSoftMarginLoss().to(device)

# Training phase
if args.phase == 'train':

    # Create directories for checkpoint, sample and logs files
    ckpt_dir = os.path.join(args.model_dir, 'checkpoint')
    os.makedirs(ckpt_dir, exist_ok=True)
    logs_dir = os.path.join(args.model_dir, 'logs')

    # Data loading
    print("[*] Loading training and validation data...")
    if args.net_type in ('gcn', 'chebcn', 'gmmcn', 'gincn'):
        train_set = GraphDataset(args.train_file, args.feats_dir,
                                 args.feats_type, args.edges_type)
        train_loader = pygeoDataLoader(train_set,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=1,
                                       drop_last=True)
        train_loader_eval = pygeoDataLoader(train_set,
                                            batch_size=1,
                                            shuffle=False)
        valid_set = GraphDataset(args.valid_file, args.feats_dir,
                                 args.feats_type, args.edges_type)
        valid_loader = pygeoDataLoader(valid_set, batch_size=1, shuffle=False)

    elif args.net_type in ('cnn1d', 'deepgoplus'):
        train_set = CNN1DDataset(args.train_file, args.feats_dir,
                                 args.feats_type)
Example #8
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):

        text_data_A = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'), )
        annotation_data_A = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy'))
        annotated_text_A = AnnotatedText(
            text_data=text_data_A,
            annotation_data=annotation_data_A,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )

        graph_data_A = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, 'mtb_' + split + '.graph'), )
        graph_A = GraphDataset(
            edges=graph_data_A,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        if self.args.split_mode:
            annotated_text_B = annotated_text_A
            graph_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, split + '.graph'), )
        else:
            text_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, 'train.text'), )
            annotation_data_B = MMapNumpyArray(
                os.path.join(self.args.data_path, 'train.annotations.npy'))
            annotated_text_B = AnnotatedText(
                text_data=text_data_B,
                annotation_data=annotation_data_B,
                dictionary=self.dictionary,
                mask_type=self.args.mask_type,
                non_mask_rate=self.args.non_mask_rate,
            )
            graph_data_B = safe_load_indexed_dataset(
                os.path.join(self.args.data_path, 'train.graph'), )

        graph_B = GraphDataset(
            edges=graph_data_B,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        if self.args.data_path in [
                '../data/nki/bin-v5-threshold20',
                '../data/nki/bin-v5-threshold20-small'
        ]:
            similar_entities = MMapNumpyArray(
                os.path.join(self.args.data_path,
                             'entity.candidates_remap.idx.npy'))
            similarity_scores = MMapNumpyArray(
                os.path.join(self.args.data_path,
                             'entity.scores_remap.idx.npy'))
        else:
            raise ValueError(
                "Top 1000 similar entities/scores data not available for the given dataset."
            )

        dataset = BoRDataset(
            split=split,
            annotated_text_A=annotated_text_A,
            annotated_text_B=annotated_text_B,
            graph_A=graph_A,
            graph_B=graph_B,
            similar_entities=similar_entities,
            similarity_scores=similarity_scores,
            seed=self.args.seed,
            dictionary=self.dictionary,
            n_strong_candidates=self.args.n_strong_candidates,
            n_weak_candidates=self.args.n_weak_candidates,
            head_tail_weight=self.args.head_tail_weight,
            n_tries_entity=self.args.n_tries_entity,
        )
        if split == 'train' and self.args.epoch_size is not None:
            dataset = EpochSplitDataset(
                dataset=dataset,
                epoch_size=self.args.epoch_size,
                seed=self.args.seed,
            )

        dataset = PrependTokenDataset(dataset, self.dictionary.bos(),
                                      ['textA', 'textB'])

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
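FixedSizeDataset recurs as the final wrapper in these examples; a minimal sketch of the assumed semantics (a seeded, fixed-size subsample that bounds the number of evaluation examples):

import numpy as np

class FixedSizeSketch:
    def __init__(self, dataset, size, seed):
        self.dataset = dataset
        # Sample without replacement so each underlying example appears
        # at most once; assumed behavior, not the repository's code.
        rng = np.random.RandomState(seed)
        self.indices = rng.choice(len(dataset),
                                  size=min(size, len(dataset)),
                                  replace=False)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index):
        return self.dataset[self.indices[index]]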
Example #9
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        text_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.text'),
        )
        annotation_data = MMapNumpyArray(
            os.path.join(self.args.data_path, split + '.annotations.npy')
        )
        annotated_text = AnnotatedText(
            text_data=text_data,
            annotation_data=annotation_data,
            dictionary=self.dictionary,
            mask_type=self.args.mask_type,
            non_mask_rate=self.args.non_mask_rate,
        )

        graph_data = safe_load_indexed_dataset(
            os.path.join(self.args.data_path, split + '.graph'),
        )
        graph = GraphDataset(
            edges=graph_data,
            subsampling_strategy=self.args.subsampling_strategy,
            subsampling_cap=self.args.subsampling_cap,
            seed=self.args.seed,
        )

        dataset = GNNDataset(
            annotated_text=annotated_text,
            graph=graph,
            dictionary=self.dictionary,
            min_common_neighbors=self.args.min_common_neighbors,
            max_common_neighbors=self.args.max_common_neighbors,
            required_min_common_neighbors=getattr(self.args, 'required_min_common_neighbors', 1),
            max_entities_size=self.args.max_entities_size,
            max_entities_from_queue=self.args.max_entities_from_queue,
            cover_random_prob=self.args.cover_random_prob,
            total_negatives=self.args.total_negatives,
            max_hard_negatives=self.args.max_hard_negatives,
            max_tokens=self.args.max_tokens - 1,  # reserve one position for the bos token prepended below
            max_sentences=self.args.max_sentences,
            num_text_chunks=self.args.num_text_chunks,
            entity_pair_counter_cap=getattr(self.args, 'entity_pair_counter_cap', None),
            num_workers=self.args.num_workers,
            seed=self.args.seed,
        )
        if split == 'train' and self.args.epoch_size is not None:
            dataset = EpochSplitDataset(
                dataset=dataset,
                epoch_size=self.args.epoch_size,
                seed=self.args.seed,
            )

        dataset = PrependTokenDataset(dataset, self.dictionary.bos(), 'text')

        n_examples = getattr(self.args, 'n_' + split + '_examples', None)
        if n_examples is not None:
            dataset = FixedSizeDataset(
                dataset=dataset,
                size=n_examples,
                seed=self.seed,
            )

        self.datasets[split] = dataset
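EpochSplitDataset appears whenever split == 'train' and epoch_size is set; a hedged sketch of the assumed semantics (a fixed seeded permutation, with a different epoch_size-long window exposed each epoch):

import numpy as np

class EpochSplitSketch:
    def __init__(self, dataset, epoch_size, seed):
        self.dataset, self.epoch_size, self.seed = dataset, epoch_size, seed
        self.set_epoch(0)

    def set_epoch(self, epoch):
        # One permutation per seed, then a sliding window per epoch;
        # wraps around at the end of the permutation.
        order = np.random.RandomState(self.seed).permutation(len(self.dataset))
        start = (epoch * self.epoch_size) % len(self.dataset)
        self.indices = np.take(order,
                               np.arange(start, start + self.epoch_size),
                               mode='wrap')

    def __len__(self):
        return self.epoch_size

    def __getitem__(self, index):
        return self.dataset[self.indices[index]]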