def __getitem__(self, index):
    with numpy_seed('EntityPredictionDataset', self.seed, self.epoch, index):
        sampled_edge = None
        while not sampled_edge:
            entity = np.random.randint(len(self.edges))
            n_entity_edges = len(self.edges[entity]) // GraphDataset.EDGE_SIZE
            if n_entity_edges > 0:
                passage_idx = np.random.randint(n_entity_edges)
                edge_start = passage_idx * GraphDataset.EDGE_SIZE
                edge = self.edges[entity][edge_start:edge_start + GraphDataset.EDGE_SIZE].numpy()
                sampled_edge = True

        start_pos, end_pos, start_block, end_block = (
            edge[GraphDataset.HEAD_START_POS],
            edge[GraphDataset.HEAD_END_POS],
            edge[GraphDataset.START_BLOCK],
            edge[GraphDataset.END_BLOCK],
        )
        passage, annotation_position = self.annotated_text.annotate_mention(
            entity, start_pos, end_pos, start_block, end_block)

        item = {
            'text': passage,
            'annotation': torch.LongTensor(annotation_position) if annotation_position else None,
            'target': entity,
        }
        return item
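# All of the sampling code in this module is wrapped in a numpy_seed context
# manager so that random draws are deterministic per (name, seed, epoch, index).
# The sketch below is a hypothetical stand-in for that helper, assuming it
# hashes its arguments into a temporary 32-bit NumPy seed and restores the
# previous RNG state on exit; the project's actual implementation may differ.
import contextlib
import hashlib

import numpy as np


@contextlib.contextmanager
def numpy_seed(*seeds):
    """Temporarily seed NumPy's global RNG from an arbitrary tuple of values."""
    if any(s is None for s in seeds):
        # No seeding requested: leave the RNG state untouched.
        yield
        return
    # Hash the (name, seed, epoch, index, ...) tuple into a valid 32-bit seed.
    digest = hashlib.md5(repr(seeds).encode('utf-8')).hexdigest()
    state = np.random.get_state()
    np.random.seed(int(digest, 16) % (2 ** 32))
    try:
        yield
    finally:
        # Restore whatever state the caller's RNG was in before entering.
        np.random.set_state(state)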
def __getitem__(self, index):
    with numpy_seed('TokenBlockAnnotatedDataset', self.seed, self.epoch, index):
        start_block, end_block = self._slice_indices.array[index]
        end_block = min(start_block + self.max_positions, end_block)
        annotations = self.dataset.annotations_block(start_block, end_block)

        if len(annotations) > 0:
            entities = np.unique(annotations[:, AnnotatedText.INDEX_ANNOTATION_ENTITY])
        else:
            entities = []

        head_entity, tail_entity = self.sample_entities(entities)
        head_start_pos, head_end_pos = self.sample_annotation(annotations, head_entity)
        tail_start_pos, tail_end_pos = self.sample_annotation(annotations, tail_entity)

        text = self.dataset.annotate_relation(
            tail_entity=tail_entity,
            head_entity=head_entity,
            head_start_pos=head_start_pos,
            head_end_pos=head_end_pos,
            tail_start_pos=tail_start_pos,
            tail_end_pos=tail_end_pos,
            start_block=start_block,
            end_block=end_block,
            annotations=annotations,
        )
        return text
def set_epoch(self, epoch):
    if epoch != self.epoch:
        with numpy_seed('GraphDataset', self.seed, epoch):
            self._indices, self._sizes = self.subsample_graph_by_entity_pairs()
        self._indices = maybe_move_to_plasma(self._indices)
        self._sizes = maybe_move_to_plasma(self._sizes)
        self.epoch = epoch
def __getitem__(self, index):
    with numpy_seed('ETPRelationDataset', self.seed, self.epoch, index):
        sampled_edge = None
        while not sampled_edge:
            entity = np.random.randint(len(self.edges))
            n_entity_edges = len(self.edges[entity]) // GraphDataset.EDGE_SIZE
            if n_entity_edges > 0:
                passage_idx = np.random.randint(n_entity_edges)
                edge_start = passage_idx * GraphDataset.EDGE_SIZE
                edge = self.edges[entity][edge_start:edge_start + GraphDataset.EDGE_SIZE].numpy()
                sampled_edge = True

        start_pos, end_pos, start_block, end_block = (
            edge[GraphDataset.HEAD_START_POS],
            edge[GraphDataset.HEAD_END_POS],
            edge[GraphDataset.START_BLOCK],
            edge[GraphDataset.END_BLOCK],
        )
        passage, mask_annotation_position, all_annotation_positions, entity_ids = \
            self.annotated_text.annotate_mention(
                entity,
                start_pos,
                end_pos,
                start_block,
                end_block,
                return_all_annotations=True,
            )

        mask_idx = all_annotation_positions.index(mask_annotation_position[0])
        replace_probs = np.ones(len(all_annotation_positions)) * (
            1 - self.mask_negative_prob) / (len(all_annotation_positions) - 1)
        replace_probs[mask_idx] = self.mask_negative_prob
        replace_position = np.random.choice(
            len(all_annotation_positions), size=1, p=replace_probs)
        entity_replacements = np.random.choice(
            self.n_entities, replace=False, size=self.total_negatives)

        # candidates = np.expand_dims(entity_ids, axis=-1).repeat(self.total_negatives, axis=-1)
        # candidates[replace_position, np.arange(len(replace_position))] = entity_replacements

        item = {
            'text': passage,
            'mask_annotation': torch.LongTensor(mask_annotation_position),
            'all_annotations': torch.LongTensor(all_annotation_positions),
            'entity_ids': torch.LongTensor(entity_ids),
            'entity_replacements': torch.LongTensor(entity_replacements),
            'replacement_position': replace_position,
            # 'candidates': torch.LongTensor(candidates),
        }
        return item
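# Worked example of the replace_probs construction above (illustrative only):
# with 4 annotated mentions and mask_negative_prob = 0.1, every position gets
# (1 - 0.1) / (4 - 1) = 0.3, and the masked mention itself gets 0.1, so for
# mask_idx = 2 we have replace_probs = [0.3, 0.3, 0.1, 0.3], which sums to 1.0.
# The masked mention is thus chosen as its own "replacement" position with
# probability mask_negative_prob, and one of the other mentions otherwise.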
def set_epoch(self, epoch):
    self.dataset.set_epoch(epoch)
    with numpy_seed('FixedSizeDataset', self.seed):
        if len(self.dataset) <= self._size:
            self.data_indices = np.arange(len(self.dataset))
        else:
            self.data_indices = np.random.choice(
                len(self.dataset),
                self._size,
                replace=False,
            )
    self._sizes = self.dataset.sizes[self.data_indices]
def __getitem__(self, index):
    with numpy_seed('GNNEvalDataset', self.seed, self.epoch, index):
        local_interval = IntervalTree()
        edge = self.graph[index]
        head = edge[GraphDataset.HEAD_ENTITY]
        tail = edge[GraphDataset.TAIL_ENTITY]
        start = edge[GraphDataset.START_BLOCK]
        end = edge[GraphDataset.END_BLOCK]
        local_interval.addi(start, end)

        head_neighbors = self.graph.get_neighbors(head)
        tail_neighbors = self.graph.get_neighbors(tail)
        mutual_neighbors = np.intersect1d(head_neighbors, tail_neighbors, assume_unique=True)
        if len(mutual_neighbors) == 0:
            return None

        found_supporting = False
        random_mutual = np.random.permutation(mutual_neighbors)
        for chosen_mutual in random_mutual:
            support1, local_interval = self.sample_relation_statement(
                head, chosen_mutual, local_interval)
            support2, local_interval = self.sample_relation_statement(
                chosen_mutual, tail, local_interval)
            if support1 is None or support2 is None:
                continue
            else:
                found_supporting = True
                break

        if found_supporting is False:
            return None

        item = {
            'target': self.annotated_text.annotate_relation(*(edge.numpy())),
            'support': [
                self.annotated_text.annotate_relation(*support1),
                self.annotated_text.annotate_relation(*support2),
            ],
            'entities': {
                'A': head,
                'B': tail,
                'C': chosen_mutual,
            },
        }
        return item
def set_epoch(self, epoch):
    if epoch == self.epoch:
        return
    assert epoch >= 1
    self.epoch = epoch

    if self.epoch_splits is None:
        self.set_dataset_epoch(1)
        self.epoch_splits = math.ceil(len(self.dataset) / self.epoch_size)
        assert self.epoch_splits >= 1
        logger.info('set epoch_split to be %d given epoch_size=%d, dataset size=%d' % (
            self.epoch_splits,
            self.epoch_size,
            len(self.dataset),
        ))

    dataset_epoch = ((epoch - 1) // self.epoch_splits) + 1
    epoch_offset = (epoch - 1) % self.epoch_splits
    self.set_dataset_epoch(dataset_epoch)

    start_time = time.time()
    data_per_epoch = len(self.dataset) // (self.epoch_splits or 1)
    data_start = data_per_epoch * epoch_offset
    data_end = min(len(self.dataset), data_per_epoch * (epoch_offset + 1))

    with numpy_seed('EpochSplitDataset', self.seed, self.dataset.epoch):
        dataset_indices = np.random.permutation(len(self.dataset))

    self._indices = dataset_indices[data_start:data_end]
    self._sizes = self.dataset.sizes[self._indices]
    self._indices = maybe_move_to_plasma(self._indices)
    self._sizes = maybe_move_to_plasma(self._sizes)

    logger.info('selected %d samples from generation epoch %d and epoch offset %d in %d seconds' % (
        data_end - data_start,
        self.dataset.epoch,
        epoch_offset,
        time.time() - start_time,
    ))
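# Worked example of the epoch-to-split mapping used in set_epoch above
# (illustrative only; epoch_splits = 3 is an arbitrary choice):
#   epoch:         1  2  3  4  5  6
#   dataset_epoch: 1  1  1  2  2  2   # ((epoch - 1) // epoch_splits) + 1
#   epoch_offset:  0  1  2  0  1  2   # (epoch - 1) % epoch_splits
# Each underlying dataset generation is therefore consumed across epoch_splits
# training epochs, each seeing a disjoint slice of a single fixed permutation
# drawn under numpy_seed for that generation.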
def __getitem__(self, index):
    with numpy_seed('GNNDataset', self.seed, self.epoch, index):
        subgraph = self._sample_subgraph(index)
        while subgraph is None:
            # logging.warning('Failed to sample subgraph for [seed=%d, epoch=%d, index=%d]' % (
            #     self.seed,
            #     self.epoch,
            #     index,
            # ))
            text_index = np.random.randint(len(self.graph))
            subgraph = self._sample_subgraph(text_index)

        sentences, text_index = self._get_all_sentences_and_index(subgraph)
        graph, graph_sizes, candidate_text_idx, logging_output = self._make_negatives(
            subgraph, text_index)

        item = {
            'text': sentences,
            'graph': graph,
            'graph_sizes': graph_sizes,
            'candidate_text_idx': candidate_text_idx,
            'target': torch.zeros(len(candidate_text_idx), dtype=torch.int64),
            'all_entity_pairs': torch.tensor(list(text_index.keys()), dtype=torch.int32),
            'yield': subgraph.get_yield(),
            'rel_cov': subgraph.get_relative_coverages_mean(),
            'nsentences': subgraph.nsentences,
            'ntokens': subgraph.ntokens,
        }
        item.update(logging_output)
        return item
def main(args):
    dict_path = os.path.join(args.data_path, 'dict.txt')
    dictionary = CustomDictionary.load(dict_path)
    entity_dict_path = os.path.join(args.data_path, 'entity.dict.txt')
    entity_dictionary = EntityDictionary.load(entity_dict_path)
    logger.info('dictionary: {} types'.format(len(dictionary)))
    logger.info('entity dictionary: {} types'.format(len(entity_dictionary)))

    text_data = safe_load_indexed_dataset(
        os.path.join(args.data_path, args.split + '.text'),
    )
    annotation_data = MMapNumpyArray(
        os.path.join(args.data_path, args.split + '.annotations.npy'))
    annotated_text = AnnotatedText(
        text_data=text_data,
        annotation_data=annotation_data,
        dictionary=dictionary,
        mask_type=args.mask_type,
        non_mask_rate=args.non_mask_rate,
    )
    graph_data = safe_load_indexed_dataset(
        os.path.join(args.data_path, args.split + '.graph'),
    )
    graph = GraphDataset(
        edges=graph_data,
        subsampling_strategy=args.subsampling_strategy,
        subsampling_cap=args.subsampling_cap,
        seed=args.seed,
    )
    graph.set_epoch(1)

    entity_pair_counter_sum = 0
    with numpy_seed('SubgraphSampler', args.seed):
        random_perm = np.random.permutation(len(graph))

    if args.save_subgraph is not None:
        for index in random_perm:
            subgraph, _ = sample_subgraph(graph, annotated_text, index, None, None, args)
            if subgraph is not None:
                break
        path = '%s.max_tokens=%d.max_sentences=%d.min_common_neighbors=%d' % (
            args.save_subgraph,
            args.max_tokens,
            args.max_sentences,
            args.min_common_neighbors,
        )
        save_subgraph(subgraph, dictionary, entity_dictionary, path, args.save_text)
    else:
        num_subgraphs, total_edges, total_covered_edges = 0, 0, 0
        total_relative_coverage_mean, total_relative_coverage_median = 0, 0
        total_full_batch = 0
        entity_pair_counter, relation_statement_counter = Counter(), Counter()

        with trange(len(graph), desc='Sampling') as progress_bar:
            for i in progress_bar:
                subgraph, fill_successfully = sample_subgraph(
                    graph,
                    annotated_text,
                    random_perm[i],
                    entity_pair_counter,
                    entity_pair_counter_sum,
                    args,
                )
                if subgraph is None:
                    continue

                num_subgraphs += 1
                relation_statement_counter.update([
                    hash(x) for x in subgraph.relation_statements.values()
                ])
                # entity_pair_counter.update([(min(h, t), max(h, t)) for (h, t) in subgraph.relation_statements.keys()])
                entity_pair_counter.update([
                    (h, t) for (h, t) in subgraph.relation_statements.keys()
                ])
                entity_pair_counter_sum += len(subgraph.relation_statements)
                total_edges += len(subgraph.relation_statements)
                total_covered_edges += len(subgraph.covered_entity_pairs)
                relative_coverages = subgraph.relative_coverages()
                total_relative_coverage_mean += np.mean(relative_coverages)
                total_relative_coverage_median += np.median(relative_coverages)
                total_full_batch += int(fill_successfully)

                entity_pairs_counts = np.array(list(entity_pair_counter.values()))
                relation_statement_counts = np.array(list(relation_statement_counter.values()))
                progress_bar.set_postfix(
                    # n=num_subgraphs,
                    mean=entity_pair_counter_sum / len(graph),
                    m_r=relation_statement_counter.most_common(1)[0][1],
                    m_e=entity_pair_counter.most_common(1)[0][1],
                    w_e=wasserstein_distance(
                        entity_pairs_counts, np.ones_like(entity_pairs_counts)),
                    w_r=wasserstein_distance(
                        relation_statement_counts, np.ones_like(relation_statement_counts)),
                    y=total_covered_edges / total_edges,
                    e=total_edges / num_subgraphs,
                    # cov_e=total_covered_edges / num_subgraphs,
                    rel_cov=total_relative_coverage_mean / num_subgraphs,
                    # rel_cov_median=total_relative_coverage_median / num_subgraphs,
                    # f=total_full_batch / num_subgraphs,
                )
def get_batch_iterator(
    self,
    dataset,
    max_tokens=None,
    max_sentences=None,
    max_positions=None,
    ignore_invalid_inputs=False,
    required_batch_size_multiple=1,
    seed=1,
    num_shards=1,
    shard_id=0,
    num_workers=0,
    epoch=0,
):
    """
    Get an iterator that yields batches of data from the given dataset.

    Args:
        dataset (~fairseq.data.FairseqDataset): dataset to batch
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each batch
            (default: None).
        max_positions (optional): max sentence length supported by the model
            (default: None).
        ignore_invalid_inputs (bool, optional): don't raise Exception for
            sentences that are too long (default: False).
        required_batch_size_multiple (int, optional): require batch size to
            be a multiple of N (default: 1).
        seed (int, optional): seed for random number generator for
            reproducibility (default: 1).
        num_shards (int, optional): shard the data iterator into N shards
            (default: 1).
        shard_id (int, optional): which shard of the data iterator to return
            (default: 0).
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means the data will be loaded in the main process
            (default: 0).
        epoch (int, optional): the epoch to start the iterator from
            (default: 0).

    Returns:
        ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
            given dataset split
    """
    assert isinstance(dataset, FairseqDataset)

    # initialize the dataset with the correct starting epoch
    global_start_time = time.time()
    dataset.set_epoch(epoch)
    # Horrible hack because fairseq wrapper doesn't set epoch for itself.
    # TODO: take ownership of wrapper
    dataset.epoch = epoch
    set_epoch_time = time.time() - global_start_time

    # get indices ordered by example size
    start_time = time.time()
    with numpy_seed('R3LTask', seed, epoch):
        indices = dataset.ordered_indices()
    sort_time = time.time() - start_time

    # create mini-batches with given size constraints
    start_time = time.time()
    batch_sampler = np.expand_dims(indices, 1)
    batch_by_size_time = time.time() - start_time

    logger.info(
        'get batch iterator [seed=%d, epoch=%d, num_shards=%d] is done in %.3f seconds '
        '(set epoch=%.3f, sorting=%.3f, batch by size=%.3f)' % (
            seed,
            epoch,
            num_shards,
            time.time() - global_start_time,
            set_epoch_time,
            sort_time,
            batch_by_size_time,
        ))

    # return a reusable, sharded iterator
    epoch_iter = iterators.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        batch_sampler=batch_sampler,
        seed=seed,
        num_shards=num_shards,
        shard_id=shard_id,
        num_workers=num_workers,
        epoch=epoch,
    )
    return epoch_iter
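# Note on the batch_sampler above (illustrative only): np.expand_dims(indices, 1)
# turns the flat index array into shape (N, 1), so every mini-batch holds exactly
# one dataset item; each item is already a fully packed sample, so no further
# batching by token or sentence count is performed here. For example:
#   indices = np.array([4, 0, 2])
#   np.expand_dims(indices, 1).tolist()  # -> [[4], [0], [2]]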