コード例 #1
0
    def test_encode_tree(self):
        """Compare both tensors returned by ``encode_tree`` with fixed values.

        The sequence under test applies ``funcdef`` and then generates the
        tokens "f" and "1".
        """
        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr_fields = [
            (field, NodeType("value", NodeConstraint.Token, True))
            for field in ("op", "arg0", "arg1")
        ]
        expr = ExpandTreeRule(NodeType("expr", NodeConstraint.Node, False),
                              expr_fields)

        node_types = [
            NodeType("def", NodeConstraint.Node, False),
            NodeType("value", NodeConstraint.Token, True),
            NodeType("expr", NodeConstraint.Node, False),
        ]
        samples = Samples([funcdef, expr], node_types,
                          [("", "f"), ("", "2")])
        encoder = ActionSequenceEncoder(samples, 0)

        action_sequence = ActionSequence()
        for step in (ApplyRule(funcdef),
                     GenerateToken("", "f"),
                     GenerateToken("", "1")):
            action_sequence.eval(step)
        d_tensor, m_tensor = encoder.encode_tree(action_sequence)

        expected_d = [0, 1, 1]
        expected_m = [[0, 1, 1], [0, 0, 0], [0, 0, 0]]
        assert np.array_equal(expected_d, d_tensor.numpy())
        assert np.array_equal(expected_m, m_tensor.numpy())
コード例 #2
0
    def test_decode(self):
        """Round-trip an action sequence through ``encode_action``/``decode``.

        The encoded tensor is sliced with ``[:-1, 1:]`` before decoding, and
        the decoded actions must equal those of an independently built
        sequence containing the same four actions.
        """
        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [(field, NodeType("value", NodeConstraint.Token, True))
             for field in ("op", "arg0", "arg1")])

        node_types = [
            NodeType("def", NodeConstraint.Node, False),
            NodeType("value", NodeConstraint.Token, True),
            NodeType("expr", NodeConstraint.Node, False),
        ]
        encoder = ActionSequenceEncoder(
            Samples([funcdef, expr], node_types, [("", "f")]), 0)

        def build_sequence():
            # Fresh action objects every call so the two sequences share
            # nothing except the rule definitions.
            built = ActionSequence()
            built.eval(ApplyRule(funcdef))
            built.eval(GenerateToken("", "f"))
            built.eval(GenerateToken("", "1"))
            built.eval(ApplyRule(CloseVariadicFieldRule()))
            return built

        action_sequence = build_sequence()
        expected_action_sequence = build_sequence()

        encoded = encoder.encode_action(action_sequence,
                                        [Token(None, "1", "1")])
        result = encoder.decode(encoded[:-1, 1:], [Token(None, "1", "1")])
        assert \
            expected_action_sequence.action_sequence == result.action_sequence
コード例 #3
0
    def test_encode_invalid_sequence(self):
        """``encode_action`` must return ``None`` for the sequence built here.

        The sequence generates the token "1", while the encoder's samples only
        list ("", "f") and the reference passed in only contains the token "2".
        """
        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr_fields = [
            ("op", NodeType("value", NodeConstraint.Token, False)),
            ("arg0", NodeType("value", NodeConstraint.Token, True)),
            ("arg1", NodeType("value", NodeConstraint.Token, True)),
        ]
        expr = ExpandTreeRule(NodeType("expr", NodeConstraint.Node, False),
                              expr_fields)

        node_types = [
            NodeType("def", NodeConstraint.Node, False),
            NodeType("value", NodeConstraint.Token, True),
            NodeType("expr", NodeConstraint.Node, True),
        ]
        samples = Samples([funcdef, expr], node_types, [("", "f")])
        encoder = ActionSequenceEncoder(samples, 0)

        action_sequence = ActionSequence()
        for step in (ApplyRule(funcdef),
                     GenerateToken("", "f"),
                     GenerateToken("", "1"),
                     ApplyRule(CloseVariadicFieldRule())):
            action_sequence.eval(step)

        encoded = encoder.encode_action(action_sequence,
                                        [Token("", "2", "2")])
        assert encoded is None
コード例 #4
0
    def test_encode_parent(self):
        """Compare ``encode_parent`` output with hand-written rows.

        The sequence applies ``funcdef``, generates three tokens and then
        closes the variadic field.
        """
        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [(field, NodeType("value", NodeConstraint.Token, True))
             for field in ("op", "arg0", "arg1")])

        node_types = [
            NodeType("def", NodeConstraint.Node, False),
            NodeType("value", NodeConstraint.Token, True),
            NodeType("expr", NodeConstraint.Node, False),
        ]
        samples = Samples([funcdef, expr], node_types,
                          [("", "f"), ("", "2")])
        encoder = ActionSequenceEncoder(samples, 0)

        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(funcdef))
        for value in ("f", "1", "2"):
            action_sequence.eval(GenerateToken("", value))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        parent = encoder.encode_parent(action_sequence)

        expected_rows = [
            [-1, -1, -1, -1],
            [1, 2, 0, 0],
            [1, 2, 0, 0],
            [1, 2, 0, 0],
            [1, 2, 0, 0],
            [1, 2, 0, 1],
        ]
        assert np.array_equal(expected_rows, parent.numpy())
コード例 #5
0
    def test_encode_empty_sequence(self):
        """Encode an ``ActionSequence`` on which no action was evaluated.

        ``encode_action`` and ``encode_parent`` must each yield a single row
        of ``-1`` values, and ``encode_tree`` must yield empty tensors.
        """
        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, False)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [(field, NodeType("value", NodeConstraint.Token, False))
             for field in ("op", "arg0", "arg1")])

        node_types = [
            NodeType("def", NodeConstraint.Node, False),
            NodeType("value", NodeConstraint.Token, False),
            NodeType("expr", NodeConstraint.Node, False),
        ]
        encoder = ActionSequenceEncoder(
            Samples([funcdef, expr], node_types, [("", "f")]), 0)

        empty_sequence = ActionSequence()
        action = encoder.encode_action(empty_sequence,
                                       [Token("", "1", "1")])
        parent = encoder.encode_parent(empty_sequence)
        d_tensor, m_tensor = encoder.encode_tree(empty_sequence)

        single_row = [[-1, -1, -1, -1]]
        assert np.array_equal(single_row, action.numpy())
        assert np.array_equal(single_row, parent.numpy())
        assert np.array_equal(np.zeros((0, )), d_tensor.numpy())
        assert np.array_equal(np.zeros((0, 0)), m_tensor.numpy())
コード例 #6
0
ファイル: test_action.py プロジェクト: nashid/mlprogram
 def test_str(self):
     """Check ``str()`` of an ``ExpandTreeRule`` and a close-field rule."""
     head = NodeType("t0", NodeConstraint.Node, False)
     first_child = NodeType("t1", NodeConstraint.Node, False)
     second_child = NodeType("t2", NodeConstraint.Node, True)

     expand = ExpandTreeRule(head, [("elem0", first_child),
                                    ("elem1", second_child)])
     assert "t0 -> [elem0: t1, elem1: t2*]" == str(expand)

     close = CloseVariadicFieldRule()
     assert "<close variadic field>" == str(close)
コード例 #7
0
    def test_invalid_close_variadic_field_rule(self):
        """Closing a field declared with ``False`` as its last flag must fail.

        After applying a rule whose only field is
        ``NodeType("value", NodeConstraint.Node, False)``, evaluating a
        ``CloseVariadicFieldRule`` has to raise ``InvalidActionException``.
        """
        elems_type = NodeType("value", NodeConstraint.Node, False)
        rule = ExpandTreeRule(NodeType("expr", NodeConstraint.Node, False),
                              [("elems", elems_type)])

        sequence = ActionSequence()
        sequence.eval(ApplyRule(rule))
        with pytest.raises(InvalidActionException):
            sequence.eval(ApplyRule(CloseVariadicFieldRule()))
コード例 #8
0
ファイル: test_action.py プロジェクト: nashid/mlprogram
    def test_str(self):
        """Check ``str()`` of ``ApplyRule`` and ``GenerateToken`` actions."""
        head = NodeType("t0", NodeConstraint.Node, False)
        first_child = NodeType("t1", NodeConstraint.Node, False)
        second_child = NodeType("t2", NodeConstraint.Node, True)

        apply_action = ApplyRule(
            ExpandTreeRule(head, [("elem0", first_child),
                                  ("elem1", second_child)]))
        assert "Apply (t0 -> [elem0: t1, elem1: t2*])" == str(apply_action)

        token_action = GenerateToken("kind", "bar")
        assert "Generate bar:kind" == str(token_action)
コード例 #9
0
 def test_generate_ignore_root_type(self):
     """``generate`` must return only the "op" node, not the root wrapper.

     A rule with ``Root()`` as its type is applied first, followed by a
     field-less "op" rule; the generated tree has to equal
     ``Node("op", [])``.
     """
     root_rule = ExpandTreeRule(
         NodeType(Root(), NodeConstraint.Node, False),
         [("root", NodeType(Root(), NodeConstraint.Node, False))])
     op_rule = ExpandTreeRule(NodeType("op", NodeConstraint.Node, False),
                              [])

     action_sequence = ActionSequence()
     for rule in (root_rule, op_rule):
         action_sequence.eval(ApplyRule(rule))

     assert Node("op", []) == action_sequence.generate()
コード例 #10
0
 def test_create_node(self):
     """``ActionSequence.create`` on a node with one non-list token field.

     The expected actions are: the root rule, the "def" expansion, and one
     ``GenerateToken`` for the leaf.
     """
     tree = Node("def", [Field("name", "literal", Leaf("str", "foo"))])
     seq = ActionSequence.create(tree)

     root_rule = ExpandTreeRule(
         NodeType(None, NodeConstraint.Node, False),
         [("root", NodeType(Root(), NodeConstraint.Node, False))])
     def_rule = ExpandTreeRule(
         NodeType("def", NodeConstraint.Node, False),
         [("name", NodeType("literal", NodeConstraint.Token, False))])
     expected_actions = [
         ApplyRule(root_rule),
         ApplyRule(def_rule),
         GenerateToken("str", "foo"),
     ]
     assert expected_actions == seq.action_sequence
コード例 #11
0
    def test_encode_completed_sequence(self):
        """Encode a sequence consisting of a single field-less rule.

        Both ``encode_action`` and ``encode_parent`` are compared against
        fixed two-row expectations.
        """
        value_rule = ExpandTreeRule(
            NodeType("value", NodeConstraint.Node, False), [])
        samples = Samples([value_rule],
                          [NodeType("value", NodeConstraint.Node, False)],
                          [("", "f")])
        encoder = ActionSequenceEncoder(samples, 0)

        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(value_rule))
        action = encoder.encode_action(action_sequence,
                                       [Token("", "1", "1")])
        parent = encoder.encode_parent(action_sequence)

        expected_action = [[-1, 2, -1, -1], [-1, -1, -1, -1]]
        expected_parent = [[-1, -1, -1, -1], [-1, -1, -1, -1]]
        assert np.array_equal(expected_action, action.numpy())
        assert np.array_equal(expected_parent, parent.numpy())
コード例 #12
0
    def test_encode_path(self):
        """Compare ``encode_path`` output for path lengths 2 and 1.

        The expected arrays use ``np.int64`` as their dtype: the ``np.long``
        alias used previously was deprecated and later removed from NumPy, so
        referencing it raises ``AttributeError`` on current releases.
        ``np.array_equal`` compares values only, so the expectation is
        otherwise unchanged.
        """
        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("constant", NodeType("value", NodeConstraint.Token, True))])

        encoder = ActionSequenceEncoder(
            Samples([funcdef, expr], [
                NodeType("def", NodeConstraint.Node, False),
                NodeType("value", NodeConstraint.Token, True),
                NodeType("expr", NodeConstraint.Node, True)
            ], [("", "f"), ("", "2")]), 0)
        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(funcdef))
        action_sequence.eval(GenerateToken("", "f"))
        action_sequence.eval(GenerateToken("", "1"))
        action_sequence.eval(GenerateToken("", "2"))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        action_sequence.eval(ApplyRule(expr))
        action_sequence.eval(GenerateToken("", "f"))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))

        # One expected row per evaluated action, for a path length of 2.
        path = encoder.encode_path(action_sequence, 2)
        assert np.array_equal(
            np.array(
                [
                    [-1, -1],  # funcdef
                    [2, -1],  # f
                    [2, -1],  # 1
                    [2, -1],  # 2
                    [2, -1],  # CloseVariadicField
                    [2, -1],  # expr
                    [3, 2],  # f
                    [3, 2],  # CloseVariadicField
                    [2, -1],  # CloseVariadicField
                ],
                dtype=np.int64),
            path.numpy())

        # Same sequence with a path length of 1: only the first column stays.
        path = encoder.encode_path(action_sequence, 1)
        assert np.array_equal(
            np.array(
                [
                    [-1],  # funcdef
                    [2],  # f
                    [2],  # 1
                    [2],  # 2
                    [2],  # CloseVariadicField
                    [2],  # expr
                    [3],  # f
                    [3],  # CloseVariadicField
                    [2],  # CloseVariadicField
                ],
                dtype=np.int64),
            path.numpy())
コード例 #13
0
    def test_decode_invalid_tensor(self):
        """``decode`` must return ``None`` for the two tensors given here.

        Both inputs are single-row ``LongTensor`` values, decoded against an
        empty reference list.
        """
        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, False)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [(field, NodeType("value", NodeConstraint.Token, False))
             for field in ("op", "arg0", "arg1")])

        node_types = [
            NodeType("def", NodeConstraint.Node, False),
            NodeType("value", NodeConstraint.Token, False),
            NodeType("expr", NodeConstraint.Node, False),
        ]
        encoder = ActionSequenceEncoder(
            Samples([funcdef, expr], node_types, [("", "f")]), 0)

        for row in ([-1, -1, -1], [-1, -1, 1]):
            assert encoder.decode(torch.LongTensor([row]), []) is None
コード例 #14
0
    def initialize(self, input: Input) -> Environment:
        """Encode ``input`` and return a state seeded with the root rule.

        The encoder is switched to eval mode, the transformed input is
        collated into a batch of one, moved with ``self._to`` and encoded
        under ``torch.no_grad()``. The first element of the split result is
        returned with an ``"action_sequence"`` entry holding a sequence on
        which only the initial root-expansion rule has been evaluated.
        """
        self.module.encoder.eval()
        batch = self.collate.collate([self.transform_input(input)])
        batch = self._to(batch)
        with torch.no_grad(), logger.block("encode_state"):
            encoded = self.module.encoder(batch)
        state = self.collate.split(encoded)[0]

        # Every sequence starts from the rule that expands into the root.
        initial_rule = ExpandTreeRule(
            NodeType(None, NodeConstraint.Node, False),
            [("root", NodeType(Root(), NodeConstraint.Node, False))])
        sequence = ActionSequence()
        sequence.eval(ApplyRule(initial_rule))
        state["action_sequence"] = sequence

        return state
コード例 #15
0
    def test_eval_root(self):
        """Check which first actions a fresh ``ActionSequence`` accepts.

        A fresh sequence has no head; a ``GenerateToken`` or a
        ``CloseVariadicFieldRule`` as the first action raises
        ``InvalidActionException``; an ``ExpandTreeRule`` is accepted and its
        resulting head, recorded actions, parent and children are checked.
        """
        assert ActionSequence().head is None

        invalid_first_actions = (
            lambda: GenerateToken("kind", ""),
            lambda: ApplyRule(CloseVariadicFieldRule()),
        )
        for make_action in invalid_first_actions:
            with pytest.raises(InvalidActionException):
                fresh = ActionSequence()
                fresh.eval(make_action())

        rule = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Node, False)),
             ("value", NodeType("args", NodeConstraint.Node, True))])
        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(rule))

        head = action_sequence.head
        assert 0 == head.action
        assert 0 == head.field
        assert [ApplyRule(rule)] == action_sequence.action_sequence
        assert action_sequence.parent(0) is None
        assert [[], []] == action_sequence._tree.children[0]
コード例 #16
0
    def test_create_leaf(self):
        """``ActionSequence.create`` for a bare leaf and for a list of leaves.

        A bare ``Leaf`` yields the root rule plus one ``GenerateToken``. A
        node whose field holds a list of leaves yields one ``GenerateToken``
        per leaf followed by a ``CloseVariadicFieldRule``.
        """
        leaf_sequence = ActionSequence.create(Leaf("str", "t0 t1"))
        token_root_rule = ExpandTreeRule(
            NodeType(None, NodeConstraint.Node, False),
            [("root", NodeType(Root(), NodeConstraint.Token, False))])
        expected_for_leaf = [
            ApplyRule(token_root_rule),
            GenerateToken("str", "t0 t1"),
        ]
        assert expected_for_leaf == leaf_sequence.action_sequence

        tree = Node(
            "value",
            [Field("name", "str", [Leaf("str", "t0"), Leaf("str", "t1")])])
        node_sequence = ActionSequence.create(tree)
        node_root_rule = ExpandTreeRule(
            NodeType(None, NodeConstraint.Node, False),
            [("root", NodeType(Root(), NodeConstraint.Node, False))])
        value_rule = ExpandTreeRule(
            NodeType("value", NodeConstraint.Node, False),
            [("name", NodeType("str", NodeConstraint.Token, True))])
        expected_for_node = [
            ApplyRule(node_root_rule),
            ApplyRule(value_rule),
            GenerateToken("str", "t0"),
            GenerateToken("str", "t1"),
            ApplyRule(CloseVariadicFieldRule()),
        ]
        assert expected_for_node == node_sequence.action_sequence
コード例 #17
0
    def test_generate_variadic_token(self):
        """Generate tokens into a field declared with ``True`` as last flag.

        The head must stay on field 0 while tokens are generated and move to
        field 1 once the field is closed; generating a further token
        afterwards must raise ``InvalidActionException``.
        """
        rule = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("value", NodeType("args", NodeConstraint.Node, True))])
        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(rule))

        # First token: recorded as child 1 of field 0, head does not move.
        action_sequence.eval(GenerateToken("", "foo"))
        assert 0 == action_sequence.head.action
        assert 0 == action_sequence.head.field
        assert [1] == action_sequence._tree.children[0][0]
        expected_actions = [ApplyRule(rule), GenerateToken("", "foo")]
        assert expected_actions == action_sequence.action_sequence
        assert Parent(0, 0) == action_sequence.parent(1)
        assert [] == action_sequence._tree.children[1]

        # Second token: appended to the same field.
        action_sequence.eval(GenerateToken("", "bar"))
        assert 0 == action_sequence.head.action
        assert 0 == action_sequence.head.field
        assert [1, 2] == action_sequence._tree.children[0][0]
        expected_actions = [
            ApplyRule(rule),
            GenerateToken("", "foo"),
            GenerateToken("", "bar"),
        ]
        assert expected_actions == action_sequence.action_sequence

        # Closing the field advances the head to field 1.
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        assert 0 == action_sequence.head.action
        assert 1 == action_sequence.head.field
        assert [1, 2, 3] == action_sequence._tree.children[0][0]
        expected_actions = [
            ApplyRule(rule),
            GenerateToken("", "foo"),
            GenerateToken("", "bar"),
            ApplyRule(CloseVariadicFieldRule()),
        ]
        assert expected_actions == action_sequence.action_sequence

        with pytest.raises(InvalidActionException):
            action_sequence.eval(GenerateToken("", "foo"))
コード例 #18
0
    def test_clone(self):
        """A clone starts equal to its source and then evolves independently.

        After one more rule is evaluated on the clone only, the tree, the
        recorded actions, the head indices and the generated result must all
        differ from the original's.
        """
        rule = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("elems", NodeType("expr", NodeConstraint.Node, True))])
        original = ActionSequence()
        original.eval(ApplyRule(rule))

        cloned = original.clone()
        assert original.generate() == cloned.generate()

        # Mutate the clone only; the original must not follow.
        cloned.eval(ApplyRule(rule))
        assert original._tree.children != cloned._tree.children
        assert original._tree.parent != cloned._tree.parent
        assert original.action_sequence != cloned.action_sequence
        assert original._head_action_index != cloned._head_action_index
        assert original._head_children_index != cloned._head_children_index
        assert original.generate() != cloned.generate()
コード例 #19
0
    def test_variadic_field(self):
        """Apply node rules into a field declared with ``True`` as last flag.

        First scenario: the field is the rule's only field, so closing it
        leaves the sequence without a head. Second scenario: another field
        follows, so closing moves the head to field 1.

        The second scenario used to pass the ``CloseVariadicFieldRule`` class
        itself to ``ApplyRule``; it now passes an instance, as every other
        call site does.
        """
        action_sequence = ActionSequence()
        rule = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("elems", NodeType("value", NodeConstraint.Node, True))])
        rule0 = ExpandTreeRule(NodeType("value", NodeConstraint.Node, False),
                               [])
        action_sequence.eval(ApplyRule(rule))
        action_sequence.eval(ApplyRule(rule0))
        assert 0 == action_sequence.head.action
        assert 0 == action_sequence.head.field
        assert [1] == action_sequence._tree.children[0][0]
        assert [ApplyRule(rule),
                ApplyRule(rule0)] == action_sequence.action_sequence
        assert Parent(0, 0) == action_sequence.parent(1)
        assert [] == action_sequence._tree.children[1]

        action_sequence.eval(ApplyRule(rule0))
        assert 0 == action_sequence.head.action
        assert 0 == action_sequence.head.field
        assert [1, 2] == action_sequence._tree.children[0][0]
        assert [ApplyRule(rule),
                ApplyRule(rule0),
                ApplyRule(rule0)] == action_sequence.action_sequence

        # Closing the only field leaves nothing to expand.
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        assert action_sequence.head is None

        action_sequence = ActionSequence()
        rule1 = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("elems", NodeType("value", NodeConstraint.Node, True)),
             ("name", NodeType("value", NodeConstraint.Node, False))])
        rule0 = ExpandTreeRule(NodeType("value", NodeConstraint.Node, False),
                               [])
        action_sequence.eval(ApplyRule(rule1))
        action_sequence.eval(ApplyRule(rule0))
        # Pass a rule instance, not the class object.
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        assert 0 == action_sequence.head.action
        assert 1 == action_sequence.head.field
コード例 #20
0
    def test_generate(self):
        """Build a complete sequence and compare the generated tree.

        The sequence expands ``funcdef``, fills its "name" field with three
        tokens, closes it, expands one ``expr`` with three tokens and closes
        the "body" field. Afterwards the sequence must have no head.
        """
        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [(field, NodeType("value", NodeConstraint.Token, False))
             for field in ("op", "arg0", "arg1")])

        steps = [
            ApplyRule(funcdef),
            GenerateToken("name", "f"),
            GenerateToken("name", "_"),
            GenerateToken("name", "0"),
            ApplyRule(CloseVariadicFieldRule()),
            ApplyRule(expr),
            GenerateToken("value", "+"),
            GenerateToken("value", "1"),
            GenerateToken("value", "2"),
            ApplyRule(CloseVariadicFieldRule()),
        ]
        action_sequence = ActionSequence()
        for step in steps:
            action_sequence.eval(step)
        assert action_sequence.head is None

        expected_name = Field(
            "name", "value",
            [Leaf("name", "f"), Leaf("name", "_"), Leaf("name", "0")])
        expected_expr = Node("expr", [
            Field("op", "value", Leaf("value", "+")),
            Field("arg0", "value", Leaf("value", "1")),
            Field("arg1", "value", Leaf("value", "2")),
        ])
        expected_tree = Node(
            "def", [expected_name, Field("body", "expr", [expected_expr])])
        assert expected_tree == action_sequence.generate()
コード例 #21
0
 def test_create_node_with_variadic_fields(self):
     """``ActionSequence.create`` on a node whose field holds a node list.

     The expected actions are: the root rule, the "list" expansion, one
     expansion per element, and a closing ``CloseVariadicFieldRule``.
     """
     tree = Node(
         "list",
         [Field("elems", "literal", [Node("str", []), Node("str", [])])])
     seq = ActionSequence.create(tree)

     root_rule = ExpandTreeRule(
         NodeType(None, NodeConstraint.Node, False),
         [("root", NodeType(Root(), NodeConstraint.Node, False))])
     list_rule = ExpandTreeRule(
         NodeType("list", NodeConstraint.Node, False),
         [("elems", NodeType("literal", NodeConstraint.Node, True))])
     expected_actions = [ApplyRule(root_rule), ApplyRule(list_rule)]
     # One field-less "str" expansion per list element.
     expected_actions += [
         ApplyRule(
             ExpandTreeRule(NodeType("str", NodeConstraint.Node, False),
                            []))
         for _ in range(2)
     ]
     expected_actions.append(ApplyRule(CloseVariadicFieldRule()))
     assert expected_actions == seq.action_sequence
コード例 #22
0
ファイル: test_action.py プロジェクト: nashid/mlprogram
 def test_str(self):
     """Check ``str()`` of ``NodeType`` for three argument combinations."""
     plain = NodeType("type", NodeConstraint.Node, False)
     assert "type" == str(plain)

     starred = NodeType("type", NodeConstraint.Node, True)
     assert "type*" == str(starred)

     token = NodeType("type", NodeConstraint.Token, False)
     assert "type(token)" == str(token)
コード例 #23
0
ファイル: test_action.py プロジェクト: nashid/mlprogram
 def test_eq(self):
     """Check ``==``/``!=`` for ``ExpandTreeRule`` and ``GenerateToken``.

     Covers equal values, differing values, comparison across the two
     classes, and comparison with the integer ``0`` on the left-hand side.
     """
     def foo_rule_with_field():
         # A fresh instance per call, so equality is never checked between
         # an object and itself.
         return ExpandTreeRule(
             NodeType("foo", NodeConstraint.Node, False),
             [("f0", NodeType("bar", NodeConstraint.Node, False))])

     assert foo_rule_with_field() == foo_rule_with_field()
     assert GenerateToken("", "foo") == GenerateToken("", "foo")

     rule_without_fields = ExpandTreeRule(
         NodeType("foo", NodeConstraint.Node, False), [])
     assert foo_rule_with_field() != rule_without_fields
     assert GenerateToken("", "foo") != GenerateToken("", "bar")

     assert foo_rule_with_field() != GenerateToken("", "foo")
     assert 0 != foo_rule_with_field()
コード例 #24
0
ファイル: test_action.py プロジェクト: nashid/mlprogram
 def test_eq(self):
     """Check ``==``/``!=`` for ``NodeType``.

     Covers equal values, a differing last flag, and comparison with the
     integer ``0`` on the left-hand side.
     """
     def foo(flag):
         # A fresh instance per call so two distinct objects are compared.
         return NodeType("foo", NodeConstraint.Node, flag)

     assert foo(False) == foo(False)
     assert foo(False) != foo(True)
     assert 0 != foo(False)
コード例 #25
0
from math import log
from typing import List

import numpy as np
import torch
import torch.nn as nn

from mlprogram.actions import ExpandTreeRule, NodeConstraint, NodeType
from mlprogram.builtins import Environment
from mlprogram.encoders import ActionSequenceEncoder, Samples
from mlprogram.languages import Root, Token
from mlprogram.samplers import ActionSequenceSampler, SamplerState
from mlprogram.utils.data import Collate, CollateOptions

# Node types shared by the rules below. The last argument is True only for
# Y_list and Str — presumably the "variadic" flag; confirm against NodeType.
R = NodeType(Root(), NodeConstraint.Node, False)
X = NodeType("X", NodeConstraint.Node, False)
Y = NodeType("Y", NodeConstraint.Node, False)
Y_list = NodeType("Y", NodeConstraint.Node, True)
Ysub = NodeType("Ysub", NodeConstraint.Node, False)
Str = NodeType("Str", NodeConstraint.Token, True)

# Expansion rules built from the types above; each has a single named field.
Root2X = ExpandTreeRule(R, [("x", X)])
Root2Y = ExpandTreeRule(R, [("y", Y)])
X2Y_list = ExpandTreeRule(X, [("y", Y_list)])
Ysub2Str = ExpandTreeRule(Ysub, [("str", Str)])


def is_subtype(arg0, arg1):
    if arg0 == arg1:
        return True
    if arg0 == "Ysub" and arg1 == "Y":