Example #1
0
    def test_encode_tree(self):
        """Check both tensors returned by ``encode_tree`` for a short sequence.

        The sequence applies ``funcdef`` and then generates two tokens.
        """
        def_type = NodeType("def", NodeConstraint.Node, False)
        expr_type = NodeType("expr", NodeConstraint.Node, False)
        token_list = NodeType("value", NodeConstraint.Token, True)

        funcdef = ExpandTreeRule(
            def_type,
            [("name", token_list),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            expr_type,
            [(field, token_list) for field in ("op", "arg0", "arg1")])

        samples = Samples([funcdef, expr],
                          [def_type, token_list, expr_type],
                          [("", "f"), ("", "2")])
        encoder = ActionSequenceEncoder(samples, 0)

        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(funcdef))
        for token in ("f", "1"):
            action_sequence.eval(GenerateToken("", token))
        d, m = encoder.encode_tree(action_sequence)

        # NOTE(review): d looks like a per-action depth and m like an
        # adjacency matrix over actions -- confirm against the encoder.
        expected_d = [0, 1, 1]
        expected_m = [[0, 1, 1], [0, 0, 0], [0, 0, 0]]
        assert np.array_equal(expected_d, d.numpy())
        assert np.array_equal(expected_m, m.numpy())
Example #2
0
    def test_decode(self):
        """Decode an encoded action tensor back into the original sequence.

        The sequence is encoded with ``encode_action``, sliced with
        ``[:-1, 1:]`` and handed to ``decode``; the decoded actions must equal
        those of an independently built copy of the same sequence.
        """
        def_type = NodeType("def", NodeConstraint.Node, False)
        expr_type = NodeType("expr", NodeConstraint.Node, False)
        token_list = NodeType("value", NodeConstraint.Token, True)

        funcdef = ExpandTreeRule(
            def_type,
            [("name", token_list),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            expr_type,
            [(field, token_list) for field in ("op", "arg0", "arg1")])
        encoder = ActionSequenceEncoder(
            Samples([funcdef, expr], [def_type, token_list, expr_type],
                    [("", "f")]), 0)

        def build_sequence():
            # funcdef, two generated tokens, then the close-variadic rule
            sequence = ActionSequence()
            sequence.eval(ApplyRule(funcdef))
            sequence.eval(GenerateToken("", "f"))
            sequence.eval(GenerateToken("", "1"))
            sequence.eval(ApplyRule(CloseVariadicFieldRule()))
            return sequence

        def reference():
            # A fresh reference list for every call, as in two literal lists
            return [Token(None, "1", "1")]

        action_sequence = build_sequence()
        expected_action_sequence = build_sequence()

        encoded = encoder.encode_action(action_sequence, reference())
        result = encoder.decode(encoded[:-1, 1:], reference())
        assert \
            expected_action_sequence.action_sequence == result.action_sequence
Example #3
0
    def test_encode_invalid_sequence(self):
        """``encode_action`` returns None for a sequence it cannot encode."""
        token_list = NodeType("value", NodeConstraint.Token, True)
        expr_list = NodeType("expr", NodeConstraint.Node, True)
        def_type = NodeType("def", NodeConstraint.Node, False)

        funcdef = ExpandTreeRule(def_type,
                                 [("name", token_list), ("body", expr_list)])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("op", NodeType("value", NodeConstraint.Token, False)),
             ("arg0", token_list),
             ("arg1", token_list)])

        samples = Samples([funcdef, expr],
                          [def_type, token_list, expr_list],
                          [("", "f")])
        encoder = ActionSequenceEncoder(samples, 0)

        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(funcdef))
        action_sequence.eval(GenerateToken("", "f"))
        action_sequence.eval(GenerateToken("", "1"))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))

        # NOTE(review): presumably invalid because the token "1" is neither
        # in the sample tokens nor in the reference -- confirm.
        encoded = encoder.encode_action(action_sequence,
                                        [Token("", "2", "2")])
        assert encoded is None
Example #4
0
    def test_encode_parent(self):
        """Check the tensor ``encode_parent`` yields for one rule and tokens."""
        def_type = NodeType("def", NodeConstraint.Node, False)
        expr_type = NodeType("expr", NodeConstraint.Node, False)
        token_list = NodeType("value", NodeConstraint.Token, True)

        funcdef = ExpandTreeRule(
            def_type,
            [("name", token_list),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            expr_type,
            [(field, token_list) for field in ("op", "arg0", "arg1")])

        samples = Samples([funcdef, expr],
                          [def_type, token_list, expr_type],
                          [("", "f"), ("", "2")])
        encoder = ActionSequenceEncoder(samples, 0)

        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(funcdef))
        for token in ("f", "1", "2"):
            action_sequence.eval(GenerateToken("", token))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        parent = encoder.encode_parent(action_sequence)

        # One all -1 row, four identical rows, then a final row whose last
        # entry is 1 instead of 0.
        expected = [[-1, -1, -1, -1]]
        expected += [[1, 2, 0, 0] for _ in range(4)]
        expected.append([1, 2, 0, 1])
        assert np.array_equal(expected, parent.numpy())
Example #5
0
    def test_encode_empty_sequence(self):
        """Encode an ``ActionSequence`` that has no actions at all.

        Action and parent encodings are a single row of -1; the two tree
        tensors are empty.
        """
        def_type = NodeType("def", NodeConstraint.Node, False)
        expr_type = NodeType("expr", NodeConstraint.Node, False)
        token_type = NodeType("value", NodeConstraint.Token, False)

        funcdef = ExpandTreeRule(
            def_type,
            [("name", token_type),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            expr_type,
            [(field, token_type) for field in ("op", "arg0", "arg1")])

        samples = Samples([funcdef, expr],
                          [def_type, token_type, expr_type],
                          [("", "f")])
        encoder = ActionSequenceEncoder(samples, 0)

        action_sequence = ActionSequence()
        action = encoder.encode_action(action_sequence, [Token("", "1", "1")])
        parent = encoder.encode_parent(action_sequence)
        d, m = encoder.encode_tree(action_sequence)

        single_padding_row = [[-1, -1, -1, -1]]
        assert np.array_equal(single_padding_row, action.numpy())
        assert np.array_equal(single_padding_row, parent.numpy())
        assert np.array_equal(np.zeros((0, )), d.numpy())
        assert np.array_equal(np.zeros((0, 0)), m.numpy())
Example #6
0
 def test_str(self):
     """Check ``str()`` of an expand rule and of the close-variadic rule."""
     t0, t1 = (NodeType(name, NodeConstraint.Node, False)
               for name in ("t0", "t1"))
     t2 = NodeType("t2", NodeConstraint.Node, True)

     expand_rule = ExpandTreeRule(t0, [("elem0", t1), ("elem1", t2)])
     assert "t0 -> [elem0: t1, elem1: t2*]" == str(expand_rule)

     close_rule = CloseVariadicFieldRule()
     assert "<close variadic field>" == str(close_rule)
Example #7
0
    def test_invalid_close_variadic_field_rule(self):
        """Closing a field declared with ``False`` as third argument raises."""
        expr_type = NodeType("expr", NodeConstraint.Node, False)
        elems_type = NodeType("value", NodeConstraint.Node, False)
        rule = ExpandTreeRule(expr_type, [("elems", elems_type)])

        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(rule))

        # The close action is built inside the context manager so that any
        # exception it raises is observed exactly as before.
        with pytest.raises(InvalidActionException):
            action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
Example #8
0
    def test_str(self):
        """Check ``str()`` of ``ApplyRule`` and ``GenerateToken`` actions."""
        t0, t1 = (NodeType(name, NodeConstraint.Node, False)
                  for name in ("t0", "t1"))
        t2 = NodeType("t2", NodeConstraint.Node, True)

        apply_action = ApplyRule(
            ExpandTreeRule(t0, [("elem0", t1), ("elem1", t2)]))
        assert "Apply (t0 -> [elem0: t1, elem1: t2*])" == str(apply_action)

        generate_action = GenerateToken("kind", "bar")
        assert "Generate bar:kind" == str(generate_action)
Example #9
0
 def test_generate_ignore_root_type(self):
     """``generate`` returns the ``op`` node, not the wrapping root rule."""
     root_rule = ExpandTreeRule(
         NodeType(Root(), NodeConstraint.Node, False),
         [("root", NodeType(Root(), NodeConstraint.Node, False))])
     op_rule = ExpandTreeRule(NodeType("op", NodeConstraint.Node, False), [])

     action_sequence = ActionSequence()
     for rule in (root_rule, op_rule):
         action_sequence.eval(ApplyRule(rule))

     generated = action_sequence.generate()
     assert Node("op", []) == generated
Example #10
0
 def test_create_node(self):
     """``ActionSequence.create`` on a node with one single-leaf field."""
     a = Node("def", [Field("name", "literal", Leaf("str", "foo"))])
     seq = ActionSequence.create(a)

     root_rule = ExpandTreeRule(
         NodeType(None, NodeConstraint.Node, False),
         [("root", NodeType(Root(), NodeConstraint.Node, False))])
     def_rule = ExpandTreeRule(
         NodeType("def", NodeConstraint.Node, False),
         [("name", NodeType("literal", NodeConstraint.Token, False))])

     expected = [
         ApplyRule(root_rule),
         ApplyRule(def_rule),
         GenerateToken("str", "foo"),
     ]
     assert expected == seq.action_sequence
Example #11
0
    def test_encode_completed_sequence(self):
        """Encode a sequence consisting of a single rule with no fields."""
        value_type = NodeType("value", NodeConstraint.Node, False)
        none = ExpandTreeRule(value_type, [])
        samples = Samples([none],
                          [NodeType("value", NodeConstraint.Node, False)],
                          [("", "f")])
        encoder = ActionSequenceEncoder(samples, 0)

        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(none))
        action = encoder.encode_action(action_sequence, [Token("", "1", "1")])
        parent = encoder.encode_parent(action_sequence)

        expected_action = [[-1, 2, -1, -1], [-1, -1, -1, -1]]
        expected_parent = [[-1, -1, -1, -1] for _ in range(2)]
        assert np.array_equal(expected_action, action.numpy())
        assert np.array_equal(expected_parent, parent.numpy())
Example #12
0
    def test_encode_path(self):
        """Check ``encode_path`` with a second argument of 2 and then of 1.

        The sequence applies ``funcdef``, generates three tokens, closes the
        field, applies ``expr``, generates one token and applies the
        close-variadic rule twice. The expected arrays list one row per
        action, in that order.
        """
        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("constant", NodeType("value", NodeConstraint.Token, True))])

        encoder = ActionSequenceEncoder(
            Samples([funcdef, expr], [
                NodeType("def", NodeConstraint.Node, False),
                NodeType("value", NodeConstraint.Token, True),
                NodeType("expr", NodeConstraint.Node, True)
            ], [("", "f"), ("", "2")]), 0)
        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(funcdef))
        action_sequence.eval(GenerateToken("", "f"))
        action_sequence.eval(GenerateToken("", "1"))
        action_sequence.eval(GenerateToken("", "2"))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        action_sequence.eval(ApplyRule(expr))
        action_sequence.eval(GenerateToken("", "f"))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        path = encoder.encode_path(action_sequence, 2)

        # np.int64 is used instead of the ``np.long`` alias, which NumPy
        # deprecated in 1.20 and removed in 1.24. np.array_equal compares
        # values only, so the assertion itself is unchanged.
        assert np.array_equal(
            np.array(
                [
                    [-1, -1],  # funcdef
                    [2, -1],  # f
                    [2, -1],  # 1
                    [2, -1],  # 2
                    [2, -1],  # CloseVariadicField
                    [2, -1],  # expr
                    [3, 2],  # f
                    [3, 2],  # CloseVariadicField
                    [2, -1],  # CloseVariadicField
                ],
                dtype=np.int64),
            path.numpy())
        path = encoder.encode_path(action_sequence, 1)
        assert np.array_equal(
            np.array(
                [
                    [-1],  # funcdef
                    [2],  # f
                    [2],  # 1
                    [2],  # 2
                    [2],  # CloseVariadicField
                    [2],  # expr
                    [3],  # f
                    [3],  # CloseVariadicField
                    [2],  # CloseVariadicField
                ],
                dtype=np.int64),
            path.numpy())
Example #13
0
    def test_decode_invalid_tensor(self):
        """``decode`` returns None for two single-row tensors it rejects."""
        def_type = NodeType("def", NodeConstraint.Node, False)
        expr_type = NodeType("expr", NodeConstraint.Node, False)
        token_type = NodeType("value", NodeConstraint.Token, False)

        funcdef = ExpandTreeRule(
            def_type,
            [("name", token_type),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            expr_type,
            [(field, token_type) for field in ("op", "arg0", "arg1")])

        samples = Samples([funcdef, expr],
                          [def_type, token_type, expr_type],
                          [("", "f")])
        encoder = ActionSequenceEncoder(samples, 0)

        for row in ([-1, -1, -1], [-1, -1, 1]):
            decoded = encoder.decode(torch.LongTensor([row]), [])
            assert decoded is None
    def initialize(self, input: Input) -> Environment:
        """Encode ``input`` and return the state seeded with a root action.

        The input is transformed, collated as a one-element batch, passed
        through ``self._to`` and the encoder (with gradients disabled), and
        split again; the first split element becomes the state. An
        ``ActionSequence`` holding only the root expansion rule is stored
        under ``"action_sequence"``.
        """
        # Switch the encoder to evaluation mode before it is run
        self.module.encoder.eval()

        batch = self.collate.collate([self.transform_input(input)])
        # NOTE(review): _to presumably moves tensors to the target device
        batch = self._to(batch)
        with torch.no_grad(), logger.block("encode_state"):
            encoded = self.module.encoder(batch)
        state = self.collate.split(encoded)[0]

        # Every sequence starts from the rule expanding the root node
        root_rule = ExpandTreeRule(
            NodeType(None, NodeConstraint.Node, False),
            [("root", NodeType(Root(), NodeConstraint.Node, False))])
        initial_sequence = ActionSequence()
        initial_sequence.eval(ApplyRule(root_rule))
        state["action_sequence"] = initial_sequence

        return state
Example #15
0
    def test_eval_root(self):
        """Check which first actions are accepted and the state they leave."""
        assert ActionSequence().head is None

        # Neither a token nor the close-variadic rule may come first
        invalid_first_actions = (
            lambda: GenerateToken("kind", ""),
            lambda: ApplyRule(CloseVariadicFieldRule()),
        )
        for make_action in invalid_first_actions:
            with pytest.raises(InvalidActionException):
                ActionSequence().eval(make_action())

        rule = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Node, False)),
             ("value", NodeType("args", NodeConstraint.Node, True))])
        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(rule))

        assert 0 == action_sequence.head.action
        assert 0 == action_sequence.head.field
        recorded = action_sequence.action_sequence
        assert [ApplyRule(rule)] == recorded
        assert action_sequence.parent(0) is None
        root_children = action_sequence._tree.children[0]
        assert [[], []] == root_children
Example #16
0
    def test_create_leaf(self):
        """``ActionSequence.create`` for a bare leaf and for a list of leaves."""
        def root_rule(constraint):
            # Fresh root expansion rule whose single field uses `constraint`
            return ExpandTreeRule(
                NodeType(None, NodeConstraint.Node, False),
                [("root", NodeType(Root(), constraint, False))])

        # A single leaf at the root
        seq = ActionSequence.create(Leaf("str", "t0 t1"))
        expected_leaf = [
            ApplyRule(root_rule(NodeConstraint.Token)),
            GenerateToken("str", "t0 t1"),
        ]
        assert expected_leaf == seq.action_sequence

        # A node whose only field holds two leaves
        leaves = [Leaf("str", "t0"), Leaf("str", "t1")]
        seq = ActionSequence.create(
            Node("value", [Field("name", "str", leaves)]))
        value_rule = ExpandTreeRule(
            NodeType("value", NodeConstraint.Node, False),
            [("name", NodeType("str", NodeConstraint.Token, True))])
        expected_node = [
            ApplyRule(root_rule(NodeConstraint.Node)),
            ApplyRule(value_rule),
            GenerateToken("str", "t0"),
            GenerateToken("str", "t1"),
            ApplyRule(CloseVariadicFieldRule()),
        ]
        assert expected_node == seq.action_sequence
Example #17
0
    def test_generate_variadic_token(self):
        """Fill a variadic token field with two tokens, then close it.

        After closing, the head moves to field 1 and another token is
        rejected with ``InvalidActionException``.
        """
        rule = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("value", NodeType("args", NodeConstraint.Node, True))])

        def history(count):
            # Fresh, equal-valued copies of the first `count` actions
            actions = [
                ApplyRule(rule),
                GenerateToken("", "foo"),
                GenerateToken("", "bar"),
                ApplyRule(CloseVariadicFieldRule()),
            ]
            return actions[:count]

        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(rule))

        # First token
        action_sequence.eval(GenerateToken("", "foo"))
        assert 0 == action_sequence.head.action
        assert 0 == action_sequence.head.field
        assert [1] == action_sequence._tree.children[0][0]
        assert history(2) == action_sequence.action_sequence
        assert Parent(0, 0) == action_sequence.parent(1)
        assert [] == action_sequence._tree.children[1]

        # Second token: the head stays on the same field
        action_sequence.eval(GenerateToken("", "bar"))
        assert 0 == action_sequence.head.action
        assert 0 == action_sequence.head.field
        assert [1, 2] == action_sequence._tree.children[0][0]
        assert history(3) == action_sequence.action_sequence

        # Closing the field advances the head to field 1
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        assert 0 == action_sequence.head.action
        assert 1 == action_sequence.head.field
        assert [1, 2, 3] == action_sequence._tree.children[0][0]
        assert history(4) == action_sequence.action_sequence

        with pytest.raises(InvalidActionException):
            action_sequence.eval(GenerateToken("", "foo"))
Example #18
0
    def test_clone(self):
        """A clone generates the same tree and then evolves independently."""
        rule = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("elems", NodeType("expr", NodeConstraint.Node, True))])
        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(rule))

        action_sequence2 = action_sequence.clone()
        assert action_sequence.generate() == action_sequence2.generate()

        # Advancing only the clone must leave the original untouched
        action_sequence2.eval(ApplyRule(rule))
        for tree_attr in ("children", "parent"):
            original_value = getattr(action_sequence._tree, tree_attr)
            cloned_value = getattr(action_sequence2._tree, tree_attr)
            assert original_value != cloned_value
        for attr in ("action_sequence",
                     "_head_action_index",
                     "_head_children_index"):
            assert getattr(action_sequence, attr) != \
                getattr(action_sequence2, attr)
        assert action_sequence.generate() != action_sequence2.generate()
Example #19
0
    def test_variadic_field(self):
        """Fill and close a variadic node field.

        First scenario: the rule has only a variadic field, so closing it
        leaves no head. Second scenario: the variadic field is followed by
        another field, so closing it moves the head to field 1.
        """
        action_sequence = ActionSequence()
        rule = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("elems", NodeType("value", NodeConstraint.Node, True))])
        rule0 = ExpandTreeRule(NodeType("value", NodeConstraint.Node, False),
                               [])
        action_sequence.eval(ApplyRule(rule))
        action_sequence.eval(ApplyRule(rule0))
        assert 0 == action_sequence.head.action
        assert 0 == action_sequence.head.field
        assert [1] == action_sequence._tree.children[0][0]
        assert [ApplyRule(rule),
                ApplyRule(rule0)] == action_sequence.action_sequence
        assert Parent(0, 0) == action_sequence.parent(1)
        assert [] == action_sequence._tree.children[1]

        action_sequence.eval(ApplyRule(rule0))
        assert 0 == action_sequence.head.action
        assert 0 == action_sequence.head.field
        assert [1, 2] == action_sequence._tree.children[0][0]
        assert [ApplyRule(rule),
                ApplyRule(rule0),
                ApplyRule(rule0)] == action_sequence.action_sequence

        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        assert action_sequence.head is None

        # Second scenario; rule0 from above is reused unchanged
        action_sequence = ActionSequence()
        rule1 = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("elems", NodeType("value", NodeConstraint.Node, True)),
             ("name", NodeType("value", NodeConstraint.Node, False))])
        action_sequence.eval(ApplyRule(rule1))
        action_sequence.eval(ApplyRule(rule0))
        # Pass a rule instance, not the CloseVariadicFieldRule class itself
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        assert 0 == action_sequence.head.action
        assert 1 == action_sequence.head.field
Example #20
0
    def test_generate(self):
        """Build a complete ``def`` tree and compare ``generate()`` output."""
        token_type = NodeType("value", NodeConstraint.Token, False)
        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [(field, token_type) for field in ("op", "arg0", "arg1")])

        name_tokens = ("f", "_", "0")
        expr_tokens = ("+", "1", "2")

        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(funcdef))
        for token in name_tokens:
            action_sequence.eval(GenerateToken("name", token))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        action_sequence.eval(ApplyRule(expr))
        for token in expr_tokens:
            action_sequence.eval(GenerateToken("value", token))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        assert action_sequence.head is None

        expected_body = Node("expr", [
            Field(field, "value", Leaf("value", token))
            for field, token in zip(("op", "arg0", "arg1"), expr_tokens)
        ])
        expected = Node("def", [
            Field("name", "value",
                  [Leaf("name", token) for token in name_tokens]),
            Field("body", "expr", [expected_body]),
        ])
        assert expected == action_sequence.generate()
Example #21
0
 def test_create_node_with_variadic_fields(self):
     """``ActionSequence.create`` on a node whose field holds two nodes."""
     elements = [Node("str", []), Node("str", [])]
     a = Node("list", [Field("elems", "literal", elements)])
     seq = ActionSequence.create(a)

     def str_rule():
         # Fresh rule for a "str" node with no fields
         return ExpandTreeRule(
             NodeType("str", NodeConstraint.Node, False), [])

     root_rule = ExpandTreeRule(
         NodeType(None, NodeConstraint.Node, False),
         [("root", NodeType(Root(), NodeConstraint.Node, False))])
     list_rule = ExpandTreeRule(
         NodeType("list", NodeConstraint.Node, False),
         [("elems", NodeType("literal", NodeConstraint.Node, True))])

     expected = [
         ApplyRule(root_rule),
         ApplyRule(list_rule),
         ApplyRule(str_rule()),
         ApplyRule(str_rule()),
         ApplyRule(CloseVariadicFieldRule()),
     ]
     assert expected == seq.action_sequence
Example #22
0
 def test_str(self):
     """Check ``str()`` of three ``NodeType`` variants."""
     expectations = (
         ("type", NodeType("type", NodeConstraint.Node, False)),
         ("type*", NodeType("type", NodeConstraint.Node, True)),
         ("type(token)", NodeType("type", NodeConstraint.Token, False)),
     )
     for expected, node_type in expectations:
         assert expected == str(node_type)
Example #23
0
 def test_eq(self):
     """Equality and inequality of rules, tokens and unrelated values."""
     def foo_type():
         return NodeType("foo", NodeConstraint.Node, False)

     def one_field_rule():
         # A fresh instance per call so that comparisons never rely on
         # object identity.
         return ExpandTreeRule(
             foo_type(),
             [("f0", NodeType("bar", NodeConstraint.Node, False))])

     assert one_field_rule() == one_field_rule()
     assert GenerateToken("", "foo") == GenerateToken("", "foo")
     assert one_field_rule() != ExpandTreeRule(foo_type(), [])
     assert GenerateToken("", "foo") != GenerateToken("", "bar")
     assert one_field_rule() != GenerateToken("", "foo")
     assert 0 != one_field_rule()
Example #24
0
 def test_eq(self):
     """Equality of ``NodeType`` depends on its third argument."""
     def foo(third_argument):
         # A fresh NodeType per call; only the third argument varies
         return NodeType("foo", NodeConstraint.Node, third_argument)

     assert foo(False) == foo(False)
     assert foo(False) != foo(True)
     assert 0 != foo(False)
from math import log
from typing import List

import numpy as np
import torch
import torch.nn as nn

from mlprogram.actions import ExpandTreeRule, NodeConstraint, NodeType
from mlprogram.builtins import Environment
from mlprogram.encoders import ActionSequenceEncoder, Samples
from mlprogram.languages import Root, Token
from mlprogram.samplers import ActionSequenceSampler, SamplerState
from mlprogram.utils.data import Collate, CollateOptions

# Node types used to build the expansion rules below.
# R is the root type; Y_list and Str pass True as the third NodeType
# argument (NOTE(review): presumably "is variadic" -- confirm).
R = NodeType(Root(), NodeConstraint.Node, False)
X = NodeType("X", NodeConstraint.Node, False)
Y = NodeType("Y", NodeConstraint.Node, False)
Y_list = NodeType("Y", NodeConstraint.Node, True)
Ysub = NodeType("Ysub", NodeConstraint.Node, False)
Str = NodeType("Str", NodeConstraint.Token, True)

# Expansion rules, named <parent>2<child>: each expands the left-hand type
# into a single field of the right-hand type.
Root2X = ExpandTreeRule(R, [("x", X)])
Root2Y = ExpandTreeRule(R, [("y", Y)])
X2Y_list = ExpandTreeRule(X, [("y", Y_list)])
Ysub2Str = ExpandTreeRule(Ysub, [("str", Str)])


def is_subtype(arg0, arg1):
    if arg0 == arg1:
        return True
    if arg0 == "Ysub" and arg1 == "Y":