Exemple #1
0
    def test_execute(self):
        """Execute four statements one at a time and inspect the state.

        After every step the history, the set of environment keys, the
        rendered value of the newest reference and the (unchanged) context
        are checked. The first step also checks the recorded type name.
        """
        program = [
            Rectangle(1, 1),
            Rectangle(3, 1),
            Difference(Reference(0), Reference(1)),
            Union(Rectangle(1, 1), Reference(2)),
        ]
        interpreter = Interpreter(3, 3, 1, False)
        state = interpreter.create_state([None])

        # Step 0: a single 1x1 rectangle.
        state = interpreter.execute(program[0], state)
        assert state.history == program[:1]
        assert set(state.environment.keys()) == {Reference(0)}
        assert state.type_environment[Reference(0)] == "Rectangle"
        canvas = show(state.environment[Reference(0)][0])
        assert canvas == "   \n # \n   \n"
        assert state.context == [None]

        # Step 1: a 3x1 rectangle.
        state = interpreter.execute(program[1], state)
        assert state.history == program[:2]
        assert set(state.environment.keys()) == {Reference(0), Reference(1)}
        canvas = show(state.environment[Reference(1)][0])
        assert canvas == "   \n###\n   \n"
        assert state.context == [None]

        # Step 2: difference of the two previously defined references.
        state = interpreter.execute(program[2], state)
        assert state.history == program[:3]
        assert set(state.environment.keys()) == {
            Reference(0), Reference(1), Reference(2)}
        canvas = show(state.environment[Reference(2)][0])
        assert canvas == "   \n# #\n   \n"
        assert state.context == [None]

        # Step 3: union of a fresh rectangle with the previous result.
        state = interpreter.execute(program[3], state)
        assert state.history == program[:4]
        assert set(state.environment.keys()) == {
            Reference(0), Reference(1), Reference(2), Reference(3)}
        canvas = show(state.environment[Reference(3)][0])
        assert canvas == "   \n###\n   \n"
        assert state.context == [None]
Exemple #2
0
def get_samples(dataset: Dataset,
                parser: Parser[csgAST],
                reference: bool = False) -> Samples:
    """Collect the rules, node types and tokens used by sample programs.

    A fixed list of example programs is parsed with `parser`; every
    non-closing rule found in their action sequences is recorded once, in
    first-seen order, together with the node types it mentions.

    Args:
        dataset: Supplies `size_candidates`, `length_candidates` and
            `degree_candidates`, from which the token list is built.
        parser: Converts each example program into an AST (or None).
        reference: When True the example programs refer to earlier
            results through `Reference` instead of nesting shapes.

    Returns:
        `Samples(rules, node_types, tokens)` with deduplicated, sorted tokens.
    """
    token_pool = [("size", value) for value in dataset.size_candidates]
    token_pool += [("length", value) for value in dataset.length_candidates]
    token_pool += [("degree", value) for value in dataset.degree_candidates]

    if reference:
        # TODO use expander
        programs = [
            Circle(1),
            Rectangle(1, 2),
            Translation(1, 1, Reference(0)),
            Rotation(45, Reference(1)),
            Union(Reference(0), Reference(1)),
            Difference(Reference(0), Reference(1))
        ]
    else:
        programs = [
            Circle(1),
            Rectangle(1, 2),
            Translation(1, 1, Circle(1)),
            Rotation(45, Circle(1)),
            Union(Circle(1), Circle(1)),
            Difference(Circle(1), Circle(1))
        ]

    rules: List[Rule] = []
    node_types = []
    seen_rules = set()
    seen_types = set()

    def remember_type(node_type):
        # Keep first-seen order while avoiding duplicates.
        if node_type not in seen_types:
            node_types.append(node_type)
            seen_types.add(node_type)

    for program in programs:
        ast = parser.parse(program)
        if ast is None:
            continue
        for action in ActionSequence.create(ast).action_sequence:
            if not isinstance(action, ApplyRule):
                continue
            rule = action.rule
            if isinstance(rule, CloseVariadicFieldRule):
                continue
            if rule not in seen_rules:
                rules.append(rule)
                seen_rules.add(rule)
            remember_type(rule.parent)
            for _, child in rule.children:
                remember_type(child)

    tokens = sorted(set(token_pool))

    return Samples(rules, node_types, tokens)
Exemple #3
0
    def test_execute_with_multiple_inputs(self):
        """A state created from two inputs stores two values per reference."""
        interpreter = Interpreter(3, 3, 1, False)
        initial = interpreter.create_state([None, None])

        after = interpreter.execute(Rectangle(1, 1), initial)
        values = after.environment[Reference(0)]
        assert len(values) == 2
Exemple #4
0
    def test_draw_same_objects(self):
        """Check environment keys when two statements draw the same shape.

        Two identical rectangles are defined, then the first one is rotated.
        After the rotation the environment is expected to hold Reference(1)
        and Reference(2) only.
        """
        statements = [
            Rectangle(1, 1),
            Rectangle(1, 1),
            Rotation(180, Reference(0)),
        ]
        expected_keys = [
            {Reference(0)},
            {Reference(0), Reference(1)},
            {Reference(1), Reference(2)},
        ]
        interpreter = Interpreter(3, 3, 1, True)
        state = interpreter.create_state([None])

        for statement, keys in zip(statements, expected_keys):
            state = interpreter.execute(statement, state)
            assert set(state.environment.keys()) == keys
Exemple #5
0
    def test_delete_used_variable(self):
        """Check that references consumed by a later statement disappear.

        With the interpreter's last argument set to True, the environment
        is expected to contain only Reference(2) after the difference and
        only Reference(3) after the union.
        """
        statements = [
            Rectangle(1, 1),
            Rectangle(3, 1),
            Difference(Reference(0), Reference(1)),
            Union(Rectangle(1, 1), Reference(2)),
        ]
        expected_keys = [
            {Reference(0)},
            {Reference(0), Reference(1)},
            {Reference(2)},
            {Reference(3)},
        ]
        interpreter = Interpreter(3, 3, 1, True)
        state = interpreter.create_state([None])

        for statement, keys in zip(statements, expected_keys):
            state = interpreter.execute(statement, state)
            assert set(state.environment.keys()) == keys
Exemple #6
0
    def sample_ast(self, rng: np.random.RandomState, n_object: int) -> AST:
        """Randomly assemble a CSG AST from `n_object` sampled leaf shapes.

        Leaf types are drawn from `self.leaf_candidates`, `n_object - 1`
        operations from `self.branch_candidates`, and a random number of
        extra operations from `self.node_candidates`. Operations are then
        applied in random order to randomly chosen objects until a single
        object is left or no operation remains.

        Args:
            rng: Random state used for every sampling decision.
            n_object: Number of leaf shapes to sample.

        Returns:
            The first object remaining in the working set.

        Raises:
            Exception: If a sampled leaf type is neither "Circle" nor
                "Rectangle", or a sampled operation is not one of
                "Translation", "Rotation", "Union" or "Difference".
        """
        objects: Dict[int, AST] = {}
        for i, t in enumerate(rng.choice(self.leaf_candidates, n_object)):
            if t == "Circle":
                objects[i] = Circle(rng.choice(self.size_candidates))
            elif t == "Rectangle":
                objects[i] = Rectangle(rng.choice(self.size_candidates),
                                       rng.choice(self.size_candidates))
            else:
                raise Exception(f"Invalid type: {t}")
        ops = {}
        for i, t in enumerate(rng.choice(self.branch_candidates,
                                         n_object - 1)):
            ops[i] = t
        # randint's upper bound is exclusive: 0 .. len(ops) + 1 extra ops.
        n_node = rng.randint(0, len(ops) + 2)
        n = len(ops)
        for i, t in enumerate(rng.choice(self.node_candidates, n_node)):
            ops[i + n] = t

        while len(objects) > 1 and len(ops) != 0:
            op_key = rng.choice(list(ops.keys()))
            op = ops.pop(op_key)
            obj0_key = rng.choice(list(objects.keys()))
            obj0 = objects.pop(obj0_key)
            if op == "Translation":
                objects[obj0_key] = Translation(
                    rng.choice(self.length_candidates),
                    rng.choice(self.length_candidates), obj0)
            elif op == "Rotation":
                objects[obj0_key] = Rotation(
                    rng.choice(self.degree_candidates), obj0)
            else:
                # Binary operation: the loop condition guarantees that at
                # least one other object is still available here.
                obj1_key = rng.choice(list(objects.keys()))
                obj1 = objects.pop(obj1_key)
                if op == "Union":
                    objects[obj0_key] = Union(obj0, obj1)
                elif op == "Difference":
                    objects[obj0_key] = Difference(obj0, obj1)
                else:
                    # Report the offending operation, not the stale loop
                    # variable `t` left over from the sampling loops above.
                    raise Exception(f"Invalid type: {op}")
        return list(objects.values())[0]
Exemple #7
0
 def unparse(self, code: AST) -> Optional[csgAST]:
     """Convert a generic AST back into a CSG program object.

     A `Leaf` yields its value (cast to `csgAST`). A `Node` is rebuilt
     according to its type name; None is returned when the type name is
     not handled, when a scalar field is not a `Leaf`, when a child field
     is not an `AST`, or when a child cannot be unparsed.
     """
     if not isinstance(code, Node):
         if isinstance(code, Leaf):
             return cast(csgAST, code.value)
         return None

     values = {field.name: field.value for field in code.fields}
     kind = code.get_type_name()

     if kind == "Circle":
         if not isinstance(values["r"], Leaf):
             return None
         return Circle(values["r"].value)

     if kind == "Rectangle":
         if not (isinstance(values["w"], Leaf)
                 and isinstance(values["h"], Leaf)):
             return None
         return Rectangle(values["w"].value, values["h"].value)

     if kind == "Translation":
         if not isinstance(values["child"], AST):
             return None
         inner = self.unparse(values["child"])
         if inner is None:
             return None
         if not (isinstance(values["x"], Leaf)
                 and isinstance(values["y"], Leaf)):
             return None
         return Translation(values["x"].value, values["y"].value, inner)

     if kind == "Rotation":
         if not isinstance(values["child"], AST):
             return None
         inner = self.unparse(values["child"])
         if inner is None:
             return None
         if not isinstance(values["theta"], Leaf):
             return None
         return Rotation(values["theta"].value, inner)

     if kind in ("Union", "Difference"):
         # Both binary operations share the same operand handling.
         if not isinstance(values["a"], AST):
             return None
         if not isinstance(values["b"], AST):
             return None
         lhs = self.unparse(values["a"])
         if lhs is None:
             return None
         rhs = self.unparse(values["b"])
         if rhs is None:
             return None
         if kind == "Union":
             return Union(lhs, rhs)
         return Difference(lhs, rhs)

     return None
    def pretrain(self, output_dir):
        """Run supervised training on a small fixed set of CSG programs.

        Eight hand-written ground-truth programs are wrapped in
        `Environment`s, extended with test cases, and passed to
        `train_supervised` inside a temporary working directory.

        Args:
            output_dir: Forwarded to `train_supervised` as its second
                argument.

        Returns:
            A tuple `(encoder, train_dataset)` where `train_dataset` is the
            dataset after the test-case transform has been applied.
        """
        dataset = Dataset(4, 1, 2, 1, 45, seed=0)
        ground_truths = [
            Circle(1),
            Rectangle(1, 2),
            Rectangle(1, 1),
            Rotation(45, Rectangle(1, 1)),
            Translation(1, 1, Rectangle(1, 1)),
            Difference(Circle(1), Circle(1)),
            Union(Rectangle(1, 2), Circle(1)),
            Difference(Rectangle(1, 1), Circle(1)),
        ]
        train_dataset = ListDataset([
            Environment({"ground_truth": program}, {"ground_truth"})
            for program in ground_truths
        ])

        with tempfile.TemporaryDirectory() as tmpdir:
            interpreter = self.interpreter()
            add_test_cases = Apply(
                module=AddTestCases(interpreter),
                in_keys=["ground_truth"],
                out_key="test_cases",
                is_out_supervision=False,
            )
            train_dataset = data_transform(train_dataset, add_test_cases)
            encoder = self.prepare_encoder(dataset, Parser())

            collate = Collate(
                test_case_tensor=CollateOptions(False, 0, 0),
                variables_tensor=CollateOptions(True, 0, 0),
                previous_actions=CollateOptions(True, 0, -1),
                hidden_state=CollateOptions(False, 0, 0),
                state=CollateOptions(False, 0, 0),
                ground_truth_actions=CollateOptions(True, 0, -1)
            )
            pipeline_steps = [
                ("to_episode", Map(self.to_episode(encoder, interpreter))),
                ("flatten", Flatten()),
                ("transform",
                 Map(self.transform(encoder, interpreter, Parser()))),
                ("collate", collate.collate),
            ]
            collate_fn = Sequence(OrderedDict(pipeline_steps))

            model = self.prepare_model(encoder)
            optimizer = self.prepare_optimizer(model)

            summed_loss = Apply(
                module=Loss(
                    reduction="sum",
                ),
                in_keys=[
                    "rule_probs",
                    "token_probs",
                    "reference_probs",
                    "ground_truth_actions",
                ],
                out_key="action_sequence_loss",
            )
            # divided by batch_size
            normalized_loss = Apply(
                [("action_sequence_loss", "lhs")],
                "loss",
                mlprogram.nn.Function(Div()),
                constants={"rhs": 1})
            picked_loss = mlprogram.nn.Function(Pick("loss"))
            loss_fn = torch.nn.Sequential(OrderedDict([
                ("loss", summed_loss),
                ("normalize", normalized_loss),
                ("pick", picked_loss),
            ]))

            train_supervised(
                tmpdir, output_dir,
                train_dataset, model, optimizer,
                loss_fn,
                None, "score",
                collate_fn,
                1, Epoch(100), evaluation_interval=Epoch(10),
                snapshot_interval=Epoch(100)
            )
        return encoder, train_dataset
Exemple #9
0
 def test_difference(self):
     """Subtracting a 3x1 bar from a 1x1 square renders as '# #'."""
     interpreter = Interpreter(3, 3, 1, False)
     program = Difference(Rectangle(1, 1), Rectangle(3, 1))
     expected = "   \n# #\n   \n"
     rendered = show(interpreter.eval(program, [None])[0])
     assert expected == rendered
Exemple #10
0
 def test_union(self):
     """The union of a horizontal and a vertical bar renders as a cross."""
     interpreter = Interpreter(3, 3, 1, False)
     program = Union(Rectangle(3, 1), Rectangle(1, 3))
     expected = " # \n###\n # \n"
     rendered = show(interpreter.eval(program, [None])[0])
     assert expected == rendered
Exemple #11
0
 def test_rotation(self):
     """A 4x1 bar rotated by 45 renders as a diagonal on a 3x3 canvas."""
     interpreter = Interpreter(3, 3, 1, False)
     program = Rotation(45, Rectangle(4, 1))
     expected = "  #\n # \n#  \n"
     rendered = show(interpreter.eval(program, [None])[0])
     assert expected == rendered
Exemple #12
0
 def test_translation(self):
     """A 1x3 bar translated by (2, 1) is drawn in the top-right corner."""
     interpreter = Interpreter(5, 5, 1, False)
     program = Translation(2, 1, Rectangle(1, 3))
     expected = "    #\n    #\n    #\n     \n     \n"
     rendered = show(interpreter.eval(program, [None])[0])
     assert expected == rendered
Exemple #13
0
 def test_rectangle(self):
     """A 1x3 rectangle renders as a vertical bar on a 3x3 canvas."""
     interpreter = Interpreter(3, 3, 1, False)
     program = Rectangle(1, 3)
     expected = " # \n # \n # \n"
     rendered = show(interpreter.eval(program, [None])[0])
     assert expected == rendered
Exemple #14
0
 def test_unparse_rectangle(self, parser):
     """Parsing then unparsing a rectangle gives back an equal rectangle."""
     parsed = parser.parse(Rectangle(1, 2))
     round_tripped = parser.unparse(parsed)
     assert Rectangle(1, 2) == round_tripped
Exemple #15
0
 def test_parse_rectangle(self, parser):
     """A rectangle parses to a node with 'w' and 'h' size leaves."""
     width_field = Field("w", "size", Leaf("size", 1))
     height_field = Field("h", "size", Leaf("size", 2))
     expected = Node("Rectangle", [width_field, height_field])
     assert expected == parser.parse(Rectangle(1, 2))