def test_decode(self):
    """Encoding an action sequence and decoding it yields the same actions.

    The encoded tensor is sliced with ``[:-1, 1:]`` (last row and first
    column dropped) before being handed to ``decode``.
    """
    def_type = NodeType("def", NodeConstraint.Node, False)
    expr_type = NodeType("expr", NodeConstraint.Node, False)
    value_tokens = NodeType("value", NodeConstraint.Token, True)

    funcdef = ExpandTreeRule(
        def_type,
        [("name", value_tokens),
         ("body", NodeType("expr", NodeConstraint.Node, True))])
    expr = ExpandTreeRule(
        expr_type,
        [(field, value_tokens) for field in ("op", "arg0", "arg1")])
    samples = Samples([funcdef, expr],
                      [def_type, value_tokens, expr_type],
                      [("", "f")])
    encoder = ActionSequenceEncoder(samples, 0)

    def build_sequence():
        # Fresh action objects on every call, so the input and the expected
        # sequence never share instances.
        seq = ActionSequence()
        seq.eval(ApplyRule(funcdef))
        for value in ("f", "1"):
            seq.eval(GenerateToken("", value))
        seq.eval(ApplyRule(CloseVariadicFieldRule()))
        return seq

    source = build_sequence()
    expected = build_sequence()

    encoded = encoder.encode_action(source, [Token(None, "1", "1")])
    decoded = encoder.decode(encoded[:-1, 1:], [Token(None, "1", "1")])

    assert expected.action_sequence == decoded.action_sequence
def test_encode_path(self):
    """Check ``encode_path`` for a requested path length of 2 and of 1.

    The expected arrays have one row per evaluated action and as many
    columns as the second argument passed to ``encode_path``.
    """
    funcdef = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Token, True)),
         ("body", NodeType("expr", NodeConstraint.Node, True))])
    expr = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [("constant", NodeType("value", NodeConstraint.Token, True))])
    encoder = ActionSequenceEncoder(
        Samples([funcdef, expr],
                [NodeType("def", NodeConstraint.Node, False),
                 NodeType("value", NodeConstraint.Token, True),
                 NodeType("expr", NodeConstraint.Node, True)],
                [("", "f"), ("", "2")]),
        0)

    action_sequence = ActionSequence()
    action_sequence.eval(ApplyRule(funcdef))
    action_sequence.eval(GenerateToken("", "f"))
    action_sequence.eval(GenerateToken("", "1"))
    action_sequence.eval(GenerateToken("", "2"))
    action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
    action_sequence.eval(ApplyRule(expr))
    action_sequence.eval(GenerateToken("", "f"))
    action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
    action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))

    # np.int64 is used instead of the np.long alias: np.long was removed in
    # NumPy 1.24 and only came back in 2.0 with a different, platform
    # dependent meaning (C long).
    path = encoder.encode_path(action_sequence, 2)
    assert np.array_equal(
        np.array(
            [
                [-1, -1],  # funcdef
                [2, -1],  # f
                [2, -1],  # 1
                [2, -1],  # 2
                [2, -1],  # CloseVariadicField
                [2, -1],  # expr
                [3, 2],  # f
                [3, 2],  # CloseVariadicField
                [2, -1],  # CloseVariadicField
            ],
            dtype=np.int64),
        path.numpy())

    path = encoder.encode_path(action_sequence, 1)
    assert np.array_equal(
        np.array(
            [
                [-1],  # funcdef
                [2],  # f
                [2],  # 1
                [2],  # 2
                [2],  # CloseVariadicField
                [2],  # expr
                [3],  # f
                [3],  # CloseVariadicField
                [2],  # CloseVariadicField
            ],
            dtype=np.int64),
        path.numpy())
def test_create_leaf(self):
    """``ActionSequence.create`` handles a bare leaf and a node of leaves."""

    def root_rule(constraint):
        # Rule expected at the head of every created sequence; only the
        # constraint of its single "root" child differs between the cases.
        return ApplyRule(
            ExpandTreeRule(
                NodeType(None, NodeConstraint.Node, False),
                [("root", NodeType(Root(), constraint, False))]))

    # Case 1: the AST is a single leaf.
    leaf_seq = ActionSequence.create(Leaf("str", "t0 t1"))
    expected_for_leaf = [
        root_rule(NodeConstraint.Token),
        GenerateToken("str", "t0 t1"),
    ]
    assert expected_for_leaf == leaf_seq.action_sequence

    # Case 2: a node whose only field holds a list of two leaves.
    ast = Node(
        "value",
        [Field("name", "str", [Leaf("str", "t0"), Leaf("str", "t1")])])
    node_seq = ActionSequence.create(ast)
    expected_for_node = [
        root_rule(NodeConstraint.Node),
        ApplyRule(
            ExpandTreeRule(
                NodeType("value", NodeConstraint.Node, False),
                [("name", NodeType("str", NodeConstraint.Token, True))])),
        GenerateToken("str", "t0"),
        GenerateToken("str", "t1"),
        ApplyRule(CloseVariadicFieldRule()),
    ]
    assert expected_for_node == node_seq.action_sequence
def test_encode_parent(self):
    """``encode_parent`` returns the expected 6x4 matrix for the sequence."""
    def_type = NodeType("def", NodeConstraint.Node, False)
    expr_type = NodeType("expr", NodeConstraint.Node, False)
    value_tokens = NodeType("value", NodeConstraint.Token, True)

    funcdef = ExpandTreeRule(
        def_type,
        [("name", value_tokens),
         ("body", NodeType("expr", NodeConstraint.Node, True))])
    expr = ExpandTreeRule(
        expr_type,
        [(field, value_tokens) for field in ("op", "arg0", "arg1")])
    samples = Samples([funcdef, expr],
                      [def_type, value_tokens, expr_type],
                      [("", "f"), ("", "2")])
    encoder = ActionSequenceEncoder(samples, 0)

    seq = ActionSequence()
    seq.eval(ApplyRule(funcdef))
    for value in ("f", "1", "2"):
        seq.eval(GenerateToken("", value))
    seq.eval(ApplyRule(CloseVariadicFieldRule()))

    expected = [
        [-1, -1, -1, -1],
        [1, 2, 0, 0],
        [1, 2, 0, 0],
        [1, 2, 0, 0],
        [1, 2, 0, 0],
        [1, 2, 0, 1],
    ]
    encoded = encoder.encode_parent(seq)
    assert np.array_equal(expected, encoded.numpy())
def test_encode_invalid_sequence(self):
    """``encode_action`` returns ``None`` for a sequence it cannot encode.

    NOTE(review): presumably the failure is caused by the generated token
    "1", which is neither in the vocabulary nor in the reference -- confirm
    against the encoder implementation.
    """
    value_tokens = NodeType("value", NodeConstraint.Token, True)
    variadic_expr = NodeType("expr", NodeConstraint.Node, True)

    funcdef = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", value_tokens), ("body", variadic_expr)])
    expr = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [("op", NodeType("value", NodeConstraint.Token, False)),
         ("arg0", value_tokens),
         ("arg1", value_tokens)])
    node_types = [
        NodeType("def", NodeConstraint.Node, False),
        value_tokens,
        variadic_expr,
    ]
    encoder = ActionSequenceEncoder(
        Samples([funcdef, expr], node_types, [("", "f")]), 0)

    seq = ActionSequence()
    seq.eval(ApplyRule(funcdef))
    for value in ("f", "1"):
        seq.eval(GenerateToken("", value))
    seq.eval(ApplyRule(CloseVariadicFieldRule()))

    encoded = encoder.encode_action(seq, [Token("", "2", "2")])
    assert encoded is None
def test_invalid_close_variadic_field_rule(self):
    """Closing a field that is not variadic raises InvalidActionException."""
    single_elem = NodeType("value", NodeConstraint.Node, False)
    parent_rule = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [("elems", single_elem)])

    seq = ActionSequence()
    seq.eval(ApplyRule(parent_rule))

    # The head field "elems" is declared non-variadic, so a close is invalid.
    with pytest.raises(InvalidActionException):
        seq.eval(ApplyRule(CloseVariadicFieldRule()))
def test_str(self):
    """String forms of ExpandTreeRule and CloseVariadicFieldRule."""
    parent, single, repeated = (
        NodeType(name, NodeConstraint.Node, is_variadic)
        for name, is_variadic in (("t0", False), ("t1", False), ("t2", True))
    )
    expand = ExpandTreeRule(parent, [("elem0", single), ("elem1", repeated)])

    # The variadic child is rendered with a trailing "*".
    assert "t0 -> [elem0: t1, elem1: t2*]" == str(expand)
    assert "<close variadic field>" == str(CloseVariadicFieldRule())
def test_generate_variadic_token(self):
    """Tokens accumulate in a variadic token field until it is closed."""
    rule = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Token, True)),
         ("value", NodeType("args", NodeConstraint.Node, True))])
    seq = ActionSequence()
    seq.eval(ApplyRule(rule))

    # First token: head stays on field 0 and the token becomes child 1.
    seq.eval(GenerateToken("", "foo"))
    assert 0 == seq.head.action
    assert 0 == seq.head.field
    assert [1] == seq._tree.children[0][0]
    expected_actions = [ApplyRule(rule), GenerateToken("", "foo")]
    assert expected_actions == seq.action_sequence
    assert Parent(0, 0) == seq.parent(1)
    assert [] == seq._tree.children[1]

    # Second token: still field 0, now with two children.
    seq.eval(GenerateToken("", "bar"))
    assert 0 == seq.head.action
    assert 0 == seq.head.field
    assert [1, 2] == seq._tree.children[0][0]
    expected_actions = [
        ApplyRule(rule),
        GenerateToken("", "foo"),
        GenerateToken("", "bar"),
    ]
    assert expected_actions == seq.action_sequence

    # Closing the field moves the head to field 1; the close action itself
    # is recorded as a third child of field 0.
    seq.eval(ApplyRule(CloseVariadicFieldRule()))
    assert 0 == seq.head.action
    assert 1 == seq.head.field
    assert [1, 2, 3] == seq._tree.children[0][0]
    expected_actions = [
        ApplyRule(rule),
        GenerateToken("", "foo"),
        GenerateToken("", "bar"),
        ApplyRule(CloseVariadicFieldRule()),
    ]
    assert expected_actions == seq.action_sequence

    # Field 1 has a Node constraint, so generating a token is rejected.
    with pytest.raises(InvalidActionException):
        seq.eval(GenerateToken("", "foo"))
def test_generate(self):
    """A completed action sequence is turned back into the expected AST."""
    single_value = NodeType("value", NodeConstraint.Token, False)
    funcdef = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Token, True)),
         ("body", NodeType("expr", NodeConstraint.Node, True))])
    expr = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [(field, single_value) for field in ("op", "arg0", "arg1")])

    seq = ActionSequence()
    seq.eval(ApplyRule(funcdef))
    for name_part in ("f", "_", "0"):
        seq.eval(GenerateToken("name", name_part))
    seq.eval(ApplyRule(CloseVariadicFieldRule()))
    seq.eval(ApplyRule(expr))
    for operand in ("+", "1", "2"):
        seq.eval(GenerateToken("value", operand))
    seq.eval(ApplyRule(CloseVariadicFieldRule()))

    # Nothing is left to expand once the sequence is complete.
    assert seq.head is None

    expected_name = Field(
        "name", "value",
        [Leaf("name", "f"), Leaf("name", "_"), Leaf("name", "0")])
    expected_expr = Node("expr", [
        Field("op", "value", Leaf("value", "+")),
        Field("arg0", "value", Leaf("value", "1")),
        Field("arg1", "value", Leaf("value", "2")),
    ])
    expected_ast = Node("def", [
        expected_name,
        Field("body", "expr", [expected_expr]),
    ])
    assert expected_ast == seq.generate()
def __init__(self, samples: Samples, token_threshold: int):
    """Build the rule, node-type and token label encoders.

    Args:
        samples: Provides ``rules``, ``node_types`` and ``tokens`` used to
            populate the three encoders.
        token_threshold: Passed to the token encoder as
            ``min_occurrences``.
    """
    # Rules reserve index 0 for Unknown, followed by the close-variadic rule.
    rule_reserved: List[Union[Unknown, CloseVariadicFieldRule]] = [
        Unknown(), CloseVariadicFieldRule()]
    self._rule_encoder = LabelEncoder(samples.rules,
                                      reserved_labels=rule_reserved,
                                      unknown_index=0)

    self._node_type_encoder = LabelEncoder(samples.node_types)

    # Tokens only reserve index 0 for Unknown.
    token_reserved = [Unknown()]
    self._token_encoder = LabelEncoder(samples.tokens,
                                       min_occurrences=token_threshold,
                                       reserved_labels=token_reserved,
                                       unknown_index=0)

    # Map each token value to the encoder indices of every (kind, value)
    # vocabulary entry carrying that value; reserved entries are skipped.
    self.value_to_idx: Dict[str, List[int]] = {}
    known_tokens = self._token_encoder.vocab[len(token_reserved):]
    for kind, value in known_tokens:
        encoded = self._token_encoder.encode((kind, value))
        self.value_to_idx.setdefault(value, []).append(encoded)
def test_eval_root(self):
    """Only an expand-tree rule is accepted as the first action."""
    # A fresh sequence has no head.
    assert ActionSequence().head is None

    # Neither a token nor a close-variadic rule may start a sequence.
    with pytest.raises(InvalidActionException):
        ActionSequence().eval(GenerateToken("kind", ""))
    with pytest.raises(InvalidActionException):
        ActionSequence().eval(ApplyRule(CloseVariadicFieldRule()))

    root_rule = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Node, False)),
         ("value", NodeType("args", NodeConstraint.Node, True))])
    seq = ActionSequence()
    seq.eval(ApplyRule(root_rule))

    # The head points at the first field of the rule that was just applied.
    head = seq.head
    assert 0 == head.action
    assert 0 == head.field
    assert [ApplyRule(root_rule)] == seq.action_sequence
    assert seq.parent(0) is None
    # One empty child list per field of the rule.
    assert [[], []] == seq._tree.children[0]
def test_variadic_field(self):
    """Node children accumulate in a variadic field until it is closed.

    The first scenario uses a rule whose only field is variadic; closing it
    finishes the sequence. The second scenario has a further field after
    the variadic one, so closing moves the head to that field.
    """
    action_sequence = ActionSequence()
    rule = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [("elems", NodeType("value", NodeConstraint.Node, True))])
    rule0 = ExpandTreeRule(NodeType("value", NodeConstraint.Node, False), [])
    action_sequence.eval(ApplyRule(rule))

    # First child: the head stays on the variadic field.
    action_sequence.eval(ApplyRule(rule0))
    assert 0 == action_sequence.head.action
    assert 0 == action_sequence.head.field
    assert [1] == action_sequence._tree.children[0][0]
    assert [ApplyRule(rule), ApplyRule(rule0)] == \
        action_sequence.action_sequence
    assert Parent(0, 0) == action_sequence.parent(1)
    assert [] == action_sequence._tree.children[1]

    # Second child is appended to the same field.
    action_sequence.eval(ApplyRule(rule0))
    assert 0 == action_sequence.head.action
    assert 0 == action_sequence.head.field
    assert [1, 2] == action_sequence._tree.children[0][0]
    assert [ApplyRule(rule), ApplyRule(rule0), ApplyRule(rule0)] == \
        action_sequence.action_sequence

    # Closing the only field leaves nothing to expand.
    action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
    assert action_sequence.head is None

    action_sequence = ActionSequence()
    rule1 = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [("elems", NodeType("value", NodeConstraint.Node, True)),
         ("name", NodeType("value", NodeConstraint.Node, False))])
    rule0 = ExpandTreeRule(NodeType("value", NodeConstraint.Node, False), [])
    action_sequence.eval(ApplyRule(rule1))
    action_sequence.eval(ApplyRule(rule0))
    # Pass a CloseVariadicFieldRule instance, as everywhere else in these
    # tests; the class object itself is not a rule.
    action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
    assert 0 == action_sequence.head.action
    assert 1 == action_sequence.head.field
def test_create_node_with_variadic_fields(self):
    """A node with a list-valued field yields a closed variadic expansion."""
    ast = Node(
        "list",
        [Field("elems", "literal", [Node("str", []), Node("str", [])])])

    def str_rule():
        # Expansion of a childless "str" node; built anew for each use.
        return ApplyRule(
            ExpandTreeRule(NodeType("str", NodeConstraint.Node, False), []))

    root = ApplyRule(
        ExpandTreeRule(
            NodeType(None, NodeConstraint.Node, False),
            [("root", NodeType(Root(), NodeConstraint.Node, False))]))
    list_expansion = ApplyRule(
        ExpandTreeRule(
            NodeType("list", NodeConstraint.Node, False),
            [("elems", NodeType("literal", NodeConstraint.Node, True))]))
    expected = [
        root,
        list_expansion,
        str_rule(),
        str_rule(),
        ApplyRule(CloseVariadicFieldRule()),
    ]

    created = ActionSequence.create(ast)
    assert expected == created.action_sequence
def enumerate_samples_per_state(self,
                                rule_pred: torch.Tensor,
                                token_pred: torch.Tensor,
                                reference_pred: torch.Tensor,
                                next_state: Environment,
                                state: SamplerState[Environment],
                                enumeration: Enumeration,
                                k: Optional[int]) \
        -> Generator[DuplicatedSamplerState[Environment], None, None]:
    """Yield successor sampler states for the head field of ``state``.

    Depending on the constraint of the head field, either token actions
    (vocabulary tokens, reference tokens and, for variadic fields, the
    close-variadic rule) or rule actions are enumerated.

    Args:
        rule_pred: Scores indexed by rule-encoder index.
        token_pred: Scores indexed by token-encoder index.
        reference_pred: Scores with one entry per item of
            ``state.state["reference"]``.
        next_state: Environment that is cloned for every yielded state.
        state: Current sampler state; its ``"action_sequence"`` and
            ``"reference"`` entries are read.
        enumeration: Selects how candidate indices are produced (see the
            nested ``indices`` helper).
        k: Maximum number of candidates / draws; must not be ``None`` for
            the sampling branch.

    Yields:
        ``DuplicatedSamplerState`` wrapping a ``SamplerState`` whose score
        is ``state.score`` plus the log of the candidate's score, together
        with the multiplicity of the candidate.

    NOTE(review): ``token_pred``, ``reference_pred`` and ``rule_pred`` are
    modified in place below -- confirm callers do not reuse them.
    NOTE(review): scores are presumably probabilities (they are passed to
    ``np.log`` and normalised for multinomial sampling) -- confirm.
    """
    def indices(pred: torch.Tensor):
        """Yield ``(index, count)`` pairs of candidates taken from ``pred``."""
        # 0 is unknown token
        if enumeration == Enumeration.Top:
            # Sort everything except index 0 by descending score; the
            # "+ 1" below maps positions back to indices into ``pred``.
            _, indices = torch.sort(pred[1:], descending=True)
            if k is not None:
                indices = indices[:k]
            for index in indices:
                yield index + 1, 1
        elif enumeration == Enumeration.Random:
            # NOTE(review): no shuffling is visible here; the first k
            # indices are taken in ascending order -- confirm intended.
            indices = list(range(1, len(pred)))
            if k is not None:
                indices = indices[:k]
            for index in indices:
                yield index, 1
        else:
            assert k is not None
            with logger.block("normalize_prob"):
                s = pred[1:].sum().item()
                # Nothing to sample from when the total mass is below eps.
                if s < self.eps:
                    return
                # Normalise, subtract eps and clamp negatives to 0.
                ps = (pred[1:] / s - self.eps).numpy()
                npred = [max(0, p) for p in ps]
            # Draw k samples; n is how often index i + 1 was drawn.
            for i, n in enumerate(self.rng.multinomial(k, npred)):
                if n == 0:
                    continue
                yield i + 1, n

    with logger.block("enumerate_samples_per_state"):
        head = state.state["action_sequence"].head
        assert head is not None
        # Node type of the field the head points at, looked up from the
        # children of the rule applied at ``head.action``.
        head_field = \
            cast(ExpandTreeRule, cast(
                ApplyRule,
                state.state["action_sequence"]
                .action_sequence[head.action]
            ).rule).children[head.field][1]
        if head_field.constraint == NodeConstraint.Token:
            # Generate token
            ref_ids = self.encoder.batch_encode_raw_value(
                [x.raw_value for x in state.state["reference"]])
            # Candidate list: vocabulary entries followed by the reference
            # tokens, matching the layout of ``pred`` built below.
            tokens = list(self.encoder._token_encoder.vocab) + \
                state.state["reference"]
            # the score will be merged into predefined token
            for i, ids in enumerate(ref_ids):
                for ref_id in ids:
                    # merge token and reference pred
                    # Add to unknown probability
                    # if there is not the corresponding token.
                    token_pred[ref_id] += reference_pred[i]
                    if ref_id != 0:
                        reference_pred[i] = 0.0
            pred = torch.cat([token_pred, reference_pred], dim=0)

            # CloseVariadicFieldRule is a candidate if variadic fields
            if head_field.is_variadic:
                close_rule_idx = \
                    self.encoder._rule_encoder.encode(
                        CloseVariadicFieldRule())
                p = rule_pred[close_rule_idx].item()
                # Appended last, so it sits at the final index of ``pred``.
                tokens.append(ApplyRule(CloseVariadicFieldRule()))
                pred = torch.cat([pred, torch.tensor([p])], dim=0)

            with logger.block("exclude_invalid_tokens"):
                # token
                for kind, idxes in self.token_kind_to_idx.items():
                    if kind is not None and \
                            not self.is_subtype(kind,
                                                head_field.type_name):
                        pred[idxes] = 0.0
                # reference
                for x, (p, token) in enumerate(
                        zip(pred[len(token_pred):],
                            tokens[len(token_pred):])):
                    # Convert the slice-relative position to an index
                    # into ``pred``.
                    x += len(token_pred)
                    if not isinstance(token, ApplyRule):
                        if isinstance(token, Token):
                            t = token.kind
                        else:
                            t = token[0]
                        if t is not None and \
                                not self.is_subtype(t,
                                                    head_field.type_name):
                            pred[x] = 0.0

            n_action = 0
            for x, n in logger.iterable_block("sample-tokens",
                                              indices(pred)):
                # Finish enumeration
                if n_action == k:
                    return

                p = pred[x].item()
                token = tokens[x]

                if isinstance(token, ApplyRule):
                    action: Action = token
                elif isinstance(token, Token):
                    action = GenerateToken(token.kind, token.raw_value)
                else:
                    action = GenerateToken(token[0], token[1])

                # Candidates with zero score are skipped; scores below
                # eps are floored to eps before taking the log.
                if p == 0.0:
                    continue
                elif p < self.eps:
                    lp = np.log(self.eps)
                else:
                    lp = np.log(p)

                n_action += n
                # Each yielded state gets its own clone of the environment.
                next_state = next_state.clone()
                # TODO we may have to clear outputs
                next_state["action_sequence"] = \
                    LazyActionSequence(
                        state.state["action_sequence"], action)
                yield DuplicatedSamplerState(
                    SamplerState(state.score + lp, next_state),
                    n)
        else:
            # Apply rule
            with logger.block("exclude_invalid_rules"):
                # expand tree rule
                for kind, idxes in self.rule_kind_to_idx.items():
                    if not (kind is not None and
                            self.is_subtype(
                                kind,
                                head_field.type_name)):
                        rule_pred[idxes] = 0.0
                # CloseVariadicField
                idx = self.encoder._rule_encoder.encode(
                    CloseVariadicFieldRule())
                if not (head_field is not None and
                        head_field.is_variadic):
                    rule_pred[idx] = 0.0

            n_rule = 0
            for x, n in logger.iterable_block("sample-rule",
                                              indices(rule_pred)):
                # Finish enumeration
                if n_rule == k:
                    return
                p = rule_pred[x].item()
                # Same zero-skip / eps-floor handling as for tokens.
                if p == 0.0:
                    continue
                elif p < self.eps:
                    lp = np.log(self.eps)
                else:
                    lp = np.log(p)

                n_rule += n
                rule = self.encoder._rule_encoder.vocab[x]

                next_state = next_state.clone()
                next_state["action_sequence"] = \
                    LazyActionSequence(
                        state.state["action_sequence"],
                        ApplyRule(rule))
                yield DuplicatedSamplerState(
                    SamplerState(state.score + lp, next_state),
                    n)