def forward(self, node):
        """
        Creates a tree where each node's value is the encoded version of the original value.

        Children are encoded recursively; the node's own value together with the
        children's encoded values and cell states is then fed to the TreeCell in
        ``self.lstm_list`` whose position matches this node's number of children
        in ``self.valid_num_children``.

        :param node: root of the (sub)tree to encode; it must expose ``value``
            and ``children``.  A ``value`` of None produces an empty node.
        :return: a tuple - (root of encoded tree, cell state)
        :raises ValueError: if ``self.valid_num_children`` has no entry equal to
            this node's number of children
        """
        value = node.value

        # Empty node: nothing to encode, hand back the zero buffer as cell state
        if value is None:
            return (Node(None), self.zero_buffer)

        # Pick the TreeCell for this number of children first, so an unsupported
        # child count is rejected before any child is encoded.
        num_children = len(node.children)
        for i, supported in enumerate(self.valid_num_children):
            if supported == num_children:
                cell = self.lstm_list[i]
                break
        else:
            raise ValueError(
                "No TreeCell registered for a node with {} children "
                "(supported child counts: {})".format(
                    num_children, list(self.valid_num_children)))

        # Recursively encode children; each entry is (encoded node, cell state)
        encoded = [self.forward(child) for child in node.children]

        # Extract the TreeCell inputs.  Empty children (value None) contribute
        # the zero buffer as their hidden value.
        inputH = [
            self.zero_buffer if enc_node.value is None else enc_node.value
            for enc_node, _ in encoded
        ]
        inputC = [cell_state for _, cell_state in encoded]

        newH, newC = cell(value, inputH, inputC)

        # Set our encoded vector as the root of the new tree
        rootNode = Node(newH)
        rootNode.children = [enc_node for enc_node, _ in encoded]
        return (rootNode, newC)
def make_tree_math(json, big_tree=True):
    """
    Build a tree of Node objects from a parsed math expression.

    :param json: parsed expression; in the general case a mapping with a "tag"
        entry and an optional "contents" entry
    :param big_tree: when False, the work is delegated to
        ``make_tree_math_short``
    :return: the root Node, whatever ``general_base_cases`` returned when it
        was truthy, or ``[]`` for an empty list
    """
    if not big_tree:
        return make_tree_math_short(json)

    # Base case for variable names, symbols, or numbers
    base_case = general_base_cases(json)
    if base_case:
        return base_case

    # Base case for empty lists (we just ignore these)
    if json == []:
        return []

    parentNode = Node(json["tag"])

    # Base case for Nil
    if "contents" not in json:
        return parentNode

    children = json["contents"]
    if children != []:
        # Treat a lone child like a one-element list so both shapes share the
        # same handling below.
        if not isinstance(children, list):
            children = [children]
        for child in children:
            subtree = make_tree_math(child)
            # Empty lists are ignored, so they never become children.  The
            # original skipped them only for a lone child; list contents now
            # get the same treatment.
            if not subtree == []:
                parentNode.children.append(subtree)
    return parentNode
    def forward(self, node):
        """
        Creates a tree where each node's value is the encoded version of the original value.

        Children are encoded recursively and fed, together with the node's own
        value, to ``self.tree_lstm``.  Nodes with fewer than two children are
        padded with ``self.zero_buffer`` so the cell always receives two child
        inputs.

        :param node: root of the (sub)tree to encode; it must expose ``value``
            and ``children``.  A ``value`` of None produces an empty node.
        :return: a tuple - (root of encoded tree, cell state)
        :raises ValueError: if the node has more than two children
        """
        value = node.value

        # Empty node: nothing to encode, hand back the zero buffer as cell state
        if value is None:
            return (Node(None), self.zero_buffer)

        # Reject unsupported child counts before encoding any child
        num_children = len(node.children)
        if num_children > 2:
            raise ValueError(
                "Cannot encode a node with {} children; at most 2 children "
                "are supported".format(num_children))

        # Recursively encode children; each entry is (encoded node, cell state)
        encoded = [self.forward(child) for child in node.children]

        # Extract the TreeCell inputs.  Empty children (value None) contribute
        # the zero buffer as their hidden value.
        inputH = [
            self.zero_buffer if enc_node.value is None else enc_node.value
            for enc_node, _ in encoded
        ]
        inputC = [cell_state for _, cell_state in encoded]

        # Pad missing children up to two inputs
        padding = [self.zero_buffer] * (2 - num_children)
        newH, newC = self.tree_lstm(value, inputH + padding, inputC + padding)

        # Set our encoded vector as the root of the new tree
        rootNode = Node(newH)
        rootNode.children = [enc_node for enc_node, _ in encoded]
        return (rootNode, newC)
def make_tree_math_short(json):
    """
    Build a compact tree of Node objects from a parsed math expression.

    Compared with the plain tag/contents nesting, several wrapper tags are
    collapsed: one of their contents is promoted to be the parent and the
    remaining contents become its children.

    :param json: parsed expression; in the general case a mapping with a "tag"
        entry and an optional "contents" entry
    :return: the root Node, or whatever ``general_base_cases`` returned when
        it was truthy
    """
    # Variable names, symbols, or numbers are handled by the shared base case
    leaf = general_base_cases(json)
    if leaf:
        return leaf

    tag = json["tag"]
    node = Node(tag)

    # Nil: a tag with no contents
    if "contents" not in json:
        return node
    contents = json["contents"]

    # Digit: the second entry becomes the parent, the first its only child
    if tag == "Digit":
        promoted = make_tree_math_short(contents[1])
        promoted.children = [make_tree_math_short(contents[0])]
        return promoted

    # "IntegerM", "VarName" and "Symbol" tokens are skipped entirely
    if tag in ("IntegerM", "VarName", "Symbol"):
        return make_tree_math_short(contents)

    # UnOp, UnOpExp and DoubOp: the first entry becomes the parent
    if tag in ("UnOp", "UnOpExp", "DoubOp"):
        promoted = make_tree_math_short(contents[0])
        promoted.children.extend(
            make_tree_math_short(operand) for operand in contents[1:])
        return promoted

    # PUnOp and BinOp: the second entry becomes the parent
    if tag in ("PUnOp", "BinOp"):
        promoted = make_tree_math_short(contents[1])
        promoted.children.extend(
            make_tree_math_short(operand)
            for operand in [contents[0]] + contents[2:])
        return promoted

    # Containers: braces are invisible and dropped; any other container gets
    # a named node wrapping its content.
    if tag == "Container":
        opener = contents[0]
        if opener == "LeftBrace":
            return make_tree_math_short(contents[1])

        container_names = {
            "AbsBar": "Abs",
            "LeftParen": "Parens",
            "Magnitude": "Magnitude",
        }
        wrapper = Node(container_names[opener])
        wrapper.children.append(make_tree_math_short(contents[1]))
        return wrapper

    # Everything else keeps its tag as the parent of all of its contents
    if type(contents) is list:
        node.children.extend(
            make_tree_math_short(item) for item in contents)
    else:
        node.children.append(make_tree_math_short(contents))
    return node
Beispiel #5
0
    def forward_prediction(self, input_tree, max_size=None):
        """
        Generate an output tree given an input tree.

        The input tree is encoded with ``self.encoder``.  Nodes are then
        predicted one at a time from a stack of unexpanded nodes, each
        prediction using attention over the encoder annotations.

        :param input_tree: tree handed to ``self.encoder``
        :param max_size: cap on the number of generated nodes; defaults to
            ``self.max_size`` when None
        :return: root Node of the generated tree
        """
        if max_size is None:
            max_size = self.max_size

        # Hard-coded limits carried over unchanged.
        # NOTE(review): what ec == 89 stands for is not visible here; it is
        # used only to stop generating further children - confirm its meaning.
        MAX_CHILDREN = 20
        STOP_EC = 89

        # Encode tree
        annotations, decoder_hiddens, decoder_cell_states = self.encoder(
            input_tree)

        # Counter of how many nodes we've generated so far
        num_nodes = 0

        PLACEHOLDER = 1

        # Output tree we're building up.  The value is a placeholder
        tree = Node(PLACEHOLDER)

        # Create stack of unexpanded nodes
        # Tuple: (hidden_state, cell_state, node to fill in, parent_value,
        #         parent ec, parent relation, child_index)
        unexpanded = [(decoder_hiddens, decoder_cell_states, tree,
                       self.root_value, True, None, 0)]

        if self.align_type <= 1:
            attention_hidden_values = self.attention_hidden(annotations)
        else:
            attention_hidden_values = annotations

        def attend(hiddens):
            """Return (last entry of ``hiddens`` with a leading dim, attention output)."""
            last_hidden = hiddens[-1].unsqueeze(0)
            attention_logits = self.attention_logits(attention_hidden_values,
                                                     last_hidden)
            attention_probs = self.softmax(
                attention_logits)  # number_of_nodes x 1
            context_vec = (attention_probs * annotations).sum(0).unsqueeze(
                0)  # 1 x hidden_size
            et = self.tanh(
                self.attention_presoftmax(
                    torch.cat((last_hidden, context_vec),
                              dim=1)))  # 1 x hidden_size
            return last_hidden, et

        # while stack isn't empty:
        while unexpanded:
            # Pop last item; the parent ec / relation / child index entries
            # are not needed to make the prediction.
            (decoder_hiddens, decoder_cell_states, curr_root, parent_val,
             _parent_ec, _parent_rel, _child_index) = unexpanded.pop()

            # Use attention and pass hidden state to make a prediction
            decoder_hiddens, et = attend(decoder_hiddens)
            next_input, ec, rel = self.decoder.make_prediction(parent_val, et)
            curr_root.value = next_input
            curr_root.ec = ec
            curr_root.relation = rel
            num_nodes += 1

            # Only generate up to max_size nodes (honours the argument; the
            # original compared against self.max_size and ignored it)
            if num_nodes > max_size:
                break

            # If we have an EOS, there are no children to generate
            if int(curr_root.value) == self.EOS_value:
                continue

            decoder_input = torch.cat((self.embedding(next_input), et), 1)
            parent = next_input

            for i in range(MAX_CHILDREN):
                # Get hidden state and cell state which will be used to generate this node's
                # children
                child_hiddens, child_cell_states = self.decoder.get_next(
                    parent, i, decoder_input, decoder_hiddens,
                    decoder_cell_states)

                # Add children to the stack
                curr_child = Node(PLACEHOLDER)
                unexpanded.append(
                    (child_hiddens, child_cell_states, curr_child, parent,
                     curr_root.ec, curr_root.relation, i))
                curr_root.children.append(curr_child)

                # NOTE(review): curr_root and the decoder states are rebound to
                # the new child here, so the next iteration derives its states
                # from this child and attaches its node under it rather than
                # next to it - kept as in the original; confirm this is intended.
                decoder_hiddens, decoder_cell_states = (child_hiddens,
                                                        child_cell_states)
                curr_root = curr_child

                # Peek at the prediction for the new child; only ec is used,
                # to decide whether to stop generating children.
                decoder_hiddens, et = attend(decoder_hiddens)
                _, ec, _ = self.decoder.make_prediction(parent, et)

                if ec == STOP_EC:
                    break

        return tree