def test_create_leaf(self):
    # A bare leaf is wrapped by the synthetic root rule (whose only child is
    # a token) and then emitted as a single GenerateToken action.
    leaf_seq = ActionSequence.create(Leaf("str", "t0 t1"))
    token_root = ExpandTreeRule(
        NodeType(None, NodeConstraint.Node, False),
        [("root", NodeType(Root(), NodeConstraint.Token, False))])
    expected = [ApplyRule(token_root), GenerateToken("str", "t0 t1")]
    assert expected == leaf_seq.action_sequence

    # A node with a variadic token field emits one token per leaf and then
    # closes the field.
    tree = Node(
        "value",
        [Field("name", "str", [Leaf("str", "t0"), Leaf("str", "t1")])])
    node_seq = ActionSequence.create(tree)
    node_root = ExpandTreeRule(
        NodeType(None, NodeConstraint.Node, False),
        [("root", NodeType(Root(), NodeConstraint.Node, False))])
    value_rule = ExpandTreeRule(
        NodeType("value", NodeConstraint.Node, False),
        [("name", NodeType("str", NodeConstraint.Token, True))])
    expected = [ApplyRule(node_root), ApplyRule(value_rule)]
    expected += [GenerateToken("str", text) for text in ("t0", "t1")]
    expected.append(ApplyRule(CloseVariadicFieldRule()))
    assert expected == node_seq.action_sequence
def test_encode_parent(self):
    funcdef = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Token, True)),
         ("body", NodeType("expr", NodeConstraint.Node, True))])
    expr = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [(field, NodeType("value", NodeConstraint.Token, True))
         for field in ("op", "arg0", "arg1")])
    samples = Samples(
        [funcdef, expr],
        [NodeType("def", NodeConstraint.Node, False),
         NodeType("value", NodeConstraint.Token, True),
         NodeType("expr", NodeConstraint.Node, False)],
        [("", "f"), ("", "2")])
    encoder = ActionSequenceEncoder(samples, 0)

    # funcdef, three tokens into the variadic "name" field, then close it.
    actions = [ApplyRule(funcdef)]
    actions += [GenerateToken("", text) for text in ("f", "1", "2")]
    actions.append(ApplyRule(CloseVariadicFieldRule()))
    action_sequence = ActionSequence()
    for act in actions:
        action_sequence.eval(act)

    parent = encoder.encode_parent(action_sequence)
    # The root has no parent; every later action hangs off funcdef, and the
    # final row points at the field now being expanded.
    expected = [[-1, -1, -1, -1]] + [[1, 2, 0, 0]] * 4 + [[1, 2, 0, 1]]
    assert np.array_equal(expected, parent.numpy())
def test_encode_invalid_sequence(self):
    funcdef = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Token, True)),
         ("body", NodeType("expr", NodeConstraint.Node, True))])
    expr = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [("op", NodeType("value", NodeConstraint.Token, False)),
         ("arg0", NodeType("value", NodeConstraint.Token, True)),
         ("arg1", NodeType("value", NodeConstraint.Token, True))])
    samples = Samples(
        [funcdef, expr],
        [NodeType("def", NodeConstraint.Node, False),
         NodeType("value", NodeConstraint.Token, True),
         NodeType("expr", NodeConstraint.Node, True)],
        [("", "f")])
    encoder = ActionSequenceEncoder(samples, 0)

    action_sequence = ActionSequence()
    for act in (ApplyRule(funcdef),
                GenerateToken("", "f"),
                GenerateToken("", "1"),
                ApplyRule(CloseVariadicFieldRule())):
        action_sequence.eval(act)

    # "1" is neither in the token vocabulary nor in the reference, so the
    # sequence cannot be encoded.
    reference = [Token("", "2", "2")]
    assert encoder.encode_action(action_sequence, reference) is None
def test_encode_completed_sequence(self):
    # A rule without children completes the tree in a single action.
    empty_rule = ExpandTreeRule(
        NodeType("value", NodeConstraint.Node, False), [])
    samples = Samples(
        [empty_rule],
        [NodeType("value", NodeConstraint.Node, False)],
        [("", "f")])
    encoder = ActionSequenceEncoder(samples, 0)

    action_sequence = ActionSequence()
    action_sequence.eval(ApplyRule(empty_rule))

    action = encoder.encode_action(action_sequence, [Token("", "1", "1")])
    parent = encoder.encode_parent(action_sequence)

    # There is no head any more, so the trailing row is pure padding.
    assert np.array_equal([[-1, 2, -1, -1], [-1, -1, -1, -1]],
                          action.numpy())
    assert np.array_equal([[-1, -1, -1, -1]] * 2, parent.numpy())
def encode_path(self, action_sequence: ActionSequence, max_depth: int) \
        -> torch.Tensor:
    """
    Return the tensor encoding, for every action, the rules on its path

    Parameters
    ----------
    action_sequence: ActionSequence
        The action sequence to be encoded
    max_depth: int
        The number of ancestors kept per action (columns of the result).

    Returns
    -------
    torch.Tensor
        The encoded tensor with shape (len(action_sequence), max_depth).
        Row i lists the rule ids of the ancestors of the i-th action,
        nearest ancestor first, truncated to ``max_depth`` entries.
        The padding value is -1.
    """
    actions = action_sequence.action_sequence
    n_actions = len(actions)
    paths = torch.full((n_actions, max_depth), -1, dtype=torch.long)
    for idx in range(n_actions):
        parent = action_sequence.parent(idx)
        if parent is None:
            continue
        parent_action = actions[parent.action]
        if isinstance(parent_action, ApplyRule):
            paths[idx, 0] = self._rule_encoder.encode(parent_action.rule)
        # The parent's own path, shifted by one, forms the rest of this row.
        paths[idx, 1:] = paths[parent.action, :max_depth - 1]
    return paths
def test_encode_empty_sequence(self):
    funcdef = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Token, False)),
         ("body", NodeType("expr", NodeConstraint.Node, True))])
    expr = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [(field, NodeType("value", NodeConstraint.Token, False))
         for field in ("op", "arg0", "arg1")])
    samples = Samples(
        [funcdef, expr],
        [NodeType("def", NodeConstraint.Node, False),
         NodeType("value", NodeConstraint.Token, False),
         NodeType("expr", NodeConstraint.Node, False)],
        [("", "f")])
    encoder = ActionSequenceEncoder(samples, 0)

    # Nothing has been evaluated yet.
    action_sequence = ActionSequence()
    action = encoder.encode_action(action_sequence, [Token("", "1", "1")])
    parent = encoder.encode_parent(action_sequence)
    depth, adjacency = encoder.encode_tree(action_sequence)

    padding_only = [[-1, -1, -1, -1]]
    assert np.array_equal(padding_only, action.numpy())
    assert np.array_equal(padding_only, parent.numpy())
    assert np.array_equal(np.zeros((0, )), depth.numpy())
    assert np.array_equal(np.zeros((0, 0)), adjacency.numpy())
def get_samples(dataset: torch.utils.data.Dataset,
                parser: Parser[Any]) -> Samples:
    """
    Collect the rules, node types and tokens used by a dataset

    Every sample's ``ground_truth`` is parsed and converted to an action
    sequence; samples the parser rejects are skipped. Duplicates are kept,
    so each list reflects how often an item occurs.

    Parameters
    ----------
    dataset: torch.utils.data.Dataset
        Iterable of samples, each providing a ``"ground_truth"`` entry.
    parser: Parser[Any]
        Parser converting a ground truth into an AST.

    Returns
    -------
    Samples
        The collected rules, node types and (kind, value) tokens.
    """
    rules: List[Rule] = []
    node_types = []
    tokens: List[Tuple[str, str]] = []

    for sample in dataset:
        tree = parser.parse(sample["ground_truth"])
        if tree is None:
            continue
        for act in ActionSequence.create(tree).action_sequence:
            if not isinstance(act, ApplyRule):
                # GenerateToken
                assert act.kind is not None
                tokens.append((act.kind, act.value))
                continue
            rule = act.rule
            if isinstance(rule, CloseVariadicFieldRule):
                continue
            rules.append(rule)
            node_types.append(rule.parent)
            node_types.extend(child for _, child in rule.children)

    return Samples(rules, node_types, tokens)
def forward(self, ground_truth: Code) -> ActionSequence:
    """
    Parse ``ground_truth`` and convert the AST into an action sequence

    Raises
    ------
    RuntimeError
        If the parser cannot parse ``ground_truth``.
    """
    tree = self.parser.parse(ground_truth)
    if tree is not None:
        return cast(ActionSequence, ActionSequence.create(tree))
    raise RuntimeError(
        f"cannot convert to ActionSequence: {ground_truth}")
def encode_tree(self, action_sequence: ActionSequence) \
        -> "Tuple[torch.Tensor, torch.Tensor]":
    """
    Return the depth vector and the adjacency matrix of the action sequence

    Parameters
    ----------
    action_sequence: ActionSequence
        The action sequence to be encoded

    Returns
    -------
    depth: torch.Tensor
        The depth of each action; actions without a parent have depth 0.
        The shape is (len(action_sequence),).
    adjacency_matrix: torch.Tensor
        The shape is (len(action_sequence), len(action_sequence)).
        Element (i, j) is 1 if the i-th action is the parent of the j-th
        action, and 0 otherwise.
    """
    L = len(action_sequence.action_sequence)
    depth = torch.zeros(L)
    m = torch.zeros(L, L)
    for i in range(L):
        p = action_sequence.parent(i)
        if p is not None:
            # Relies on a parent being visited before its children, so
            # depth[p.action] is already final here.
            depth[i] = depth[p.action] + 1
            m[p.action, i] = 1
    return depth, m
def initialize(self, input: Input) -> Environment:
    """
    Encode ``input`` and seed the state with the root expansion rule

    The encoder is switched to eval mode and run without gradients. The
    returned state carries an ``"action_sequence"`` entry that already
    contains the synthetic root rule.
    """
    self.module.encoder.eval()
    batch = self._to(self.collate.collate([self.transform_input(input)]))
    with torch.no_grad(), logger.block("encode_state"):
        batch = self.module.encoder(batch)
    state = self.collate.split(batch)[0]

    # Every generated program starts by expanding the synthetic root node.
    root_rule = ExpandTreeRule(
        NodeType(None, NodeConstraint.Node, False),
        [("root", NodeType(Root(), NodeConstraint.Node, False))])
    action_sequence = ActionSequence()
    action_sequence.eval(ApplyRule(root_rule))
    state["action_sequence"] = action_sequence
    return state
def get_samples(dataset: Dataset, parser: Parser[csgAST],
                reference: bool = False) -> Samples:
    """
    Build the vocabulary (rules, node types, tokens) for the CSG domain

    Tokens come from the dataset's size/length/degree candidates. Rules and
    node types are collected, without duplicates and in first-seen order,
    from a fixed set of example programs that use every CSG construct.

    Parameters
    ----------
    dataset: Dataset
        Provides ``size_candidates``, ``length_candidates`` and
        ``degree_candidates``.
    parser: Parser[csgAST]
        Parser converting a CSG program into an AST.
    reference: bool
        If True, the example programs refer to sub-shapes via ``Reference``
        instead of nesting shapes directly.

    Returns
    -------
    Samples
        The collected rules, node types and the sorted, de-duplicated tokens.
    """
    tokens = [("size", x) for x in dataset.size_candidates]
    tokens.extend(("length", x) for x in dataset.length_candidates)
    tokens.extend(("degree", x) for x in dataset.degree_candidates)

    if reference:
        # TODO use expander
        examples = [
            Circle(1), Rectangle(1, 2),
            Translation(1, 1, Reference(0)),
            Rotation(45, Reference(1)),
            Union(Reference(0), Reference(1)),
            Difference(Reference(0), Reference(1)),
        ]
    else:
        examples = [
            Circle(1), Rectangle(1, 2),
            Translation(1, 1, Circle(1)),
            Rotation(45, Circle(1)),
            Union(Circle(1), Circle(1)),
            Difference(Circle(1), Circle(1)),
        ]

    rules: List[Rule] = []
    node_types = []
    seen_rules = set()
    seen_node_types = set()

    def add_node_type(node_type):
        # Keep first-seen order while skipping duplicates.
        if node_type not in seen_node_types:
            seen_node_types.add(node_type)
            node_types.append(node_type)

    for example in examples:
        tree = parser.parse(example)
        if tree is None:
            continue
        for act in ActionSequence.create(tree).action_sequence:
            if not isinstance(act, ApplyRule):
                continue
            rule = act.rule
            if isinstance(rule, CloseVariadicFieldRule):
                continue
            if rule not in seen_rules:
                seen_rules.add(rule)
                rules.append(rule)
            add_node_type(rule.parent)
            for _, child in rule.children:
                add_node_type(child)

    return Samples(rules, node_types, sorted(set(tokens)))
def test_encode_tree(self):
    funcdef = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Token, True)),
         ("body", NodeType("expr", NodeConstraint.Node, True))])
    expr = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [(field, NodeType("value", NodeConstraint.Token, True))
         for field in ("op", "arg0", "arg1")])
    samples = Samples(
        [funcdef, expr],
        [NodeType("def", NodeConstraint.Node, False),
         NodeType("value", NodeConstraint.Token, True),
         NodeType("expr", NodeConstraint.Node, False)],
        [("", "f"), ("", "2")])
    encoder = ActionSequenceEncoder(samples, 0)

    action_sequence = ActionSequence()
    for act in (ApplyRule(funcdef),
                GenerateToken("", "f"),
                GenerateToken("", "1")):
        action_sequence.eval(act)

    depth, adjacency = encoder.encode_tree(action_sequence)
    # Both tokens are direct children of the root funcdef action.
    assert np.array_equal([0, 1, 1], depth.numpy())
    assert np.array_equal([[0, 1, 1], [0, 0, 0], [0, 0, 0]],
                          adjacency.numpy())
def test_generate_ignore_root_type(self):
    # A Root-typed wrapper rule must not appear in the generated tree.
    root_rule = ExpandTreeRule(
        NodeType(Root(), NodeConstraint.Node, False),
        [("root", NodeType(Root(), NodeConstraint.Node, False))])
    op_rule = ExpandTreeRule(NodeType("op", NodeConstraint.Node, False), [])

    action_sequence = ActionSequence()
    for rule in (root_rule, op_rule):
        action_sequence.eval(ApplyRule(rule))

    assert Node("op", []) == action_sequence.generate()
def test_create_node(self):
    tree = Node("def", [Field("name", "literal", Leaf("str", "foo"))])
    seq = ActionSequence.create(tree)

    root_rule = ExpandTreeRule(
        NodeType(None, NodeConstraint.Node, False),
        [("root", NodeType(Root(), NodeConstraint.Node, False))])
    def_rule = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("literal", NodeConstraint.Token, False))])
    # A non-variadic token field produces one token and no close action.
    expected = [
        ApplyRule(root_rule),
        ApplyRule(def_rule),
        GenerateToken("str", "foo"),
    ]
    assert expected == seq.action_sequence
def test_invalid_close_variadic_field_rule(self):
    # "elems" is not variadic, so closing it must be rejected.
    non_variadic = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [("elems", NodeType("value", NodeConstraint.Node, False))])
    seq = ActionSequence()
    seq.eval(ApplyRule(non_variadic))

    with pytest.raises(InvalidActionException):
        seq.eval(ApplyRule(CloseVariadicFieldRule()))
def decode(self, tensor: torch.LongTensor, reference: List[Token]) \
        -> Optional[ActionSequence]:
    """
    Return the action sequence corresponding to the tensor

    Parameters
    ----------
    tensor: torch.LongTensor
        The encoded tensor with the shape of (len(action_sequence), 3).
        Each action is encoded by the tuple of
        (ID of the applied rule, ID of the inserted token,
         index of the word copied from the reference).
        The padding value should be -1.
    reference: List[Token]
        The tokens that copy actions refer to by index.

    Returns
    -------
    Optional[ActionSequence]
        The action sequence corresponding to the tensor, or None if a row
        encodes no action or a copy index is out of bounds. Errors raised by
        ``ActionSequence.eval`` for an ill-formed sequence are not caught.
    """
    retval = ActionSequence()
    for i in range(tensor.shape[0]):
        if tensor[i, 0] > 0:
            # ApplyRule
            rule = self._rule_encoder.decode(tensor[i, 0])
            retval.eval(ApplyRule(rule))
        elif tensor[i, 1] > 0:
            # GenerateToken
            kind, value = self._token_encoder.decode(tensor[i, 1])
            retval.eval(GenerateToken(kind, value))
        elif tensor[i, 2] >= 0:
            # GenerateToken (Copy)
            # .item() works for tensors on any device; .numpy() fails on
            # CUDA tensors.
            index = int(tensor[i, 2].item())
            if index >= len(reference):
                logger.debug("reference index is out-of-bounds")
                return None
            token = reference[index]
            retval.eval(GenerateToken(token.kind, token.raw_value))
        else:
            logger.debug("invalid actions")
            return None
    return retval
def test_create_node_with_variadic_fields(self):
    tree = Node(
        "list",
        [Field("elems", "literal", [Node("str", []), Node("str", [])])])
    seq = ActionSequence.create(tree)

    root_rule = ExpandTreeRule(
        NodeType(None, NodeConstraint.Node, False),
        [("root", NodeType(Root(), NodeConstraint.Node, False))])
    list_rule = ExpandTreeRule(
        NodeType("list", NodeConstraint.Node, False),
        [("elems", NodeType("literal", NodeConstraint.Node, True))])
    expected = [ApplyRule(root_rule), ApplyRule(list_rule)]
    # One childless "str" expansion per element, then the field is closed.
    expected += [
        ApplyRule(ExpandTreeRule(NodeType("str", NodeConstraint.Node, False),
                                 []))
        for _ in range(2)
    ]
    expected.append(ApplyRule(CloseVariadicFieldRule()))
    assert expected == seq.action_sequence
def test_clone(self):
    rule = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [("elems", NodeType("expr", NodeConstraint.Node, True))])
    original = ActionSequence()
    original.eval(ApplyRule(rule))

    cloned = original.clone()
    assert original.generate() == cloned.generate()

    # Extending the clone must not leak into the original's state.
    cloned.eval(ApplyRule(rule))
    assert original._tree.children != cloned._tree.children
    assert original._tree.parent != cloned._tree.parent
    for attr in ("action_sequence",
                 "_head_action_index",
                 "_head_children_index"):
        assert getattr(original, attr) != getattr(cloned, attr)
    assert original.generate() != cloned.generate()
def test_generate_token(self):
    rule = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Token, False)),
         ("value", NodeType("args", NodeConstraint.Node, True))])

    seq = ActionSequence()
    seq.eval(ApplyRule(rule))
    seq.eval(GenerateToken("", "foo"))

    # The token fills the non-variadic "name" field, so the head advances
    # to the "value" field of the same action.
    assert 0 == seq.head.action
    assert 1 == seq.head.field
    assert [1] == seq._tree.children[0][0]
    assert [ApplyRule(rule), GenerateToken("", "foo")] == seq.action_sequence
    assert Parent(0, 0) == seq.parent(1)
    assert [] == seq._tree.children[1]
    # "value" expects nodes, so another token is invalid.
    with pytest.raises(InvalidActionException):
        seq.eval(GenerateToken("", "bar"))

    # "name" is not variadic, so it cannot be closed.
    seq = ActionSequence()
    seq.eval(ApplyRule(rule))
    with pytest.raises(InvalidActionException):
        seq.eval(ApplyRule(CloseVariadicFieldRule()))
def test_decode(self):
    funcdef = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Token, True)),
         ("body", NodeType("expr", NodeConstraint.Node, True))])
    expr = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [(field, NodeType("value", NodeConstraint.Token, True))
         for field in ("op", "arg0", "arg1")])
    samples = Samples(
        [funcdef, expr],
        [NodeType("def", NodeConstraint.Node, False),
         NodeType("value", NodeConstraint.Token, True),
         NodeType("expr", NodeConstraint.Node, False)],
        [("", "f")])
    encoder = ActionSequenceEncoder(samples, 0)

    def build_sequence():
        seq = ActionSequence()
        for act in (ApplyRule(funcdef),
                    GenerateToken("", "f"),
                    GenerateToken("", "1"),
                    ApplyRule(CloseVariadicFieldRule())):
            seq.eval(act)
        return seq

    action_sequence = build_sequence()
    expected_action_sequence = build_sequence()

    encoded = encoder.encode_action(action_sequence, [Token(None, "1", "1")])
    # decode takes (rule, token, copy) columns and no trailing head row.
    result = encoder.decode(encoded[:-1, 1:], [Token(None, "1", "1")])
    assert \
        expected_action_sequence.action_sequence == result.action_sequence
def test_eval_root(self):
    assert ActionSequence().head is None

    # Only an ExpandTreeRule may start a sequence.
    with pytest.raises(InvalidActionException):
        ActionSequence().eval(GenerateToken("kind", ""))
    with pytest.raises(InvalidActionException):
        ActionSequence().eval(ApplyRule(CloseVariadicFieldRule()))

    rule = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Node, False)),
         ("value", NodeType("args", NodeConstraint.Node, True))])
    seq = ActionSequence()
    seq.eval(ApplyRule(rule))

    # The head points at the first field of the root action.
    assert 0 == seq.head.action
    assert 0 == seq.head.field
    assert [ApplyRule(rule)] == seq.action_sequence
    assert seq.parent(0) is None
    assert [[], []] == seq._tree.children[0]
def test_generate(self):
    funcdef = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Token, True)),
         ("body", NodeType("expr", NodeConstraint.Node, True))])
    expr = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [(field, NodeType("value", NodeConstraint.Token, False))
         for field in ("op", "arg0", "arg1")])

    # def name: f _ 0 ; body: [expr(+, 1, 2)]
    actions = [ApplyRule(funcdef)]
    actions += [GenerateToken("name", text) for text in ("f", "_", "0")]
    actions.append(ApplyRule(CloseVariadicFieldRule()))
    actions.append(ApplyRule(expr))
    actions += [GenerateToken("value", text) for text in ("+", "1", "2")]
    actions.append(ApplyRule(CloseVariadicFieldRule()))

    seq = ActionSequence()
    for act in actions:
        seq.eval(act)
    assert seq.head is None

    body = Node("expr", [
        Field("op", "value", Leaf("value", "+")),
        Field("arg0", "value", Leaf("value", "1")),
        Field("arg1", "value", Leaf("value", "2")),
    ])
    expected = Node("def", [
        Field("name", "value",
              [Leaf("name", text) for text in ("f", "_", "0")]),
        Field("body", "expr", [body]),
    ])
    assert expected == seq.generate()
def test_variadic_field(self):
    # Scenario 1: one variadic field filled with two nodes, then closed.
    action_sequence = ActionSequence()
    rule = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [("elems", NodeType("value", NodeConstraint.Node, True))])
    rule0 = ExpandTreeRule(NodeType("value", NodeConstraint.Node, False),
                           [])
    action_sequence.eval(ApplyRule(rule))
    action_sequence.eval(ApplyRule(rule0))
    # The variadic field stays open, so the head does not move.
    assert 0 == action_sequence.head.action
    assert 0 == action_sequence.head.field
    assert [1] == action_sequence._tree.children[0][0]
    assert [ApplyRule(rule), ApplyRule(rule0)] == \
        action_sequence.action_sequence
    assert Parent(0, 0) == action_sequence.parent(1)
    assert [] == action_sequence._tree.children[1]

    action_sequence.eval(ApplyRule(rule0))
    assert 0 == action_sequence.head.action
    assert 0 == action_sequence.head.field
    assert [1, 2] == action_sequence._tree.children[0][0]
    assert [ApplyRule(rule), ApplyRule(rule0), ApplyRule(rule0)] == \
        action_sequence.action_sequence

    action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
    assert action_sequence.head is None

    # Scenario 2: closing a variadic field moves the head to the next field.
    action_sequence = ActionSequence()
    rule1 = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [("elems", NodeType("value", NodeConstraint.Node, True)),
         ("name", NodeType("value", NodeConstraint.Node, False))])
    rule0 = ExpandTreeRule(NodeType("value", NodeConstraint.Node, False),
                           [])
    action_sequence.eval(ApplyRule(rule1))
    action_sequence.eval(ApplyRule(rule0))
    # ApplyRule takes a rule *instance*; passing the class itself would not
    # exercise the close-field rule this scenario is about.
    action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
    assert 0 == action_sequence.head.action
    assert 1 == action_sequence.head.field
def test_encode_path(self):
    funcdef = ExpandTreeRule(
        NodeType("def", NodeConstraint.Node, False),
        [("name", NodeType("value", NodeConstraint.Token, True)),
         ("body", NodeType("expr", NodeConstraint.Node, True))])
    expr = ExpandTreeRule(
        NodeType("expr", NodeConstraint.Node, False),
        [("constant", NodeType("value", NodeConstraint.Token, True))])
    encoder = ActionSequenceEncoder(
        Samples([funcdef, expr],
                [NodeType("def", NodeConstraint.Node, False),
                 NodeType("value", NodeConstraint.Token, True),
                 NodeType("expr", NodeConstraint.Node, True)],
                [("", "f"), ("", "2")]),
        0)
    action_sequence = ActionSequence()
    action_sequence.eval(ApplyRule(funcdef))
    action_sequence.eval(GenerateToken("", "f"))
    action_sequence.eval(GenerateToken("", "1"))
    action_sequence.eval(GenerateToken("", "2"))
    action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
    action_sequence.eval(ApplyRule(expr))
    action_sequence.eval(GenerateToken("", "f"))
    action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))
    action_sequence.eval(ApplyRule(CloseVariadicFieldRule()))

    # np.long was removed in NumPy 1.24; np.int64 matches torch.long.
    path = encoder.encode_path(action_sequence, 2)
    assert np.array_equal(
        np.array(
            [
                [-1, -1],  # funcdef
                [2, -1],  # f
                [2, -1],  # 1
                [2, -1],  # 2
                [2, -1],  # CloseVariadicField
                [2, -1],  # expr
                [3, 2],  # f
                [3, 2],  # CloseVariadicField
                [2, -1],  # CloseVariadicField
            ],
            dtype=np.int64),
        path.numpy())

    # With max_depth=1 only the direct parent's rule is kept.
    path = encoder.encode_path(action_sequence, 1)
    assert np.array_equal(
        np.array(
            [
                [-1],  # funcdef
                [2],  # f
                [2],  # 1
                [2],  # 2
                [2],  # CloseVariadicField
                [2],  # expr
                [3],  # f
                [3],  # CloseVariadicField
                [2],  # CloseVariadicField
            ],
            dtype=np.int64),
        path.numpy())
def encode_action(self, action_sequence: ActionSequence,
                  reference: List[Token]) \
        -> Optional[torch.Tensor]:
    """
    Return the tensor encoding the action sequence

    Parameters
    ----------
    action_sequence: ActionSequence
        The action sequence to be encoded
    reference: List[Token]
        Tokens that GenerateToken actions may be copied from; they are
        matched against the action's value by ``raw_value``.

    Returns
    -------
    Optional[torch.Tensor]
        The encoded tensor with the shape (len(action_sequence) + 1, 4).
        Row i encodes the i-th action as the tuple of
        (ID of the node type of the parent field, ID of the applied rule,
         ID of the inserted token, index of the word copied from the
         reference).
        The last row only holds the node type of the field the head points
        to, if there is a head. The padding value is -1.
        None if a token is neither in the vocabulary nor in the reference.
    """
    reference_value = [token.raw_value for token in reference]
    actions = action_sequence.action_sequence
    length = len(actions)
    action = torch.full((length + 1, 4), -1, dtype=torch.long)

    for i, a in enumerate(actions):
        parent = action_sequence.parent(i)
        if parent is not None:
            parent_action = cast(ApplyRule, actions[parent.action])
            parent_rule = cast(ExpandTreeRule, parent_action.rule)
            action[i, 0] = self._node_type_encoder.encode(
                parent_rule.children[parent.field][1])

        if isinstance(a, ApplyRule):
            action[i, 1] = self._rule_encoder.encode(a.rule)
            continue

        # GenerateToken. An encoded id of 0 means the token is unknown.
        # .item() works for tensors on any device, unlike .numpy().
        encoded_token = \
            int(self._token_encoder.encode((a.kind, a.value)).item())
        if encoded_token != 0:
            action[i, 2] = encoded_token
        # TODO use kind in reference
        try:
            copy_index = reference_value.index(cast(str, a.value))
        except ValueError:
            # Neither in the vocabulary nor copyable from the reference.
            if encoded_token == 0:
                logger.debug("cannot encode token")
                return None
        else:
            action[i, 3] = copy_index

    head = action_sequence.head
    if head is not None:
        head_action = cast(ApplyRule, actions[head.action])
        head_rule = cast(ExpandTreeRule, head_action.rule)
        action[length, 0] = self._node_type_encoder.encode(
            head_rule.children[head.field][1])
    return action