def _prepare_doc(self, curr_doc: Document) -> Dict: """ Returns a cache dictionary with preprocessed data. This should only be called once per document, since data inside same document does not get shuffled. """ ret = {} preprocessed_sents, max_len = [], 0 for curr_sent in curr_doc.raw_sentences(): # TODO: uncased/cased option curr_processed_sent = list(map(lambda s: s.lower().strip(), curr_sent)) + ["<PAD>"] preprocessed_sents.append(curr_processed_sent) if len(curr_processed_sent) > max_len: max_len = len(curr_processed_sent) for i in range(len(preprocessed_sents)): preprocessed_sents[i].extend(["<PAD>"] * (max_len - len(preprocessed_sents[i]))) cluster_sets = [] mention_to_cluster_id = {} for i, curr_cluster in enumerate(curr_doc.clusters): cluster_sets.append(set(curr_cluster)) for mid in curr_cluster: mention_to_cluster_id[mid] = i all_candidate_data = [] for idx_head, (head_id, head_mention) in enumerate(curr_doc.mentions.items(), start=1): gt_antecedent_ids = cluster_sets[mention_to_cluster_id[head_id]] # Note: no data for dummy antecedent (len(`features`) is one less than `candidates`) candidates, candidate_data = [None], [] starts, ends = [], [] candidate_attention = [] correct_antecedents = [] curr_head_data = [[], []] for curr_token in head_mention.tokens: curr_head_data[0].append(curr_token.sentence_index) curr_head_data[1].append(curr_token.position_in_sentence) num_tokens = len(head_mention.tokens) if num_tokens > self.max_span_size: curr_head_data[0] = curr_head_data[0][:self.max_span_size] curr_head_data[1] = curr_head_data[1][:self.max_span_size] else: curr_head_data[0] += [head_mention.tokens[0].sentence_index] * (self.max_span_size - num_tokens) curr_head_data[1] += [-1] * (self.max_span_size - num_tokens) head_start = 0 head_end = num_tokens head_attention = torch.ones((1, self.max_span_size), dtype=torch.bool) head_attention[0, num_tokens:] = False for idx_candidate, (cand_id, cand_mention) in enumerate(curr_doc.mentions.items(), start=1): if 
idx_candidate >= idx_head: break candidates.append(cand_id) # Maps tokens to positions inside document (idx_sent, idx_inside_sent) for efficient indexing later curr_candidate_data = [[], []] for curr_token in cand_mention.tokens: curr_candidate_data[0].append(curr_token.sentence_index) curr_candidate_data[1].append(curr_token.position_in_sentence) num_tokens = len(cand_mention.tokens) if num_tokens > self.max_span_size: curr_candidate_data[0] = curr_candidate_data[0][:self.max_span_size] curr_candidate_data[1] = curr_candidate_data[1][:self.max_span_size] else: curr_candidate_data[0] += [cand_mention.tokens[0].sentence_index] * (self.max_span_size - num_tokens) curr_candidate_data[1] += [-1] * (self.max_span_size - num_tokens) candidate_data.append(curr_candidate_data) starts.append(0) ends.append(num_tokens) curr_attention = torch.ones((1, self.max_span_size), dtype=torch.bool) curr_attention[0, num_tokens:] = False candidate_attention.append(curr_attention) is_coreferent = cand_id in gt_antecedent_ids if is_coreferent: correct_antecedents.append(idx_candidate) if len(correct_antecedents) == 0: correct_antecedents.append(0) candidate_attention = torch.cat(candidate_attention) if len(candidate_attention) > 0 else [] all_candidate_data.append({ "head_id": head_id, "head_data": torch.tensor([curr_head_data]), "head_attention": head_attention, "head_start": head_start, "head_end": head_end, "candidates": candidates, "candidate_data": torch.tensor(candidate_data), "candidate_attention": candidate_attention, "correct_antecedents": correct_antecedents }) ret["preprocessed_sents"] = preprocessed_sents ret["steps"] = all_candidate_data return ret
def _prepare_doc(self, curr_doc: Document) -> Dict:
    """
    Returns a cache dictionary with preprocessed data. This should only be called
    once per document, since data inside same document does not get shuffled.
    """
    ret = {}

    if self.max_segment_size is None:
        # Default: each sentence is its own segment, so sentences are encoded independently
        def get_position(tok):
            return tok.sentence_index, tok.position_in_sentence

        raw_encoded = batch_to_ids(curr_doc.raw_sentences())
    else:
        # Otherwise, the flattened document is cut into fixed-size token segments,
        # and each segment is encoded independently
        def get_position(tok):
            doc_position = tok.position_in_document
            return doc_position // self.max_segment_size, doc_position % self.max_segment_size

        flattened_doc = list(chain(*curr_doc.raw_sentences()))
        num_segments = (len(flattened_doc) + self.max_segment_size - 1) // self.max_segment_size
        raw_encoded = batch_to_ids([
            flattened_doc[idx_seg * self.max_segment_size: (idx_seg + 1) * self.max_segment_size]
            for idx_seg in range(num_segments)
        ])

    # Convention: append a PAD word ([0] * max_chars vector) to every segment,
    # so padded mention positions have something to index into
    pad_word = torch.zeros((1, ELMoCharacterMapper.max_word_length), dtype=torch.long)
    encoded_segments = torch.stack([torch.cat((seg, pad_word)) for seg in raw_encoded])

    cluster_sets = [set(cluster) for cluster in curr_doc.clusters]
    mention_to_cluster_id = {mid: i
                             for i, cluster in enumerate(curr_doc.clusters)
                             for mid in cluster}

    def span_data(mention):
        """Maps mention tokens to (idx_seg, idx_inside_seg) lists of exactly
        `self.max_span_size` entries (truncating or padding), plus an attention mask
        that is False on padded entries. Padding positions index into the PAD token
        of the last kept segment (in-segment index -1)."""
        seg_rows, pos_rows = [], []
        for tok in mention.tokens:
            idx_segment, idx_inside_segment = get_position(tok)
            seg_rows.append(idx_segment)
            pos_rows.append(idx_inside_segment)
        num_words = len(seg_rows)
        if num_words > self.max_span_size:
            seg_rows = seg_rows[:self.max_span_size]
            pos_rows = pos_rows[:self.max_span_size]
        else:
            fill = self.max_span_size - num_words
            seg_rows = seg_rows + [seg_rows[-1]] * fill
            pos_rows = pos_rows + [-1] * fill
        mask = torch.ones((1, self.max_span_size), dtype=torch.bool)
        mask[0, num_words:] = False
        return [seg_rows, pos_rows], mask

    all_candidate_data = []
    for idx_head, (head_id, head_mention) in enumerate(curr_doc.mentions.items(), 1):
        gt_antecedent_ids = cluster_sets[mention_to_cluster_id[head_id]]

        # Note: no data for dummy antecedent (len(`features`) is one less than `candidates`)
        candidates = [None]
        candidate_data, candidate_attention = [], []
        correct_antecedents = []

        curr_head_data, head_attention = span_data(head_mention)

        for idx_candidate, (cand_id, cand_mention) in enumerate(curr_doc.mentions.items(), 1):
            # Candidate antecedents are exactly the mentions iterated before the head
            if idx_candidate >= idx_head:
                break

            candidates.append(cand_id)
            cand_positions, cand_mask = span_data(cand_mention)
            candidate_data.append(cand_positions)
            candidate_attention.append(cand_mask)

            if cand_id in gt_antecedent_ids:
                correct_antecedents.append(idx_candidate)

        # No preceding coreferent mention: the dummy antecedent (index 0) is correct
        if not correct_antecedents:
            correct_antecedents.append(0)

        candidate_attention = torch.cat(candidate_attention) if candidate_attention else []
        all_candidate_data.append({
            "head_id": head_id,
            "head_data": torch.tensor([curr_head_data]),
            "head_attention": head_attention,
            "candidates": candidates,
            "candidate_data": torch.tensor(candidate_data),
            "candidate_attention": candidate_attention,
            "correct_antecedents": correct_antecedents
        })

    ret["preprocessed_segments"] = encoded_segments
    ret["steps"] = all_candidate_data
    return ret