    def __getitem__(self, index):
        with numpy_seed('EntityPredictionDataset', self.seed, self.epoch,
                        index):
            # Rejection-sample an entity until we find one with at least
            # one edge, then pick one of its edges uniformly at random.
            sampled_edge = None
            while not sampled_edge:
                entity = np.random.randint(len(self.edges))
                n_entity_edges = len(self.edges[entity]) // GraphDataset.EDGE_SIZE
                if n_entity_edges > 0:
                    passage_idx = np.random.randint(n_entity_edges)
                    edge_start = passage_idx * GraphDataset.EDGE_SIZE
                    edge = self.edges[entity][
                        edge_start:edge_start + GraphDataset.EDGE_SIZE].numpy()
                    sampled_edge = True

            # Unpack the mention span and its enclosing text block.
            start_pos = edge[GraphDataset.HEAD_START_POS]
            end_pos = edge[GraphDataset.HEAD_END_POS]
            start_block = edge[GraphDataset.START_BLOCK]
            end_block = edge[GraphDataset.END_BLOCK]
            passage, annotation_position = self.annotated_text.annotate_mention(
                entity, start_pos, end_pos, start_block, end_block)

        item = {
            'text': passage,
            'annotation': (torch.LongTensor(annotation_position)
                           if annotation_position else None),
            'target': entity,
        }
        return item
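Every example on this page wraps its random sampling in the same numpy_seed context manager, whose definition is not shown. A minimal sketch of such a helper, assuming it hashes a string tag plus integer arguments into a 32-bit seed and restores the previous global RNG state on exit:

import contextlib
import hashlib

import numpy as np


@contextlib.contextmanager
def numpy_seed(*args):
    # Hypothetical sketch: derive a deterministic 32-bit seed from the
    # arguments, e.g. ('EntityPredictionDataset', seed, epoch, index),
    # so the same (seed, epoch, index) triple reproduces the same draws.
    digest = hashlib.md5(repr(args).encode()).hexdigest()
    derived_seed = int(digest, 16) % (2 ** 32)
    state = np.random.get_state()
    np.random.seed(derived_seed)
    try:
        yield
    finally:
        # Restore whatever RNG state the caller had, so seeding is scoped.
        np.random.set_state(state)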
Example #2
    def __getitem__(self, index):
        with numpy_seed('TokenBlockAnnotatedDataset', self.seed, self.epoch,
                        index):
            start_block, end_block = self._slice_indices.array[index]
            end_block = min(start_block + self.max_positions, end_block)

            annotations = self.dataset.annotations_block(
                start_block, end_block)
            if len(annotations) > 0:
                entities = np.unique(
                    annotations[:, AnnotatedText.INDEX_ANNOTATION_ENTITY])
            else:
                entities = []
            head_entity, tail_entity = self.sample_entities(entities)
            head_start_pos, head_end_pos = self.sample_annotation(
                annotations, head_entity)
            tail_start_pos, tail_end_pos = self.sample_annotation(
                annotations, tail_entity)

            text = self.dataset.annotate_relation(
                tail_entity=tail_entity,
                head_entity=head_entity,
                head_start_pos=head_start_pos,
                head_end_pos=head_end_pos,
                tail_start_pos=tail_start_pos,
                tail_end_pos=tail_end_pos,
                start_block=start_block,
                end_block=end_block,
                annotations=annotations,
            )
            return text
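The helpers sample_entities and sample_annotation are not shown. One plausible sketch of the former, assuming it draws a distinct head and tail entity uniformly from the entities annotated in the block (the behavior when fewer than two entities exist is a guess):

import numpy as np

def sample_entities(entities):
    # Hypothetical helper: pick two distinct entities for head and tail.
    if len(entities) < 2:
        return None, None
    head, tail = np.random.choice(entities, size=2, replace=False)
    return head, tail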
Example #3
    def set_epoch(self, epoch):
        if epoch != self.epoch:
            with numpy_seed('GraphDataset', self.seed, epoch):
                self._indices, self._sizes = self.subsample_graph_by_entity_pairs()
            self._indices = maybe_move_to_plasma(self._indices)
            self._sizes = maybe_move_to_plasma(self._sizes)
            self.epoch = epoch
Example #4
    def __getitem__(self, index):
        with numpy_seed('ETPRelationDataset', self.seed, self.epoch, index):
            # Rejection-sample an entity until we find one with at least
            # one edge, then pick one of its edges uniformly at random.
            sampled_edge = None
            while not sampled_edge:
                entity = np.random.randint(len(self.edges))
                n_entity_edges = len(self.edges[entity]) // GraphDataset.EDGE_SIZE
                if n_entity_edges > 0:
                    passage_idx = np.random.randint(n_entity_edges)
                    edge_start = passage_idx * GraphDataset.EDGE_SIZE
                    edge = self.edges[entity][
                        edge_start:edge_start + GraphDataset.EDGE_SIZE].numpy()
                    sampled_edge = True

            # Unpack the head mention span and its enclosing text block.
            start_pos = edge[GraphDataset.HEAD_START_POS]
            end_pos = edge[GraphDataset.HEAD_END_POS]
            start_block = edge[GraphDataset.START_BLOCK]
            end_block = edge[GraphDataset.END_BLOCK]
            (passage, mask_annotation_position, all_annotation_positions,
             entity_ids) = self.annotated_text.annotate_mention(
                 entity,
                 start_pos,
                 end_pos,
                 start_block,
                 end_block,
                 return_all_annotations=True)

            mask_idx = all_annotation_positions.index(
                mask_annotation_position[0])

            # Put probability mass `mask_negative_prob` on the masked mention
            # and spread the remaining mass uniformly over the others.
            replace_probs = (np.ones(len(all_annotation_positions))
                             * (1 - self.mask_negative_prob)
                             / (len(all_annotation_positions) - 1))
            replace_probs[mask_idx] = self.mask_negative_prob
            replace_position = np.random.choice(
                len(all_annotation_positions), size=1, p=replace_probs)

            entity_replacements = np.random.choice(self.n_entities,
                                                   replace=False,
                                                   size=self.total_negatives)

            # candidates = np.expand_dims(entity_ids, axis=-1).repeat(self.total_negatives, axis=-1)
            # candidates[replace_position, np.arange(len(replace_position))] = entity_replacements

        item = {
            'text': passage,
            'mask_annotation': torch.LongTensor(mask_annotation_position),
            'all_annotations': torch.LongTensor(all_annotation_positions),
            'entity_ids': torch.LongTensor(entity_ids),
            'entity_replacements': torch.LongTensor(entity_replacements),
            'replacement_position': replace_position,
            # 'candidates': torch.LongTensor(candidates),
        }
        return item
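The replace_probs construction puts probability mask_negative_prob on the masked mention itself and spreads the remaining mass uniformly over the other annotations (note the expression divides by zero when there is only one annotation, so the data presumably guarantees at least two). A quick check with hypothetical values shows the vector is a valid distribution:

import numpy as np

mask_negative_prob = 0.5   # hypothetical value
n_annotations = 5          # stands in for len(all_annotation_positions)
mask_idx = 2

replace_probs = np.ones(n_annotations) * (1 - mask_negative_prob) / (n_annotations - 1)
replace_probs[mask_idx] = mask_negative_prob

print(replace_probs)               # [0.125 0.125 0.5 0.125 0.125]
assert np.isclose(replace_probs.sum(), 1.0)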
Example #5
    def set_epoch(self, epoch):
        self.dataset.set_epoch(epoch)
        with numpy_seed('FixedSizeDataset', self.seed):
            if len(self.dataset) <= self._size:
                # The dataset already fits within the budget: keep everything.
                self.data_indices = np.arange(len(self.dataset))
            else:
                # Otherwise subsample a fixed-size subset without replacement.
                self.data_indices = np.random.choice(
                    len(self.dataset),
                    self._size,
                    replace=False,
                )
        self._sizes = self.dataset.sizes[self.data_indices]
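Because the choice happens inside numpy_seed('FixedSizeDataset', self.seed) with no epoch term, the same subset is selected on every call. A small demonstration, reusing the hypothetical numpy_seed sketch shown after the first example:

import numpy as np

def pick(seed, n=10, k=4):
    with numpy_seed('FixedSizeDataset', seed):
        return np.random.choice(n, k, replace=False)

assert (pick(17) == pick(17)).all()  # identical subsets for identical seeds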
Example #6
    def __getitem__(self, index):
        with numpy_seed('GNNEvalDataset', self.seed, self.epoch, index):
            local_interval = IntervalTree()
            edge = self.graph[index]
            head = edge[GraphDataset.HEAD_ENTITY]
            tail = edge[GraphDataset.TAIL_ENTITY]

            start = edge[GraphDataset.START_BLOCK]
            end = edge[GraphDataset.END_BLOCK]
            local_interval.addi(start, end)
            head_neighbors = self.graph.get_neighbors(head)
            tail_neighbors = self.graph.get_neighbors(tail)

            mutual_neighbors = np.intersect1d(head_neighbors,
                                              tail_neighbors,
                                              assume_unique=True)
            if len(mutual_neighbors) == 0:
                return None

            found_supporting = False
            random_mutual = np.random.permutation(mutual_neighbors)

            for chosen_mutual in random_mutual:
                support1, local_interval = self.sample_relation_statement(
                    head, chosen_mutual, local_interval)
                support2, local_interval = self.sample_relation_statement(
                    chosen_mutual, tail, local_interval)

                if support1 is None or support2 is None:
                    continue
                found_supporting = True
                break

            if not found_supporting:
                return None

        item = {
            'target': self.annotated_text.annotate_relation(*edge.numpy()),
            'support': [
                self.annotated_text.annotate_relation(*support1),
                self.annotated_text.annotate_relation(*support2),
            ],
            'entities': {'A': head, 'B': tail, 'C': chosen_mutual},
        }

        return item
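The IntervalTree (from the intervaltree package) threaded through sample_relation_statement presumably reserves text spans so the two supporting statements do not reuse the block already taken by the target edge or by each other. A minimal illustration of that overlap check, with hypothetical spans:

from intervaltree import IntervalTree

tree = IntervalTree()
tree.addi(100, 200)           # span already used by the target edge

candidate = (150, 250)        # hypothetical candidate support span
if tree.overlap(*candidate):  # non-empty: [150, 250) intersects [100, 200)
    print('reject: overlaps a reserved span')
else:
    tree.addi(*candidate)     # accept and reserve it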
Example #7
    def set_epoch(self, epoch):
        if epoch == self.epoch:
            return

        assert epoch >= 1
        self.epoch = epoch

        if self.epoch_splits is None:
            self.set_dataset_epoch(1)
            self.epoch_splits = math.ceil(len(self.dataset) / self.epoch_size)
            assert self.epoch_splits >= 1
            logger.info('set epoch_split to be %d given epoch_size=%d, dataset size=%d' % (
                self.epoch_splits,
                len(self.dataset),
                self.epoch_size
            ))

        dataset_epoch = ((epoch - 1) // self.epoch_splits) + 1
        epoch_offset = (epoch - 1) % self.epoch_splits

        self.set_dataset_epoch(dataset_epoch)

        start_time = time.time()
        data_per_epoch = len(self.dataset) // (self.epoch_splits or 1)
        data_start = data_per_epoch * epoch_offset
        data_end = min(len(self.dataset), data_per_epoch * (epoch_offset + 1))

        with numpy_seed('EpochSplitDataset', self.seed, self.dataset.epoch):
            dataset_indices = np.random.permutation(len(self.dataset))
        self._indices = dataset_indices[data_start:data_end]
        self._sizes = self.dataset.sizes[self._indices]

        self._indices = maybe_move_to_plasma(self._indices)
        self._sizes = maybe_move_to_plasma(self._sizes)

        logger.info('selected %d samples from generation epoch %d and epoch offset %d in %d seconds' % (
            data_end - data_start,
            self.dataset.epoch,
            epoch_offset,
            time.time() - start_time,
        ))
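The dataset_epoch / epoch_offset arithmetic walks through all epoch_splits slices of one underlying dataset epoch before advancing to the next. A quick trace with a hypothetical epoch_splits = 3:

epoch_splits = 3  # hypothetical value

for epoch in range(1, 7):
    dataset_epoch = ((epoch - 1) // epoch_splits) + 1
    epoch_offset = (epoch - 1) % epoch_splits
    print(epoch, dataset_epoch, epoch_offset)
# epochs 1-3 -> dataset epoch 1 (offsets 0, 1, 2); epochs 4-6 -> dataset epoch 2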
Example #8
    def __getitem__(self, index):
        with numpy_seed('GNNDataset', self.seed, self.epoch, index):
            subgraph = self._sample_subgraph(index)
            while subgraph is None:
                # logging.warning('Failed to sample subgraph for [seed=%d, epoch=%d, index=%d]' % (
                #     self.seed,
                #     self.epoch,
                #     index,
                # ))
                text_index = np.random.randint(len(self.graph))
                subgraph = self._sample_subgraph(text_index)

        sentences, text_index = self._get_all_sentences_and_index(subgraph)
        graph, graph_sizes, candidate_text_idx, logging_output = self._make_negatives(
            subgraph, text_index)

        item = {
            'text': sentences,
            'graph': graph,
            'graph_sizes': graph_sizes,
            'candidate_text_idx': candidate_text_idx,
            'target': torch.zeros(len(candidate_text_idx), dtype=torch.int64),
            'all_entity_pairs': torch.tensor(list(text_index.keys()),
                                             dtype=torch.int32),
            'yield': subgraph.get_yield(),
            'rel_cov': subgraph.get_relative_coverages_mean(),
            'nsentences': subgraph.nsentences,
            'ntokens': subgraph.ntokens,
        }
        item.update(logging_output)
        return item
Example #9
def main(args):
    dict_path = os.path.join(args.data_path, 'dict.txt')
    dictionary = CustomDictionary.load(dict_path)

    entity_dict_path = os.path.join(args.data_path, 'entity.dict.txt')
    entity_dictionary = EntityDictionary.load(entity_dict_path)

    logger.info('dictionary: {} types'.format(len(dictionary)))
    logger.info('entity dictionary: {} types'.format(len(entity_dictionary)))

    text_data = safe_load_indexed_dataset(
        os.path.join(args.data_path, args.split + '.text'))
    annotation_data = MMapNumpyArray(
        os.path.join(args.data_path, args.split + '.annotations.npy'))
    annotated_text = AnnotatedText(
        text_data=text_data,
        annotation_data=annotation_data,
        dictionary=dictionary,
        mask_type=args.mask_type,
        non_mask_rate=args.non_mask_rate,
    )

    graph_data = safe_load_indexed_dataset(
        os.path.join(args.data_path, args.split + '.graph'))
    graph = GraphDataset(
        edges=graph_data,
        subsampling_strategy=args.subsampling_strategy,
        subsampling_cap=args.subsampling_cap,
        seed=args.seed,
    )
    graph.set_epoch(1)

    entity_pair_counter_sum = 0

    with numpy_seed('SubgraphSampler', args.seed):
        random_perm = np.random.permutation(len(graph))

        if args.save_subgraph is not None:
            for index in random_perm:
                subgraph, _ = sample_subgraph(graph, annotated_text, index,
                                              None, None, args)
                if subgraph is not None:
                    break

            path = '%s.max_tokens=%d.max_sentences=%d.min_common_neighbors=%d' % (
                args.save_subgraph,
                args.max_tokens,
                args.max_sentences,
                args.min_common_neighbors,
            )
            save_subgraph(subgraph, dictionary, entity_dictionary, path,
                          args.save_text)
        else:
            num_subgraphs, total_edges, total_covered_edges = 0, 0, 0
            total_relative_coverage_mean, total_relative_coverage_median = 0, 0
            total_full_batch = 0
            entity_pair_counter = Counter()
            relation_statement_counter = Counter()
            with trange(len(graph), desc='Sampling') as progress_bar:
                for i in progress_bar:
                    subgraph, fill_successfully = sample_subgraph(
                        graph,
                        annotated_text,
                        random_perm[i],
                        entity_pair_counter,
                        entity_pair_counter_sum,
                        args,
                    )
                    if subgraph is None:
                        continue

                    num_subgraphs += 1
                    relation_statement_counter.update([
                        hash(x) for x in subgraph.relation_statements.values()
                    ])
                    # entity_pair_counter.update([(min(h, t), max(h, t)) for (h, t) in subgraph.relation_statements.keys()])
                    entity_pair_counter.update(
                        subgraph.relation_statements.keys())
                    entity_pair_counter_sum += len(subgraph.relation_statements)
                    total_edges += len(subgraph.relation_statements)
                    total_covered_edges += len(subgraph.covered_entity_pairs)
                    relative_coverages = subgraph.relative_coverages()
                    total_relative_coverage_mean += np.mean(relative_coverages)
                    total_relative_coverage_median += np.median(
                        relative_coverages)
                    total_full_batch += int(fill_successfully)

                    entity_pairs_counts = np.array(
                        list(entity_pair_counter.values()))
                    relation_statement_counts = np.array(
                        list(relation_statement_counter.values()))

                    progress_bar.set_postfix(
                        # n=num_subgraphs,
                        mean=entity_pair_counter_sum / len(graph),
                        m_r=relation_statement_counter.most_common(1)[0][1],
                        m_e=entity_pair_counter.most_common(1)[0][1],
                        w_e=wasserstein_distance(
                            entity_pairs_counts,
                            np.ones_like(entity_pairs_counts)),
                        w_r=wasserstein_distance(
                            relation_statement_counts,
                            np.ones_like(relation_statement_counts)),
                        y=total_covered_edges / total_edges,
                        e=total_edges / num_subgraphs,
                        # cov_e=total_covered_edges / num_subgraphs,
                        rel_cov=total_relative_coverage_mean / num_subgraphs,
                        # rel_cov_median=total_relative_coverage_median / num_subgraphs,
                        # f=total_full_batch / num_subgraphs,
                    )
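The w_e and w_r progress metrics use scipy.stats.wasserstein_distance between the observed count values and an all-ones vector of the same shape, i.e. a rough measure of how far sampling is from hitting every entity pair (or relation statement) exactly once. A minimal illustration with made-up counts:

import numpy as np
from scipy.stats import wasserstein_distance

counts = np.array([1, 1, 1, 1, 10])  # hypothetical entity-pair counts
uniform = np.ones_like(counts)

# 0.0 would mean the empirical count distribution is already uniform
print(wasserstein_distance(counts, uniform))  # 1.8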
Example #10
    def get_batch_iterator(
        self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
        ignore_invalid_inputs=False, required_batch_size_multiple=1,
        seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=0,
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_tokens (int, optional): max number of tokens in each batch
                (default: None).
            max_sentences (int, optional): max number of sentences in each
                batch (default: None).
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N (default: 1).
            seed (int, optional): seed for random number generator for
                reproducibility (default: 1).
            num_shards (int, optional): shard the data iterator into N
                shards (default: 1).
            shard_id (int, optional): which shard of the data iterator to
                return (default: 0).
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means the data will be loaded in the main process
                (default: 0).
            epoch (int, optional): the epoch to start the iterator from
                (default: 0).
        Returns:
            ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                given dataset split
        """
        assert isinstance(dataset, FairseqDataset)

        # initialize the dataset with the correct starting epoch
        global_start_time = time.time()
        dataset.set_epoch(epoch)

        # Horrible hack because fairseq wrapper doesn't set epoch for itself. TODO: take ownership of wrapper
        dataset.epoch = epoch
        set_epoch_time = time.time() - global_start_time

        # get indices ordered by example size
        start_time = time.time()
        with numpy_seed('R3LTask', seed, epoch):
            indices = dataset.ordered_indices()
        sort_time = time.time() - start_time

        # each index becomes its own batch of size one; size-based batching
        # is not actually applied at this stage
        start_time = time.time()
        batch_sampler = np.expand_dims(indices, 1)
        batch_by_size_time = time.time() - start_time
        logger.info(
            'get batch iterator [seed=%d, epoch=%d, num_shards=%d] is done in %.3f seconds '
            '(set epoch=%.3f, sorting=%.3f, batch by size=%.3f)' % (
                seed,
                epoch,
                num_shards,
                time.time() - global_start_time,
                set_epoch_time,
                sort_time,
                batch_by_size_time,
        ))

        # return a reusable, sharded iterator
        epoch_iter = iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
        )
        return epoch_iter
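For illustration, the expand_dims trick above simply adds a trailing dimension to the index array, so every batch handed to the iterator holds exactly one sample:

import numpy as np

indices = np.array([3, 1, 2])
batch_sampler = np.expand_dims(indices, 1)
print(batch_sampler.tolist())  # [[3], [1], [2]] -> one sample per batch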