def test_create_leaf(self):
    seq = ActionSequence.create(Leaf("str", "t0 t1"))
    assert [
        ApplyRule(
            ExpandTreeRule(NodeType(None, NodeConstraint.Node, False), [
                ("root", NodeType(Root(), NodeConstraint.Token, False))
            ])),
        GenerateToken("str", "t0 t1")
    ] == seq.action_sequence

    seq = ActionSequence.create(
        Node(
            "value",
            [Field("name", "str", [Leaf("str", "t0"), Leaf("str", "t1")])]))
    assert [
        ApplyRule(
            ExpandTreeRule(
                NodeType(None, NodeConstraint.Node, False),
                [("root", NodeType(Root(), NodeConstraint.Node, False))])),
        ApplyRule(
            ExpandTreeRule(
                NodeType("value", NodeConstraint.Node, False),
                [("name", NodeType("str", NodeConstraint.Token, True))])),
        GenerateToken("str", "t0"),
        GenerateToken("str", "t1"),
        ApplyRule(CloseVariadicFieldRule())
    ] == seq.action_sequence
def test_encode_tree(self):
    funcdef = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Token, True)),
         ("body", NodeType("expr", NodeConstraint.Node, True))])
    expr = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [("op", NodeType("value", NodeConstraint.Token, True)),
         ("arg0", NodeType("value", NodeConstraint.Token, True)),
         ("arg1", NodeType("value", NodeConstraint.Token, True))])

    encoder = ActionSequenceEncoder(
        Samples([funcdef, expr],
                [
                    NodeType("def", NodeConstraint.Node, False),
                    NodeType("value", NodeConstraint.Token, True),
                    NodeType("expr", NodeConstraint.Node, False)
                ],
                [("", "f"), ("", "2")]),
        0)
    action_sequence = ActionSequence()
    action_sequence.eval(ApplyRule(funcdef))
    action_sequence.eval(GenerateToken("", "f"))
    action_sequence.eval(GenerateToken("", "1"))
    d, m = encoder.encode_tree(action_sequence)

    assert np.array_equal([0, 1, 1], d.numpy())
    assert np.array_equal(
        [[0, 1, 1], [0, 0, 0], [0, 0, 0]],
        m.numpy())
def test_decode(self):
    funcdef = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Token, True)),
         ("body", NodeType("expr", NodeConstraint.Node, True))])
    expr = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [("op", NodeType("value", NodeConstraint.Token, True)),
         ("arg0", NodeType("value", NodeConstraint.Token, True)),
         ("arg1", NodeType("value", NodeConstraint.Token, True))])

    encoder = ActionSequenceEncoder(
        Samples([funcdef, expr],
                [
                    NodeType("def", NodeConstraint.Node, False),
                    NodeType("value", NodeConstraint.Token, True),
                    NodeType("expr", NodeConstraint.Node, False)
                ],
                [("", "f")]),
        0)
    action_sequence = ActionSequence()
    action_sequence.eval(ApplyRule(funcdef))
    action_sequence.eval(GenerateToken("", "f"))
    action_sequence.eval(GenerateToken("", "1"))
    action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))

    expected_action_sequence = ActionSequence()
    expected_action_sequence.eval(ApplyRule(funcdef))
    expected_action_sequence.eval(GenerateToken("", "f"))
    expected_action_sequence.eval(GenerateToken("", "1"))
    expected_action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))

    result = encoder.decode(
        encoder.encode_action(
            action_sequence, [Token(None, "1", "1")])[:-1, 1:],
        [Token(None, "1", "1")])
    assert \
        expected_action_sequence.action_sequence == result.action_sequence
def test_encode_parent(self):
    funcdef = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Token, True)),
         ("body", NodeType("expr", NodeConstraint.Node, True))])
    expr = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [("op", NodeType("value", NodeConstraint.Token, True)),
         ("arg0", NodeType("value", NodeConstraint.Token, True)),
         ("arg1", NodeType("value", NodeConstraint.Token, True))])

    encoder = ActionSequenceEncoder(
        Samples([funcdef, expr],
                [
                    NodeType("def", NodeConstraint.Node, False),
                    NodeType("value", NodeConstraint.Token, True),
                    NodeType("expr", NodeConstraint.Node, False)
                ],
                [("", "f"), ("", "2")]),
        0)
    action_sequence = ActionSequence()
    action_sequence.eval(ApplyRule(funcdef))
    action_sequence.eval(GenerateToken("", "f"))
    action_sequence.eval(GenerateToken("", "1"))
    action_sequence.eval(GenerateToken("", "2"))
    action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
    parent = encoder.encode_parent(action_sequence)

    assert np.array_equal(
        [[-1, -1, -1, -1], [1, 2, 0, 0], [1, 2, 0, 0],
         [1, 2, 0, 0], [1, 2, 0, 0], [1, 2, 0, 1]],
        parent.numpy())
def test_encode_invalid_sequence(self):
    funcdef = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Token, True)),
         ("body", NodeType("expr", NodeConstraint.Node, True))])
    expr = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [("op", NodeType("value", NodeConstraint.Token, False)),
         ("arg0", NodeType("value", NodeConstraint.Token, True)),
         ("arg1", NodeType("value", NodeConstraint.Token, True))])
    encoder = ActionSequenceEncoder(
        Samples([funcdef, expr],
                [
                    NodeType("def", NodeConstraint.Node, False),
                    NodeType("value", NodeConstraint.Token, True),
                    NodeType("expr", NodeConstraint.Node, True)
                ],
                [("", "f")]),
        0)

    action_sequence = ActionSequence()
    action_sequence.eval(ApplyRule(funcdef))
    action_sequence.eval(GenerateToken("", "f"))
    action_sequence.eval(GenerateToken("", "1"))
    action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))

    assert encoder.encode_action(
        action_sequence, [Token("", "2", "2")]) is None
def test_encode_path(self):
    funcdef = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Token, True)),
         ("body", NodeType("expr", NodeConstraint.Node, True))])
    expr = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [("constant", NodeType("value", NodeConstraint.Token, True))])
    encoder = ActionSequenceEncoder(
        Samples([funcdef, expr],
                [
                    NodeType("def", NodeConstraint.Node, False),
                    NodeType("value", NodeConstraint.Token, True),
                    NodeType("expr", NodeConstraint.Node, True)
                ],
                [("", "f"), ("", "2")]),
        0)
    action_sequence = ActionSequence()
    action_sequence.eval(ApplyRule(funcdef))
    action_sequence.eval(GenerateToken("", "f"))
    action_sequence.eval(GenerateToken("", "1"))
    action_sequence.eval(GenerateToken("", "2"))
    action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
    action_sequence.eval(ApplyRule(expr))
    action_sequence.eval(GenerateToken("", "f"))
    action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
    action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))

    path = encoder.encode_path(action_sequence, 2)
    assert np.array_equal(
        np.array(
            [
                [-1, -1],  # funcdef
                [2, -1],   # f
                [2, -1],   # 1
                [2, -1],   # 2
                [2, -1],   # CloseVariadicField
                [2, -1],   # expr
                [3, 2],    # f
                [3, 2],    # CloseVariadicField
                [2, -1],   # CloseVariadicField
            ],
            dtype=np.long),
        path.numpy())

    path = encoder.encode_path(action_sequence, 1)
    assert np.array_equal(
        np.array(
            [
                [-1],  # funcdef
                [2],   # f
                [2],   # 1
                [2],   # 2
                [2],   # CloseVariadicField
                [2],   # expr
                [3],   # f
                [3],   # CloseVariadicField
                [2],   # CloseVariadicField
            ],
            dtype=np.long),
        path.numpy())
def test_eq(self):
    assert ExpandTreeRule(
        NodeType("foo", NodeConstraint.Node, False),
        [("f0", NodeType("bar", NodeConstraint.Node, False))]) == \
        ExpandTreeRule(
            NodeType("foo", NodeConstraint.Node, False),
            [("f0", NodeType("bar", NodeConstraint.Node, False))])
    assert GenerateToken("", "foo") == GenerateToken("", "foo")
    assert ExpandTreeRule(
        NodeType("foo", NodeConstraint.Node, False),
        [("f0", NodeType("bar", NodeConstraint.Node, False))]) != \
        ExpandTreeRule(NodeType("foo", NodeConstraint.Node, False), [])
    assert GenerateToken("", "foo") != GenerateToken("", "bar")
    assert ExpandTreeRule(
        NodeType("foo", NodeConstraint.Node, False),
        [("f0", NodeType("bar", NodeConstraint.Node, False))]
    ) != GenerateToken("", "foo")
    assert 0 != ExpandTreeRule(
        NodeType("foo", NodeConstraint.Node, False),
        [("f0", NodeType("bar", NodeConstraint.Node, False))])
def decode(self, tensor: torch.LongTensor, reference: List[Token]) \
        -> Optional[ActionSequence]:
    """
    Return the action sequence corresponding to the tensor

    Parameters
    ----------
    tensor: torch.LongTensor
        The encoded tensor with the shape of (len(action_sequence), 3).
        Each action is encoded as the tuple of (ID of the applied rule,
        ID of the generated token, the index of the word copied from
        the reference). The padding value should be -1.
    reference: List[Token]
        The reference tokens that copy actions may point into.

    Returns
    -------
    Optional[ActionSequence]
        The action sequence corresponding to the tensor.
        None if the action sequence cannot be generated.
    """
    retval = ActionSequence()
    for i in range(tensor.shape[0]):
        if tensor[i, 0] > 0:
            # ApplyRule
            rule = self._rule_encoder.decode(tensor[i, 0])
            retval.eval(ApplyRule(rule))
        elif tensor[i, 1] > 0:
            # GenerateToken
            kind, value = self._token_encoder.decode(tensor[i, 1])
            retval.eval(GenerateToken(kind, value))
        elif tensor[i, 2] >= 0:
            # GenerateToken (copy from the reference)
            index = int(tensor[i, 2].numpy())
            if index >= len(reference):
                logger.debug("reference index is out-of-bounds")
                return None
            token = reference[index]
            retval.eval(GenerateToken(token.kind, token.raw_value))
        else:
            logger.debug("invalid actions")
            return None
    return retval
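# The sketch below is not part of the library; it only illustrates the
# dispatch order that `decode` applies to each encoded row: a positive rule
# ID means ApplyRule, otherwise a positive token ID means GenerateToken,
# otherwise a non-negative copy index means copying a reference token, and
# an all-padding row is invalid. `classify_row` is a hypothetical helper
# introduced for this illustration only.
import torch


def classify_row(row: torch.Tensor) -> str:
    if row[0] > 0:
        return "ApplyRule"
    elif row[1] > 0:
        return "GenerateToken"
    elif row[2] >= 0:
        return "GenerateToken (copy from reference)"
    return "invalid"


rows = torch.tensor([[3, -1, -1], [-1, 2, -1], [-1, -1, 0], [-1, -1, -1]])
print([classify_row(r) for r in rows])
# ['ApplyRule', 'GenerateToken', 'GenerateToken (copy from reference)',
#  'invalid']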
def test_str(self):
    t0 = NodeType("t0", NodeConstraint.Node, False)
    t1 = NodeType("t1", NodeConstraint.Node, False)
    t2 = NodeType("t2", NodeConstraint.Node, True)
    assert "Apply (t0 -> [elem0: t1, elem1: t2*])" == \
        str(ApplyRule(
            ExpandTreeRule(t0, [("elem0", t1), ("elem1", t2)])))
    assert "Generate bar:kind" == str(GenerateToken("kind", "bar"))
def test_generate_token(self):
    action_sequence = ActionSequence()
    rule = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Token, False)),
         ("value", NodeType("args", NodeConstraint.Node, True))])
    action_sequence.eval(ApplyRule(rule))
    action_sequence.eval(GenerateToken("", "foo"))
    assert 0 == action_sequence.head.action
    assert 1 == action_sequence.head.field
    assert [1] == action_sequence._tree.children[0][0]
    assert [ApplyRule(rule), GenerateToken("", "foo")] == \
        action_sequence.action_sequence
    assert Parent(0, 0) == action_sequence.parent(1)
    assert [] == action_sequence._tree.children[1]

    with pytest.raises(InvalidActionException):
        action_sequence.eval(GenerateToken("", "bar"))

    action_sequence = ActionSequence()
    action_sequence.eval(ApplyRule(rule))
    with pytest.raises(InvalidActionException):
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
def test_create_node(self):
    a = Node("def", [Field("name", "literal", Leaf("str", "foo"))])
    seq = ActionSequence.create(a)
    assert [
        ApplyRule(
            ExpandTreeRule(
                NodeType(None, NodeConstraint.Node, False),
                [("root", NodeType(Root(), NodeConstraint.Node, False))])),
        ApplyRule(
            ExpandTreeRule(NodeType("def", NodeConstraint.Node, False), [
                ("name", NodeType("literal", NodeConstraint.Token, False))
            ])),
        GenerateToken("str", "foo")
    ] == seq.action_sequence
def test_eval_root(self):
    action_sequence = ActionSequence()
    assert action_sequence.head is None
    with pytest.raises(InvalidActionException):
        action_sequence = ActionSequence()
        action_sequence.eval(GenerateToken("kind", ""))
    with pytest.raises(InvalidActionException):
        action_sequence = ActionSequence()
        action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))

    action_sequence = ActionSequence()
    rule = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Node, False)),
         ("value", NodeType("args", NodeConstraint.Node, True))])
    action_sequence.eval(ApplyRule(rule))
    assert 0 == action_sequence.head.action
    assert 0 == action_sequence.head.field
    assert [ApplyRule(rule)] == action_sequence.action_sequence
    assert action_sequence.parent(0) is None
    assert [[], []] == action_sequence._tree.children[0]
def test_generate(self):
    funcdef = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Token, True)),
         ("body", NodeType("expr", NodeConstraint.Node, True))])
    expr = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [("op", NodeType("value", NodeConstraint.Token, False)),
         ("arg0", NodeType("value", NodeConstraint.Token, False)),
         ("arg1", NodeType("value", NodeConstraint.Token, False))])

    action_sequence = ActionSequence()
    action_sequence.eval(ApplyRule(funcdef))
    action_sequence.eval(GenerateToken("name", "f"))
    action_sequence.eval(GenerateToken("name", "_"))
    action_sequence.eval(GenerateToken("name", "0"))
    action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
    action_sequence.eval(ApplyRule(expr))
    action_sequence.eval(GenerateToken("value", "+"))
    action_sequence.eval(GenerateToken("value", "1"))
    action_sequence.eval(GenerateToken("value", "2"))
    action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
    assert action_sequence.head is None

    assert Node("def", [
        Field("name", "value",
              [Leaf("name", "f"), Leaf("name", "_"), Leaf("name", "0")]),
        Field("body", "expr", [
            Node("expr", [
                Field("op", "value", Leaf("value", "+")),
                Field("arg0", "value", Leaf("value", "1")),
                Field("arg1", "value", Leaf("value", "2"))
            ])
        ])
    ]) == action_sequence.generate()
def enumerate_samples_per_state(self, rule_pred: torch.Tensor,
                                token_pred: torch.Tensor,
                                reference_pred: torch.Tensor,
                                next_state: Environment,
                                state: SamplerState[Environment],
                                enumeration: Enumeration,
                                k: Optional[int]) \
        -> Generator[DuplicatedSamplerState[Environment], None, None]:
    def indices(pred: torch.Tensor):
        # 0 is unknown token
        if enumeration == Enumeration.Top:
            _, indices = torch.sort(pred[1:], descending=True)
            if k is not None:
                indices = indices[:k]
            for index in indices:
                yield index + 1, 1
        elif enumeration == Enumeration.Random:
            indices = list(range(1, len(pred)))
            if k is not None:
                indices = indices[:k]
            for index in indices:
                yield index, 1
        else:
            assert k is not None
            with logger.block("normalize_prob"):
                s = pred[1:].sum().item()
                if s < self.eps:
                    return
                ps = (pred[1:] / s - self.eps).numpy()
                npred = [max(0, p) for p in ps]
            for i, n in enumerate(self.rng.multinomial(k, npred)):
                if n == 0:
                    continue
                yield i + 1, n

    with logger.block("enumerate_samples_per_state"):
        head = state.state["action_sequence"].head
        assert head is not None
        head_field = \
            cast(ExpandTreeRule, cast(
                ApplyRule,
                state.state["action_sequence"]
                .action_sequence[head.action]
            ).rule).children[head.field][1]
        if head_field.constraint == NodeConstraint.Token:
            # Generate token
            ref_ids = self.encoder.batch_encode_raw_value(
                [x.raw_value for x in state.state["reference"]])
            tokens = list(self.encoder._token_encoder.vocab) + \
                state.state["reference"]
            # the score will be merged into predefined token
            for i, ids in enumerate(ref_ids):
                for ref_id in ids:
                    # merge token and reference pred
                    # Add to unknown probability
                    # if there is not the corresponding token.
                    token_pred[ref_id] += reference_pred[i]
                    if ref_id != 0:
                        reference_pred[i] = 0.0
            pred = torch.cat([token_pred, reference_pred], dim=0)

            # CloseVariadicFieldRule is a candidate if variadic fields
            if head_field.is_variadic:
                close_rule_idx = \
                    self.encoder._rule_encoder.encode(
                        CloseVariadicFieldRule())
                p = rule_pred[close_rule_idx].item()
                tokens.append(ApplyRule(CloseVariadicFieldRule()))
                pred = torch.cat([pred, torch.tensor([p])], dim=0)

            with logger.block("exclude_invalid_tokens"):
                # token
                for kind, idxes in self.token_kind_to_idx.items():
                    if kind is not None and \
                            not self.is_subtype(kind, head_field.type_name):
                        pred[idxes] = 0.0
                # reference
                for x, (p, token) in enumerate(
                        zip(pred[len(token_pred):],
                            tokens[len(token_pred):])):
                    x += len(token_pred)
                    if not isinstance(token, ApplyRule):
                        if isinstance(token, Token):
                            t = token.kind
                        else:
                            t = token[0]
                        if t is not None and \
                                not self.is_subtype(t, head_field.type_name):
                            pred[x] = 0.0

            n_action = 0
            for x, n in logger.iterable_block("sample-tokens",
                                              indices(pred)):
                # Finish enumeration
                if n_action == k:
                    return

                p = pred[x].item()
                token = tokens[x]
                if isinstance(token, ApplyRule):
                    action: Action = token
                elif isinstance(token, Token):
                    action = GenerateToken(token.kind, token.raw_value)
                else:
                    action = GenerateToken(token[0], token[1])

                if p == 0.0:
                    continue
                elif p < self.eps:
                    lp = np.log(self.eps)
                else:
                    lp = np.log(p)

                n_action += n
                next_state = next_state.clone()
                # TODO we may have to clear outputs
                next_state["action_sequence"] = \
                    LazyActionSequence(
                        state.state["action_sequence"], action)
                yield DuplicatedSamplerState(
                    SamplerState(state.score + lp, next_state), n)
        else:
            # Apply rule
            with logger.block("exclude_invalid_rules"):
                # expand tree rule
                for kind, idxes in self.rule_kind_to_idx.items():
                    if not (kind is not None and self.is_subtype(
                            kind, head_field.type_name)):
                        rule_pred[idxes] = 0.0
                # CloseVariadicField
                idx = self.encoder._rule_encoder.encode(
                    CloseVariadicFieldRule())
                if not (head_field is not None and
                        head_field.is_variadic):
                    rule_pred[idx] = 0.0

            n_rule = 0
            for x, n in logger.iterable_block("sample-rule",
                                              indices(rule_pred)):
                # Finish enumeration
                if n_rule == k:
                    return

                p = rule_pred[x].item()
                if p == 0.0:
                    continue
                elif p < self.eps:
                    lp = np.log(self.eps)
                else:
                    lp = np.log(p)
                n_rule += n
                rule = self.encoder._rule_encoder.vocab[x]
                next_state = next_state.clone()
                next_state["action_sequence"] = \
                    LazyActionSequence(
                        state.state["action_sequence"],
                        ApplyRule(rule))
                yield DuplicatedSamplerState(
                    SamplerState(state.score + lp, next_state), n)
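# A self-contained sketch (not library code) of the probability handling in
# the sampling branch of `indices` above: index 0 is the unknown token and is
# skipped, the remaining scores are renormalized with a small epsilon
# subtracted, and each drawn index is reported together with how many of the
# k samples landed on it. `pred`, `eps`, `k`, and `rng` mirror the names in
# the surrounding code, but the values here are made up for illustration.
import numpy as np
import torch

rng = np.random.RandomState(0)
pred = torch.tensor([0.05, 0.50, 0.30, 0.15])  # pred[0] is the unknown token
eps, k = 1e-8, 10

s = pred[1:].sum().item()
npred = [max(0, p) for p in (pred[1:] / s - eps).numpy()]
for i, n in enumerate(rng.multinomial(k, npred)):
    if n > 0:
        # the real generator yields (i + 1, n): the index shifted past the
        # unknown token, and the multiplicity of that sample
        print(i + 1, n)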