def test_create_leaf(self):
    """Check ActionSequence.create on a bare leaf and on a variadic leaf field."""
    # Case 1: a lone Leaf. The root expands to a token-constrained child and
    # the whole leaf value ("t0 t1") is emitted as one GenerateToken.
    leaf_seq = ActionSequence.create(Leaf("str", "t0 t1"))
    expected_leaf = [
        ApplyRule(ExpandTreeRule(
            NodeType(None, NodeConstraint.Node, False),
            [("root", NodeType(Root(), NodeConstraint.Token, False))])),
        GenerateToken("str", "t0 t1"),
    ]
    assert expected_leaf == leaf_seq.action_sequence

    # Case 2: a Node whose field holds a list of leaves. The field type is
    # marked variadic (True), each leaf becomes its own token, and the field
    # is terminated by CloseVariadicFieldRule.
    node_seq = ActionSequence.create(
        Node("value", [
            Field("name", "str", [Leaf("str", "t0"), Leaf("str", "t1")]),
        ]))
    expected_node = [
        ApplyRule(ExpandTreeRule(
            NodeType(None, NodeConstraint.Node, False),
            [("root", NodeType(Root(), NodeConstraint.Node, False))])),
        ApplyRule(ExpandTreeRule(
            NodeType("value", NodeConstraint.Node, False),
            [("name", NodeType("str", NodeConstraint.Token, True))])),
        *[GenerateToken("str", tok) for tok in ("t0", "t1")],
        ApplyRule(CloseVariadicFieldRule()),
    ]
    assert expected_node == node_seq.action_sequence
def forward(self, ground_truth: Code) -> ActionSequence:
    """Convert ``ground_truth`` into an ActionSequence.

    The input is parsed with ``self.parser`` and the resulting tree is passed
    to ``ActionSequence.create``.

    Args:
        ground_truth: the code to convert.

    Returns:
        The action sequence built from the parsed tree.

    Raises:
        RuntimeError: if ``self.parser.parse`` returns ``None``.
    """
    parsed = self.parser.parse(ground_truth)
    if parsed is not None:
        return cast(ActionSequence, ActionSequence.create(parsed))
    raise RuntimeError(
        f"cannot convert to ActionSequence: {ground_truth}")
def get_samples(dataset: torch.utils.data.Dataset,
                parser: Parser[Any]) -> Samples:
    """Gather the rules, node types and tokens that appear in ``dataset``.

    Each sample's ``"ground_truth"`` entry is parsed with ``parser``. Samples
    for which the parser returns ``None`` are skipped. The surviving trees
    are turned into action sequences, and their contents are recorded in
    encounter order, duplicates included.

    Args:
        dataset: iterable of samples, each indexable by ``"ground_truth"``.
        parser: parser used to turn a ground truth into a tree.

    Returns:
        ``Samples(rules, node_types, tokens)`` where ``tokens`` are
        ``(kind, value)`` pairs taken from the token-generating actions.
    """
    rules: List[Rule] = []
    node_types = []
    tokens: List[Tuple[str, str]] = []

    for sample in dataset:
        parsed = parser.parse(sample["ground_truth"])
        if parsed is None:
            continue
        for action in ActionSequence.create(parsed).action_sequence:
            if not isinstance(action, ApplyRule):
                # Token-generating action; its kind must be known.
                assert action.kind is not None
                tokens.append((action.kind, action.value))
                continue
            rule = action.rule
            # CloseVariadicFieldRule has no parent/children to record.
            if isinstance(rule, CloseVariadicFieldRule):
                continue
            rules.append(rule)
            node_types.append(rule.parent)
            node_types.extend(child for _, child in rule.children)

    return Samples(rules, node_types, tokens)
def get_samples(dataset: Dataset, parser: Parser[csgAST],
                reference: bool = False) -> Samples:
    """Build the rule / node-type / token vocabulary for the CSG domain.

    Tokens come from the dataset's size, length and degree candidates. Rules
    and node types come from parsing one template of each CSG construct
    (Circle, Rectangle, Translation, Rotation, Union, Difference).

    Args:
        dataset: provides ``size_candidates``, ``length_candidates`` and
            ``degree_candidates``.
        parser: parser applied to each template.
        reference: if True, compound templates use ``Reference`` operands;
            otherwise they use ``Circle(1)``.

    Returns:
        ``Samples(rules, node_types, tokens)``. Rules and node types are
        de-duplicated in first-seen order. Tokens are de-duplicated and
        sorted.
    """
    rules: List[Rule] = []
    node_types = []
    seen_rules = set()
    seen_node_types = set()

    candidate_tokens = [("size", x) for x in dataset.size_candidates]
    candidate_tokens += [("length", x) for x in dataset.length_candidates]
    candidate_tokens += [("degree", x) for x in dataset.degree_candidates]

    def _operand(index):
        # Operand of a compound template: a Reference when reference=True,
        # otherwise a concrete Circle(1).
        return Reference(index) if reference else Circle(1)

    # TODO use expander
    templates = [
        Circle(1),
        Rectangle(1, 2),
        Translation(1, 1, _operand(0)),
        Rotation(45, _operand(1)),
        Union(_operand(0), _operand(1)),
        Difference(_operand(0), _operand(1)),
    ]

    def _add_node_type(node_type):
        # Record a node type once, keeping first-seen order.
        if node_type in seen_node_types:
            return
        seen_node_types.add(node_type)
        node_types.append(node_type)

    for template in templates:
        parsed = parser.parse(template)
        if parsed is None:
            continue
        for action in ActionSequence.create(parsed).action_sequence:
            # Only rule applications contribute rules/node types.
            if not isinstance(action, ApplyRule):
                continue
            rule = action.rule
            if isinstance(rule, CloseVariadicFieldRule):
                continue
            if rule not in seen_rules:
                seen_rules.add(rule)
                rules.append(rule)
            _add_node_type(rule.parent)
            for _, child in rule.children:
                _add_node_type(child)

    return Samples(rules, node_types, sorted(set(candidate_tokens)))
def test_create_node(self):
    """A node with one non-variadic leaf field yields root, node, token."""
    tree = Node("def", [Field("name", "literal", Leaf("str", "foo"))])

    # The field type "literal" becomes a token-constrained, non-variadic
    # child, and the leaf value is emitted as a single token.
    expected = [
        ApplyRule(ExpandTreeRule(
            NodeType(None, NodeConstraint.Node, False),
            [("root", NodeType(Root(), NodeConstraint.Node, False))])),
        ApplyRule(ExpandTreeRule(
            NodeType("def", NodeConstraint.Node, False),
            [("name", NodeType("literal", NodeConstraint.Token, False))])),
        GenerateToken("str", "foo"),
    ]

    actual = ActionSequence.create(tree)
    assert expected == actual.action_sequence
def test_create_node_with_variadic_fields(self):
    """A list-valued field of nodes is variadic and ends with a close rule."""
    tree = Node(
        "list",
        [Field("elems", "literal", [Node("str", []), Node("str", [])])])

    # Each child Node("str", []) expands with no children of its own.
    expand_str = ApplyRule(ExpandTreeRule(
        NodeType("str", NodeConstraint.Node, False), []))

    expected = [
        ApplyRule(ExpandTreeRule(
            NodeType(None, NodeConstraint.Node, False),
            [("root", NodeType(Root(), NodeConstraint.Node, False))])),
        ApplyRule(ExpandTreeRule(
            NodeType("list", NodeConstraint.Node, False),
            [("elems", NodeType("literal", NodeConstraint.Node, True))])),
        expand_str,
        expand_str,
        ApplyRule(CloseVariadicFieldRule()),
    ]

    actual = ActionSequence.create(tree)
    assert expected == actual.action_sequence