def test_encode_tree(self):
    """Check encode_tree on a three-action sequence.

    The sequence is ApplyRule(def) followed by GenerateToken("f") and
    GenerateToken("1"). The expected depth vector is [0, 1, 1]. The expected
    3x3 matrix has only row 0 set, at columns 1 and 2 (presumably
    parent -> child edges from the rule to both tokens).
    """
    # Rule: def -> name (variadic token) + body (variadic expr node).
    def_rule = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [
            ("name", NodeType("value", NodeConstraint.Token, True)),
            ("body", NodeType("expr", NodeConstraint.Node, True)),
        ],
    )
    # Rule: expr -> op, arg0, arg1; each child is a variadic token field.
    expr_rule = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [
            (field, NodeType("value", NodeConstraint.Token, True))
            for field in ("op", "arg0", "arg1")
        ],
    )
    vocabulary = Samples(
        [def_rule, expr_rule],
        [
            NodeType("def", NodeConstraint.Node, False),
            NodeType("value", NodeConstraint.Token, True),
            NodeType("expr", NodeConstraint.Node, False),
        ],
        [("", "f"), ("", "2")],
    )
    encoder = ActionSequenceEncoder(vocabulary, 0)

    seq = ActionSequence()
    for step in (
        ApplyRule(def_rule),
        GenerateToken("", "f"),
        GenerateToken("", "1"),
    ):
        seq.eval(step)

    depth, matrix = encoder.encode_tree(seq)
    assert np.array_equal([0, 1, 1], depth.numpy())
    assert np.array_equal(
        [[0, 1, 1],
         [0, 0, 0],
         [0, 0, 0]],
        matrix.numpy(),
    )
def test_encode_empty_sequence(self):
    """Check how the encoder handles an action sequence with no actions.

    encode_action and encode_parent are each expected to return a single
    row of four -1 values. encode_tree is expected to return an empty depth
    vector of shape (0,) and an empty matrix of shape (0, 0).
    """
    # Rule: def -> name (non-variadic token) + body (variadic expr node).
    def_rule = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [
            ("name", NodeType("value", NodeConstraint.Token, False)),
            ("body", NodeType("expr", NodeConstraint.Node, True)),
        ],
    )
    # Rule: expr -> op, arg0, arg1; each child is a non-variadic token field.
    expr_rule = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [
            (field, NodeType("value", NodeConstraint.Token, False))
            for field in ("op", "arg0", "arg1")
        ],
    )
    vocabulary = Samples(
        [def_rule, expr_rule],
        [
            NodeType("def", NodeConstraint.Node, False),
            NodeType("value", NodeConstraint.Token, False),
            NodeType("expr", NodeConstraint.Node, False),
        ],
        [("", "f")],
    )
    encoder = ActionSequenceEncoder(vocabulary, 0)

    empty_seq = ActionSequence()
    # The encode calls run in this order: action, parent, tree.
    encoded_action = encoder.encode_action(empty_seq, [Token("", "1", "1")])
    encoded_parent = encoder.encode_parent(empty_seq)
    depth, matrix = encoder.encode_tree(empty_seq)

    assert np.array_equal([[-1, -1, -1, -1]], encoded_action.numpy())
    assert np.array_equal([[-1, -1, -1, -1]], encoded_parent.numpy())
    assert np.array_equal(np.zeros((0, )), depth.numpy())
    assert np.array_equal(np.zeros((0, 0)), matrix.numpy())