def get_gold_score(self, label_scores, sen_len, gold):
    """Return the total score of the gold (oracle) derivation of a sentence.

    The derivation starts at the root span (0, sen_len).  Every span longer
    than one token is split at ``min(gold.oracle_splits(left, right))``.
    Every span in the derivation contributes the score of its oracle label
    ``gold.oracle_label(left, right)``.  A multi-token span's score is
    ``left_child + right_child + own_label_score``.

    Only spans that belong to the derivation are visited (O(sen_len) oracle
    calls).  Spans outside the derivation can never contribute to the root
    score, so there is no need to score the full O(sen_len ** 2) chart.

    Args:
        label_scores: per-span label scores, indexed first by
            ``get_position(sen_len, left, right)`` and then by label id.
        sen_len: number of tokens in the sentence; must be at least 1.
        gold: oracle providing ``oracle_label(left, right)`` and
            ``oracle_splits(left, right)``.

    Returns:
        The scalar score of the gold derivation (an element of
        ``label_scores``' type, e.g. a 0-d tensor).

    Raises:
        ValueError: if ``sen_len`` is smaller than 1.
    """
    if sen_len < 1:
        raise ValueError("sen_len must be at least 1")

    # Pre-order walk of the gold derivation.  Each entry is
    # (left, right, split); split is None for single-token spans.
    derivation = []
    pending = [(0, sen_len)]
    while pending:
        left, right = pending.pop()
        if right - left == 1:
            derivation.append((left, right, None))
            continue
        split = min(gold.oracle_splits(left, right))
        derivation.append((left, right, split))
        pending.append((left, split))
        pending.append((split, right))

    # In pre-order a child is always recorded after its parent.  Iterating in
    # reverse therefore scores both children before the span that combines them.
    span_score = {}
    for left, right, split in reversed(derivation):
        label_score = label_scores[get_position(sen_len, left, right)]
        oracle_label = gold.oracle_label(left, right)
        oracle_label_index = self.vocab.parse_label2id(oracle_label)
        oracle_label_score = label_score[oracle_label_index]

        if split is None:
            span_score[left, right] = oracle_label_score
        else:
            span_score[left, right] = (span_score[left, split] +
                                       span_score[split, right] +
                                       oracle_label_score)
    return span_score[0, sen_len]
 def score(self, span_vectors, sen_len, all_span):
     """Score the requested spans with this module's forward pass.

     Gathers ``span_vectors[get_position(sen_len, i, j)]`` for every
     ``(i, j)`` in ``all_span`` and stacks them.  The stack is sent through
     ``self.forward`` as a batch of one.  The batch axis is then dropped and
     the remaining axes are rotated with ``permute(1, 2, 0)``.

     Args:
         span_vectors: per-span vectors indexed by ``get_position``.
         sen_len: number of tokens in the sentence.
         all_span: iterable of ``(start, end)`` span boundaries.

     Returns:
         The permuted score tensor produced by ``self.forward``.
     """
     selected = []
     for start, end in all_span:
         selected.append(span_vectors[get_position(sen_len, start, end)])
     batch = torch.stack(selected).unsqueeze(0)
     batched_scores = self.forward(batch)
     unbatched = batched_scores.squeeze(0)
     return unbatched.permute(1, 2, 0)
 def get_gold_labels(self, tree, sen_len):
     """Return a flat list with one gold label id per span of the sentence.

     The list has ``sen_len * (sen_len + 1) // 2`` entries, initialised to 0.
     ``get_position`` presumably maps every span into that range; confirm
     against its definition.  Every ``InternalParseNode`` in ``tree`` writes
     ``self.vocab.parse_label2id(node.label)`` at
     ``get_position(sen_len, node.left, node.right)``.  Every other node is
     skipped, and so are its children.

     Args:
         tree: root node of the gold parse tree.
         sen_len: number of tokens in the sentence.

     Returns:
         List of label ids, 0 for spans not labelled by the tree.
     """
     labels = [0] * ((sen_len + 1) * sen_len // 2)
     stack = [tree]
     while stack:
         current = stack.pop()
         if not isinstance(current, InternalParseNode):
             continue
         label_id = self.vocab.parse_label2id(current.label)
         start, end = current.left, current.right
         labels[get_position(sen_len, start, end)] = label_id
         # Push children reversed so they are popped in left-to-right order.
         stack.extend(reversed(current.children))
     return labels
    def CKY(self, label_scores, sen_len, gold=None):
        """Find the highest-scoring binary tree over ``sen_len`` tokens with CKY.

        Every span (left, right) takes the argmax of its label-score row.
        A multi-token span's chart score adds that label score to the best
        sum of chart scores of two sub-spans (left, k) and (k, right), for
        k in (left, right).  Label index 0 (presumably the null label) is
        forbidden on the root span.

        When ``self.training`` is set, each span's label scores are first
        passed through ``self.augment`` together with the gold label index.
        This is presumably loss-augmented decoding.

        Args:
            label_scores: per-span label scores, indexed by
                ``get_position(sen_len, left, right)`` and then by label id.
                This tensor is not modified.
            sen_len: number of tokens in the sentence.
            gold: oracle providing ``oracle_label(left, right)``; required
                when ``self.training`` is set.

        Returns:
            ``(None, score)`` in training mode.  Otherwise ``(tree, score)``,
            where ``tree`` is the converted single tree built by
            ``self.trace_back``.  ``score`` is the chart score of the root span.

        Raises:
            ValueError: if ``self.training`` is set and ``gold`` is None.
        """
        if self.training and gold is None:
            raise ValueError("CKY requires `gold` when self.training is set")

        # Chart cell [length - 1, left] describes the span (left, left + length).
        label_index_matrix = label_scores.new_zeros((sen_len, sen_len),
                                                    dtype=torch.long)
        best_split_matrix = label_scores.new_full((sen_len, sen_len),
                                                  -1,
                                                  dtype=torch.long)
        accum_score_matrix = label_scores.new_zeros((sen_len, sen_len))

        for length in range(1, sen_len + 1):
            for left in range(0, sen_len + 1 - length):
                right = left + length
                label_score = label_scores[get_position(sen_len, left, right)]

                if self.training:
                    oracle_label = gold.oracle_label(left, right)
                    oracle_label_index = self.vocab.parse_label2id(
                        oracle_label)
                    label_score = self.augment(label_score, oracle_label_index)

                if length == sen_len:
                    # Root span: rule out label 0.  Mask a copy, because
                    # label_score may be a view of label_scores; writing into
                    # it in place would leak -inf into the caller's tensor.
                    label_score = label_score.clone()
                    label_score[0] = float("-inf")

                argmax_label_score, argmax_label_index = torch.max(label_score,
                                                                   dim=0)
                label_index_matrix[length - 1, left] = argmax_label_index

                if length == 1:
                    accum_score_matrix[length - 1, left] = argmax_label_score
                    continue

                # Candidate i pairs (left, left + i + 1) with
                # (left + i + 1, right), i.e. split point k = left + i + 1.
                span_scores = (accum_score_matrix[range(0, length - 1), left] +
                               accum_score_matrix[range(length - 2, -1, -1),
                                                  range(left + 1, right)])
                accum_score, best_split = torch.max(span_scores, dim=0)
                best_split_matrix[length - 1][left] = best_split + left + 1

                accum_score_matrix[length -
                                   1][left] = accum_score + argmax_label_score
        if self.training:
            return None, accum_score_matrix[-1][0]

        tree = self.trace_back(0, sen_len, best_split_matrix,
                               label_index_matrix)
        assert len(tree) == 1
        return tree[0].convert(), accum_score_matrix[-1][0]
Example #5
0
    def CKY(self, label_scores, sen_len):
        """Decode the highest-scoring binary tree over ``sen_len`` tokens with CKY.

        Every span (left, right) takes its argmax label.  A multi-token span's
        chart score is the best sum of two sub-span chart scores over all
        split points k in (left, right).  The argmax label score is added on
        top unless that label is ``self.vocab.NULL_index``, in which case the
        label adds nothing.  Single-token spans always score their argmax
        label score.  Label index 0 is forbidden on the root span.

        Args:
            label_scores: per-span label scores, indexed by
                ``get_position(sen_len, left, right)`` and then by label id.
                This tensor is not modified.
            sen_len: number of tokens in the sentence.

        Returns:
            ``convert()`` of the single tree built by ``self.trace_back`` from
            the best-split and label-index charts.
        """
        # Chart cell [length - 1, left] describes the span (left, left + length).
        label_index_matrix = label_scores.new_zeros(
            (sen_len, sen_len), dtype=torch.long
        )
        best_split_matrix = label_scores.new_full(
            (sen_len, sen_len), -1, dtype=torch.long
        )
        accum_score_matrix = label_scores.new_zeros((sen_len, sen_len))

        for length in range(1, sen_len + 1):
            for left in range(0, sen_len + 1 - length):
                right = left + length

                label_score = label_scores[get_position(sen_len, left, right)]

                if length == sen_len:
                    # Root span: rule out label 0.  Mask a copy, because
                    # label_score may be a view of label_scores; writing into
                    # it in place would leak -inf into the caller's tensor.
                    label_score = label_score.clone()
                    label_score[0] = float("-inf")

                argmax_label_score, argmax_label_index = torch.max(label_score, dim=0)
                label_index_matrix[length - 1, left] = argmax_label_index

                if length == 1:
                    # NOTE(review): leaves keep their label score even when the
                    # argmax is NULL_index, unlike longer spans below.
                    accum_score_matrix[length - 1, left] = argmax_label_score
                    continue
                # Candidate i pairs (left, left + i + 1) with
                # (left + i + 1, right), i.e. split point k = left + i + 1.
                span_scores = (
                    accum_score_matrix[range(0, length - 1), left]
                    + accum_score_matrix[
                        range(length - 2, -1, -1), range(left + 1, right)
                    ]
                )
                accum_score, best_split = torch.max(span_scores, dim=0)
                best_split_matrix[length - 1][left] = best_split + left + 1
                if int(argmax_label_index) == self.vocab.NULL_index:
                    accum_score_matrix[length - 1][left] = accum_score
                else:
                    accum_score_matrix[length - 1][left] = (
                        accum_score + argmax_label_score
                    )
        tree = self.trace_back(0, sen_len, best_split_matrix, label_index_matrix)
        assert len(tree) == 1
        return tree[0].convert()
    def helper(self,
               label_scores,
               split_scores,
               sen_len,
               left,
               right,
               gold=None):
        """Decode the span [left, right) top-down, greedily.

        The span takes its argmax label.  Label index 0 is forbidden when the
        span covers the whole sentence.  A multi-token span is then split at
        the point k in (left, right) that maximises
        ``split_scores[pos(left, k)] + split_scores[pos(k, right)]``.  Both
        halves are decoded recursively.

        In training mode (``self.training``), label and split scores are first
        passed through ``self.augment`` with the gold index.  The tree is then
        built from the gold label and gold split (``min`` of
        ``gold.oracle_splits``).  The returned value sums, over every decision,
        ``best_augmented_score - gold_augmented_score``.

        In eval mode the tree follows the argmax decisions.  The returned value
        is the sum of the chosen label and split scores.

        Args:
            label_scores: per-span label scores, indexed by
                ``get_position(sen_len, left, right)`` and then by label id.
                This tensor is not modified.
            split_scores: per-span split scores, indexed by ``get_position``.
            sen_len: number of tokens in the sentence.
            left, right: span boundaries to decode.
            gold: oracle providing ``oracle_label`` / ``oracle_splits``;
                required when ``self.training`` is set.

        Returns:
            ``(children, loss)``.  ``children`` is a list of parse nodes: a
            single ``InternalParseNode`` when the span has a (truthy) label,
            otherwise the unwrapped child nodes.

        Raises:
            ValueError: if ``self.training`` is set and ``gold`` is None.
        """
        if self.training and gold is None:
            raise ValueError("helper requires `gold` when self.training is set")

        position = get_position(sen_len, left, right)
        label_score = label_scores[position]
        if self.training:
            oracle_label = gold.oracle_label(left, right)
            oracle_label_index = self.vocab.parse_label2id(oracle_label)
            label_score = self.augment(label_score, oracle_label_index)
            oracle_label_score = label_score[oracle_label_index]

        if right - left == sen_len:
            # Root span: rule out label 0.  Mask a copy: label_score may be a
            # view of label_scores, and oracle_label_score is a view of
            # label_score.  An in-place write would corrupt the caller's
            # tensor, and would also turn the gold score into -inf when the
            # oracle index is 0.
            label_score = label_score.clone()
            label_score[0] = float("-inf")

        argmax_label_score, argmax_label_index = torch.max(label_score, dim=0)
        argmax_label = self.vocab.id2parse_label(int(argmax_label_index))

        if self.training:
            label = oracle_label
            label_loss = argmax_label_score - oracle_label_score
        else:
            label = argmax_label
            label_loss = label_score[argmax_label_index]

        if right - left == 1:
            # "pos" / "word" are placeholder values for the leaf node.
            tree = LeafParseNode(left, "pos", "word")
            if label:
                tree = InternalParseNode(label, [tree])
            return [tree], label_loss

        # Entry i scores splitting at k = left + 1 + i.
        left_positions = get_position(sen_len, left, range(left + 1, right))
        right_positions = get_position(sen_len, range(left + 1, right), right)
        splits = split_scores[left_positions] + split_scores[right_positions]

        if self.training:
            oracle_splits = gold.oracle_splits(left, right)
            oracle_split = min(oracle_splits)
            oracle_split_index = oracle_split - (left + 1)
            splits = self.augment(splits, oracle_split_index)
            oracle_split_score = splits[oracle_split_index]

        argmax_split_score, argmax_split_index = torch.max(splits, dim=0)
        argmax_split = argmax_split_index + (left + 1)

        if self.training:
            split = oracle_split
            split_loss = argmax_split_score - oracle_split_score
        else:
            split = argmax_split
            split_loss = splits[argmax_split_index]

        left_trees, left_loss = self.helper(label_scores, split_scores,
                                            sen_len, left, int(split), gold)
        right_trees, right_loss = self.helper(label_scores, split_scores,
                                              sen_len, int(split), right, gold)

        children = left_trees + right_trees
        loss = label_loss + split_loss + left_loss + right_loss
        if label:
            children = [InternalParseNode(label, children)]
        return children, loss