Exemplo n.º 1
0
def unfolded_decoding(W_d, b_d, tree, encoded):
    """Recursively decode ``encoded`` along the shape of ``tree``.

    A new ``Tree`` is built that mirrors the structure of ``tree``. Its root
    holds ``encoded``; if ``tree`` has children, ``encoded`` is passed through
    ``decode(W_d, b_d, encoded)`` and the result is cut into consecutive
    slices of length ``m`` (the second dimension of ``W_d``), one slice per
    child, each of which is decoded recursively in the same way.

    Args:
        W_d: object with a two-element ``shape``; forwarded to ``decode``.
            Presumably the decoder weight matrix -- confirm against ``decode``.
        b_d: forwarded unchanged to ``decode`` (presumably the decoder bias).
        tree: structural template. Only ``nltk.tree.Tree`` instances with at
            least one child are recursed into; anything else is treated as a
            terminal node.
        encoded: value stored as the node of the returned tree and, for
            non-terminal nodes, the input to ``decode``.

    Returns:
        A ``Tree`` whose node is ``encoded`` and whose children are the
        decoded subtrees (no children for terminal nodes). If ``tree`` has a
        ``span`` attribute it is copied onto the result.
    """
    # Only the second dimension is used; the unpacking still requires W_d to
    # be two-dimensional, as in the original code.
    _, m = W_d.shape

    # store all a_e results in tree structure
    decoding_tree = Tree(encoded, [])
    try:
        decoding_tree.span = tree.span
    except AttributeError:
        # Copying the span is best-effort: templates without one (e.g. raw
        # terminal values) are simply left without a span.
        pass

    # If the given node has children, decode the node's encoding, split it,
    # and use the pieces as the children's encodings, recursing until terminal
    # nodes are reached.
    if isinstance(tree, nltk.tree.Tree) and len(tree) > 0:
        decoded = decode(W_d, b_d, encoded)
        for i, child in enumerate(tree):
            # NOTE: the number of branchings is NOT assumed, but that it is
            # uniform and that len(input layer) = n*len(encoding) IS assumed.
            start = i * m
            decoding_tree.append(
                unfolded_decoding(W_d, b_d, child, decoded[start:start + m])
            )

    return decoding_tree
Exemplo n.º 2
0
def init_tree(tree, bf, cn=0):
    """Validate ``tree`` and rebuild it with ``span`` and ``cn`` annotations.

    The input is walked recursively. Every non-terminal must have exactly
    ``bf`` children; terminals may be either childless ``nltk.tree.Tree``
    instances or bare (non-``None``) values, which are wrapped in a ``Tree``.
    Each rebuilt node receives:

    * ``cn``   -- the child number passed in for that node; children of a node
      numbered ``cn`` are numbered ``cn * bf + i``.
    * ``span`` -- a ``(start, end)`` pair. For terminals it is derived from
      ``cn`` and ``len(node)``; for non-terminals it runs from the first
      child's start to the last child's end. The ``- 1`` in the terminal
      formulas suggests an inclusive end index -- confirm against callers.

    Args:
        tree: an ``nltk.tree.Tree``, a bare terminal value, or ``None``.
        bf: required branching factor of every non-terminal node.
        cn: child number of ``tree``; defaults to 0 for the root.

    Returns:
        The annotated ``Tree``, or ``False`` if any part of the input is
        invalid (a message is printed and the whole tree is discarded).
    """
    # A missing node invalidates the whole tree.
    if tree is None:
        print("non-terminal ERROR: terminal node " + str(cn) + " is NoneType, tree discarded from training set.")
        return False

    if not isinstance(tree, nltk.tree.Tree):
        # Not a tree and not None: assume it is a terminal node array and
        # wrap it in a tree.
        out = Tree(tree, [])
        # NOTE(review): this span is offset by (cn - 1) whereas the childless
        # Tree branch below uses cn for the same kind of node; one of the two
        # looks like an off-by-one. Behavior is kept as-is -- confirm which
        # convention callers expect before changing it.
        out.span = ((cn - 1) * len(out.node), cn * len(out.node) - 1)

    elif len(tree) == bf:
        # Correct branching factor: rebuild every child, and propagate a
        # failure upwards so that the entire tree is discarded.
        children = []
        for i, subtree in enumerate(tree):
            child = init_tree(subtree, bf, cn * bf + i)
            if child is False:
                return False
            children.append(child)
        out = Tree(tree.node, children)
        out.span = (out[0].span[0], out[-1].span[1])

    elif len(tree) == 0:
        # A childless tree is a terminal node, unless its node is None.
        out = Tree(tree.node, [])
        try:
            out.span = (cn * len(out.node), (cn + 1) * len(out.node) - 1)
        except TypeError:
            # len() of a None node raises TypeError.
            print("non-terminal ERROR: NoneType node encountered, tree discarded from training set.")
            return False

    else:
        # Neither a terminal node nor of the correct branching factor.
        print("non-terminal ERROR: wrong branching factor, tree discarded from training set.")
        return False

    out.cn = cn
    return out