예제 #1
0
    def test_create_leaf(self):
        """Check the action sequences built from leaf-only inputs.

        Two inputs are covered: a lone ``Leaf``, and a ``Node`` whose single
        field holds a list of two ``Leaf`` values.
        """
        # Case 1: the whole tree is a single token leaf.
        single = ActionSequence.create(Leaf("str", "t0 t1"))
        root_rule = ExpandTreeRule(
            NodeType(None, NodeConstraint.Node, False),
            [("root", NodeType(Root(), NodeConstraint.Token, False))])
        expected = [ApplyRule(root_rule), GenerateToken("str", "t0 t1")]
        assert expected == single.action_sequence

        # Case 2: a node whose "name" field carries two leaves.
        leaves = [Leaf("str", "t0"), Leaf("str", "t1")]
        tree = Node("value", [Field("name", "str", leaves)])
        multi = ActionSequence.create(tree)
        root_rule = ExpandTreeRule(
            NodeType(None, NodeConstraint.Node, False),
            [("root", NodeType(Root(), NodeConstraint.Node, False))])
        value_rule = ExpandTreeRule(
            NodeType("value", NodeConstraint.Node, False),
            [("name", NodeType("str", NodeConstraint.Token, True))])
        expected = [
            ApplyRule(root_rule),
            ApplyRule(value_rule),
            GenerateToken("str", "t0"),
            GenerateToken("str", "t1"),
            # The list-valued field is terminated by a close rule.
            ApplyRule(CloseVariadicFieldRule()),
        ]
        assert expected == multi.action_sequence
예제 #2
0
 def forward(self, ground_truth: Code) -> ActionSequence:
     """Parse ``ground_truth`` and build an ActionSequence from the result.

     Args:
         ground_truth: The code handed to ``self.parser.parse``.

     Returns:
         The value of ``ActionSequence.create`` applied to the parsed tree.

     Raises:
         RuntimeError: If the parser returns ``None`` for ``ground_truth``.
     """
     tree = self.parser.parse(ground_truth)
     if tree is not None:
         return cast(ActionSequence, ActionSequence.create(tree))
     raise RuntimeError(
         f"cannot convert to ActionSequence: {ground_truth}")
예제 #3
0
def get_samples(dataset: torch.utils.data.Dataset,
                parser: Parser[Any]) -> Samples:
    """Gather rules, node types and tokens from every sample in ``dataset``.

    Each sample's ``"ground_truth"`` entry is parsed; samples for which the
    parser returns ``None`` are skipped. Nothing is de-duplicated: every
    occurrence is appended in encounter order.

    Args:
        dataset: Iterable of mappings that provide a ``"ground_truth"`` key.
        parser: Object whose ``parse`` turns a ground truth into an AST
            (or ``None``).

    Returns:
        ``Samples(rules, node_types, tokens)`` built from the gathered lists.
    """
    collected_rules: List[Rule] = []
    collected_types = []
    collected_tokens: List[Tuple[str, str]] = []

    for entry in dataset:
        tree = parser.parse(entry["ground_truth"])
        if tree is None:
            continue
        for step in ActionSequence.create(tree).action_sequence:
            if not isinstance(step, ApplyRule):
                # Anything that is not a rule application is recorded as a
                # (kind, value) token pair.
                assert step.kind is not None
                collected_tokens.append((step.kind, step.value))
                continue
            applied = step.rule
            if isinstance(applied, CloseVariadicFieldRule):
                # Close rules contribute neither a rule nor node types.
                continue
            collected_rules.append(applied)
            collected_types.append(applied.parent)
            collected_types.extend(child for _, child in applied.children)

    return Samples(collected_rules, collected_types, collected_tokens)
예제 #4
0
def get_samples(dataset: Dataset,
                parser: Parser[csgAST],
                reference: bool = False) -> Samples:
    """Build the rule / node-type / token inventory from fixed examples.

    Tokens come only from the dataset's size, length and degree candidates
    (de-duplicated and sorted). Rules and node types are gathered by parsing
    a fixed list of example shapes and are kept in first-seen order without
    duplicates.

    Args:
        dataset: Provides ``size_candidates``, ``length_candidates`` and
            ``degree_candidates``.
        parser: Object whose ``parse`` turns an example into an AST
            (or ``None``, in which case the example is skipped).
        reference: When true, the composite examples take ``Reference``
            operands instead of ``Circle(1)``.

    Returns:
        ``Samples(rules, node_types, tokens)``.
    """
    rules: List[Rule] = []
    node_types = []
    seen_rules = set()
    seen_types = set()

    def remember_type(node_type):
        # Record a node type the first time it is encountered.
        if node_type not in seen_types:
            node_types.append(node_type)
            seen_types.add(node_type)

    tokens = (
        [("size", v) for v in dataset.size_candidates]
        + [("length", v) for v in dataset.length_candidates]
        + [("degree", v) for v in dataset.degree_candidates]
    )

    if reference:
        # TODO use expander
        examples = [
            Circle(1),
            Rectangle(1, 2),
            Translation(1, 1, Reference(0)),
            Rotation(45, Reference(1)),
            Union(Reference(0), Reference(1)),
            Difference(Reference(0), Reference(1)),
        ]
    else:
        examples = [
            Circle(1),
            Rectangle(1, 2),
            Translation(1, 1, Circle(1)),
            Rotation(45, Circle(1)),
            Union(Circle(1), Circle(1)),
            Difference(Circle(1), Circle(1)),
        ]

    for example in examples:
        tree = parser.parse(example)
        if tree is None:
            continue
        for step in ActionSequence.create(tree).action_sequence:
            # Only rule applications matter here; other actions are ignored.
            if not isinstance(step, ApplyRule):
                continue
            applied = step.rule
            if isinstance(applied, CloseVariadicFieldRule):
                continue
            if applied not in seen_rules:
                rules.append(applied)
                seen_rules.add(applied)
            remember_type(applied.parent)
            for _, child in applied.children:
                remember_type(child)

    return Samples(rules, node_types, sorted(set(tokens)))
예제 #5
0
 def test_create_node(self):
     """Check the action sequence for a node with one single-leaf field."""
     tree = Node("def", [Field("name", "literal", Leaf("str", "foo"))])
     result = ActionSequence.create(tree)

     root_rule = ExpandTreeRule(
         NodeType(None, NodeConstraint.Node, False),
         [("root", NodeType(Root(), NodeConstraint.Node, False))])
     def_rule = ExpandTreeRule(
         NodeType("def", NodeConstraint.Node, False),
         [("name", NodeType("literal", NodeConstraint.Token, False))])
     expected = [
         ApplyRule(root_rule),
         ApplyRule(def_rule),
         GenerateToken("str", "foo"),
     ]
     assert expected == result.action_sequence
예제 #6
0
 def test_create_node_with_variadic_fields(self):
     """Check the action sequence for a field holding a list of nodes.

     The expected sequence expands the root, then the ``list`` node, then
     each of the two childless ``str`` nodes, and ends with a close rule.
     """
     elements = [Node("str", []), Node("str", [])]
     tree = Node("list", [Field("elems", "literal", elements)])
     result = ActionSequence.create(tree)

     root_step = ApplyRule(
         ExpandTreeRule(
             NodeType(None, NodeConstraint.Node, False),
             [("root", NodeType(Root(), NodeConstraint.Node, False))]))
     list_step = ApplyRule(
         ExpandTreeRule(
             NodeType("list", NodeConstraint.Node, False),
             [("elems", NodeType("literal", NodeConstraint.Node, True))]))
     # One expansion with no children per element, built as distinct objects.
     element_steps = [
         ApplyRule(
             ExpandTreeRule(NodeType("str", NodeConstraint.Node, False), []))
         for _ in range(2)
     ]
     close_step = ApplyRule(CloseVariadicFieldRule())

     expected = [root_step, list_step, *element_steps, close_step]
     assert expected == result.action_sequence