Example #1
0
    def test_encode_parent(self):
        """Check the parent encoding of a def rule whose variadic ``name``
        field receives three tokens before being closed."""
        def_rule = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr_rule = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [(field, NodeType("value", NodeConstraint.Token, True))
             for field in ("op", "arg0", "arg1")])

        node_types = [
            NodeType("def", NodeConstraint.Node, False),
            NodeType("value", NodeConstraint.Token, True),
            NodeType("expr", NodeConstraint.Node, False),
        ]
        encoder = ActionSequenceEncoder(
            Samples([def_rule, expr_rule], node_types,
                    [("", "f"), ("", "2")]),
            0)

        seq = ActionSequence()
        for action in (ApplyRule(def_rule),
                       GenerateToken("", "f"),
                       GenerateToken("", "1"),
                       GenerateToken("", "2"),
                       ApplyRule(CloseVariadicFieldRule())):
            seq.eval(action)

        expected = [[-1, -1, -1, -1],
                    [1, 2, 0, 0],
                    [1, 2, 0, 0],
                    [1, 2, 0, 0],
                    [1, 2, 0, 0],
                    [1, 2, 0, 1]]
        assert np.array_equal(expected, encoder.encode_parent(seq).numpy())
Example #2
0
    def test_decode(self):
        """decode() turns an encoded action tensor back into an equal action
        sequence."""
        def_rule = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr_rule = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [(field, NodeType("value", NodeConstraint.Token, True))
             for field in ("op", "arg0", "arg1")])
        encoder = ActionSequenceEncoder(
            Samples([def_rule, expr_rule],
                    [NodeType("def", NodeConstraint.Node, False),
                     NodeType("value", NodeConstraint.Token, True),
                     NodeType("expr", NodeConstraint.Node, False)],
                    [("", "f")]),
            0)

        def build_sequence():
            # def rule, tokens "f" and "1" for the name field, then close it.
            seq = ActionSequence()
            seq.eval(ApplyRule(def_rule))
            seq.eval(GenerateToken("", "f"))
            seq.eval(GenerateToken("", "1"))
            seq.eval(ApplyRule(CloseVariadicFieldRule()))
            return seq

        action_sequence = build_sequence()
        expected_action_sequence = build_sequence()

        # "1" is not in the token vocabulary, so it presumably has to be
        # recovered via the reference. The slice drops the last row and the
        # first column of the encoding before it is handed to decode().
        encoded = encoder.encode_action(action_sequence,
                                        [Token(None, "1", "1")])
        result = encoder.decode(encoded[:-1, 1:], [Token(None, "1", "1")])
        assert \
            expected_action_sequence.action_sequence == result.action_sequence
Example #3
0
    def test_create_leaf(self):
        """ActionSequence.create on a bare leaf, and on a node whose variadic
        token field holds two leaves."""
        leaf_seq = ActionSequence.create(Leaf("str", "t0 t1"))
        expected_leaf = [
            ApplyRule(ExpandTreeRule(
                NodeType(None, NodeConstraint.Node, False),
                [("root", NodeType(Root(), NodeConstraint.Token, False))])),
            GenerateToken("str", "t0 t1"),
        ]
        assert expected_leaf == leaf_seq.action_sequence

        node_seq = ActionSequence.create(
            Node("value",
                 [Field("name", "str",
                        [Leaf("str", "t0"), Leaf("str", "t1")])]))
        expected_node = [
            ApplyRule(ExpandTreeRule(
                NodeType(None, NodeConstraint.Node, False),
                [("root", NodeType(Root(), NodeConstraint.Node, False))])),
            ApplyRule(ExpandTreeRule(
                NodeType("value", NodeConstraint.Node, False),
                [("name", NodeType("str", NodeConstraint.Token, True))])),
            GenerateToken("str", "t0"),
            GenerateToken("str", "t1"),
            ApplyRule(CloseVariadicFieldRule()),
        ]
        assert expected_node == node_seq.action_sequence
Example #4
0
    def test_encode_invalid_sequence(self):
        """encode_action returns None for this sequence: the generated token
        "1" is neither in the token vocabulary nor in the reference
        (presumably the reason it cannot be encoded)."""
        def_rule = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr_rule = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("op", NodeType("value", NodeConstraint.Token, False)),
             ("arg0", NodeType("value", NodeConstraint.Token, True)),
             ("arg1", NodeType("value", NodeConstraint.Token, True))])
        samples = Samples(
            [def_rule, expr_rule],
            [NodeType("def", NodeConstraint.Node, False),
             NodeType("value", NodeConstraint.Token, True),
             NodeType("expr", NodeConstraint.Node, True)],
            [("", "f")])
        encoder = ActionSequenceEncoder(samples, 0)

        seq = ActionSequence()
        for action in [ApplyRule(def_rule),
                       GenerateToken("", "f"),
                       GenerateToken("", "1"),
                       ApplyRule(CloseVariadicFieldRule())]:
            seq.eval(action)

        reference = [Token("", "2", "2")]
        assert encoder.encode_action(seq, reference) is None
Example #5
0
    def test_encode_path(self):
        """encode_path with maximum depths 2 and 1.

        Judging by the expected values, each row lists rule IDs along the
        action's ancestor path (nearest first), cut to the requested depth
        and padded with -1.
        """
        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("constant", NodeType("value", NodeConstraint.Token, True))])

        encoder = ActionSequenceEncoder(
            Samples([funcdef, expr], [
                NodeType("def", NodeConstraint.Node, False),
                NodeType("value", NodeConstraint.Token, True),
                NodeType("expr", NodeConstraint.Node, True)
            ], [("", "f"), ("", "2")]), 0)
        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(funcdef))
        action_sequence.eval(GenerateToken("", "f"))
        action_sequence.eval(GenerateToken("", "1"))
        action_sequence.eval(GenerateToken("", "2"))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        action_sequence.eval(ApplyRule(expr))
        action_sequence.eval(GenerateToken("", "f"))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))

        # Expected paths for depth 2. The values suggest 2 is the ID of
        # ``funcdef`` and 3 the ID of ``expr``.
        # np.int64 instead of np.long: np.long is missing from NumPy
        # 1.24-1.26, where using it raises AttributeError.
        expected_depth2 = np.array(
            [
                [-1, -1],  # funcdef
                [2, -1],  # f
                [2, -1],  # 1
                [2, -1],  # 2
                [2, -1],  # CloseVariadicField
                [2, -1],  # expr
                [3, 2],  # f
                [3, 2],  # CloseVariadicField
                [2, -1],  # CloseVariadicField
            ],
            dtype=np.int64)

        path = encoder.encode_path(action_sequence, 2)
        assert np.array_equal(expected_depth2, path.numpy())

        # With depth 1 only the first column of the depth-2 paths remains.
        path = encoder.encode_path(action_sequence, 1)
        assert np.array_equal(expected_depth2[:, :1], path.numpy())
Example #6
0
    def test_invalid_close_variadic_field_rule(self):
        """Closing a field that is not variadic raises
        InvalidActionException."""
        non_variadic_rule = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("elems", NodeType("value", NodeConstraint.Node, False))])
        seq = ActionSequence()
        seq.eval(ApplyRule(non_variadic_rule))

        close_action = ApplyRule(CloseVariadicFieldRule())
        with pytest.raises(InvalidActionException):
            seq.eval(close_action)
Example #7
0
 def test_generate_ignore_root_type(self):
     """generate() returns the child node rather than the Root-typed
     wrapper rule applied first."""
     root_rule = ExpandTreeRule(
         NodeType(Root(), NodeConstraint.Node, False),
         [("root", NodeType(Root(), NodeConstraint.Node, False))])
     op_rule = ExpandTreeRule(NodeType("op", NodeConstraint.Node, False),
                              [])

     seq = ActionSequence()
     seq.eval(ApplyRule(root_rule))
     seq.eval(ApplyRule(op_rule))
     assert Node("op", []) == seq.generate()
Example #8
0
 def test_create_node(self):
     """ActionSequence.create on a node with one non-variadic token field."""
     tree = Node("def", [Field("name", "literal", Leaf("str", "foo"))])
     seq = ActionSequence.create(tree)

     root_expansion = ApplyRule(ExpandTreeRule(
         NodeType(None, NodeConstraint.Node, False),
         [("root", NodeType(Root(), NodeConstraint.Node, False))]))
     def_expansion = ApplyRule(ExpandTreeRule(
         NodeType("def", NodeConstraint.Node, False),
         [("name", NodeType("literal", NodeConstraint.Token, False))]))
     expected = [root_expansion, def_expansion, GenerateToken("str", "foo")]
     assert expected == seq.action_sequence
Example #9
0
    def test_encode_tree(self):
        """encode_tree on a partial sequence: the def rule followed by two
        tokens for its ``name`` field."""
        def_rule = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr_rule = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [(field, NodeType("value", NodeConstraint.Token, True))
             for field in ("op", "arg0", "arg1")])
        encoder = ActionSequenceEncoder(
            Samples([def_rule, expr_rule],
                    [NodeType("def", NodeConstraint.Node, False),
                     NodeType("value", NodeConstraint.Token, True),
                     NodeType("expr", NodeConstraint.Node, False)],
                    [("", "f"), ("", "2")]),
            0)

        seq = ActionSequence()
        for action in (ApplyRule(def_rule),
                       GenerateToken("", "f"),
                       GenerateToken("", "1")):
            seq.eval(action)

        # Presumably: per-action depth, and a matrix marking action 0 as the
        # parent of actions 1 and 2.
        depths, matrix = encoder.encode_tree(seq)
        assert np.array_equal([0, 1, 1], depths.numpy())
        assert np.array_equal([[0, 1, 1], [0, 0, 0], [0, 0, 0]],
                              matrix.numpy())
Example #10
0
    def test_str(self):
        """str() of ApplyRule marks variadic field types with ``*``; str() of
        GenerateToken renders ``value:kind``."""
        parent = NodeType("t0", NodeConstraint.Node, False)
        plain_child = NodeType("t1", NodeConstraint.Node, False)
        variadic_child = NodeType("t2", NodeConstraint.Node, True)
        expansion = ExpandTreeRule(parent,
                                   [("elem0", plain_child),
                                    ("elem1", variadic_child)])
        assert "Apply (t0 -> [elem0: t1, elem1: t2*])" == \
            str(ApplyRule(expansion))

        assert "Generate bar:kind" == str(GenerateToken("kind", "bar"))
Example #11
0
    def test_eval_root(self):
        """On an empty sequence only an ExpandTreeRule is accepted; applying
        it sets the head to field 0 of action 0."""
        assert ActionSequence().head is None

        # Neither a token nor a close action may start a sequence.
        for first_action in (GenerateToken("kind", ""),
                             ApplyRule(CloseVariadicFieldRule())):
            with pytest.raises(InvalidActionException):
                ActionSequence().eval(first_action)

        rule = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Node, False)),
             ("value", NodeType("args", NodeConstraint.Node, True))])
        seq = ActionSequence()
        seq.eval(ApplyRule(rule))

        head = seq.head
        assert 0 == head.action
        assert 0 == head.field
        assert [ApplyRule(rule)] == seq.action_sequence
        assert seq.parent(0) is None
        assert [[], []] == seq._tree.children[0]
Example #12
0
    def test_clone(self):
        """clone() returns an independent copy: evaluating one more action on
        the copy makes the checked internal state diverge from the
        original."""
        rule = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("elems", NodeType("expr", NodeConstraint.Node, True))])
        original = ActionSequence()
        original.eval(ApplyRule(rule))

        duplicate = original.clone()
        assert original.generate() == duplicate.generate()

        duplicate.eval(ApplyRule(rule))
        assert original._tree.children != duplicate._tree.children
        assert original._tree.parent != duplicate._tree.parent
        for attr in ("action_sequence", "_head_action_index",
                     "_head_children_index"):
            assert getattr(original, attr) != getattr(duplicate, attr)
        assert original.generate() != duplicate.generate()
Example #13
0
    def test_generate_token(self):
        """A token fills a non-variadic token field, after which a second
        token is rejected; closing that field is rejected as well."""
        rule = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, False)),
             ("value", NodeType("args", NodeConstraint.Node, True))])
        seq = ActionSequence()
        seq.eval(ApplyRule(rule))
        seq.eval(GenerateToken("", "foo"))

        # The head stays on action 0 and moves to its second field.
        head = seq.head
        assert 0 == head.action
        assert 1 == head.field
        assert [1] == seq._tree.children[0][0]
        assert [ApplyRule(rule),
                GenerateToken("", "foo")] == seq.action_sequence
        assert Parent(0, 0) == seq.parent(1)
        assert [] == seq._tree.children[1]

        with pytest.raises(InvalidActionException):
            seq.eval(GenerateToken("", "bar"))

        fresh = ActionSequence()
        fresh.eval(ApplyRule(rule))
        with pytest.raises(InvalidActionException):
            fresh.eval(ApplyRule(CloseVariadicFieldRule()))
Example #14
0
 def test_create_node_with_variadic_fields(self):
     """ActionSequence.create on a node whose field holds a list of nodes:
     one rule per element, then a close action."""
     tree = Node("list",
                 [Field("elems", "literal",
                        [Node("str", []), Node("str", [])])])
     seq = ActionSequence.create(tree)

     def empty_str_rule():
         # Fresh ApplyRule for a "str" node without fields.
         return ApplyRule(
             ExpandTreeRule(NodeType("str", NodeConstraint.Node, False), []))

     expected = [
         ApplyRule(ExpandTreeRule(
             NodeType(None, NodeConstraint.Node, False),
             [("root", NodeType(Root(), NodeConstraint.Node, False))])),
         ApplyRule(ExpandTreeRule(
             NodeType("list", NodeConstraint.Node, False),
             [("elems", NodeType("literal", NodeConstraint.Node, True))])),
         empty_str_rule(),
         empty_str_rule(),
         ApplyRule(CloseVariadicFieldRule()),
     ]
     assert expected == seq.action_sequence
Example #15
0
    def test_encode_completed_sequence(self):
        """Encode a sequence made of a single rule that has no fields."""
        leaf_rule = ExpandTreeRule(
            NodeType("value", NodeConstraint.Node, False), [])
        samples = Samples([leaf_rule],
                          [NodeType("value", NodeConstraint.Node, False)],
                          [("", "f")])
        encoder = ActionSequenceEncoder(samples, 0)

        seq = ActionSequence()
        seq.eval(ApplyRule(leaf_rule))

        action = encoder.encode_action(seq, [Token("", "1", "1")])
        parent = encoder.encode_parent(seq)

        assert np.array_equal([[-1, 2, -1, -1], [-1, -1, -1, -1]],
                              action.numpy())
        all_padding = [[-1] * 4 for _ in range(2)]
        assert np.array_equal(all_padding, parent.numpy())
Example #16
0
    def test_generate(self):
        """Once every field is filled the head is None and generate()
        rebuilds the whole tree."""
        funcdef = ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("value", NodeConstraint.Token, True)),
             ("body", NodeType("expr", NodeConstraint.Node, True))])
        expr = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [(field, NodeType("value", NodeConstraint.Token, False))
             for field in ("op", "arg0", "arg1")])

        actions = (
            [ApplyRule(funcdef)]
            + [GenerateToken("name", name) for name in ("f", "_", "0")]
            + [ApplyRule(CloseVariadicFieldRule()), ApplyRule(expr)]
            + [GenerateToken("value", value) for value in ("+", "1", "2")]
            + [ApplyRule(CloseVariadicFieldRule())]
        )
        action_sequence = ActionSequence()
        for action in actions:
            action_sequence.eval(action)

        assert action_sequence.head is None
        expected = Node("def", [
            Field("name", "value",
                  [Leaf("name", name) for name in ("f", "_", "0")]),
            Field("body", "expr", [
                Node("expr", [
                    Field(field, "value", Leaf("value", value))
                    for field, value in (("op", "+"), ("arg0", "1"),
                                         ("arg1", "2"))
                ])
            ]),
        ])
        assert expected == action_sequence.generate()
Example #17
0
    def decode(self, tensor: torch.LongTensor, reference: List[Token]) \
            -> Optional[ActionSequence]:
        """
        Return the action sequence corresponding to the tensor

        Parameters
        ----------
        tensor: torch.LongTensor
            The encoded tensor with the shape of
            (len(action_sequence), 3). Each action will be encoded by the tuple
            of (ID of the applied rule, ID of the inserted token,
            the index of the word copied from the reference).
            The padding value should be -1.
        reference: List[Token]
            The tokens that copy actions index into. A copied token becomes
            ``GenerateToken(token.kind, token.raw_value)``.

        Returns
        -------
        Optional[ActionSequence]
            The action sequence corresponding to the tensor.
            None if a row encodes no action (rule ID and token ID are not
            positive and the copy index is negative), or if a copy index is
            out of bounds for ``reference``.

        Notes
        -----
        The columns are checked in order: a positive rule ID takes priority
        over a positive token ID, which takes priority over a copy index.
        Every decoded action is passed to ``ActionSequence.eval``. An action
        that is invalid at its position is not mapped to None; whatever
        ``eval`` raises (e.g. InvalidActionException) propagates.
        """
        retval = ActionSequence()
        for i in range(tensor.shape[0]):
            if tensor[i, 0] > 0:
                # ApplyRule (ID 0 is not treated as a rule)
                rule = self._rule_encoder.decode(tensor[i, 0])
                retval.eval(ApplyRule(rule))
            elif tensor[i, 1] > 0:
                # GenerateToken from the token vocabulary
                kind, value = self._token_encoder.decode(tensor[i, 1])
                retval.eval(GenerateToken(kind, value))
            elif tensor[i, 2] >= 0:
                # GenerateToken (Copy)
                # .item() works for tensors on any device; .numpy() fails on
                # CUDA tensors.
                index = int(tensor[i, 2].item())
                if index >= len(reference):
                    logger.debug("reference index is out-of-bounds")
                    return None
                token = reference[index]
                retval.eval(GenerateToken(token.kind, token.raw_value))
            else:
                logger.debug("invalid actions")
                return None

        return retval

    def initialize(self, input: Input) -> Environment:
        """Encode ``input`` and build the initial search state.

        The encoder module is put into eval mode and run under
        ``torch.no_grad()``. The returned state also holds, under
        ``"action_sequence"``, a new ActionSequence on which the rule that
        expands the untyped top node into a Root-typed ``root`` field has
        already been applied.
        """
        self.module.encoder.eval()
        batch = self._to(self.collate.collate([self.transform_input(input)]))
        with torch.no_grad(), logger.block("encode_state"):
            batch = self.module.encoder(batch)
        state = self.collate.split(batch)[0]

        # Add initial rule
        root_rule = ExpandTreeRule(
            NodeType(None, NodeConstraint.Node, False),
            [("root", NodeType(Root(), NodeConstraint.Node, False))])
        seq = ActionSequence()
        seq.eval(ApplyRule(root_rule))
        state["action_sequence"] = seq
        return state
Example #19
0
    def test_variadic_field(self):
        """A variadic Node field accepts repeated rules until it is closed.

        In the first case the variadic field is the rule's only field, so
        closing it leaves no head. In the second case it is followed by
        ``name``, so closing it moves the head to field 1 of action 0.
        """
        action_sequence = ActionSequence()
        rule = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("elems", NodeType("value", NodeConstraint.Node, True))])
        rule0 = ExpandTreeRule(NodeType("value", NodeConstraint.Node, False),
                               [])
        action_sequence.eval(ApplyRule(rule))
        action_sequence.eval(ApplyRule(rule0))
        # The head stays on the variadic field of action 0.
        assert 0 == action_sequence.head.action
        assert 0 == action_sequence.head.field
        assert [1] == action_sequence._tree.children[0][0]
        assert [ApplyRule(rule),
                ApplyRule(rule0)] == action_sequence.action_sequence
        assert Parent(0, 0) == action_sequence.parent(1)
        assert [] == action_sequence._tree.children[1]

        action_sequence.eval(ApplyRule(rule0))
        assert 0 == action_sequence.head.action
        assert 0 == action_sequence.head.field
        assert [1, 2] == action_sequence._tree.children[0][0]
        assert [ApplyRule(rule),
                ApplyRule(rule0),
                ApplyRule(rule0)] == action_sequence.action_sequence

        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        assert action_sequence.head is None

        action_sequence = ActionSequence()
        rule1 = ExpandTreeRule(
            NodeType("expr", NodeConstraint.Node, False),
            [("elems", NodeType("value", NodeConstraint.Node, True)),
             ("name", NodeType("value", NodeConstraint.Node, False))])
        rule0 = ExpandTreeRule(NodeType("value", NodeConstraint.Node, False),
                               [])
        action_sequence.eval(ApplyRule(rule1))
        action_sequence.eval(ApplyRule(rule0))
        # Pass an instance, not the CloseVariadicFieldRule class itself, so
        # the recorded action is a well-formed close action.
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
        assert 0 == action_sequence.head.action
        assert 1 == action_sequence.head.field
    def enumerate_samples_per_state(self,
                                    rule_pred: torch.Tensor,
                                    token_pred: torch.Tensor,
                                    reference_pred: torch.Tensor,
                                    next_state: Environment,
                                    state: SamplerState[Environment],
                                    enumeration: Enumeration,
                                    k: Optional[int]) \
            -> Generator[DuplicatedSamplerState[Environment], None, None]:
        """Enumerate successors of ``state`` by choosing its next action.

        The field at the head of ``state.state["action_sequence"]`` decides
        which actions are candidates:

        * Token field: GenerateToken from the token vocabulary or a token
          copied from ``state.state["reference"]``, plus
          CloseVariadicFieldRule (scored from ``rule_pred``) when the field
          is variadic.
        * Otherwise: ApplyRule for every rule in the rule vocabulary.

        Candidates removed by the type checks below get score 0 and are
        skipped. Index 0 of every score vector is never proposed.

        Parameters
        ----------
        rule_pred, token_pred, reference_pred: torch.Tensor
            Scores of the rules, vocabulary tokens and reference tokens.
            NOTE: these tensors are modified in place (reference scores are
            merged into token scores and invalid candidates are zeroed).
        next_state: Environment
            Template; every yielded state is a clone of it whose
            ``"action_sequence"`` is replaced by a LazyActionSequence.
        state: SamplerState[Environment]
            The state being expanded; its score is the base of the yielded
            scores.
        enumeration: Enumeration
            Top: candidates in descending score order. Random: candidates in
            an order shuffled with ``self.rng``. Otherwise: ``k`` draws from
            the normalized scores (``k`` must not be None).
        k: Optional[int]
            Maximum number of candidates / draws; None means no limit (not
            allowed for sampling).

        Yields
        ------
        DuplicatedSamplerState[Environment]
            A successor with score ``state.score + log(p)`` (``p`` clipped
            below at ``self.eps``) and the number of times it was drawn.
        """
        def indices(pred: torch.Tensor):
            # Yield (index into pred, multiplicity) pairs.
            # 0 is unknown token
            if enumeration == Enumeration.Top:
                _, indices = torch.sort(pred[1:], descending=True)
                if k is not None:
                    indices = indices[:k]
                for index in indices:
                    yield index + 1, 1
            elif enumeration == Enumeration.Random:
                indices = list(range(1, len(pred)))
                # Without shuffling, "Random" would always propose the first
                # k candidates in vocabulary order.
                self.rng.shuffle(indices)
                if k is not None:
                    indices = indices[:k]
                for index in indices:
                    yield index, 1
            else:
                assert k is not None
                with logger.block("normalize_prob"):
                    s = pred[1:].sum().item()
                    if s < self.eps:
                        return
                    # Subtracting eps presumably keeps the total from
                    # exceeding 1 through rounding; negatives clip to 0.
                    ps = (pred[1:] / s - self.eps).numpy()
                    npred = [max(0, p) for p in ps]
                for i, n in enumerate(self.rng.multinomial(k, npred)):
                    if n == 0:
                        continue
                    yield i + 1, n

        with logger.block("enumerate_samples_per_state"):
            head = state.state["action_sequence"].head
            assert head is not None
            # Type of the field that the next action has to fill.
            head_field = \
                cast(ExpandTreeRule, cast(
                    ApplyRule,
                    state.state["action_sequence"]
                    .action_sequence[head.action]
                ).rule).children[head.field][1]
            if head_field.constraint == NodeConstraint.Token:
                # Generate token
                ref_ids = self.encoder.batch_encode_raw_value(
                    [x.raw_value for x in state.state["reference"]])
                # Candidates: vocabulary entries followed by reference
                # tokens, aligned with ``pred`` below.
                tokens = list(self.encoder._token_encoder.vocab) + \
                    state.state["reference"]
                # the score will be merged into predefined token
                for i, ids in enumerate(ref_ids):
                    for ref_id in ids:
                        # merge token and reference pred
                        # Add to unknown probability
                        # if there is not the corresponding token.
                        token_pred[ref_id] += reference_pred[i]
                        if ref_id != 0:
                            reference_pred[i] = 0.0
                pred = torch.cat([token_pred, reference_pred], dim=0)

                # CloseVariadicFieldRule is a candidate if variadic fields
                if head_field.is_variadic:
                    close_rule_idx = \
                        self.encoder._rule_encoder.encode(
                            CloseVariadicFieldRule())
                    p = rule_pred[close_rule_idx].item()
                    tokens.append(ApplyRule(CloseVariadicFieldRule()))
                    pred = torch.cat([pred, torch.tensor([p])], dim=0)

                with logger.block("exclude_invalid_tokens"):
                    # token
                    for kind, idxes in self.token_kind_to_idx.items():
                        if kind is not None and \
                                not self.is_subtype(kind,
                                                    head_field.type_name):
                            pred[idxes] = 0.0
                    # reference
                    for x, (p, token) in enumerate(
                            zip(pred[len(token_pred):],
                                tokens[len(token_pred):])):
                        x += len(token_pred)
                        if not isinstance(token, ApplyRule):
                            if isinstance(token, Token):
                                t = token.kind
                            else:
                                t = token[0]
                            if t is not None and \
                                    not self.is_subtype(t,
                                                        head_field.type_name):
                                pred[x] = 0.0

                n_action = 0
                for x, n in logger.iterable_block("sample-tokens",
                                                  indices(pred)):
                    # Finish enumeration
                    if n_action == k:
                        return

                    p = pred[x].item()
                    token = tokens[x]

                    if isinstance(token, ApplyRule):
                        action: Action = token
                    elif isinstance(token, Token):
                        action = GenerateToken(token.kind, token.raw_value)
                    else:
                        action = GenerateToken(token[0], token[1])

                    if p == 0.0:
                        continue
                    elif p < self.eps:
                        lp = np.log(self.eps)
                    else:
                        lp = np.log(p)

                    n_action += n
                    # Clone the caller's template rather than the previously
                    # yielded (possibly consumer-modified) state.
                    child = next_state.clone()
                    # TODO we may have to clear outputs
                    child["action_sequence"] = \
                        LazyActionSequence(
                            state.state["action_sequence"], action)
                    yield DuplicatedSamplerState(
                        SamplerState(state.score + lp, child), n)
            else:
                # Apply rule
                with logger.block("exclude_invalid_rules"):
                    # expand tree rule
                    for kind, idxes in self.rule_kind_to_idx.items():
                        if not (kind is not None and self.is_subtype(
                                kind, head_field.type_name)):
                            rule_pred[idxes] = 0.0
                    # CloseVariadicField
                    idx = self.encoder._rule_encoder.encode(
                        CloseVariadicFieldRule())
                    if not (head_field is not None and head_field.is_variadic):
                        rule_pred[idx] = 0.0

                n_rule = 0
                for x, n in logger.iterable_block("sample-rule",
                                                  indices(rule_pred)):
                    # Finish enumeration
                    if n_rule == k:
                        return

                    p = rule_pred[x].item()
                    if p == 0.0:
                        continue
                    elif p < self.eps:
                        lp = np.log(self.eps)
                    else:
                        lp = np.log(p)

                    n_rule += n
                    rule = self.encoder._rule_encoder.vocab[x]

                    # Clone the caller's template rather than the previously
                    # yielded (possibly consumer-modified) state.
                    child = next_state.clone()
                    child["action_sequence"] = \
                        LazyActionSequence(
                            state.state["action_sequence"],
                            ApplyRule(rule))
                    yield DuplicatedSamplerState(
                        SamplerState(state.score + lp, child), n)