def get_gold_score(self, label_scores, sen_len, gold):
    """Return the accumulated score of the oracle (gold) tree.

    The chart is filled bottom-up by span length:

    * a length-1 span scores just its oracle label;
    * a longer span scores its oracle label plus the chart scores of the two
      sub-spans produced by the smallest oracle split point.

    Args:
        label_scores: per-span label score vectors, addressed by
            ``get_position(sen_len, left, right)`` and then by label id.
        sen_len: number of words in the sentence.
        gold: oracle providing ``oracle_label(left, right)`` and
            ``oracle_splits(left, right)``.

    Returns:
        The chart cell for the whole-sentence span (0, sen_len).
    """
    # chart[k - 1, i] holds the gold score of the span (i, i + k).
    chart = label_scores.new_zeros((sen_len, sen_len))
    for span_len in range(1, sen_len + 1):
        for start in range(sen_len + 1 - span_len):
            end = start + span_len
            row = label_scores[get_position(sen_len, start, end)]
            gold_label_id = self.vocab.parse_label2id(gold.oracle_label(start, end))
            gold_label_score = row[gold_label_id]

            if span_len == 1:
                chart[0, start] = gold_label_score
                continue

            split = min(gold.oracle_splits(start, end))
            left_part = chart[split - start - 1, start]
            right_part = chart[end - split - 1, split]
            chart[span_len - 1, start] = left_part + right_part + gold_label_score
    return chart[-1][0]
def score(self, span_vectors, sen_len, all_span):
    """Score the spans listed in ``all_span`` with ``self.forward``.

    Each ``(i, j)`` pair in ``all_span`` selects
    ``span_vectors[get_position(sen_len, i, j)]``. The selected vectors are
    stacked and given a leading batch axis of size 1 before ``self.forward``.
    That axis is then removed and the remaining three axes are rotated with
    ``permute(1, 2, 0)``.

    NOTE(review): the meaning of the three returned axes depends on
    ``self.forward``'s output layout, which is not visible here — confirm.
    """
    gathered = []
    for start, end in all_span:
        gathered.append(span_vectors[get_position(sen_len, start, end)])
    batch = torch.stack(gathered).unsqueeze(0)
    scores = self.forward(batch).squeeze(0)
    return scores.permute(1, 2, 0)
def get_gold_labels(self, tree, sen_len):
    """Return a flat list of gold label ids, one slot per span.

    There are ``sen_len * (sen_len + 1) // 2`` slots, addressed by
    ``get_position(sen_len, left, right)``. Every ``InternalParseNode`` in
    ``tree`` writes ``self.vocab.parse_label2id(node.label)`` into its span's
    slot. Slots not covered by any internal node keep the value 0.
    """
    labels = [0] * (sen_len * (sen_len + 1) // 2)
    pending = [tree]
    while pending:
        current = pending.pop()
        if not isinstance(current, InternalParseNode):
            continue
        label_id = self.vocab.parse_label2id(current.label)
        labels[get_position(sen_len, current.left, current.right)] = label_id
        # Reverse so children are popped (visited) left-to-right.
        pending.extend(reversed(current.children))
    return labels
def CKY(self, label_scores, sen_len, gold=None):
    """Decode the highest-scoring bracketing with CKY over span label scores.

    Every span takes its argmax label score. A span longer than one word adds
    the best sum of its two sub-span chart scores over all split points.
    For the whole-sentence span, label index 0 is excluded from the argmax
    (presumably the NULL/empty label — confirm against the vocab).

    In training mode (``self.training``), each span's label scores are first
    passed through ``self.augment`` together with the oracle label from
    ``gold``. ``gold`` must therefore be provided when training.

    Returns:
        ``(None, score)`` in training mode. Otherwise
        ``(tree, score)``, where ``tree`` is built by ``self.trace_back`` and
        ``.convert()``. ``score`` is the chart cell of the span (0, sen_len).
    """
    # Chart layout: cell [k - 1, i] describes the span (i, i + k).
    label_index_matrix = label_scores.new_zeros((sen_len, sen_len), dtype=torch.long)
    best_split_matrix = label_scores.new_full((sen_len, sen_len), -1, dtype=torch.long)
    accum_score_matrix = label_scores.new_zeros((sen_len, sen_len))
    for length in range(1, sen_len + 1):
        for left in range(0, sen_len + 1 - length):
            right = left + length
            label_score = label_scores[get_position(sen_len, left, right)]
            if self.training:
                oracle_label = gold.oracle_label(left, right)
                oracle_label_index = self.vocab.parse_label2id(oracle_label)
                label_score = self.augment(label_score, oracle_label_index)
            if length == sen_len:
                # Integer indexing returns a view, so mask a copy. Writing into
                # the view would silently overwrite the caller's label_scores.
                label_score = label_score.clone()
                label_score[0] = float("-inf")
            argmax_label_score, argmax_label_index = torch.max(label_score, dim=0)
            label_index_matrix[length - 1, left] = argmax_label_index
            if length == 1:
                accum_score_matrix[length - 1, left] = argmax_label_score
                continue
            # Entry k-1 pairs the left child (left, left + k), stored at row k-1,
            # with the right child (left + k, right), stored at row length-k-1.
            span_scores = (
                accum_score_matrix[range(0, length - 1), left]
                + accum_score_matrix[range(length - 2, -1, -1), range(left + 1, right)]
            )
            accum_score, best_split = torch.max(span_scores, dim=0)
            best_split_matrix[length - 1][left] = best_split + left + 1
            accum_score_matrix[length - 1][left] = accum_score + argmax_label_score
    if self.training:
        return None, accum_score_matrix[-1][0]
    tree = self.trace_back(0, sen_len, best_split_matrix, label_index_matrix)
    assert len(tree) == 1
    return tree[0].convert(), accum_score_matrix[-1][0]
def CKY(self, label_scores, sen_len):
    """Decode the highest-scoring bracketing with CKY and return the tree.

    Every span takes its argmax label. A span longer than one word adds the
    best sum of its two sub-span chart scores over all split points. It also
    adds its label score, except when the argmax label is
    ``self.vocab.NULL_index``. Single-word spans always store their argmax
    label score.

    For the whole-sentence span, label index 0 is excluded from the argmax
    (presumably the NULL/empty label — confirm against the vocab).

    Returns:
        ``tree[0].convert()``, where ``tree`` comes from ``self.trace_back``.
    """
    # Chart layout: cell [k - 1, i] describes the span (i, i + k).
    label_index_matrix = label_scores.new_zeros((sen_len, sen_len), dtype=torch.long)
    best_split_matrix = label_scores.new_full((sen_len, sen_len), -1, dtype=torch.long)
    accum_score_matrix = label_scores.new_zeros((sen_len, sen_len))
    for length in range(1, sen_len + 1):
        for left in range(0, sen_len + 1 - length):
            right = left + length
            label_score = label_scores[get_position(sen_len, left, right)]
            if length == sen_len:
                # Integer indexing returns a view, so mask a copy. Writing into
                # the view would silently overwrite the caller's label_scores.
                label_score = label_score.clone()
                label_score[0] = float("-inf")
            argmax_label_score, argmax_label_index = torch.max(label_score, dim=0)
            label_index_matrix[length - 1, left] = argmax_label_index
            if length == 1:
                accum_score_matrix[length - 1, left] = argmax_label_score
                continue
            # Entry k-1 pairs the left child (left, left + k), stored at row k-1,
            # with the right child (left + k, right), stored at row length-k-1.
            span_scores = (
                accum_score_matrix[range(0, length - 1), left]
                + accum_score_matrix[range(length - 2, -1, -1), range(left + 1, right)]
            )
            accum_score, best_split = torch.max(span_scores, dim=0)
            best_split_matrix[length - 1][left] = best_split + left + 1
            if int(argmax_label_index) == self.vocab.NULL_index:
                accum_score_matrix[length - 1][left] = accum_score
            else:
                accum_score_matrix[length - 1][left] = accum_score + argmax_label_score
    tree = self.trace_back(0, sen_len, best_split_matrix, label_index_matrix)
    assert len(tree) == 1
    return tree[0].convert()
def helper(self, label_scores, split_scores, sen_len, left, right, gold=None):
    """Greedily parse the span (left, right) top-down.

    In training mode (``self.training``), the parse follows the oracle label
    and the smallest oracle split from ``gold``. Scores are passed through
    ``self.augment`` first, and each decision contributes
    ``argmax_score - oracle_score`` to the loss. ``gold`` is required in this
    mode.

    Otherwise the parse follows the argmax label and the argmax split. Each
    decision contributes the chosen score.

    For the whole-sentence span, label index 0 is excluded from the argmax
    (presumably the NULL/empty label — confirm against the vocab). A label is
    attached only when it is truthy.

    Returns:
        ``(trees, loss)``. ``trees`` is a list of parse nodes covering
        (left, right). ``loss`` is the summed label/split terms of this span
        and all sub-spans.
    """
    position = get_position(sen_len, left, right)
    label_score = label_scores[position]
    if self.training:
        oracle_label = gold.oracle_label(left, right)
        oracle_label_index = self.vocab.parse_label2id(oracle_label)
        label_score = self.augment(label_score, oracle_label_index)
        oracle_label_score = label_score[oracle_label_index]
    if right - left == sen_len:
        # Integer indexing returns views, so mask a copy. Writing in place would
        # overwrite the caller's label_scores. It would also overwrite
        # oracle_label_score (a view of label_score) when the oracle label is
        # index 0, which turns label_loss into inf.
        label_score = label_score.clone()
        label_score[0] = float("-inf")
    argmax_label_score, argmax_label_index = torch.max(label_score, dim=0)
    if self.training:
        label = oracle_label
        label_loss = argmax_label_score - oracle_label_score
    else:
        label = self.vocab.id2parse_label(int(argmax_label_index))
        label_loss = label_score[argmax_label_index]

    if right - left == 1:
        tree = LeafParseNode(left, "pos", "word")
        if label:
            tree = InternalParseNode(label, [tree])
        return [tree], label_loss

    # Candidate split s in (left, right) scores split(left, s) + split(s, right).
    left_positions = get_position(sen_len, left, range(left + 1, right))
    right_positions = get_position(sen_len, range(left + 1, right), right)
    splits = split_scores[left_positions] + split_scores[right_positions]
    if self.training:
        oracle_split = min(gold.oracle_splits(left, right))
        oracle_split_index = oracle_split - (left + 1)
        splits = self.augment(splits, oracle_split_index)
        oracle_split_score = splits[oracle_split_index]
    argmax_split_score, argmax_split_index = torch.max(splits, dim=0)
    if self.training:
        split = oracle_split
        split_loss = argmax_split_score - oracle_split_score
    else:
        split = argmax_split_index + (left + 1)
        split_loss = splits[argmax_split_index]

    left_trees, left_loss = self.helper(
        label_scores, split_scores, sen_len, left, int(split), gold
    )
    right_trees, right_loss = self.helper(
        label_scores, split_scores, sen_len, int(split), right, gold
    )
    children = left_trees + right_trees
    loss = label_loss + split_loss + left_loss + right_loss
    if label:
        children = [InternalParseNode(label, children)]
    return children, loss