def test_create_leaf(self):
    """Check ``ActionSequence.create`` on a bare Leaf and on a Node whose
    only field holds a list of Leaves."""

    def root_rule(constraint):
        # The wrapper rule create() emits first: parent type None with a
        # single "root" field of type Root().
        return ApplyRule(
            ExpandTreeRule(
                NodeType(None, NodeConstraint.Node, False),
                [("root", NodeType(Root(), constraint, False))]))

    # A lone Leaf is generated directly as the (token-typed) root field.
    leaf_seq = ActionSequence.create(Leaf("str", "t0 t1"))
    expected_leaf = [
        root_rule(NodeConstraint.Token),
        GenerateToken("str", "t0 t1"),
    ]
    assert expected_leaf == leaf_seq.action_sequence

    # A list of Leaves becomes a variadic token field, terminated by
    # CloseVariadicFieldRule.
    tree = Node("value", [
        Field("name", "str", [Leaf("str", "t0"), Leaf("str", "t1")]),
    ])
    node_seq = ActionSequence.create(tree)
    expected_node = [
        root_rule(NodeConstraint.Node),
        ApplyRule(
            ExpandTreeRule(
                NodeType("value", NodeConstraint.Node, False),
                [("name", NodeType("str", NodeConstraint.Token, True))])),
        GenerateToken("str", "t0"),
        GenerateToken("str", "t1"),
        ApplyRule(CloseVariadicFieldRule()),
    ]
    assert expected_node == node_seq.action_sequence
def test_generate_ignore_root_type(self):
    """A leading rule whose parent type is ``Root()`` is dropped by
    ``generate()``; only the tree below it is returned."""

    def root_node_type():
        return NodeType(Root(), NodeConstraint.Node, False)

    sequence = ActionSequence()

    # Root -> root: Root  (the wrapper generate() should skip)
    wrapper = ExpandTreeRule(root_node_type(), [("root", root_node_type())])
    sequence.eval(ApplyRule(wrapper))

    # The actual content: an "op" node without fields.
    op_rule = ExpandTreeRule(NodeType("op", NodeConstraint.Node, False), [])
    sequence.eval(ApplyRule(op_rule))

    assert Node("op", []) == sequence.generate()
def generate(self) -> AST:
    """
    Generate AST from the action sequence

    The tree is rebuilt recursively from ``self._action_sequence`` and the
    per-action child indices in ``self._tree.children``. If the first action
    applies an ``ExpandTreeRule`` whose parent type name equals ``Root()``,
    that wrapper action is skipped and the AST is built from action 1.

    Returns
    -------
    AST
        The AST corresponding to the action sequence
    """
    def build(head: int, node_type: Optional[NodeType] = None) -> AST:
        # Rebuild the subtree whose root is the action at index `head`.
        # `node_type` is the type of the field being filled, if any.
        action = self._action_sequence[head]

        if isinstance(action, GenerateToken):
            if action.kind is not None:
                return Leaf(action.kind, action.value)
            # A token without an explicit kind takes the type name of the
            # field it was generated for.
            assert node_type is not None
            assert node_type.type_name is not None
            return Leaf(node_type.type_name, action.value)

        if not isinstance(action, ApplyRule):
            # NOTE(review): any other action kind yields None here, despite
            # the `-> AST` annotation.
            return None

        # The head action should apply an ExpandTreeRule.
        rule = cast(ExpandTreeRule, action.rule)
        tree = Node(rule.parent.type_name, [])
        for (field_name, field_type), child_actions in zip(
                rule.children, self._tree.children[head]):
            assert field_type.type_name is not None
            if not field_type.is_variadic:
                # Single-valued field: only the first child action is used.
                tree.fields.append(
                    Field(field_name, field_type.type_name,
                          build(child_actions[0], field_type)))
                continue

            # Variadic field: collect children until CloseVariadicFieldRule.
            tree.fields.append(Field(field_name, field_type.type_name, []))
            for child in child_actions:
                child_action = self._action_sequence[child]
                if isinstance(child_action, ApplyRule) and \
                        isinstance(child_action.rule, CloseVariadicFieldRule):
                    break
                assert isinstance(tree.fields[-1].value, list)
                tree.fields[-1].value.append(build(child, field_type))
        return tree

    sequence = self.action_sequence
    if len(sequence) != 0:
        first = sequence[0]
        if isinstance(first, ApplyRule) and \
                isinstance(first.rule, ExpandTreeRule) and \
                first.rule.parent.type_name == Root():
            # Ignore Root -> ???
            return build(1)
        # NOTE(review): sequences built by `create` start with a rule whose
        # parent type name is None (its `Node(None, ...)` wrapper), so that
        # wrapper is NOT skipped here. Confirm callers expect
        # Node(None, [Field("root", ...)]) in that case.
    # NOTE(review): on an empty sequence this still calls build(0), which
    # indexes self._action_sequence[0]. It presumably raises IndexError,
    # assuming `action_sequence` exposes the same list; consider an
    # explicit error.
    return build(0)
def test_create_node(self):
    """``create`` on a Node with one non-variadic, token-typed field."""
    tree = Node("def", [Field("name", "literal", Leaf("str", "foo"))])
    seq = ActionSequence.create(tree)

    root_rule = ApplyRule(
        ExpandTreeRule(
            NodeType(None, NodeConstraint.Node, False),
            [("root", NodeType(Root(), NodeConstraint.Node, False))]))
    def_rule = ApplyRule(
        ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("literal", NodeConstraint.Token, False))]))
    expected = [root_rule, def_rule, GenerateToken("str", "foo")]

    assert expected == seq.action_sequence
def initialize(self, input: Input) -> Environment:
    """
    Encode ``input`` and seed the decoding state with the initial rule

    Switches the encoder to evaluation mode via
    ``self.module.encoder.eval()``. It then transforms ``input``, collates it
    as a one-element batch, runs the encoder under ``torch.no_grad()``, and
    takes element 0 of the split result. Finally it stores an
    ``ActionSequence`` that has already applied ``None -> root: Root`` under
    the ``"action_sequence"`` key.

    Parameters
    ----------
    input: Input
        The input to encode. The name shadows the builtin but is part of
        the public signature.

    Returns
    -------
    Environment
        The encoded state of the single input, with ``"action_sequence"``
        added.
    """
    self.module.encoder.eval()

    batch = self.collate.collate([self.transform_input(input)])
    # NOTE(review): presumably moves the batch to the target device; confirm.
    batch = self._to(batch)
    with torch.no_grad():
        with logger.block("encode_state"):
            batch = self.module.encoder(batch)
    state = self.collate.split(batch)[0]

    # Seed decoding with the synthetic rule: None -> root: Root.
    action_sequence = ActionSequence()
    seed_rule = ExpandTreeRule(
        NodeType(None, NodeConstraint.Node, False),
        [("root", NodeType(Root(), NodeConstraint.Node, False))])
    action_sequence.eval(ApplyRule(seed_rule))
    state["action_sequence"] = action_sequence
    return state
def test_create_node_with_variadic_fields(self):
    """``create`` on a Node whose variadic field holds child Nodes."""
    tree = Node("list", [
        Field("elems", "literal", [Node("str", []), Node("str", [])]),
    ])
    seq = ActionSequence.create(tree)

    def empty_str_rule():
        # Each "str" child has no fields, so it expands to nothing.
        return ApplyRule(
            ExpandTreeRule(NodeType("str", NodeConstraint.Node, False), []))

    expected = [
        ApplyRule(
            ExpandTreeRule(
                NodeType(None, NodeConstraint.Node, False),
                [("root", NodeType(Root(), NodeConstraint.Node, False))])),
        ApplyRule(
            ExpandTreeRule(
                NodeType("list", NodeConstraint.Node, False),
                [("elems", NodeType("literal", NodeConstraint.Node, True))])),
        empty_str_rule(),
        empty_str_rule(),
        ApplyRule(CloseVariadicFieldRule()),
    ]
    assert expected == seq.action_sequence
def create(node: AST):
    """
    Return the action sequence corresponding to this AST

    The AST is first wrapped as ``Node(None, [Field("root", Root(), node)])``,
    so the sequence always starts with the rule ``None -> root: Root``. The
    tree is then emitted in pre-order:

    - an ``ApplyRule(ExpandTreeRule)`` for every ``Node``;
    - a ``GenerateToken`` for every ``Leaf``;
    - an ``ApplyRule(CloseVariadicFieldRule)`` after the elements of every
      list-valued field.

    Parameters
    ----------
    node: AST

    Returns
    -------
    ActionSequence
        The corresponding action sequence

    Raises
    ------
    RuntimeError
        If the tree contains an object that is neither a Node nor a Leaf.
    """
    def field_type(field: Field) -> NodeType:
        # List values make the field variadic. The field is token-typed when
        # it holds a Leaf; for lists only the first element is inspected,
        # and an empty list counts as Node-typed.
        value = field.value
        variadic = isinstance(value, list)
        if variadic:
            is_token = len(value) > 0 and isinstance(value[0], Leaf)
        else:
            is_token = isinstance(value, Leaf)
        constraint = NodeConstraint.Token if is_token else NodeConstraint.Node
        return NodeType(field.type_name, constraint, variadic)

    def to_sequence(tree: AST) -> List[Action]:
        if isinstance(tree, Node):
            children = [(f.name, field_type(f)) for f in tree.fields]
            actions: List[Action] = [
                ApplyRule(
                    ExpandTreeRule(
                        NodeType(tree.type_name, NodeConstraint.Node, False),
                        children))
            ]
            for f in tree.fields:
                if not isinstance(f.value, list):
                    actions.extend(to_sequence(f.value))
                    continue
                for element in f.value:
                    actions.extend(to_sequence(element))
                # Mark the end of the variadic field.
                actions.append(ApplyRule(CloseVariadicFieldRule()))
            return actions

        if isinstance(tree, Leaf):
            kind = tree.get_type_name()
            assert not isinstance(kind, Root)
            assert kind is not None
            return [GenerateToken(kind, tree.value)]

        logger.critical(f"Invalid type of node: {type(tree)}")
        raise RuntimeError(f"Invalid type of node: {type(tree)}")

    result = ActionSequence()
    wrapped = Node(None, [Field("root", Root(), node)])
    for action in to_sequence(wrapped):
        result.eval(action)
    return result
from math import log from typing import List import numpy as np import torch import torch.nn as nn from mlprogram.actions import ExpandTreeRule, NodeConstraint, NodeType from mlprogram.builtins import Environment from mlprogram.encoders import ActionSequenceEncoder, Samples from mlprogram.languages import Root, Token from mlprogram.samplers import ActionSequenceSampler, SamplerState from mlprogram.utils.data import Collate, CollateOptions R = NodeType(Root(), NodeConstraint.Node, False) X = NodeType("X", NodeConstraint.Node, False) Y = NodeType("Y", NodeConstraint.Node, False) Y_list = NodeType("Y", NodeConstraint.Node, True) Ysub = NodeType("Ysub", NodeConstraint.Node, False) Str = NodeType("Str", NodeConstraint.Token, True) Root2X = ExpandTreeRule(R, [("x", X)]) Root2Y = ExpandTreeRule(R, [("y", Y)]) X2Y_list = ExpandTreeRule(X, [("y", Y_list)]) Ysub2Str = ExpandTreeRule(Ysub, [("str", Str)]) def is_subtype(arg0, arg1): if arg0 == arg1: return True if arg0 == "Ysub" and arg1 == "Y":