コード例 #1
0
ファイル: functions.py プロジェクト: nashid/mlprogram
def get_samples(dataset: Dataset,
                parser: Parser[csgAST],
                reference: bool = False) -> Samples:
    """Collect the grammar rules, node types and tokens of the CSG DSL.

    A fixed list of example programs (one per construct) is parsed and
    turned into action sequences. Every ``ApplyRule`` action whose rule is
    not a ``CloseVariadicFieldRule`` contributes its rule, its parent node
    type and its children's node types. Each distinct item is kept once, in
    first-seen order.

    Args:
        dataset: Supplies ``size_candidates``, ``length_candidates`` and
            ``degree_candidates``, from which the token list is built.
        parser: Converts an example program into an AST. Programs for which
            it returns ``None`` are skipped.
        reference: If True, the examples use ``Reference`` children instead
            of nested shapes.

    Returns:
        ``Samples(rules, node_types, tokens)``, where ``tokens`` is the
        sorted, de-duplicated list of ``(kind, value)`` pairs.
    """
    candidate_tokens = (
        [("size", value) for value in dataset.size_candidates]
        + [("length", value) for value in dataset.length_candidates]
        + [("degree", value) for value in dataset.degree_candidates]
    )

    if reference:
        # TODO use expander
        examples = [
            Circle(1),
            Rectangle(1, 2),
            Translation(1, 1, Reference(0)),
            Rotation(45, Reference(1)),
            Union(Reference(0), Reference(1)),
            Difference(Reference(0), Reference(1)),
        ]
    else:
        examples = [
            Circle(1),
            Rectangle(1, 2),
            Translation(1, 1, Circle(1)),
            Rotation(45, Circle(1)),
            Union(Circle(1), Circle(1)),
            Difference(Circle(1), Circle(1)),
        ]

    rules: List[Rule] = []
    seen_rules = set()
    node_types = []
    seen_node_types = set()

    def remember_node_type(node_type) -> None:
        # Append each node type only the first time it is seen.
        if node_type not in seen_node_types:
            seen_node_types.add(node_type)
            node_types.append(node_type)

    for program in examples:
        tree = parser.parse(program)
        if tree is None:
            continue
        for action in ActionSequence.create(tree).action_sequence:
            if not isinstance(action, ApplyRule):
                continue
            rule = action.rule
            # Closing a variadic field introduces no new rule or node type.
            if isinstance(rule, CloseVariadicFieldRule):
                continue
            if rule not in seen_rules:
                seen_rules.add(rule)
                rules.append(rule)
            remember_node_type(rule.parent)
            for _, child_type in rule.children:
                remember_node_type(child_type)

    return Samples(rules, node_types, sorted(set(candidate_tokens)))
コード例 #2
0
    def sample_ast(self, rng: np.random.RandomState, n_object: int) -> AST:
        """Sample a random CSG program that combines ``n_object`` leaf shapes.

        Leaves are drawn from ``self.leaf_candidates`` ("Circle" or
        "Rectangle"), with dimensions drawn from ``self.size_candidates``.
        The operation pool holds ``n_object - 1`` names drawn from
        ``self.branch_candidates``. It also holds between 0 and ``n_object``
        (inclusive) names drawn from ``self.node_candidates``.

        Operations are then popped in random order and applied to randomly
        chosen objects. "Translation" and "Rotation" wrap one object;
        "Union" and "Difference" combine two. Sampling stops once a single
        object remains or the pool is empty, and any operations still in
        the pool are discarded.

        Args:
            rng: Source of randomness. Every random choice is drawn from it.
            n_object: Number of leaf shapes to sample.

        Returns:
            The first remaining object. NOTE(review): if the pool runs out
            while several objects remain, only the first one is returned.
            This presumably relies on ``branch_candidates`` containing only
            binary operations, so that ``n_object - 1`` of them merge every
            leaf. Confirm.

        Raises:
            Exception: If a leaf or operation name is not recognized.
        """
        objects: Dict[int, AST] = {}
        for i, leaf in enumerate(rng.choice(self.leaf_candidates, n_object)):
            if leaf == "Circle":
                objects[i] = Circle(rng.choice(self.size_candidates))
            elif leaf == "Rectangle":
                objects[i] = Rectangle(rng.choice(self.size_candidates),
                                       rng.choice(self.size_candidates))
            else:
                raise Exception(f"Invalid type: {leaf}")

        # n_object - 1 branch operations, keyed 0 .. n_object - 2.
        ops = dict(enumerate(rng.choice(self.branch_candidates,
                                        n_object - 1)))
        # Plus 0 .. len(ops) + 1 node operations. randint's upper bound is
        # exclusive. These are keyed after the branch operations.
        n_node = rng.randint(0, len(ops) + 2)
        ops.update(enumerate(rng.choice(self.node_candidates, n_node),
                             start=len(ops)))

        while len(objects) > 1 and len(ops) != 0:
            op_key = rng.choice(list(ops.keys()))
            op = ops.pop(op_key)
            obj0_key = rng.choice(list(objects.keys()))
            obj0 = objects.pop(obj0_key)
            if op == "Translation":
                objects[obj0_key] = Translation(
                    rng.choice(self.length_candidates),
                    rng.choice(self.length_candidates), obj0)
            elif op == "Rotation":
                objects[obj0_key] = Rotation(
                    rng.choice(self.degree_candidates), obj0)
            else:
                # Binary operation: another object must still exist, because
                # len(objects) was > 1 before obj0 was popped.
                obj1_key = rng.choice(list(objects.keys()))
                obj1 = objects.pop(obj1_key)
                if op == "Union":
                    objects[obj0_key] = Union(obj0, obj1)
                elif op == "Difference":
                    objects[obj0_key] = Difference(obj0, obj1)
                else:
                    # Report the offending operation. The old code used a
                    # stale loop variable here, which named the wrong item.
                    raise Exception(f"Invalid type: {op}")
        return list(objects.values())[0]
コード例 #3
0
ファイル: parser.py プロジェクト: nashid/mlprogram
 def unparse(self, code: AST) -> Optional[csgAST]:
     """Rebuild a CSG object from an AST.

     A ``Leaf`` yields its stored value. A ``Node`` whose type name is
     Circle, Rectangle, Translation, Rotation, Union or Difference is
     rebuilt recursively, provided its numeric fields are ``Leaf`` nodes and
     its sub-shape fields are ``AST`` instances. ``None`` is returned in
     these cases:

     - any other input;
     - an unknown node name;
     - a field of the wrong kind;
     - a child that cannot be unparsed.
     """
     if not isinstance(code, Node):
         if isinstance(code, Leaf):
             return cast(csgAST, code.value)
         return None

     fields = {field.name: field.value for field in code.fields}
     kind = code.get_type_name()

     def all_leaves(*names: str) -> bool:
         # Stops at the first field that is not a Leaf, matching the
         # original chained `and` checks.
         return all(isinstance(fields[name], Leaf) for name in names)

     if kind == "Circle":
         if not all_leaves("r"):
             return None
         return Circle(fields["r"].value)

     if kind == "Rectangle":
         if not all_leaves("w", "h"):
             return None
         return Rectangle(fields["w"].value, fields["h"].value)

     if kind in ("Translation", "Rotation"):
         # The child is validated and unparsed before the scalar fields
         # are inspected.
         if not isinstance(fields["child"], AST):
             return None
         child = self.unparse(fields["child"])
         if child is None:
             return None
         if kind == "Translation":
             if not all_leaves("x", "y"):
                 return None
             return Translation(fields["x"].value, fields["y"].value, child)
         if not all_leaves("theta"):
             return None
         return Rotation(fields["theta"].value, child)

     if kind in ("Union", "Difference"):
         if not (isinstance(fields["a"], AST)
                 and isinstance(fields["b"], AST)):
             return None
         lhs = self.unparse(fields["a"])
         if lhs is None:
             return None
         rhs = self.unparse(fields["b"])
         if rhs is None:
             return None
         combine = Union if kind == "Union" else Difference
         return combine(lhs, rhs)

     return None
コード例 #4
0
    def pretrain(self, output_dir):
        """Pre-train a model on a small, fixed set of CSG programs.

        Eight hand-written programs are wrapped in ``Environment``s and given
        test cases computed by ``self.interpreter()``. They are then used to
        train the model from ``self.prepare_model`` with
        ``train_supervised``, for ``Epoch(100)`` with evaluation every
        ``Epoch(10)``. Scratch files go into a temporary directory that is
        deleted afterwards. ``output_dir`` is passed through to
        ``train_supervised``.

        Returns:
            ``(encoder, train_dataset)``: the encoder from
            ``self.prepare_encoder`` and the training set with test cases
            attached.
        """
        dataset = Dataset(4, 1, 2, 1, 45, seed=0)
        programs = [
            Circle(1),
            Rectangle(1, 2),
            Rectangle(1, 1),
            Rotation(45, Rectangle(1, 1)),
            Translation(1, 1, Rectangle(1, 1)),
            Difference(Circle(1), Circle(1)),
            Union(Rectangle(1, 2), Circle(1)),
            Difference(Rectangle(1, 1), Circle(1)),
        ]
        train_dataset = ListDataset([
            Environment({"ground_truth": program}, {"ground_truth"})
            for program in programs
        ])

        with tempfile.TemporaryDirectory() as tmpdir:
            interpreter = self.interpreter()
            add_test_cases = Apply(
                module=AddTestCases(interpreter),
                in_keys=["ground_truth"],
                out_key="test_cases",
                is_out_supervision=False,
            )
            train_dataset = data_transform(train_dataset, add_test_cases)
            encoder = self.prepare_encoder(dataset, Parser())

            collate = Collate(
                test_case_tensor=CollateOptions(False, 0, 0),
                variables_tensor=CollateOptions(True, 0, 0),
                previous_actions=CollateOptions(True, 0, -1),
                hidden_state=CollateOptions(False, 0, 0),
                state=CollateOptions(False, 0, 0),
                ground_truth_actions=CollateOptions(True, 0, -1),
            )
            collate_fn = Sequence(OrderedDict([
                ("to_episode", Map(self.to_episode(encoder, interpreter))),
                ("flatten", Flatten()),
                ("transform",
                 Map(self.transform(encoder, interpreter, Parser()))),
                ("collate", collate.collate),
            ]))

            model = self.prepare_model(encoder)
            optimizer = self.prepare_optimizer(model)
            loss_fn = torch.nn.Sequential(OrderedDict([
                ("loss", Apply(
                    module=Loss(reduction="sum"),
                    in_keys=[
                        "rule_probs",
                        "token_probs",
                        "reference_probs",
                        "ground_truth_actions",
                    ],
                    out_key="action_sequence_loss",
                )),
                # The summed loss is divided by the batch size, which is the
                # constant 1 here.
                ("normalize", Apply(
                    [("action_sequence_loss", "lhs")],
                    "loss",
                    mlprogram.nn.Function(Div()),
                    constants={"rhs": 1},
                )),
                ("pick", mlprogram.nn.Function(Pick("loss"))),
            ]))
            train_supervised(
                tmpdir, output_dir,
                train_dataset, model, optimizer,
                loss_fn,
                None, "score",
                collate_fn,
                1, Epoch(100), evaluation_interval=Epoch(10),
                snapshot_interval=Epoch(100),
            )
        return encoder, train_dataset
コード例 #5
0
ファイル: test_interpreter.py プロジェクト: nashid/mlprogram
 def test_multiple_inputs(self):
     """Evaluating with two inputs produces one result per input."""
     outputs = Interpreter(1, 1, 1, False).eval(Circle(1), [None] * 2)
     assert len(outputs) == 2
コード例 #6
0
ファイル: test_interpreter.py プロジェクト: nashid/mlprogram
 def test_circle(self):
     """Evaluating Circle(1) with Interpreter(1, 1, 1, False) shows "#\\n"."""
     interpreter = Interpreter(1, 1, 1, False)
     rendered = show(interpreter.eval(Circle(1), [None])[0])
     assert "#\n" == rendered
コード例 #7
0
 def test_unparse_circle(self, parser):
     """parse followed by unparse round-trips a Circle."""
     expected = Circle(1)
     assert expected == parser.unparse(parser.parse(Circle(1)))
コード例 #8
0
 def test_parse_circle(self, parser):
     """A Circle parses to a Node with a single "r" field holding a size leaf."""
     radius_field = Field("r", "size", Leaf("size", 1))
     expected = Node("Circle", [radius_field])
     assert expected == parser.parse(Circle(1))
コード例 #9
0
    def test_unexpand(self, expander):
        """unexpand rebuilds a nested program from a Reference-linked list.

        ``Reference(i)`` stands for the i-th element of the list.
        """
        cases = [
            ([Circle(1)], Circle(1)),
            ([Rotation(1, Circle(1))], Rotation(1, Circle(1))),
            ([Circle(1), Difference(Circle(1), Reference(0))],
             Difference(Circle(1), Circle(1))),
            ([Circle(1), Rotation(1, Reference(0)),
              Difference(Circle(1), Reference(1))],
             Difference(Circle(1), Rotation(1, Circle(1)))),
        ]
        for expanded, expected in cases:
            assert expander.unexpand(expanded) == expected
コード例 #10
0
    def test_expand(self, expander):
        """expand flattens a program into a list linked by Reference indices.

        Leaves come first. Each compound step refers to earlier list
        elements via ``Reference(i)``.
        """
        cases = [
            (Circle(1), [Circle(1)]),
            (Rotation(1, Circle(1)),
             [Circle(1), Rotation(1, Reference(0))]),
            (Difference(Circle(1), Circle(1)),
             [Circle(1), Circle(1), Difference(Reference(0), Reference(1))]),
            (Difference(Circle(1), Rotation(1, Circle(1))),
             [Circle(1), Circle(1), Rotation(1, Reference(1)),
              Difference(Reference(0), Reference(2))]),
        ]
        for program, expected in cases:
            assert expander.expand(program) == expected