示例#1
0
    def test_decode(self):
        """Round-trip: encode an action sequence, then decode it back.

        The encoded action tensor is trimmed (last row and first column
        removed) before being decoded with the same reference tokens; the
        decoded actions must equal an independently built copy of the input.
        """
        def token_field(field_name):
            # Fresh variadic, token-valued field of type "value".
            return (field_name,
                    NodeType("value", NodeConstraint.Token, True))

        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [token_field("name"),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [token_field(name) for name in ("op", "arg0", "arg1")])

        node_types = [
            NodeType("def", NodeConstraint.Node, False),
            NodeType("value", NodeConstraint.Token, True),
            NodeType("expr", NodeConstraint.Node, False),
        ]
        encoder = ActionSequenceEncoder(
            Samples([funcdef, expr], node_types, [("", "f")]), 0)

        def build_sequence():
            # Both the input and the expectation are built the same way.
            seq = ActionSequence()
            for action in (ApplyRule(funcdef),
                           GenerateToken("", "f"),
                           GenerateToken("", "1"),
                           ApplyRule(CloseVariadicFieldRule())):
                seq.eval(action)
            return seq

        action_sequence = build_sequence()
        expected_action_sequence = build_sequence()

        encoded = encoder.encode_action(action_sequence,
                                        [Token(None, "1", "1")])
        result = encoder.decode(encoded[:-1, 1:], [Token(None, "1", "1")])
        assert \
            expected_action_sequence.action_sequence == result.action_sequence
示例#2
0
    def test_encode_path(self):
        """``encode_path`` appears to list, per action, the ids of its
        ancestor rules (nearest first) up to the given depth, padded with -1.

        NOTE(review): the ids look like rule-encoder indices
        (2 = ``funcdef``, 3 = ``expr``); confirm against ``LabelEncoder``.
        """
        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("constant", NodeType("value", NodeConstraint.Token, True))])

        encoder = ActionSequenceEncoder(
            Samples([funcdef, expr], [
                NodeType("def", NodeConstraint.Node, False),
                NodeType("value", NodeConstraint.Token, True),
                NodeType("expr", NodeConstraint.Node, True)
            ], [("", "f"), ("", "2")]), 0)
        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(funcdef))
        action_sequence.eval(GenerateToken("", "f"))
        action_sequence.eval(GenerateToken("", "1"))
        action_sequence.eval(GenerateToken("", "2"))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        action_sequence.eval(ApplyRule(expr))
        action_sequence.eval(GenerateToken("", "f"))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        path = encoder.encode_path(action_sequence, 2)

        # ``np.long`` was removed in NumPy 1.24 (AttributeError); ``np.int64``
        # is the explicit, portable 64-bit integer dtype.
        assert np.array_equal(
            np.array(
                [
                    [-1, -1],  # funcdef
                    [2, -1],  # f
                    [2, -1],  # 1
                    [2, -1],  # 2
                    [2, -1],  # CloseVariadicField
                    [2, -1],  # expr
                    [3, 2],  # f
                    [3, 2],  # CloseVariadicField
                    [2, -1],  # CloseVariadicField
                ],
                dtype=np.int64),
            path.numpy())
        path = encoder.encode_path(action_sequence, 1)
        assert np.array_equal(
            np.array(
                [
                    [-1],  # funcdef
                    [2],  # f
                    [2],  # 1
                    [2],  # 2
                    [2],  # CloseVariadicField
                    [2],  # expr
                    [3],  # f
                    [3],  # CloseVariadicField
                    [2],  # CloseVariadicField
                ],
                dtype=np.int64),
            path.numpy())
示例#3
0
    def test_create_leaf(self):
        """``ActionSequence.create`` on a bare leaf and on a node holding a
        variadic token field.

        Both expected sequences start with a synthetic root expansion whose
        "root" field is Token-constrained for a leaf and Node-constrained
        for a node.
        """
        def root_rule(constraint):
            # Root expansion expected at the head of every created sequence.
            return ApplyRule(ExpandTreeRule(
                NodeType(None, NodeConstraint.Node, False),
                [("root", NodeType(Root(), constraint, False))]))

        leaf_seq = ActionSequence.create(Leaf("str", "t0 t1"))
        expected = [root_rule(NodeConstraint.Token),
                    GenerateToken("str", "t0 t1")]
        assert expected == leaf_seq.action_sequence

        leaves = [Leaf("str", "t0"), Leaf("str", "t1")]
        node_seq = ActionSequence.create(
            Node("value", [Field("name", "str", leaves)]))
        expected = [
            root_rule(NodeConstraint.Node),
            ApplyRule(ExpandTreeRule(
                NodeType("value", NodeConstraint.Node, False),
                [("name", NodeType("str", NodeConstraint.Token, True))])),
            GenerateToken("str", "t0"),
            GenerateToken("str", "t1"),
            ApplyRule(CloseVariadicFieldRule()),
        ]
        assert expected == node_seq.action_sequence
示例#4
0
    def test_encode_parent(self):
        """``encode_parent`` on a funcdef with three generated tokens.

        The expected matrix starts with an all -1 row followed by rows that
        share the prefix ``[1, 2, 0]``.
        """
        def token_field(name):
            return (name, NodeType("value", NodeConstraint.Token, True))

        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [token_field("name"),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [token_field("op"), token_field("arg0"), token_field("arg1")])
        samples = Samples(
            [funcdef, expr],
            [NodeType("def", NodeConstraint.Node, False),
             NodeType("value", NodeConstraint.Token, True),
             NodeType("expr", NodeConstraint.Node, False)],
            [("", "f"), ("", "2")])
        encoder = ActionSequenceEncoder(samples, 0)

        action_sequence = ActionSequence()
        actions = [ApplyRule(funcdef)]
        actions += [GenerateToken("", value) for value in ("f", "1", "2")]
        actions.append(ApplyRule(CloseVariadicFieldRule()))
        for action in actions:
            action_sequence.eval(action)

        parent = encoder.encode_parent(action_sequence)
        expected = [[-1, -1, -1, -1]] + [[1, 2, 0, 0]] * 4 + [[1, 2, 0, 1]]
        assert np.array_equal(expected, parent.numpy())
示例#5
0
    def test_encode_invalid_sequence(self):
        """``encode_action`` is expected to return None for this sequence.

        The sequence generates the token "1", which is neither among the
        sample tokens (only ("", "f")) nor in the reference list (only "2").
        """
        def value_field(name, variadic):
            return (name, NodeType("value", NodeConstraint.Token, variadic))

        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [value_field("name", True),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [value_field("op", False),
             value_field("arg0", True),
             value_field("arg1", True)])
        node_types = [
            NodeType("def", NodeConstraint.Node, False),
            NodeType("value", NodeConstraint.Token, True),
            NodeType("expr", NodeConstraint.Node, True),
        ]
        encoder = ActionSequenceEncoder(
            Samples([funcdef, expr], node_types, [("", "f")]), 0)

        action_sequence = ActionSequence()
        for action in (ApplyRule(funcdef),
                       GenerateToken("", "f"),
                       GenerateToken("", "1"),
                       ApplyRule(CloseVariadicFieldRule())):
            action_sequence.eval(action)

        references = [Token("", "2", "2")]
        assert encoder.encode_action(action_sequence, references) is None
示例#6
0
    def test_invalid_close_variadic_field_rule(self):
        """Closing a non-variadic field must raise InvalidActionException."""
        non_variadic_field = ("elems",
                              NodeType("value", NodeConstraint.Node, False))
        expand = ExpandTreeRule(NodeType("expr", NodeConstraint.Node, False),
                                [non_variadic_field])

        seq = ActionSequence()
        seq.eval(ApplyRule(expand))
        with pytest.raises(InvalidActionException):
            seq.eval(ApplyRule(CloseVariadicFieldRule()))
示例#7
0
 def test_str(self):
     """String forms of ExpandTreeRule and CloseVariadicFieldRule."""
     parent = NodeType("t0", NodeConstraint.Node, False)
     child = NodeType("t1", NodeConstraint.Node, False)
     variadic_child = NodeType("t2", NodeConstraint.Node, True)
     rule = ExpandTreeRule(parent, [("elem0", child),
                                    ("elem1", variadic_child)])
     # The variadic field is rendered with a trailing "*".
     assert str(rule) == "t0 -> [elem0: t1, elem1: t2*]"
     assert str(CloseVariadicFieldRule()) == "<close variadic field>"
示例#8
0
    def test_generate_variadic_token(self):
        """Tokens accumulate in a variadic token field until it is closed.

        After the close rule the head moves to field 1, and generating
        another token then raises InvalidActionException.
        """
        seq = ActionSequence()
        rule = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("value", NodeType("args", NodeConstraint.Node, True))])
        seq.eval(ApplyRule(rule))
        expected = [ApplyRule(rule)]

        def check(field, children):
            # The head stays on action 0 (the def rule); only the field moves.
            assert seq.head.action == 0
            assert seq.head.field == field
            assert seq._tree.children[0][0] == children
            assert seq.action_sequence == expected

        seq.eval(GenerateToken("", "foo"))
        expected.append(GenerateToken("", "foo"))
        check(0, [1])
        assert seq.parent(1) == Parent(0, 0)
        assert seq._tree.children[1] == []

        seq.eval(GenerateToken("", "bar"))
        expected.append(GenerateToken("", "bar"))
        check(0, [1, 2])

        seq.eval(ApplyRule(CloseVariadicFieldRule()))
        expected.append(ApplyRule(CloseVariadicFieldRule()))
        check(1, [1, 2, 3])

        with pytest.raises(InvalidActionException):
            seq.eval(GenerateToken("", "foo"))
示例#9
0
    def test_generate(self):
        """Evaluating a complete action sequence leaves no head, and
        ``generate`` rebuilds the corresponding AST."""
        def value_field(name, variadic):
            return (name, NodeType("value", NodeConstraint.Token, variadic))

        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [value_field("name", True),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [value_field(name, False) for name in ("op", "arg0", "arg1")])

        actions = (
            [ApplyRule(funcdef)]
            + [GenerateToken("name", v) for v in ("f", "_", "0")]
            + [ApplyRule(CloseVariadicFieldRule()), ApplyRule(expr)]
            + [GenerateToken("value", v) for v in ("+", "1", "2")]
            + [ApplyRule(CloseVariadicFieldRule())]
        )
        action_sequence = ActionSequence()
        for action in actions:
            action_sequence.eval(action)
        assert action_sequence.head is None

        expected = Node("def", [
            Field("name", "value",
                  [Leaf("name", v) for v in ("f", "_", "0")]),
            Field("body", "expr", [
                Node("expr", [
                    Field(name, "value", Leaf("value", v))
                    for name, v in (("op", "+"), ("arg0", "1"), ("arg1", "2"))
                ])
            ]),
        ])
        assert expected == action_sequence.generate()
示例#10
0
 def __init__(self, samples: Samples, token_threshold: int):
     # Rule labels: Unknown and CloseVariadicFieldRule are reserved, and
     # index 0 is used as the unknown index.
     rule_reserved: List[Union[Unknown, CloseVariadicFieldRule]] = [
         Unknown(), CloseVariadicFieldRule()]
     self._rule_encoder = LabelEncoder(samples.rules,
                                       reserved_labels=rule_reserved,
                                       unknown_index=0)
     self._node_type_encoder = LabelEncoder(samples.node_types)

     # Token labels: only Unknown is reserved; tokens seen fewer than
     # ``token_threshold`` times are handled by LabelEncoder's
     # ``min_occurrences``.
     token_reserved = [Unknown()]
     self._token_encoder = LabelEncoder(samples.tokens,
                                        min_occurrences=token_threshold,
                                        reserved_labels=token_reserved,
                                        unknown_index=0)

     # Group vocabulary indices by raw token value; the vocabulary holds
     # (kind, value) pairs, so one value can map to several indices.
     self.value_to_idx: Dict[str, List[int]] = {}
     for kind, value in self._token_encoder.vocab[len(token_reserved):]:
         idx = self._token_encoder.encode((kind, value))
         self.value_to_idx.setdefault(value, []).append(idx)
示例#11
0
    def test_eval_root(self):
        """A fresh sequence has no head.

        Evaluating a token or a close-variadic rule on an empty sequence
        raises. After applying an expanding rule, the head points at its
        first field and the rule has no parent.
        """
        action_sequence = ActionSequence()
        assert action_sequence.head is None

        # Build the sequence outside ``pytest.raises`` so that only ``eval``
        # can satisfy the expected exception.
        action_sequence = ActionSequence()
        with pytest.raises(InvalidActionException):
            action_sequence.eval(GenerateToken("kind", ""))
        action_sequence = ActionSequence()
        with pytest.raises(InvalidActionException):
            action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))

        action_sequence = ActionSequence()
        rule = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Node, False)),
             ("value", NodeType("args", NodeConstraint.Node, True))])
        action_sequence.eval(ApplyRule(rule))
        assert 0 == action_sequence.head.action
        assert 0 == action_sequence.head.field
        assert [ApplyRule(rule)] == action_sequence.action_sequence
        assert action_sequence.parent(0) is None
        assert [[], []] == action_sequence._tree.children[0]
示例#12
0
    def test_variadic_field(self):
        """Node-valued variadic fields accept repeated rules until closed.

        Closing the only field of a rule finishes the sequence (head becomes
        None). Closing a variadic field that is followed by another field
        moves the head to that next field.
        """
        action_sequence = ActionSequence()
        rule = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("elems", NodeType("value", NodeConstraint.Node, True))])
        rule0 = ExpandTreeRule(NodeType("value", NodeConstraint.Node, False),
                               [])
        action_sequence.eval(ApplyRule(rule))
        action_sequence.eval(ApplyRule(rule0))
        assert 0 == action_sequence.head.action
        assert 0 == action_sequence.head.field
        assert [1] == action_sequence._tree.children[0][0]
        assert [ApplyRule(rule),
                ApplyRule(rule0)] == action_sequence.action_sequence
        assert Parent(0, 0) == action_sequence.parent(1)
        assert [] == action_sequence._tree.children[1]

        action_sequence.eval(ApplyRule(rule0))
        assert 0 == action_sequence.head.action
        assert 0 == action_sequence.head.field
        assert [1, 2] == action_sequence._tree.children[0][0]
        assert [ApplyRule(rule),
                ApplyRule(rule0),
                ApplyRule(rule0)] == action_sequence.action_sequence

        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        assert action_sequence.head is None

        action_sequence = ActionSequence()
        rule1 = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("elems", NodeType("value", NodeConstraint.Node, True)),
             ("name", NodeType("value", NodeConstraint.Node, False))])
        rule0 = ExpandTreeRule(NodeType("value", NodeConstraint.Node, False),
                               [])
        action_sequence.eval(ApplyRule(rule1))
        action_sequence.eval(ApplyRule(rule0))
        # Use an *instance* of the close rule, as everywhere else; the bare
        # class object is not a rule instance and is presumably not
        # recognized as a close action.
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        assert 0 == action_sequence.head.action
        assert 1 == action_sequence.head.field
示例#13
0
 def test_create_node_with_variadic_fields(self):
     """``ActionSequence.create`` expands each element of a variadic node
     field in order and then closes the field."""
     elements = [Node("str", []), Node("str", [])]
     seq = ActionSequence.create(
         Node("list", [Field("elems", "literal", elements)]))

     def empty_str_rule():
         # Expansion of a "str" node that has no fields.
         return ApplyRule(ExpandTreeRule(
             NodeType("str", NodeConstraint.Node, False), []))

     expected = [
         ApplyRule(ExpandTreeRule(
             NodeType(None, NodeConstraint.Node, False),
             [("root", NodeType(Root(), NodeConstraint.Node, False))])),
         ApplyRule(ExpandTreeRule(
             NodeType("list", NodeConstraint.Node, False),
             [("elems", NodeType("literal", NodeConstraint.Node, True))])),
         empty_str_rule(),
         empty_str_rule(),
         ApplyRule(CloseVariadicFieldRule()),
     ]
     assert expected == seq.action_sequence
示例#14
0
    def enumerate_samples_per_state(self,
                                    rule_pred: torch.Tensor,
                                    token_pred: torch.Tensor,
                                    reference_pred: torch.Tensor,
                                    next_state: Environment,
                                    state: SamplerState[Environment],
                                    enumeration: Enumeration,
                                    k: Optional[int]) \
            -> Generator[DuplicatedSamplerState[Environment], None, None]:
        """Enumerate successor states of ``state`` for one decoding step.

        The field at the head of ``state``'s action sequence decides what
        kind of action is proposed:

        * Token fields: tokens from the vocabulary, references, and (for
          variadic fields) the close-variadic rule.
        * Other fields: grammar rules.

        Candidates whose kind is not a subtype of the head field's type (and,
        for rules, the close rule on non-variadic fields) get a zero score and
        are skipped.

        Args:
            rule_pred: Scores over the rule vocabulary; index 0 is the
                unknown label and is never proposed. May be modified in place.
            token_pred: Scores over the token vocabulary; index 0 is the
                unknown label. May be modified in place.
            reference_pred: Scores over ``state.state["reference"]``.
                May be modified in place.
            next_state: Environment that every yielded state is cloned from.
            state: The state being expanded.
            enumeration: ``Top`` takes the highest-scoring candidates,
                ``Random`` takes candidates in a random order, and any other
                value draws ``k`` multinomial samples from the normalized
                scores.
            k: Maximum number of candidates (``None`` means no limit for
                ``Top``/``Random``; required for sampling).

        Yields:
            ``DuplicatedSamplerState`` wrapping a new ``SamplerState`` with
            score ``state.score + log(p)`` (``p`` floored at ``self.eps``),
            together with how many times the candidate was drawn.
        """
        def indices(pred: torch.Tensor):
            # Yields (index, count) pairs; 0 is unknown token and skipped.
            if enumeration == Enumeration.Top:
                _, candidates = torch.sort(pred[1:], descending=True)
                if k is not None:
                    candidates = candidates[:k]
                for index in candidates:
                    yield index + 1, 1
            elif enumeration == Enumeration.Random:
                candidates = list(range(1, len(pred)))
                # Without this shuffle "Random" would deterministically return
                # the first k vocabulary entries.
                self.rng.shuffle(candidates)
                if k is not None:
                    candidates = candidates[:k]
                for index in candidates:
                    yield index, 1
            else:
                assert k is not None
                with logger.block("normalize_prob"):
                    s = pred[1:].sum().item()
                    if s < self.eps:
                        return
                    # Subtract eps so that rounding cannot push the sum of
                    # the probabilities above 1.
                    ps = (pred[1:] / s - self.eps).numpy()
                    npred = [max(0, p) for p in ps]
                for i, n in enumerate(self.rng.multinomial(k, npred)):
                    if n == 0:
                        continue
                    yield i + 1, n

        with logger.block("enumerate_samples_per_state"):
            head = state.state["action_sequence"].head
            assert head is not None
            # Type of the field that the next action has to fill.
            head_field = \
                cast(ExpandTreeRule, cast(
                    ApplyRule,
                    state.state["action_sequence"]
                    .action_sequence[head.action]
                ).rule).children[head.field][1]
            if head_field.constraint == NodeConstraint.Token:
                # Generate token
                ref_ids = self.encoder.batch_encode_raw_value(
                    [x.raw_value for x in state.state["reference"]])
                # Candidates: token vocabulary followed by the references
                # (the close rule may be appended below).
                tokens = list(self.encoder._token_encoder.vocab) + \
                    state.state["reference"]
                # the score will be merged into predefined token
                for i, ids in enumerate(ref_ids):
                    for ref_id in ids:
                        # merge token and reference pred
                        # Add to unknown probability
                        # if there is not the corresponding token.
                        token_pred[ref_id] += reference_pred[i]
                        if ref_id != 0:
                            reference_pred[i] = 0.0
                pred = torch.cat([token_pred, reference_pred], dim=0)

                # CloseVariadicFieldRule is a candidate if variadic fields
                if head_field.is_variadic:
                    close_rule_idx = \
                        self.encoder._rule_encoder.encode(
                            CloseVariadicFieldRule())
                    p = rule_pred[close_rule_idx].item()
                    tokens.append(ApplyRule(CloseVariadicFieldRule()))
                    pred = torch.cat([pred, torch.tensor([p])], dim=0)

                with logger.block("exclude_invalid_tokens"):
                    # token
                    for kind, idxes in self.token_kind_to_idx.items():
                        if kind is not None and \
                                not self.is_subtype(kind,
                                                    head_field.type_name):
                            pred[idxes] = 0.0
                    # reference
                    for x, (_, token) in enumerate(
                            zip(pred[len(token_pred):],
                                tokens[len(token_pred):])):
                        x += len(token_pred)
                        if not isinstance(token, ApplyRule):
                            if isinstance(token, Token):
                                t = token.kind
                            else:
                                t = token[0]
                            if t is not None and \
                                    not self.is_subtype(t,
                                                        head_field.type_name):
                                pred[x] = 0.0

                n_action = 0
                for x, n in logger.iterable_block("sample-tokens",
                                                  indices(pred)):
                    # Finish enumeration
                    if n_action == k:
                        return

                    p = pred[x].item()
                    token = tokens[x]

                    if isinstance(token, ApplyRule):
                        action: Action = token
                    elif isinstance(token, Token):
                        action = GenerateToken(token.kind, token.raw_value)
                    else:
                        action = GenerateToken(token[0], token[1])

                    # Excluded candidates were zeroed above.
                    if p == 0.0:
                        continue
                    elif p < self.eps:
                        lp = np.log(self.eps)
                    else:
                        lp = np.log(p)

                    n_action += n
                    # Clone the pristine environment for every candidate so
                    # siblings never inherit changes made to a state that
                    # was already yielded.
                    new_state = next_state.clone()
                    # TODO we may have to clear outputs
                    new_state["action_sequence"] = \
                        LazyActionSequence(
                            state.state["action_sequence"], action)
                    yield DuplicatedSamplerState(
                        SamplerState(state.score + lp, new_state), n)
            else:
                # Apply rule
                with logger.block("exclude_invalid_rules"):
                    # expand tree rule
                    for kind, idxes in self.rule_kind_to_idx.items():
                        if not (kind is not None and self.is_subtype(
                                kind, head_field.type_name)):
                            rule_pred[idxes] = 0.0
                    # CloseVariadicField
                    idx = self.encoder._rule_encoder.encode(
                        CloseVariadicFieldRule())
                    if not (head_field is not None and head_field.is_variadic):
                        rule_pred[idx] = 0.0

                n_rule = 0
                for x, n in logger.iterable_block("sample-rule",
                                                  indices(rule_pred)):
                    # Finish enumeration
                    if n_rule == k:
                        return

                    p = rule_pred[x].item()
                    if p == 0.0:
                        continue
                    elif p < self.eps:
                        lp = np.log(self.eps)
                    else:
                        lp = np.log(p)

                    n_rule += n
                    rule = self.encoder._rule_encoder.vocab[x]

                    # Clone the pristine environment (see token branch).
                    new_state = next_state.clone()
                    new_state["action_sequence"] = \
                        LazyActionSequence(
                            state.state["action_sequence"],
                            ApplyRule(rule))
                    yield DuplicatedSamplerState(
                        SamplerState(state.score + lp, new_state), n)