예제 #1
0
    def test_create_leaf(self):
        """Check the actions ActionSequence.create emits for leaf inputs.

        Covers a lone leaf used as the whole AST, and a node whose single
        field holds a list of leaves.
        """
        # Case 1: the AST is a lone leaf.
        leaf_seq = ActionSequence.create(Leaf("str", "t0 t1"))
        expected_leaf_actions = [
            ApplyRule(
                ExpandTreeRule(
                    NodeType(None, NodeConstraint.Node, False),
                    [("root",
                      NodeType(Root(), NodeConstraint.Token, False))])),
            GenerateToken("str", "t0 t1"),
        ]
        assert expected_leaf_actions == leaf_seq.action_sequence

        # Case 2: a list-valued field; the expected sequence ends with a
        # CloseVariadicFieldRule after the two tokens.
        leaves = [Leaf("str", "t0"), Leaf("str", "t1")]
        node_seq = ActionSequence.create(
            Node("value", [Field("name", "str", leaves)]))
        expected_node_actions = [
            ApplyRule(
                ExpandTreeRule(
                    NodeType(None, NodeConstraint.Node, False),
                    [("root",
                      NodeType(Root(), NodeConstraint.Node, False))])),
            ApplyRule(
                ExpandTreeRule(
                    NodeType("value", NodeConstraint.Node, False),
                    [("name",
                      NodeType("str", NodeConstraint.Token, True))])),
            GenerateToken("str", "t0"),
            GenerateToken("str", "t1"),
            ApplyRule(CloseVariadicFieldRule()),
        ]
        assert expected_node_actions == node_seq.action_sequence
예제 #2
0
    def test_encode_parent(self):
        """Check encode_parent after a variadic token field has been closed.

        Five actions are evaluated; the expected matrix has six rows of
        four columns, and its first row is all -1.
        """
        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("op", NodeType("value", NodeConstraint.Token, True)),
             ("arg0", NodeType("value", NodeConstraint.Token, True)),
             ("arg1", NodeType("value", NodeConstraint.Token, True))])

        samples = Samples(
            [funcdef, expr],
            [
                NodeType("def", NodeConstraint.Node, False),
                NodeType("value", NodeConstraint.Token, True),
                NodeType("expr", NodeConstraint.Node, False),
            ],
            [("", "f"), ("", "2")])
        encoder = ActionSequenceEncoder(samples, 0)

        action_sequence = ActionSequence()
        for action in (
                ApplyRule(funcdef),
                GenerateToken("", "f"),
                GenerateToken("", "1"),
                GenerateToken("", "2"),
                ApplyRule(CloseVariadicFieldRule())):
            action_sequence.eval(action)
        parent = encoder.encode_parent(action_sequence)

        expected = (
            [[-1, -1, -1, -1]]
            + [[1, 2, 0, 0]] * 4
            + [[1, 2, 0, 1]]
        )
        assert np.array_equal(expected, parent.numpy())
예제 #3
0
    def test_encode_invalid_sequence(self):
        """encode_action returns None when a token cannot be encoded.

        The sequence generates the token "1", which appears neither in
        the sample tokens (only "f") nor in the reference (only "2").
        """
        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("op", NodeType("value", NodeConstraint.Token, False)),
             ("arg0", NodeType("value", NodeConstraint.Token, True)),
             ("arg1", NodeType("value", NodeConstraint.Token, True))])

        samples = Samples(
            [funcdef, expr],
            [
                NodeType("def", NodeConstraint.Node, False),
                NodeType("value", NodeConstraint.Token, True),
                NodeType("expr", NodeConstraint.Node, True),
            ],
            [("", "f")])
        encoder = ActionSequenceEncoder(samples, 0)

        action_sequence = ActionSequence()
        for action in (
                ApplyRule(funcdef),
                GenerateToken("", "f"),
                GenerateToken("", "1"),
                ApplyRule(CloseVariadicFieldRule())):
            action_sequence.eval(action)

        encoded = encoder.encode_action(action_sequence,
                                        [Token("", "2", "2")])
        assert encoded is None
예제 #4
0
    def test_encode_completed_sequence(self):
        """Check both encodings for a sequence made of one childless rule.

        The expected action and parent matrices each have two rows; the
        second row of both is all -1.
        """
        childless = ExpandTreeRule(
            NodeType("value", NodeConstraint.Node, False), [])
        samples = Samples(
            [childless],
            [NodeType("value", NodeConstraint.Node, False)],
            [("", "f")])
        encoder = ActionSequenceEncoder(samples, 0)

        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(childless))
        action = encoder.encode_action(action_sequence, [Token("", "1", "1")])
        parent = encoder.encode_parent(action_sequence)

        padding_row = [-1, -1, -1, -1]
        assert np.array_equal([[-1, 2, -1, -1], padding_row],
                              action.numpy())
        assert np.array_equal([padding_row, padding_row],
                              parent.numpy())
예제 #5
0
    def encode_path(self, action_sequence: ActionSequence, max_depth: int) \
            -> torch.Tensor:
        """
        Return the tensor encoding the ancestor rules of each action

        Parameters
        ----------
        action_sequence: action_sequence
            The action_sequence containing action sequence to be encoded
        max_depth: int
            The number of columns of the result, i.e. how many ancestors
            are recorded per action.

        Returns
        -------
        torch.Tensor
            The encoded tensor. The shape of tensor is
            (len(action_sequence), max_depth).
            [i, 0] is the rule id of the parent action of the i-th action
            and [i, 1:] is the parent's own row shifted right by one, so
            each row lists ancestors starting from the immediate parent.
            The padding value is -1.
        """
        actions = action_sequence.action_sequence
        n_actions = len(actions)
        path = torch.full((n_actions, max_depth), -1, dtype=torch.long)
        for idx in range(n_actions):
            parent_ref = action_sequence.parent(idx)
            if parent_ref is None:
                # No parent: the row stays fully padded.
                continue
            parent_action = actions[parent_ref.action]
            if isinstance(parent_action, ApplyRule):
                path[idx, 0] = self._rule_encoder.encode(parent_action.rule)
            # Inherit the parent's ancestors, dropping the oldest one when
            # the row is already full.
            path[idx, 1:] = path[parent_ref.action, :max_depth - 1]

        return path
예제 #6
0
    def test_encode_empty_sequence(self):
        """Check every encoding of an ActionSequence with no actions.

        The action and parent encodings are a single all -1 row; the
        depth vector and adjacency matrix are empty.
        """
        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, False)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("op", NodeType("value", NodeConstraint.Token, False)),
             ("arg0", NodeType("value", NodeConstraint.Token, False)),
             ("arg1", NodeType("value", NodeConstraint.Token, False))])

        samples = Samples(
            [funcdef, expr],
            [
                NodeType("def", NodeConstraint.Node, False),
                NodeType("value", NodeConstraint.Token, False),
                NodeType("expr", NodeConstraint.Node, False),
            ],
            [("", "f")])
        encoder = ActionSequenceEncoder(samples, 0)

        empty = ActionSequence()
        action = encoder.encode_action(empty, [Token("", "1", "1")])
        parent = encoder.encode_parent(empty)
        depth, matrix = encoder.encode_tree(empty)

        only_padding = [[-1, -1, -1, -1]]
        assert np.array_equal(only_padding, action.numpy())
        assert np.array_equal(only_padding, parent.numpy())
        assert np.array_equal(np.zeros((0, )), depth.numpy())
        assert np.array_equal(np.zeros((0, 0)), matrix.numpy())
예제 #7
0
def get_samples(dataset: torch.utils.data.Dataset,
                parser: Parser[Any]) -> Samples:
    """
    Collect the rules, node types and tokens used by a dataset

    Parameters
    ----------
    dataset: torch.utils.data.Dataset
        Iterated sample by sample; the "ground_truth" entry of each
        sample is handed to the parser.
    parser: Parser[Any]
        Its parse method turns a ground truth into an AST. A sample whose
        parse result is None is skipped.

    Returns
    -------
    Samples
        Built from every applied rule other than CloseVariadicFieldRule,
        the parent and child node types of those rules, and the
        (kind, value) pair of every other action, in encounter order.
        Duplicates are not removed.
    """
    rules: List[Rule] = []
    node_types = []
    tokens: List[Tuple[str, str]] = []

    for sample in dataset:
        ast = parser.parse(sample["ground_truth"])
        if ast is None:
            continue
        for action in ActionSequence.create(ast).action_sequence:
            if not isinstance(action, ApplyRule):
                # Token-generating action.
                assert action.kind is not None
                tokens.append((action.kind, action.value))
                continue
            rule = action.rule
            if isinstance(rule, CloseVariadicFieldRule):
                continue
            rules.append(rule)
            node_types.append(rule.parent)
            node_types.extend(child for _, child in rule.children)

    return Samples(rules, node_types, tokens)
예제 #8
0
 def forward(self, ground_truth: Code) -> ActionSequence:
     """Parse the ground truth and convert the AST to an ActionSequence.

     Raises RuntimeError when the parser returns None.
     """
     ast = self.parser.parse(ground_truth)
     if ast is not None:
         return cast(ActionSequence, ActionSequence.create(ast))
     raise RuntimeError(
         f"cannot convert to ActionSequence: {ground_truth}")
예제 #9
0
    # The return annotation is quoted so that it needs no additional name at
    # definition time; `Tuple` refers to typing.Tuple.
    def encode_tree(self, action_sequence: ActionSequence) \
            -> "Tuple[torch.Tensor, torch.Tensor]":
        """
        Return the depths and the adjacency matrix of the action sequence

        Parameters
        ----------
        action_sequence: action_sequence
            The action_sequence containing action sequence to be encoded

        Returns
        -------
        depth: torch.Tensor
            The depth of each action. The shape is (len(action_sequence),).
            An action without a parent has depth 0.
        adjacency_matrix: torch.Tensor
            The encoded tensor. The shape of tensor is
            (len(action_sequence), len(action_sequence)). If i th action is
            a parent of j th action, (i, j) element will be 1. the element
            will be 0 otherwise.
        """
        L = len(action_sequence.action_sequence)
        depth = torch.zeros(L)
        m = torch.zeros(L, L)

        for i in range(L):
            p = action_sequence.parent(i)
            if p is None:
                # Parentless action: depth stays 0 and no edge is added.
                continue
            # Reads depth[p.action], so this relies on the parent's depth
            # having been filled in before the child is visited.
            depth[i] = depth[p.action] + 1
            m[p.action, i] = 1

        return depth, m
예제 #10
0
    def initialize(self, input: Input) -> Environment:
        """Encode the input and return the initial synthesis state.

        The encoder is switched to eval mode and run without gradient
        tracking on a one-element batch built from the input. The
        resulting state gets an "action_sequence" entry holding a fresh
        ActionSequence with only the root expansion rule applied.
        """
        encoder = self.module.encoder
        encoder.eval()
        batch = self._to(
            self.collate.collate([self.transform_input(input)]))
        with torch.no_grad(), logger.block("encode_state"):
            encoded = encoder(batch)
        # The batch holds a single element, so keep only the first split.
        state = self.collate.split(encoded)[0]

        # Add initial rule
        root_rule = ExpandTreeRule(
            NodeType(None, NodeConstraint.Node, False),
            [("root", NodeType(Root(), NodeConstraint.Node, False))])
        initial_sequence = ActionSequence()
        initial_sequence.eval(ApplyRule(root_rule))
        state["action_sequence"] = initial_sequence

        return state
예제 #11
0
def get_samples(dataset: Dataset,
                parser: Parser[csgAST],
                reference: bool = False) -> Samples:
    """
    Collect the rules, node types and tokens of the CSG language

    Parameters
    ----------
    dataset: Dataset
        Supplies size_candidates, length_candidates and degree_candidates,
        which become the "size", "length" and "degree" tokens.
    parser: Parser[csgAST]
        Parses each prototype shape into an AST; prototypes that parse to
        None are skipped.
    reference: bool
        If True, the operands of the composite prototypes are Reference
        nodes instead of Circle(1).

    Returns
    -------
    Samples
        The rules and node types in first-seen order without duplicates,
        and the tokens de-duplicated and sorted.
    """
    token_set = {("size", x) for x in dataset.size_candidates}
    token_set.update(("length", x) for x in dataset.length_candidates)
    token_set.update(("degree", x) for x in dataset.degree_candidates)

    if reference:
        # TODO use expander
        def operand(index):
            return Reference(index)
    else:
        def operand(index):
            return Circle(1)

    # One prototype per shape constructor.
    prototypes = [
        Circle(1),
        Rectangle(1, 2),
        Translation(1, 1, operand(0)),
        Rotation(45, operand(1)),
        Union(operand(0), operand(1)),
        Difference(operand(0), operand(1)),
    ]

    rules: List[Rule] = []
    node_types = []
    seen_rules = set()
    seen_types = set()

    def remember_type(node_type):
        # Record a node type the first time it is seen.
        if node_type not in seen_types:
            node_types.append(node_type)
            seen_types.add(node_type)

    for prototype in prototypes:
        ast = parser.parse(prototype)
        if ast is None:
            continue
        for action in ActionSequence.create(ast).action_sequence:
            if not isinstance(action, ApplyRule):
                continue
            rule = action.rule
            if isinstance(rule, CloseVariadicFieldRule):
                continue
            if rule not in seen_rules:
                rules.append(rule)
                seen_rules.add(rule)
            remember_type(rule.parent)
            for _, child in rule.children:
                remember_type(child)

    tokens = sorted(token_set)

    return Samples(rules, node_types, tokens)
예제 #12
0
    def test_encode_tree(self):
        """Check depth and adjacency for a rule followed by two tokens.

        The rule is expected at depth 0 and both tokens at depth 1, with
        row 0 of the adjacency matrix marking both tokens.
        """
        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("op", NodeType("value", NodeConstraint.Token, True)),
             ("arg0", NodeType("value", NodeConstraint.Token, True)),
             ("arg1", NodeType("value", NodeConstraint.Token, True))])

        samples = Samples(
            [funcdef, expr],
            [
                NodeType("def", NodeConstraint.Node, False),
                NodeType("value", NodeConstraint.Token, True),
                NodeType("expr", NodeConstraint.Node, False),
            ],
            [("", "f"), ("", "2")])
        encoder = ActionSequenceEncoder(samples, 0)

        action_sequence = ActionSequence()
        for action in (
                ApplyRule(funcdef),
                GenerateToken("", "f"),
                GenerateToken("", "1")):
            action_sequence.eval(action)
        depth, matrix = encoder.encode_tree(action_sequence)

        expected_depth = [0, 1, 1]
        expected_matrix = [[0, 1, 1], [0, 0, 0], [0, 0, 0]]
        assert np.array_equal(expected_depth, depth.numpy())
        assert np.array_equal(expected_matrix, matrix.numpy())
예제 #13
0
 def test_generate_ignore_root_type(self):
     """generate() yields the "op" node when the first rule has a Root parent."""
     root_rule = ExpandTreeRule(
         NodeType(Root(), NodeConstraint.Node, False),
         [("root", NodeType(Root(), NodeConstraint.Node, False))])
     action_sequence = ActionSequence()
     action_sequence.eval(ApplyRule(root_rule))

     op_rule = ExpandTreeRule(
         NodeType("op", NodeConstraint.Node, False), [])
     action_sequence.eval(ApplyRule(op_rule))

     expected_ast = Node("op", [])
     assert expected_ast == action_sequence.generate()
예제 #14
0
 def test_create_node(self):
     """Check the actions created for a node with one single-leaf field."""
     ast = Node("def", [Field("name", "literal", Leaf("str", "foo"))])
     created = ActionSequence.create(ast)

     expected_actions = [
         ApplyRule(
             ExpandTreeRule(
                 NodeType(None, NodeConstraint.Node, False),
                 [("root", NodeType(Root(), NodeConstraint.Node, False))])),
         ApplyRule(
             ExpandTreeRule(
                 NodeType("def", NodeConstraint.Node, False),
                 [("name",
                   NodeType("literal", NodeConstraint.Token, False))])),
         GenerateToken("str", "foo"),
     ]
     assert expected_actions == created.action_sequence
예제 #15
0
    def test_invalid_close_variadic_field_rule(self):
        """Closing a field that is not variadic raises InvalidActionException."""
        non_variadic = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("elems", NodeType("value", NodeConstraint.Node, False))])

        sequence = ActionSequence()
        sequence.eval(ApplyRule(non_variadic))
        with pytest.raises(InvalidActionException):
            sequence.eval(ApplyRule(CloseVariadicFieldRule()))
예제 #16
0
    def decode(self, tensor: torch.LongTensor, reference: List[Token]) \
            -> Optional[ActionSequence]:
        """
        Return the action sequence corresponding to the tensor

        Parameters
        ----------
        tensor: torch.LongTensor
            The encoded tensor with the shape of
            (len(action_sequence), 3). Each action will be encoded by the tuple
            of (ID of the applied rule, ID of the inserted token,
            the index of the word copied from the reference).
            The padding value should be -1.
        reference: List[Token]
            The tokens that copy actions index into. The kind and
            raw_value of the selected token are used for the generated
            token.

        Returns
        -------
        Optional[action_sequence]
            The action sequence corresponding to the tensor
            None if the action sequence cannot be generated.
        """

        retval = ActionSequence()
        for i in range(tensor.shape[0]):
            # The columns are tried in order: rule, token, copy. Ids that
            # are not strictly positive fall through to the next column.
            if tensor[i, 0] > 0:
                # ApplyRule
                rule = self._rule_encoder.decode(tensor[i, 0])
                retval.eval(ApplyRule(rule))
            elif tensor[i, 1] > 0:
                # GenerateToken
                kind, value = self._token_encoder.decode(tensor[i, 1])
                retval.eval(GenerateToken(kind, value))
            elif tensor[i, 2] >= 0:
                # GenerateToken (Copy)
                # .item() reads the scalar without going through NumPy,
                # so it also works when the tensor is not on the CPU.
                index = int(tensor[i, 2].item())
                if index >= len(reference):
                    logger.debug("reference index is out-of-bounds")
                    return None
                token = reference[index]
                retval.eval(GenerateToken(token.kind, token.raw_value))
            else:
                logger.debug("invalid actions")
                return None

        return retval
예제 #17
0
 def test_create_node_with_variadic_fields(self):
     """Check the actions created for a field holding a list of nodes.

     The expected sequence expands each element and ends with a
     CloseVariadicFieldRule.
     """
     elements = [Node("str", []), Node("str", [])]
     ast = Node("list", [Field("elems", "literal", elements)])
     created = ActionSequence.create(ast)

     expected_actions = [
         ApplyRule(
             ExpandTreeRule(
                 NodeType(None, NodeConstraint.Node, False),
                 [("root", NodeType(Root(), NodeConstraint.Node, False))])),
         ApplyRule(
             ExpandTreeRule(
                 NodeType("list", NodeConstraint.Node, False),
                 [("elems",
                   NodeType("literal", NodeConstraint.Node, True))])),
         ApplyRule(
             ExpandTreeRule(
                 NodeType("str", NodeConstraint.Node, False), [])),
         ApplyRule(
             ExpandTreeRule(
                 NodeType("str", NodeConstraint.Node, False), [])),
         ApplyRule(CloseVariadicFieldRule()),
     ]
     assert expected_actions == created.action_sequence
예제 #18
0
    def test_clone(self):
        """A clone starts equal to the original and then evolves separately.

        After one more rule is applied to the clone only, the tree, the
        action list, the head indices and the generated AST all differ.
        """
        rule = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("elems", NodeType("expr", NodeConstraint.Node, True))])
        original = ActionSequence()
        original.eval(ApplyRule(rule))

        copy = original.clone()
        assert original.generate() == copy.generate()

        # Extend only the clone.
        copy.eval(ApplyRule(rule))
        assert original._tree.children != copy._tree.children
        assert original._tree.parent != copy._tree.parent
        assert original.action_sequence != copy.action_sequence
        assert original._head_action_index != copy._head_action_index
        assert original._head_children_index != copy._head_children_index
        assert original.generate() != copy.generate()
예제 #19
0
    def test_generate_token(self):
        """Check bookkeeping after a token fills a non-variadic token field.

        The head is expected to move to field 1. A further token, or a
        close rule issued right after the initial rule, must raise
        InvalidActionException.
        """
        rule = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, False)),
             ("value", NodeType("args", NodeConstraint.Node, True))])
        sequence = ActionSequence()
        sequence.eval(ApplyRule(rule))
        sequence.eval(GenerateToken("", "foo"))

        head = sequence.head
        assert 0 == head.action
        assert 1 == head.field
        assert [1] == sequence._tree.children[0][0]
        expected_actions = [ApplyRule(rule), GenerateToken("", "foo")]
        assert expected_actions == sequence.action_sequence
        assert Parent(0, 0) == sequence.parent(1)
        assert [] == sequence._tree.children[1]

        with pytest.raises(InvalidActionException):
            sequence.eval(GenerateToken("", "bar"))

        fresh = ActionSequence()
        fresh.eval(ApplyRule(rule))
        with pytest.raises(InvalidActionException):
            fresh.eval(ApplyRule(CloseVariadicFieldRule()))
예제 #20
0
    def test_decode(self):
        """Encoding a sequence and decoding it reproduces the same actions.

        The token "1" is not among the sample tokens but is present in
        the reference. The last row and the first column of the encoding
        are dropped before decoding.
        """
        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("op", NodeType("value", NodeConstraint.Token, True)),
             ("arg0", NodeType("value", NodeConstraint.Token, True)),
             ("arg1", NodeType("value", NodeConstraint.Token, True))])

        samples = Samples(
            [funcdef, expr],
            [
                NodeType("def", NodeConstraint.Node, False),
                NodeType("value", NodeConstraint.Token, True),
                NodeType("expr", NodeConstraint.Node, False),
            ],
            [("", "f")])
        encoder = ActionSequenceEncoder(samples, 0)

        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(funcdef))
        action_sequence.eval(GenerateToken("", "f"))
        action_sequence.eval(GenerateToken("", "1"))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))

        expected_action_sequence = ActionSequence()
        expected_action_sequence.eval(ApplyRule(funcdef))
        expected_action_sequence.eval(GenerateToken("", "f"))
        expected_action_sequence.eval(GenerateToken("", "1"))
        expected_action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))

        encoded = encoder.encode_action(action_sequence,
                                        [Token(None, "1", "1")])
        decoder_input = encoded[:-1, 1:]
        result = encoder.decode(decoder_input, [Token(None, "1", "1")])
        assert \
            expected_action_sequence.action_sequence == result.action_sequence
예제 #21
0
    def test_eval_root(self):
        """Check the state of a sequence before and after its first rule.

        An empty sequence has no head and rejects a token or a close rule
        as its first action. After an expansion rule, the head points at
        field 0 of action 0.
        """
        assert ActionSequence().head is None

        with pytest.raises(InvalidActionException):
            ActionSequence().eval(GenerateToken("kind", ""))
        with pytest.raises(InvalidActionException):
            ActionSequence().eval(ApplyRule(CloseVariadicFieldRule()))

        rule = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Node, False)),
             ("value", NodeType("args", NodeConstraint.Node, True))])
        sequence = ActionSequence()
        sequence.eval(ApplyRule(rule))

        assert 0 == sequence.head.action
        assert 0 == sequence.head.field
        assert [ApplyRule(rule)] == sequence.action_sequence
        assert sequence.parent(0) is None
        # One empty child list per field of the rule.
        assert [[], []] == sequence._tree.children[0]
예제 #22
0
    def test_generate(self):
        """generate() rebuilds the AST from a finished action sequence.

        The sequence fills a variadic token field with three tokens and a
        variadic node field with one expression, leaving no head.
        """
        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("op", NodeType("value", NodeConstraint.Token, False)),
             ("arg0", NodeType("value", NodeConstraint.Token, False)),
             ("arg1", NodeType("value", NodeConstraint.Token, False))])

        action_sequence = ActionSequence()
        for action in (
                ApplyRule(funcdef),
                GenerateToken("name", "f"),
                GenerateToken("name", "_"),
                GenerateToken("name", "0"),
                ApplyRule(CloseVariadicFieldRule()),
                ApplyRule(expr),
                GenerateToken("value", "+"),
                GenerateToken("value", "1"),
                GenerateToken("value", "2"),
                ApplyRule(CloseVariadicFieldRule())):
            action_sequence.eval(action)
        assert action_sequence.head is None

        expected_name = Field("name", "value", [
            Leaf("name", "f"),
            Leaf("name", "_"),
            Leaf("name", "0"),
        ])
        expected_body = Field("body", "expr", [
            Node("expr", [
                Field("op", "value", Leaf("value", "+")),
                Field("arg0", "value", Leaf("value", "1")),
                Field("arg1", "value", Leaf("value", "2")),
            ])
        ])
        expected_ast = Node("def", [expected_name, expected_body])
        assert expected_ast == action_sequence.generate()
예제 #23
0
    def test_variadic_field(self):
        """Check head movement while filling and closing a variadic field.

        The head stays on the variadic field while elements are added.
        Closing the field clears the head when it is the last field, and
        moves the head to the next field otherwise.
        """
        action_sequence = ActionSequence()
        rule = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("elems", NodeType("value", NodeConstraint.Node, True))])
        rule0 = ExpandTreeRule(NodeType("value", NodeConstraint.Node, False),
                               [])
        action_sequence.eval(ApplyRule(rule))
        action_sequence.eval(ApplyRule(rule0))
        # First element added: the head is still field 0 of action 0.
        assert 0 == action_sequence.head.action
        assert 0 == action_sequence.head.field
        assert [1] == action_sequence._tree.children[0][0]
        assert [ApplyRule(rule),
                ApplyRule(rule0)] == action_sequence.action_sequence
        assert Parent(0, 0) == action_sequence.parent(1)
        assert [] == action_sequence._tree.children[1]

        action_sequence.eval(ApplyRule(rule0))
        # Second element added to the same field.
        assert 0 == action_sequence.head.action
        assert 0 == action_sequence.head.field
        assert [1, 2] == action_sequence._tree.children[0][0]
        assert [ApplyRule(rule),
                ApplyRule(rule0),
                ApplyRule(rule0)] == action_sequence.action_sequence

        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        assert action_sequence.head is None

        # A variadic field followed by another field: closing the first
        # one moves the head to field 1.
        action_sequence = ActionSequence()
        rule1 = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("elems", NodeType("value", NodeConstraint.Node, True)),
             ("name", NodeType("value", NodeConstraint.Node, False))])
        rule0 = ExpandTreeRule(NodeType("value", NodeConstraint.Node, False),
                               [])
        action_sequence.eval(ApplyRule(rule1))
        action_sequence.eval(ApplyRule(rule0))
        # Pass a rule instance, not the class itself.
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        assert 0 == action_sequence.head.action
        assert 1 == action_sequence.head.field
예제 #24
0
    def test_encode_path(self):
        """Check encode_path for a nested sequence at two depth limits.

        With max_depth=2 each row holds the parent's rule id followed by
        the grandparent's; with max_depth=1 only the parent's rule id is
        kept.
        """
        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("constant", NodeType("value", NodeConstraint.Token, True))])

        encoder = ActionSequenceEncoder(
            Samples([funcdef, expr], [
                NodeType("def", NodeConstraint.Node, False),
                NodeType("value", NodeConstraint.Token, True),
                NodeType("expr", NodeConstraint.Node, True)
            ], [("", "f"), ("", "2")]), 0)
        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(funcdef))
        action_sequence.eval(GenerateToken("", "f"))
        action_sequence.eval(GenerateToken("", "1"))
        action_sequence.eval(GenerateToken("", "2"))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        action_sequence.eval(ApplyRule(expr))
        action_sequence.eval(GenerateToken("", "f"))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        path = encoder.encode_path(action_sequence, 2)

        # np.int64 instead of np.long: the np.long alias does not exist
        # in every NumPy release.
        assert np.array_equal(
            np.array(
                [
                    [-1, -1],  # funcdef
                    [2, -1],  # f
                    [2, -1],  # 1
                    [2, -1],  # 2
                    [2, -1],  # CloseVariadicField
                    [2, -1],  # expr
                    [3, 2],  # f
                    [3, 2],  # CloseVariadicField
                    [2, -1],  # CloseVariadicField
                ],
                dtype=np.int64),
            path.numpy())
        path = encoder.encode_path(action_sequence, 1)
        assert np.array_equal(
            np.array(
                [
                    [-1],  # funcdef
                    [2],  # f
                    [2],  # 1
                    [2],  # 2
                    [2],  # CloseVariadicField
                    [2],  # expr
                    [3],  # f
                    [3],  # CloseVariadicField
                    [2],  # CloseVariadicField
                ],
                dtype=np.int64),
            path.numpy())
예제 #25
0
    def encode_action(self,
                      action_sequence: ActionSequence,
                      reference: List[Token]) \
            -> Optional[torch.Tensor]:
        """
        Return the tensor encoding the action sequence

        Parameters
        ----------
        action_sequence: action_sequence
            The action_sequence containing action sequence to be encoded
        reference: List[Token]
            The tokens whose raw_value can be copied. A generated token is
            matched against these raw values.

        Returns
        -------
        Optional[torch.Tensor]
            The encoded tensor. The shape of tensor is
            (len(action_sequence) + 1, 4). Each action will be encoded by
            the tuple of (ID of the node types, ID of the applied rule,
            ID of the inserted token, the index of the word copied from
            the reference). The extra last row only carries the node type
            of the current head, if there is one.
            The padding value should be -1.
            None if the action sequence cannot be encoded.
        """
        reference_value = [token.raw_value for token in reference]
        actions = action_sequence.action_sequence
        length = len(actions)
        action = torch.full((length + 1, 4), -1, dtype=torch.long)

        for i, a in enumerate(actions):
            # Column 0: type of the field this action fills.
            parent = action_sequence.parent(i)
            if parent is not None:
                parent_action = cast(ApplyRule, actions[parent.action])
                parent_rule = cast(ExpandTreeRule, parent_action.rule)
                action[i, 0] = self._node_type_encoder.encode(
                    parent_rule.children[parent.field][1])

            if isinstance(a, ApplyRule):
                # Column 1: applied rule.
                action[i, 1] = self._rule_encoder.encode(a.rule)
                continue

            encoded_token = \
                int(self._token_encoder.encode((a.kind, a.value)).numpy())
            # The token encoder's id 0 is treated as "not encodable".
            in_vocabulary = encoded_token != 0
            in_reference = a.value in reference_value

            if in_vocabulary:
                # Column 2: token id.
                action[i, 2] = encoded_token
            if in_reference:
                # Column 3: position of the first matching reference token.
                # TODO use kind in reference
                action[i, 3] = \
                    reference_value.index(cast(str, a.value))
            if not in_vocabulary and not in_reference:
                logger.debug("cannot encode token")
                return None

        head = action_sequence.head
        if head is not None:
            head_action = cast(ApplyRule, actions[head.action])
            head_rule = cast(ExpandTreeRule, head_action.rule)
            action[length, 0] = self._node_type_encoder.encode(
                head_rule.children[head.field][1])

        return action