Code example #1
0
    def reinforce(self, train_dataset, encoder, output_dir):
        """Fine-tune a pretrained model with REINFORCE.

        Loads the pretrained snapshot from ``output_dir/pretrain`` and
        writes RL training artifacts under ``output_dir/RL``.  The training
        loss combines a reward-weighted policy (action-sequence) loss with
        a value-network BCE loss scaled by 1e-2.
        """
        interpreter = self.interpreter()

        # How each batch key is padded/stacked when collated.
        collate = Collate(
            test_case_tensor=CollateOptions(False, 0, 0),
            variables_tensor=CollateOptions(True, 0, 0),
            previous_actions=CollateOptions(True, 0, -1),
            hidden_state=CollateOptions(False, 0, 0),
            state=CollateOptions(False, 0, 0),
            ground_truth_actions=CollateOptions(True, 0, -1),
            reward=CollateOptions(False, 0, 0)
        )
        # Pipeline: samples -> episodes -> flattened steps -> tensors -> batch.
        collate_fn = Sequence(OrderedDict([
            ("to_episode", Map(self.to_episode(encoder,
                                               interpreter))),
            ("flatten", Flatten()),
            ("transform", Map(self.transform(
                encoder, interpreter, Parser()))),
            ("collate", collate.collate)
        ]))

        model = self.prepare_model(encoder)
        optimizer = self.prepare_optimizer(model)

        # The locals below are created in the same order in which the
        # original inline arguments were evaluated, so side effects (if any)
        # of construction happen in an identical sequence.
        pretrained_dir = os.path.join(output_dir, "pretrain")
        rl_dir = os.path.join(output_dir, "RL")
        synthesizer = self.prepare_synthesizer(model, encoder, interpreter)

        # Policy head: per-sample action-sequence loss, weighted by reward.
        policy_loss = torch.nn.Sequential(OrderedDict([
            ("loss",
             Apply(
                 module=mlprogram.nn.action_sequence.Loss(
                     reduction="none"
                 ),
                 in_keys=[
                     "rule_probs",
                     "token_probs",
                     "reference_probs",
                     "ground_truth_actions",
                 ],
                 out_key="action_sequence_loss",
             )),
            ("weight_by_reward",
             Apply(
                 [("reward", "lhs"),
                  ("action_sequence_loss", "rhs")],
                 "action_sequence_loss",
                 mlprogram.nn.Function(Mul())))
        ]))
        # Value head: BCE between predicted value and the reshaped reward,
        # down-weighted by 1e-2.
        value_loss = torch.nn.Sequential(OrderedDict([
            ("reshape_reward",
             Apply(
                 [("reward", "x")],
                 "value_loss_target",
                 Reshape([-1, 1]))),
            ("BCE",
             Apply(
                 [("value", "input"),
                  ("value_loss_target", "target")],
                 "value_loss",
                 torch.nn.BCELoss(reduction='sum'))),
            ("reweight",
             Apply(
                 [("value_loss", "lhs")],
                 "value_loss",
                 mlprogram.nn.Function(Mul()),
                 constants={"rhs": 1e-2})),
        ]))
        loss_fn = torch.nn.Sequential(OrderedDict([
            ("policy", policy_loss),
            ("value", value_loss),
            ("aggregate",
             Apply(
                 ["action_sequence_loss", "value_loss"],
                 "loss",
                 AggregatedLoss())),
            ("normalize",
             Apply(
                 [("loss", "lhs")],
                 "loss",
                 mlprogram.nn.Function(Div()),
                 constants={"rhs": 1})),
            ("pick",
             mlprogram.nn.Function(
                 Pick("loss")))
        ]))
        evaluator = EvaluateSynthesizer(
            train_dataset,
            self.prepare_synthesizer(model, encoder, interpreter,
                                     rollout=False),
            {}, top_n=[])
        # Reward: 1.0 when the IoU against the expected output clears 0.9,
        # else 0.0 (Threshold with dtype="float").
        reward = metrics.use_environment(
            metric=metrics.TestCaseResult(
                interpreter=interpreter,
                metric=metrics.use_environment(
                    metric=metrics.Iou(),
                    in_keys=["actual", "expected"],
                    value_key="actual",
                )
            ),
            in_keys=["test_cases", "actual"],
            value_key="actual",
            transform=Threshold(threshold=0.9, dtype="float"),
        )
        train_REINFORCE(
            pretrained_dir, rl_dir,
            train_dataset,
            synthesizer,
            model, optimizer,
            loss_fn,
            evaluator,
            "generation_rate",
            reward,
            collate_fn,
            1, 1,
            Epoch(10), evaluation_interval=Epoch(10),
            snapshot_interval=Epoch(10),
            use_pretrained_model=True,
            use_pretrained_optimizer=True,
            threshold=1.0)
Code example #2
0
    def prepare_synthesizer(self, model, encoder, interpreter, rollout=True):
        """Assemble the hierarchical SMC synthesizer.

        An inner SMC searches over action sequences for one statement; the
        outer SMC extends programs statement by statement.  With
        ``rollout=True`` the result is a plain SMC over a filtered sampler;
        otherwise the sampler is augmented with the value network and the
        synthesizer gets a timeout plus a final result filter.
        """

        def make_result_metric():
            # IoU-based test-case metric.  Both branches below need an
            # identical construction, so it is factored into this helper
            # (called per-branch, preserving construction count and order).
            return metrics.use_environment(
                metric=metrics.TestCaseResult(
                    interpreter,
                    metric=metrics.use_environment(
                        metric=metrics.Iou(),
                        in_keys=["actual", "expected"],
                        value_key="actual",
                    )
                ),
                in_keys=["test_cases", "actual"],
                value_key="actual"
            )

        def make_input_transform():
            # test_cases -> test_case_tensor, then variables -> tensor.
            # Also duplicated in the original; factored for the same reason.
            return Sequence(OrderedDict([
                ("tinput",
                 Apply(
                     module=TransformInputs(),
                     in_keys=["test_cases"],
                     out_key="test_case_tensor",
                 )),
                ("tvariable",
                 Apply(
                     module=TransformVariables(),
                     in_keys=["variables", "test_case_tensor"],
                     out_key="variables_tensor"
                 )),
            ]))

        collate = Collate(
            test_case_tensor=CollateOptions(False, 0, 0),
            input_feature=CollateOptions(False, 0, 0),
            test_case_feature=CollateOptions(False, 0, 0),
            reference_features=CollateOptions(True, 0, 0),
            variables_tensor=CollateOptions(True, 0, 0),
            previous_actions=CollateOptions(True, 0, -1),
            hidden_state=CollateOptions(False, 0, 0),
            state=CollateOptions(False, 0, 0),
            ground_truth_actions=CollateOptions(True, 0, -1)
        )
        action_sampler = ActionSequenceSampler(
            encoder, IsSubtype(),
            make_input_transform(),
            Compose(OrderedDict([
                ("add_previous_actions",
                 Apply(
                    module=AddPreviousActions(encoder, n_dependent=1),
                    in_keys=["action_sequence", "reference"],
                    out_key="previous_actions",
                    constants={"train": False},
                    )),
                ("add_state", AddState("state")),
                ("add_hidden_state", AddState("hidden_state"))
            ])),
            collate, model,
            rng=np.random.RandomState(0))
        # Map sampled action sequences back to program text for dedup/keys.
        action_sampler = mlprogram.samplers.transform(
            action_sampler,
            Parser().unparse
        )
        statement_synthesizer = SMC(
            5, 1,
            action_sampler,
            max_try_num=1,
            to_key=Pick("action_sequence"),
            rng=np.random.RandomState(0)
        )

        program_sampler = SequentialProgramSampler(
            statement_synthesizer,
            Apply(
                module=TransformInputs(),
                in_keys=["test_cases"],
                out_key="test_case_tensor",
            ),
            collate,
            model.encode_input,
            interpreter=interpreter,
            expander=Expander(),
            rng=np.random.RandomState(0))

        if rollout:
            # Rollout path: keep only samples whose metric reaches 1.0.
            filtered_sampler = FilteredSampler(
                program_sampler,
                make_result_metric(),
                1.0
            )
            return SMC(3, 1, filtered_sampler, rng=np.random.RandomState(0),
                       to_key=Pick("interpreter_state"), max_try_num=1)

        # Evaluation path: score partial programs with the value network.
        value_sampler = SamplerWithValueNetwork(
            program_sampler,
            make_input_transform(),
            collate,
            torch.nn.Sequential(OrderedDict([
                ("encoder", model.encoder),
                ("value", model.value),
                ("pick",
                 mlprogram.nn.Function(
                     Pick("value")))
            ])))

        timed_synthesizer = SynthesizerWithTimeout(
            SMC(3, 1, value_sampler, rng=np.random.RandomState(0),
                to_key=Pick("interpreter_state"),
                max_try_num=1),
            1
        )
        return FilteredSynthesizer(
            timed_synthesizer,
            make_result_metric(),
            1.0
        )
Code example #3
0
    def pretrain(self, output_dir):
        """Pretrain the model with supervised learning.

        Builds a fixed eight-program training set covering the shape
        primitives and combinators (Circle, Rectangle, Rotation,
        Translation, Difference, Union), derives test cases for each
        ground-truth program via the interpreter, then runs supervised
        training for 100 epochs, saving snapshots under
        ``output_dir/pretrain``.

        Returns:
            tuple: ``(encoder, train_dataset)`` — the prepared encoder and
            the transformed training dataset, for later RL fine-tuning.
        """
        dataset = Dataset(4, 1, 2, 1, 45, seed=0)
        train_dataset = ListDataset([
            Environment(
                {"ground_truth": Circle(1)},
                {"ground_truth"},
            ),
            Environment(
                {"ground_truth": Rectangle(1, 2)},
                {"ground_truth"},
            ),
            Environment(
                {"ground_truth": Rectangle(1, 1)},
                {"ground_truth"},
            ),
            Environment(
                {"ground_truth": Rotation(45, Rectangle(1, 1))},
                {"ground_truth"},
            ),
            Environment(
                {"ground_truth": Translation(1, 1, Rectangle(1, 1))},
                {"ground_truth"},
            ),
            Environment(
                {"ground_truth": Difference(Circle(1), Circle(1))},
                {"ground_truth"},
            ),
            Environment(
                {"ground_truth": Union(Rectangle(1, 2), Circle(1))},
                {"ground_truth"},
            ),
            Environment(
                {"ground_truth": Difference(Rectangle(1, 1), Circle(1))},
                {"ground_truth"},
            ),
        ])

        interpreter = self.interpreter()
        # Attach "test_cases" computed from each ground-truth program.
        train_dataset = data_transform(
            train_dataset,
            Apply(
                module=AddTestCases(interpreter),
                in_keys=["ground_truth"],
                out_key="test_cases",
                is_out_supervision=False,
            )
        )
        encoder = self.prepare_encoder(dataset, Parser())

        # How each batch key is padded/stacked when collated.
        collate = Collate(
            test_case_tensor=CollateOptions(False, 0, 0),
            variables_tensor=CollateOptions(True, 0, 0),
            previous_actions=CollateOptions(True, 0, -1),
            hidden_state=CollateOptions(False, 0, 0),
            state=CollateOptions(False, 0, 0),
            ground_truth_actions=CollateOptions(True, 0, -1)
        )
        # Pipeline: samples -> episodes -> flattened steps -> tensors -> batch.
        collate_fn = Sequence(OrderedDict([
            ("to_episode", Map(self.to_episode(encoder,
                                               interpreter))),
            ("flatten", Flatten()),
            ("transform", Map(self.transform(
                encoder, interpreter, Parser()))),
            ("collate", collate.collate)
        ]))

        model = self.prepare_model(encoder)
        optimizer = self.prepare_optimizer(model)
        train_supervised(
            os.path.join(output_dir, "pretrain"),
            train_dataset, model, optimizer,
            torch.nn.Sequential(OrderedDict([
                ("loss",
                 Apply(
                     module=Loss(
                         reduction="sum",
                     ),
                     in_keys=[
                         "rule_probs",
                         "token_probs",
                         "reference_probs",
                         "ground_truth_actions",
                     ],
                     out_key="action_sequence_loss",
                 )),
                ("normalize",  # divided by batch_size
                 Apply(
                     [("action_sequence_loss", "lhs")],
                     "loss",
                     mlprogram.nn.Function(Div()),
                     constants={"rhs": 1})),
                ("pick",
                 mlprogram.nn.Function(
                     Pick("loss")))
            ])),
            None, "score",
            collate_fn,
            1, Epoch(100), evaluation_interval=Epoch(10),
            snapshot_interval=Epoch(100)
        )
        return encoder, train_dataset
Code example #4
0
def parser():
    """Return a newly constructed ``Parser`` instance."""
    instance = Parser()
    return instance