示例#1
0
    def test_create_leaf(self):
        # A bare leaf becomes a single token under the synthetic root rule.
        leaf_seq = ActionSequence.create(Leaf("str", "t0 t1"))
        expected_leaf = [
            ApplyRule(ExpandTreeRule(
                NodeType(None, NodeConstraint.Node, False),
                [("root", NodeType(Root(), NodeConstraint.Token, False))])),
            GenerateToken("str", "t0 t1"),
        ]
        assert expected_leaf == leaf_seq.action_sequence

        # A list of leaves is a variadic token field: one token per leaf,
        # terminated by CloseVariadicFieldRule.
        tokens = [Leaf("str", "t0"), Leaf("str", "t1")]
        node_seq = ActionSequence.create(
            Node("value", [Field("name", "str", tokens)]))
        expected_node = [
            ApplyRule(ExpandTreeRule(
                NodeType(None, NodeConstraint.Node, False),
                [("root", NodeType(Root(), NodeConstraint.Node, False))])),
            ApplyRule(ExpandTreeRule(
                NodeType("value", NodeConstraint.Node, False),
                [("name", NodeType("str", NodeConstraint.Token, True))])),
            GenerateToken("str", "t0"),
            GenerateToken("str", "t1"),
            ApplyRule(CloseVariadicFieldRule()),
        ]
        assert expected_node == node_seq.action_sequence
示例#2
0
 def test_generate_ignore_root_type(self):
     # generate() skips a leading rule whose parent type is Root(), so the
     # resulting AST is just the node built by the second rule.
     seq = ActionSequence()
     root_rule = ExpandTreeRule(
         NodeType(Root(), NodeConstraint.Node, False),
         [("root", NodeType(Root(), NodeConstraint.Node, False))])
     seq.eval(ApplyRule(root_rule))
     op_rule = ExpandTreeRule(NodeType("op", NodeConstraint.Node, False), [])
     seq.eval(ApplyRule(op_rule))

     generated = seq.generate()
     assert Node("op", []) == generated
示例#3
0
    def generate(self) -> AST:
        """
        Generate AST from the action sequence

        If the first action expands a node whose type name is ``Root()``,
        that wrapper rule is skipped and the AST is built starting from the
        second action instead.

        Returns
        -------
        AST
            The AST corresponding to the action sequence

        Raises
        ------
        IndexError
            If the action sequence is empty.
        RuntimeError
            If an action that is neither GenerateToken nor ApplyRule is
            encountered.
        """
        def generate(head: int, node_type: Optional[NodeType] = None) -> AST:
            # Build the subtree whose root is the action at index ``head``.
            # ``node_type`` is the type of the field being filled; it supplies
            # the leaf kind when a GenerateToken carries no kind of its own.
            action = self._action_sequence[head]
            if isinstance(action, GenerateToken):
                if action.kind is None:
                    assert node_type is not None
                    assert node_type.type_name is not None
                    return Leaf(node_type.type_name, action.value)
                return Leaf(action.kind, action.value)
            if not isinstance(action, ApplyRule):
                # Previously this fell through and returned None silently
                raise RuntimeError(f"Invalid type of action: {type(action)}")

            # The head action should apply ExpandTreeRule
            rule = cast(ExpandTreeRule, action.rule)
            ast = Node(rule.parent.type_name, [])
            # self._tree.children[head] yields, per child field of ``rule``,
            # the indices of the actions that fill that field
            for (name, child_type), actions in zip(rule.children,
                                                  self._tree.children[head]):
                assert child_type.type_name is not None
                if not child_type.is_variadic:
                    ast.fields.append(
                        Field(name, child_type.type_name,
                              generate(actions[0], child_type)))
                    continue

                # Variadic field: collect children until the
                # CloseVariadicFieldRule that terminates the field
                ast.fields.append(Field(name, child_type.type_name, []))
                values = ast.fields[-1].value
                assert isinstance(values, list)
                for act in actions:
                    child = self._action_sequence[act]
                    if isinstance(child, ApplyRule) and \
                            isinstance(child.rule, CloseVariadicFieldRule):
                        break
                    values.append(generate(act, child_type))
            return ast

        if len(self.action_sequence) == 0:
            # Same exception type as before, but with an explicit message
            raise IndexError(
                "Cannot generate an AST from an empty action sequence")
        begin = self.action_sequence[0]
        if isinstance(begin, ApplyRule) and \
                isinstance(begin.rule, ExpandTreeRule) and \
                begin.rule.parent.type_name == Root():
            # Ignore Root -> ???
            return generate(1)
        return generate(0)
示例#4
0
 def test_create_node(self):
     tree = Node("def", [Field("name", "literal", Leaf("str", "foo"))])
     seq = ActionSequence.create(tree)

     # Expected: the synthetic root expansion, then the "def" node with a
     # single non-variadic token field, then the token itself.
     root_rule = ExpandTreeRule(
         NodeType(None, NodeConstraint.Node, False),
         [("root", NodeType(Root(), NodeConstraint.Node, False))])
     def_rule = ExpandTreeRule(
         NodeType("def", NodeConstraint.Node, False),
         [("name", NodeType("literal", NodeConstraint.Token, False))])
     expected = [
         ApplyRule(root_rule),
         ApplyRule(def_rule),
         GenerateToken("str", "foo"),
     ]
     assert expected == seq.action_sequence
    def initialize(self, input: Input) -> Environment:
        """
        Encode ``input`` and build the initial decoder state.

        The encoder is put in eval mode and run under ``torch.no_grad()``.
        The returned state additionally stores, under ``"action_sequence"``,
        a fresh ActionSequence that has already applied the rule expanding
        the synthetic (None-typed) root node into its ``"root"`` field.
        """
        encoder = self.module.encoder
        encoder.eval()

        # Collate the single transformed input as a batch of one
        batch = self.collate.collate([self.transform_input(input)])
        batch = self._to(batch)
        with torch.no_grad(), logger.block("encode_state"):
            encoded = encoder(batch)
        # Take the single element back out of the batch
        state = self.collate.split(encoded)[0]

        # Seed decoding with the root expansion rule
        action_sequence = ActionSequence()
        root_rule = ExpandTreeRule(
            NodeType(None, NodeConstraint.Node, False),
            [("root", NodeType(Root(), NodeConstraint.Node, False))])
        action_sequence.eval(ApplyRule(root_rule))
        state["action_sequence"] = action_sequence
        return state
示例#6
0
 def test_create_node_with_variadic_fields(self):
     # Each element of a list-valued field is expanded in turn, and the
     # field is terminated by CloseVariadicFieldRule.
     elems = [Node("str", []), Node("str", [])]
     tree = Node("list", [Field("elems", "literal", elems)])
     seq = ActionSequence.create(tree)

     expected = [
         ApplyRule(ExpandTreeRule(
             NodeType(None, NodeConstraint.Node, False),
             [("root", NodeType(Root(), NodeConstraint.Node, False))])),
         ApplyRule(ExpandTreeRule(
             NodeType("list", NodeConstraint.Node, False),
             [("elems", NodeType("literal", NodeConstraint.Node, True))])),
     ]
     expected.extend(
         ApplyRule(ExpandTreeRule(
             NodeType("str", NodeConstraint.Node, False), []))
         for _ in elems)
     expected.append(ApplyRule(CloseVariadicFieldRule()))
     assert expected == seq.action_sequence
示例#7
0
    def create(node: AST) -> "ActionSequence":
        """
        Return the action sequence corresponding to this AST

        The AST is wrapped in a synthetic ``Node(None, [Field("root",
        Root(), node)])``, so the sequence always starts with the rule that
        expands that wrapper. Nodes are emitted in pre-order:

        - an ``ApplyRule(ExpandTreeRule(...))`` for each ``Node``
        - a ``GenerateToken`` for each ``Leaf``
        - an ``ApplyRule(CloseVariadicFieldRule())`` after the elements of
          every list-valued (variadic) field

        Parameters
        ----------
        node: AST

        Returns
        -------
        ActionSequence
            The corresponding action sequence

        Raises
        ------
        RuntimeError
            If the tree contains an object that is neither a Node nor a Leaf.
        """
        def to_node_type(field: Field) -> NodeType:
            # A list value marks a variadic field. A field is a token field
            # when its value (or, for lists, its first element) is a Leaf;
            # an empty list is treated as a node field.
            if isinstance(field.value, list):
                is_variadic = True
                sample = field.value[0] if len(field.value) > 0 else None
            else:
                is_variadic = False
                sample = field.value
            if isinstance(sample, Leaf):
                constraint = NodeConstraint.Token
            else:
                constraint = NodeConstraint.Node
            return NodeType(field.type_name, constraint, is_variadic)

        def to_sequence(node: AST) -> List[Action]:
            # Pre-order traversal producing the actions that build ``node``
            if isinstance(node, Node):
                children = [(f.name, to_node_type(f)) for f in node.fields]
                seq: List[Action] = [
                    ApplyRule(
                        ExpandTreeRule(
                            NodeType(node.type_name, NodeConstraint.Node,
                                     False), children))
                ]
                for field in node.fields:
                    if isinstance(field.value, list):
                        for v in field.value:
                            seq.extend(to_sequence(v))
                        # Variadic fields are terminated explicitly
                        seq.append(ApplyRule(CloseVariadicFieldRule()))
                    else:
                        seq.extend(to_sequence(field.value))
                return seq
            if isinstance(node, Leaf):
                node_type = node.get_type_name()
                assert not isinstance(node_type, Root)
                assert node_type is not None
                return [GenerateToken(node_type, node.value)]
            logger.critical(f"Invalid type of node: {type(node)}")
            raise RuntimeError(f"Invalid type of node: {type(node)}")

        action_sequence = ActionSequence()
        # Wrap the tree so its root is produced through a "root" field
        wrapped = Node(None, [Field("root", Root(), node)])
        for action in to_sequence(wrapped):
            action_sequence.eval(action)
        return action_sequence
from math import log
from typing import List

import numpy as np
import torch
import torch.nn as nn

from mlprogram.actions import ExpandTreeRule, NodeConstraint, NodeType
from mlprogram.builtins import Environment
from mlprogram.encoders import ActionSequenceEncoder, Samples
from mlprogram.languages import Root, Token
from mlprogram.samplers import ActionSequenceSampler, SamplerState
from mlprogram.utils.data import Collate, CollateOptions

# Node types of the toy grammar shared by the tests in this module.
# NodeType(type_name, constraint, is_variadic)
R = NodeType(Root(), NodeConstraint.Node, False)  # grammar root
X, Y, Ysub = (NodeType(type_name, NodeConstraint.Node, False)
              for type_name in ("X", "Y", "Ysub"))
Y_list = NodeType("Y", NodeConstraint.Node, True)  # variadic field of Y
Str = NodeType("Str", NodeConstraint.Token, True)  # variadic token field

# Expansion rules, named <Parent>2<Child>: each expands the parent node
# into a single named child field.
Root2X = ExpandTreeRule(R, [("x", X)])
Root2Y = ExpandTreeRule(R, [("y", Y)])
X2Y_list = ExpandTreeRule(X, [("y", Y_list)])
Ysub2Str = ExpandTreeRule(Ysub, [("str", Str)])


def is_subtype(arg0, arg1):
    if arg0 == arg1:
        return True
    if arg0 == "Ysub" and arg1 == "Y":