Example #1
0
 def make_tree():
     nonlocal idx
     idx += 1
     i, j, label_idx = p_i[idx], p_j[idx], p_label[idx]
     label = self.label_vocab.value(label_idx)
     if (i + 1) >= j:
         tag, word = sentence[i]
         pred_tag = None
         if pred_pos is not None:
             pred_tag = pred_pos[i]
         tree = trees.LeafParseNode(int(i), tag, word, p_father[i],
                                    self.type_vocab.value(type[i]),
                                    pred_tag)
         if label:
             assert label[0] != Sub_Head
             tree = trees.InternalParseNode(label, [tree])
         return [tree]
     else:
         left_trees = make_tree()
         right_trees = make_tree()
         children = left_trees + right_trees
         if label and label[0] != Sub_Head:
             return [trees.InternalParseNode(label, children)]
         else:
             return children
def distance_to_tree(dist, cons, unary, leaves):
    """Rebuild a binary parse tree from split distances.

    ``dist[k]`` scores the boundary between ``leaves[k]`` and
    ``leaves[k + 1]``.  The largest distance (the first one on ties, as with
    ``np.argmax``) becomes the top split, labelled ``cons[k]``, and each side
    is rebuilt recursively.  A single leaf is wrapped in its ``unary`` label
    unless that label is ``vocabulary.PAD`` or the empty tuple.

    Args:
        dist: ``len(leaves) - 1`` split scores.
        cons: one constituent label per split.
        unary: one unary label per leaf.
        leaves: the leaf nodes, left to right.

    Returns:
        The root node of the rebuilt tree.
    """
    assert len(dist) == len(leaves) - 1
    assert len(cons) == len(dist)
    assert len(unary) == len(leaves)

    if len(dist) == 0:
        node = leaves[0]
        top_label = unary[0]
        if top_label != vocabulary.PAD and top_label != ():
            node = trees.InternalParseNode(top_label, [node])
        return node

    top = np.argmax(dist)
    cut = top + 1
    left_tree = distance_to_tree(dist[:top], cons[:top], unary[:cut],
                                 leaves[:cut])
    right_tree = distance_to_tree(dist[cut:], cons[cut:], unary[cut:],
                                  leaves[cut:])
    return trees.InternalParseNode(cons[top], [left_tree, right_tree])
def binarize_tree(tree):
    """Binarize a tree by repeatedly splitting off the leftmost child.

    An internal node with more than two children keeps its first child and
    groups the remaining children under a new node labelled ``()``, which is
    itself binarized recursively, so the result is right-branching.  Unary
    nodes are kept as they are, but an internal child below them is also
    binarized.

    Args:
        tree: a ``trees.LeafParseNode`` or ``trees.InternalParseNode``.

    Returns:
        A tree in which every internal node has one or two children.  Leaf
        nodes, and unary nodes directly over a leaf, are returned as the very
        same objects.
    """
    if isinstance(tree, trees.LeafParseNode):
        return tree

    children = tree.children
    if len(children) == 1:
        child = children[0]
        if isinstance(child, trees.LeafParseNode):
            # Nothing below to binarize: keep the original node object.
            return tree
        # A unary node over an internal node: the subtree underneath may
        # still have more than two children, so it must be binarized too.
        return trees.InternalParseNode(tree.label, [binarize_tree(child)])

    left_child = binarize_tree(children[0])
    if len(children) == 2:
        right_child = binarize_tree(children[1])
    else:
        # Group children[1:] under an unlabeled node and binarize that.
        right_child = binarize_tree(
            trees.InternalParseNode((), children[1:]))
    return trees.InternalParseNode(tree.label, [left_child, right_child])
 def helper(tree):
     """Splice out unlabeled nodes below ``tree``.

     Returns ``[tree]`` for a leaf.  For an internal node, the results of all
     children are concatenated; the node is rebuilt around them when its
     label is truthy, otherwise the concatenated list itself is returned.
     """
     if not isinstance(tree, trees.LeafParseNode):
         flat = [node for child in tree.children for node in helper(child)]
         return [trees.InternalParseNode(tree.label, flat)] if tree.label else flat
     return [tree]
Example #5
0
 def make_tree():
     nonlocal idx
     idx += 1
     i, j, label_idx = p_i[idx], p_j[idx], p_label[idx]
     label = self.label_vocab.value(label_idx)
     if (i + 1) >= j:
         tag, word = sentence[i]
         tree = trees.LeafParseNode(int(i), tag, word)
         if label:
             tree = trees.InternalParseNode(label, [tree])
         return [tree]
     else:
         left_trees = make_tree()
         right_trees = make_tree()
         children = left_trees + right_trees
         if label:
             return [trees.InternalParseNode(label, children)]
         else:
             return children
 def make_tree():
     nonlocal idx
     idx += 1
     i, j, _ = span_list[idx]
     lbl = label_list[idx]
     label = label_vocab.value(lbl)
     if (i + 1) >= j:
         tag, word = sent[i]
         tree = trees.LeafParseNode(int(i), tag, word)
         if label:
             tree = trees.InternalParseNode(label, [tree])
         return [tree], j
     else:
         first_trees, jj = make_tree()
         children = first_trees
         while jj < j:
             right_trees, jj = make_tree()
             children = children + right_trees
         if label:
             return [trees.InternalParseNode(label, children)], j
         else:
             return [trees.InternalParseNode(("None", ), children)], j
def debinarize_tree(tree):
    """Undo binarization by splicing out unlabeled intermediate nodes.

    Every internal node whose label is falsy (such as the ``()`` nodes that
    ``binarize_tree`` introduces) is removed, and its children are attached
    to its parent.  If more than one node remains at the top, they are
    wrapped under a new ``('S',)`` root.

    Args:
        tree: a ``trees.LeafParseNode`` or ``trees.InternalParseNode``.

    Returns:
        The root of the debinarized tree.
    """
    def flatten(node):
        if isinstance(node, trees.LeafParseNode):
            return [node]
        spliced = [part for child in node.children for part in flatten(child)]
        if not node.label:
            return spliced
        return [trees.InternalParseNode(node.label, spliced)]

    roots = flatten(tree)
    return roots[0] if len(roots) == 1 else trees.InternalParseNode(('S', ), roots)
Example #8
0
        def helper(left, right):
            """Greedily parse the span ``sentence[left:right]`` top-down.

            The span gets the highest-scoring label from ``self.f_label``
            (during training without ``explore``: the oracle label instead).
            It then gets the highest-scoring split point from ``self.f_split``
            (or the leftmost oracle split), and both halves are parsed
            recursively.

            Args:
                left: start index of the span (inclusive).
                right: end index of the span (exclusive).

            Returns:
                ``(children, loss)``.  ``children`` is a list of parse nodes
                covering the span.  During training, ``loss`` sums the margins
                ``score(predicted) - score(oracle)`` over this subtree.  During
                inference, it sums the chosen scores.
            """
            assert 0 <= left < right <= len(sentence)

            label_scores = self.f_label(get_span_encoding(left, right))
            label_scores.requires_grad_(True)

            if is_train:
                # Oracle label for sentence[left:right]; augment() presumably
                # adds a loss-augmentation margin relative to that index.
                oracle_label = gold.oracle_label(left, right)
                oracle_label_index = self.label_vocab.index(oracle_label)
                label_scores = augment(label_scores, oracle_label_index)

            # .cpu() keeps the numpy conversion working when the model is on
            # a GPU (it is a no-op for CPU tensors).
            label_scores_np = label_scores.detach().cpu().numpy()
            if right - left < len(sentence):
                argmax_label_index = int(label_scores_np.argmax())
            else:
                # The whole-sentence span may not take label index 0
                # (presumably the empty label), so the root always gets a
                # real label -- confirm against label_vocab.
                argmax_label_index = int(label_scores_np[1:].argmax() + 1)
            argmax_label = self.label_vocab.value(argmax_label_index)

            if is_train:
                label = argmax_label if explore else oracle_label
                label_loss = (
                    label_scores[argmax_label_index] -
                    label_scores[oracle_label_index]
                    if argmax_label != oracle_label else torch.zeros(1))
            else:
                label = argmax_label
                label_loss = label_scores[argmax_label_index]

            if right - left == 1:
                tag, word = sentence[left]
                tree = trees.LeafParseNode(left, tag, word)
                if label:
                    tree = trees.InternalParseNode(label, [tree])
                return [tree], label_loss

            # Score every split point in one batched call.  The scores must
            # stay graph-connected tensors: turning encodings/scores into
            # Python values (.tolist() / .item()) and rebuilding a tensor
            # would cut them out of autograd, and split_loss would then never
            # train f_split or the span encoder.
            left_encodings = []
            right_encodings = []
            for split in range(left + 1, right):
                left_encodings.append(get_span_encoding(left, split))
                right_encodings.append(get_span_encoding(split, right))
            # Assumes f_split scores each row of the stacked batch with a
            # single value, so the sum reshapes to one score per split point.
            left_scores = self.f_split(torch.stack(left_encodings))
            right_scores = self.f_split(torch.stack(right_encodings))
            split_scores = (left_scores + right_scores).view(len(left_encodings))

            if is_train:
                oracle_splits = gold.oracle_splits(left, right)
                oracle_split = min(oracle_splits)
                oracle_split_index = oracle_split - (left + 1)
                split_scores = augment(split_scores, oracle_split_index)

            split_scores_np = split_scores.detach().cpu().numpy()
            argmax_split_index = int(split_scores_np.argmax())
            argmax_split = argmax_split_index + (left + 1)

            if is_train:
                split = argmax_split if explore else oracle_split
                split_loss = (
                    split_scores[argmax_split_index] -
                    split_scores[oracle_split_index]
                    if argmax_split != oracle_split else torch.zeros(1))
            else:
                split = argmax_split
                split_loss = split_scores[argmax_split_index]

            left_trees, left_loss = helper(left, split)
            right_trees, right_loss = helper(split, right)

            children = left_trees + right_trees
            if label:
                children = [trees.InternalParseNode(label, children)]

            # The constant torch.zeros(1) losses do not require grad.  Marking
            # them presumably keeps the summed loss differentiable even when
            # every choice in this subtree matched the oracle.
            label_loss.requires_grad_(True)
            split_loss.requires_grad_(True)
            left_loss.requires_grad_(True)
            right_loss.requires_grad_(True)
            return children, label_loss + split_loss + left_loss + right_loss
        def helper(left, right):
            """Top-down greedy parse of ``sentence[left:right]``.

            Picks the span label from ``self.f_label`` and then a split point
            scored by ``self.f_split`` over all candidate splits in one
            batch.  During training without ``explore``, the oracle label and
            the leftmost oracle split are followed instead.  Both halves are
            then parsed recursively, left first.

            Returns:
                ``(nodes, loss)``: the nodes covering the span and the summed
                label/split losses of the subtree.
            """
            span_scores = self.f_label(get_span_encoding(left, right))
            span_scores.requires_grad_(True)

            if is_train:
                gold_label = gold.oracle_label(left, right)
                gold_label_idx = self.label_vocab.index(gold_label)
                span_scores = augment(span_scores, gold_label_idx)

            scores_np = span_scores.data.numpy()
            if right - left < len(sentence):
                best_label_idx = int(scores_np.argmax())
            else:
                # The whole-sentence span skips index 0 (presumably the
                # empty label) so the root is always labelled.
                best_label_idx = int(scores_np[1:].argmax() + 1)
            best_label = self.label_vocab.value(best_label_idx)

            if not is_train:
                label = best_label
                label_loss = span_scores[best_label_idx]
            else:
                label = best_label if explore else gold_label
                if best_label != gold_label:
                    label_loss = (span_scores[best_label_idx] -
                                  span_scores[gold_label_idx])
                else:
                    label_loss = Variable(torch.zeros(1))

            if right - left == 1:
                tag, word = sentence[left]
                leaf = trees.LeafParseNode(left, tag, word)
                nodes = [trees.InternalParseNode(label, [leaf])] if label else [leaf]
                return nodes, label_loss

            # One (left-part, right-part) encoding pair per split point, in
            # the same call order as a plain interleaved loop.
            pairs = [(get_span_encoding(left, k), get_span_encoding(k, right))
                     for k in range(left + 1, right)]
            lefts = [pair[0] for pair in pairs]
            rights = [pair[1] for pair in pairs]
            split_scores = (self.f_split(torch.stack(lefts)) +
                            self.f_split(torch.stack(rights))).view(len(pairs))
            split_scores.requires_grad_(True)

            if is_train:
                gold_split = min(gold.oracle_splits(left, right))
                gold_split_idx = gold_split - (left + 1)
                split_scores = augment(split_scores, gold_split_idx)

            best_split_idx = int(split_scores.data.numpy().argmax())
            best_split = best_split_idx + (left + 1)

            if not is_train:
                split = best_split
                split_loss = split_scores[best_split_idx]
            else:
                split = best_split if explore else gold_split
                if best_split != gold_split:
                    split_loss = (split_scores[best_split_idx] -
                                  split_scores[gold_split_idx])
                else:
                    split_loss = Variable(torch.zeros(1))

            left_nodes, left_loss = helper(left, split)
            right_nodes, right_loss = helper(split, right)

            nodes = left_nodes + right_nodes
            if label:
                nodes = [trees.InternalParseNode(label, nodes)]
            return nodes, label_loss + split_loss + left_loss + right_loss
Example #10
0
        def helper(force_gold):
            """Chart-decode the sentence, allowing latent branching choices.

            Visits every span [left, right) in order of increasing length and
            fills two tables keyed by ``(left, right)``:

            * ``chart``: ``(nodes, score)`` for the best structured analysis
              of the span.  ``score`` is the span's label score plus the
              scores of the two parts it was split into.
            * ``label_scores_span_max``: ``(nodes, label_score)`` for a flat
              alternative: the span's leaves with no internal structure,
              wrapped in the span's label when that label is non-empty and
              does not end with "'".  It is scored by this span's label score
              only.

            A split is a ``(split, branching)`` pair.  Branching 0 combines
            the flat left part with the structured right part (called "right
            branching" below); branching 1 does the reverse ("left
            branching").  ``.npvalue()`` / ``.value()`` are presumably DyNet
            expression accessors returning numpy / scalar values.

            Args:
                force_gold: when true (training only), follow the oracle
                    labels and splits instead of the model's argmax.

            Returns:
                ``(tree, score)`` for the span covering the whole sentence.
            """
            if force_gold:
                assert is_train

            chart = {}
            # Flat (no internal structure) analysis of each span.
            label_scores_span_max = {}

            for length in range(1, len(sentence) + 1):
                for left in range(0, len(sentence) + 1 - length):
                    right = left + length

                    label_scores = get_label_scores(left, right)

                    if is_train:
                        oracle_label = gold.oracle_label(left, right)
                        oracle_label_index = self.label_vocab.index(
                            oracle_label)

                    if force_gold:
                        label = oracle_label
                        label_index = oracle_label_index
                        label_score = label_scores[label_index]

                        # Label style 3: even when forcing gold, use the
                        # model's argmax label (the whole-sentence span skips
                        # index 0, presumably the empty label).
                        if self.nontlabelstyle == 3:
                            label_scores_np = label_scores.npvalue()
                            argmax_label_index = int(label_scores_np.argmax(
                            ) if length < len(
                                sentence) else label_scores_np[1:].argmax() +
                                                     1)
                            argmax_label = self.label_vocab.value(
                                argmax_label_index)
                            label = argmax_label
                            label_score = label_scores[argmax_label_index]

                    else:
                        if is_train:
                            label_scores = augment(label_scores,
                                                   oracle_label_index)
                        label_scores_np = label_scores.npvalue()
                        #argmax_score = dy.argmax(label_scores, gradient_mode="straight_through_gradient")
                        #dy.dot_product()
                        # The whole-sentence span skips index 0 (presumably
                        # the empty label) so the root is always labelled.
                        argmax_label_index = int(
                            label_scores_np.argmax() if length < len(sentence)
                            else label_scores_np[1:].argmax() + 1)
                        argmax_label = self.label_vocab.value(
                            argmax_label_index)
                        label = argmax_label
                        label_score = label_scores[argmax_label_index]

                    # Single-word span: the (possibly labelled) leaf is both
                    # the structured and the flat analysis.
                    if length == 1:
                        tag, word = sentence[left]
                        tree = trees.LeafParseNode(left, tag, word)
                        if label:
                            tree = trees.InternalParseNode(label, [tree])
                        chart[left, right] = [tree], label_score
                        label_scores_span_max[left,
                                              right] = [tree], label_score
                        continue

                    if force_gold:
                        oracle_splits = gold.oracle_splits(left, right)

                        # Inside the latent region (latentscope comes from the
                        # enclosing scope; presumably inclusive (start, end)
                        # bounds -- confirm), a span whose label ends in "'"
                        # or is EMPTY may branch either way: try the first
                        # oracle split right-branching and the last oracle
                        # split left-branching, and keep the better one.
                        if (
                                len(label) > 0 and
                            (label[0].endswith("'") or label[0] == EMPTY)
                        ) and latentscope[0] <= left <= right <= latentscope[
                                1]:  #  if label == (EMPTY,) and latentscope[0] <= left <= right <= latentscope[1]: # and label != ():
                            # Latent during Training
                            #if self.train_constraint:

                            # if self.train_constraint:
                            #     oracle_splits = [(oracle_splits[0], 0), (oracle_splits[-1], 1)]
                            # else:
                            #     #oracle_splits = [(p, 0) for p in oracle_splits] + [(p, 1) for p in oracle_splits]  # it is not correct
                            #     pass

                            oracle_splits = [(oracle_splits[0], 0),
                                             (oracle_splits[-1], 1)]
                            best_split = max(
                                oracle_splits,
                                key=lambda
                                sb:  #(split, branching)  #branching == 0: right branching;  1: left branching
                                label_scores_span_max[left, sb[0]][1].value(
                                ) + chart[sb[0], right][1].value()
                                if sb[1] == 0 else chart[left, sb[0]][1].value(
                                ) + label_scores_span_max[sb[0], right][
                                    1].value())
                        else:

                            # Outside the latent case: leftmost oracle split.
                            best_split = (min(oracle_splits), 0
                                          )  #by default right branching

                    else:
                        # Decoding: try every split point in both branching
                        # directions and keep the highest-scoring one.
                        pred_range = range(left + 1, right)
                        pred_splits = [(p, 0) for p in pred_range
                                       ] + [(p, 1) for p in pred_range]
                        best_split = max(
                            pred_splits,
                            key=lambda
                            sb:  # (split, branching)  #branching == 0: right branching;  1: left branching
                            label_scores_span_max[left, sb[0]][1].value(
                            ) + chart[sb[0], right][1].value()
                            if sb[1] == 0 else chart[left, sb[0]][1].value(
                            ) + label_scores_span_max[sb[0], right][1].value())

                    # Flat alternative for this span: its leaves with no
                    # internal structure, scored by this span's label only.
                    children_leaf = [
                        trees.LeafParseNode(pos, sentence[pos][0],
                                            sentence[pos][1])
                        for pos in range(left, right)
                    ]
                    label_scores_span_max[left,
                                          right] = children_leaf, label_score

                    if best_split[1] == 0:  #Right Branching
                        left_trees, left_score = label_scores_span_max[
                            left, best_split[0]]
                        right_trees, right_score = chart[best_split[0], right]
                    else:  #Left Branching
                        left_trees, left_score = chart[left, best_split[0]]
                        right_trees, right_score = label_scores_span_max[
                            best_split[0], right]

                    children = left_trees + right_trees

                    if label:
                        children = [trees.InternalParseNode(label, children)]
                        # Labels ending in "'" leave the flat alternative as
                        # bare leaves; other labels wrap it.
                        if not label[0].endswith("'"):
                            children_leaf = [
                                trees.InternalParseNode(label, children_leaf)
                            ]
                            label_scores_span_max[
                                left, right] = children_leaf, label_score

                    # Span score = own label score + scores of both parts.
                    chart[left,
                          right] = (children,
                                    label_score + left_score + right_score)

            children, score = chart[0, len(sentence)]
            assert len(children) == 1
            return children[0], score
Example #11
0
        def helper(left, right):
            """Greedy top-down parse of ``sentence[left:right]``.

            Chooses a label with ``self.f_label`` and a split point with
            ``self.f_split``, which scores all candidate splits in one
            batch.  During training without ``explore``, the oracle label and
            the leftmost oracle split are followed instead.  Both halves are
            then parsed recursively, left first.

            Returns:
                ``(nodes, loss)``: the nodes covering the span and the summed
                label/split losses over the subtree.
            """
            assert 0 <= left < right <= len(sentence)

            scores = self.f_label(get_span_encoding(left, right))
            if is_train:
                gold_label = gold.oracle_label(left, right)
                gold_idx = self.label_vocab.index(gold_label)
                scores = augment(scores, gold_idx)

            as_np = scores.data.numpy()
            # The whole-sentence span skips index 0 (presumably the empty
            # label) so that the root always receives a label.
            whole_sentence = right - left >= len(sentence)
            best_idx = (int(as_np[1:].argmax() + 1) if whole_sentence
                        else int(as_np.argmax()))
            best_label = self.label_vocab.value(best_idx)

            if is_train:
                label = best_label if explore else gold_label
                if best_label == gold_label:
                    label_loss = Variable(torch.zeros(1))
                else:
                    label_loss = scores[best_idx] - scores[gold_idx]
            else:
                label, label_loss = best_label, scores[best_idx]

            if right - left == 1:
                tag, word = sentence[left]
                node = trees.LeafParseNode(left, tag, word)
                if label:
                    node = trees.InternalParseNode(label, [node])
                return [node], label_loss

            left_encs, right_encs = [], []
            for k in range(left + 1, right):
                left_encs.append(get_span_encoding(left, k))
                right_encs.append(get_span_encoding(k, right))
            combined = (self.f_split(torch.stack(left_encs)) +
                        self.f_split(torch.stack(right_encs)))
            split_scores = combined.view(len(left_encs))

            if is_train:
                gold_split = min(gold.oracle_splits(left, right))
                gold_split_idx = gold_split - (left + 1)
                split_scores = augment(split_scores, gold_split_idx)

            best_split_idx = int(split_scores.data.numpy().argmax())
            best_split = best_split_idx + (left + 1)

            if is_train:
                split = best_split if explore else gold_split
                if best_split == gold_split:
                    split_loss = Variable(torch.zeros(1))
                else:
                    split_loss = (split_scores[best_split_idx] -
                                  split_scores[gold_split_idx])
            else:
                split, split_loss = best_split, split_scores[best_split_idx]

            left_nodes, left_loss = helper(left, split)
            right_nodes, right_loss = helper(split, right)

            nodes = left_nodes + right_nodes
            if label:
                nodes = [trees.InternalParseNode(label, nodes)]
            return nodes, label_loss + split_loss + left_loss + right_loss
Example #12
0
        def helper(force_gold):
            """Fill a span chart bottom-up and return the best full tree.

            Spans are visited in order of increasing length, so every
            sub-span is already in ``chart`` when a longer span needs it.
            ``chart[left, right]`` holds ``(nodes, score)``, where ``score``
            is the span's label score plus the scores of its two parts.
            ``.npvalue()`` / ``.value()`` are presumably DyNet expression
            accessors.

            Args:
                force_gold: when true (training only), follow the oracle
                    label and the leftmost oracle split instead of the
                    model's argmax.

            Returns:
                ``(tree, score)`` for the span covering the whole sentence.
            """
            if force_gold:
                assert is_train

            chart = {}
            size = len(sentence)

            for length in range(1, size + 1):
                for left in range(size + 1 - length):
                    right = left + length
                    scores = get_label_scores(left, right)

                    if is_train:
                        gold_label = gold.oracle_label(left, right)
                        gold_index = self.label_vocab.index(gold_label)

                    if not force_gold:
                        if is_train:
                            scores = augment(scores, gold_index)
                        values = scores.npvalue()
                        if length < size:
                            best_index = int(values.argmax())
                        else:
                            # The whole-sentence span skips index 0
                            # (presumably the empty label).
                            best_index = int(values[1:].argmax() + 1)
                        label = self.label_vocab.value(best_index)
                        label_score = scores[best_index]
                    else:
                        label = gold_label
                        label_score = scores[gold_index]

                    if length == 1:
                        tag, word = sentence[left]
                        node = trees.LeafParseNode(left, tag, word)
                        if label:
                            node = trees.InternalParseNode(label, [node])
                        chart[left, right] = [node], label_score
                        continue

                    if force_gold:
                        split = min(gold.oracle_splits(left, right))
                    else:
                        split = max(
                            range(left + 1, right),
                            key=lambda k: (chart[left, k][1].value() +
                                           chart[k, right][1].value()))

                    left_nodes, left_score = chart[left, split]
                    right_nodes, right_score = chart[split, right]

                    nodes = left_nodes + right_nodes
                    if label:
                        nodes = [trees.InternalParseNode(label, nodes)]
                    chart[left, right] = (
                        nodes, label_score + left_score + right_score)

            nodes, score = chart[0, size]
            assert len(nodes) == 1
            return nodes[0], score