Ejemplo n.º 1
0
    def train(self, output_dir):
        """Run supervised training and return the fitted encoders.

        Builds the action-sequence loss pipeline and the batch-collation
        function, prepares the query/action encoders, model and optimizer,
        then delegates the training loop to ``train_supervised``.

        Args:
            output_dir: directory where training outputs (snapshots,
                evaluation results) are written.

        Returns:
            Tuple ``(qencoder, aencoder)`` — the prepared encoders.

        NOTE(review): ``train_dataset`` and ``test_dataset`` are not defined
        in this method; presumably module-level fixtures — confirm.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            # Loss pipeline: compute the action-sequence loss from the
            # predicted probabilities, then pick it out of the environment
            # as the scalar training loss.
            loss_fn = nn.Sequential(
                OrderedDict([
                    ("loss",
                     Apply(
                         module=Loss(),
                         in_keys=[
                             "rule_probs",
                             "token_probs",
                             "reference_probs",
                             "ground_truth_actions",
                         ],
                         out_key="action_sequence_loss",
                     )),
                    ("pick",
                     mlprogram.nn.Function(Pick("action_sequence_loss")))
                ]))
            # Per-key batching options; assumed layout is
            # (is_sequence?, dim, padding_value) — confirm against the
            # CollateOptions definition.
            collate = Collate(word_nl_query=CollateOptions(True, 0, -1),
                              nl_query_features=CollateOptions(True, 0, -1),
                              reference_features=CollateOptions(True, 0, -1),
                              actions=CollateOptions(True, 0, -1),
                              previous_actions=CollateOptions(True, 0, -1),
                              previous_action_rules=CollateOptions(
                                  True, 0, -1),
                              history=CollateOptions(False, 1, 0),
                              hidden_state=CollateOptions(False, 0, 0),
                              state=CollateOptions(False, 0, 0),
                              ground_truth_actions=CollateOptions(True, 0,
                                                                  -1)).collate

            qencoder, aencoder = \
                self.prepare_encoder(train_dataset, Parser())
            transform = Map(self.transform_cls(qencoder, aencoder, Parser()))
            model = self.prepare_model(qencoder, aencoder)
            optimizer = self.prepare_optimizer(model)
            # Training loop: evaluation selects the best model by
            # "accuracy@5" on test_dataset (top-5 synthesis accuracy).
            train_supervised(tmpdir,
                             output_dir,
                             train_dataset,
                             model,
                             optimizer,
                             loss_fn,
                             EvaluateSynthesizer(test_dataset,
                                                 self.prepare_synthesizer(
                                                     model, qencoder,
                                                     aencoder),
                                                 {"accuracy": Accuracy()},
                                                 top_n=[5]),
                             "accuracy@5",
                             lambda x: collate(transform(x)),
                             1,
                             Epoch(100),
                             evaluation_interval=Epoch(100),
                             snapshot_interval=Epoch(100),
                             threshold=1.0)
        return qencoder, aencoder
Ejemplo n.º 2
0
 def test_exception(self):
     """An element whose mapped callable raises is mapped to None."""
     mapper = Map(raise_exception)
     output = mapper([2])
     assert output == [None]
Ejemplo n.º 3
0
 def test_multiprocessing(self):
     """Map still applies the callable element-wise when given its extra
     positional argument (presumably a worker count — confirm)."""
     mapper = Map(add1, 1)
     output = mapper([2])
     assert output == [3]
Ejemplo n.º 4
0
 def test_happy_path(self):
     """Map applies the wrapped callable to each element of the input."""
     increment = Map(lambda value: value + 1)
     result = increment([2])
     assert result == [3]
Ejemplo n.º 5
0
    def reinforce(self, train_dataset, encoder, output_dir):
        """Fine-tune the pretrained model with REINFORCE.

        Builds the episode/collate pipeline and a combined policy + value
        loss module, then delegates the loop to ``train_REINFORCE``.

        Args:
            train_dataset: dataset of environments to roll out on.
            encoder: action-sequence encoder (e.g. from pretraining).
            output_dir: directory for snapshots and evaluation results.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            interpreter = self.interpreter()

            # Per-key batching options; assumed layout is
            # (is_sequence?, dim, padding_value) — confirm against the
            # CollateOptions definition.
            collate = Collate(
                test_case_tensor=CollateOptions(False, 0, 0),
                variables_tensor=CollateOptions(True, 0, 0),
                previous_actions=CollateOptions(True, 0, -1),
                hidden_state=CollateOptions(False, 0, 0),
                state=CollateOptions(False, 0, 0),
                ground_truth_actions=CollateOptions(True, 0, -1),
                reward=CollateOptions(False, 0, 0)
            )
            # Batch pipeline: expand each sample into episodes, flatten,
            # transform into model inputs, then collate into one batch.
            collate_fn = Sequence(OrderedDict([
                ("to_episode", Map(self.to_episode(encoder,
                                                   interpreter))),
                ("flatten", Flatten()),
                ("transform", Map(self.transform(
                    encoder, interpreter, Parser()))),
                ("collate", collate.collate)
            ]))

            model = self.prepare_model(encoder)
            optimizer = self.prepare_optimizer(model)
            # NOTE(review): ``output_dir`` is passed twice (1st and 3rd
            # positional argument) — verify against train_REINFORCE's
            # signature.
            train_REINFORCE(
                output_dir, tmpdir, output_dir,
                train_dataset,
                self.prepare_synthesizer(model, encoder, interpreter),
                model, optimizer,
                # Combined loss: the policy branch weights the (unreduced)
                # action-sequence loss by the reward; the value branch
                # regresses the value head onto the reward via BCE; both
                # are aggregated and reduced to a single scalar "loss".
                torch.nn.Sequential(OrderedDict([
                    ("policy",
                     torch.nn.Sequential(OrderedDict([
                         ("loss",
                          Apply(
                              module=mlprogram.nn.action_sequence.Loss(
                                  reduction="none"
                              ),
                              in_keys=[
                                  "rule_probs",
                                  "token_probs",
                                  "reference_probs",
                                  "ground_truth_actions",
                              ],
                              out_key="action_sequence_loss",
                          )),
                         ("weight_by_reward",
                             Apply(
                                 [("reward", "lhs"),
                                  ("action_sequence_loss", "rhs")],
                                 "action_sequence_loss",
                                 mlprogram.nn.Function(Mul())))
                     ]))),
                    ("value",
                     torch.nn.Sequential(OrderedDict([
                         ("reshape_reward",
                             Apply(
                                 [("reward", "x")],
                                 "value_loss_target",
                                 Reshape([-1, 1]))),
                         ("BCE",
                             Apply(
                                 [("value", "input"),
                                  ("value_loss_target", "target")],
                                 "value_loss",
                                 torch.nn.BCELoss(reduction='sum'))),
                         # Down-weight the value loss by 1e-2 relative to
                         # the policy loss.
                         ("reweight",
                             Apply(
                                 [("value_loss", "lhs")],
                                 "value_loss",
                                 mlprogram.nn.Function(Mul()),
                                 constants={"rhs": 1e-2})),
                     ]))),
                    ("aggregate",
                     Apply(
                         ["action_sequence_loss", "value_loss"],
                         "loss",
                         AggregatedLoss())),
                    # Division by 1 is a no-op here; presumably a
                    # placeholder for batch-size normalization — confirm.
                    ("normalize",
                     Apply(
                         [("loss", "lhs")],
                         "loss",
                         mlprogram.nn.Function(Div()),
                         constants={"rhs": 1})),
                    ("pick",
                     mlprogram.nn.Function(
                         Pick("loss")))
                ])),
                EvaluateSynthesizer(
                    train_dataset,
                    self.prepare_synthesizer(model, encoder, interpreter,
                                             rollout=False),
                    {}, top_n=[]),
                "generation_rate",
                # Reward metric: IoU between actual and expected outputs,
                # thresholded at 0.9 into a binary float reward.
                metrics.use_environment(
                    metric=metrics.TestCaseResult(
                        interpreter=interpreter,
                        metric=metrics.use_environment(
                            metric=metrics.Iou(),
                            in_keys=["actual", "expected"],
                            value_key="actual",
                        )
                    ),
                    in_keys=["test_cases", "actual"],
                    value_key="actual",
                    transform=Threshold(threshold=0.9, dtype="float"),
                ),
                collate_fn,
                1, 1,
                Epoch(10), evaluation_interval=Epoch(10),
                snapshot_interval=Epoch(10),
                use_pretrained_model=True,
                use_pretrained_optimizer=True,
                threshold=1.0)
Ejemplo n.º 6
0
    def pretrain(self, output_dir):
        """Supervised pretraining on a small fixed set of CSG programs.

        Builds a hand-picked list of ground-truth shape programs, attaches
        test cases via the interpreter, then runs ``train_supervised``.

        Args:
            output_dir: directory for snapshots and evaluation results.

        Returns:
            Tuple ``(encoder, train_dataset)`` — the fitted encoder and the
            transformed training dataset.
        """
        dataset = Dataset(4, 1, 2, 1, 45, seed=0)
        """
        Fixed ground-truth programs covering both primitives (Circle,
        Rectangle) and each combinator (Rotation, Translation, Difference,
        Union).
        """
        train_dataset = ListDataset([
            Environment(
                {"ground_truth": Circle(1)},
                set(["ground_truth"]),
            ),
            Environment(
                {"ground_truth": Rectangle(1, 2)},
                set(["ground_truth"]),
            ),
            Environment(
                {"ground_truth": Rectangle(1, 1)},
                set(["ground_truth"]),
            ),
            Environment(
                {"ground_truth": Rotation(45, Rectangle(1, 1))},
                set(["ground_truth"]),
            ),
            Environment(
                {"ground_truth": Translation(1, 1, Rectangle(1, 1))},
                set(["ground_truth"]),
            ),
            Environment(
                {"ground_truth": Difference(Circle(1), Circle(1))},
                set(["ground_truth"]),
            ),
            Environment(
                {"ground_truth": Union(Rectangle(1, 2), Circle(1))},
                set(["ground_truth"]),
            ),
            Environment(
                {"ground_truth": Difference(Rectangle(1, 1), Circle(1))},
                set(["ground_truth"]),
            ),
        ])

        with tempfile.TemporaryDirectory() as tmpdir:
            interpreter = self.interpreter()
            # Derive "test_cases" for each sample by executing the
            # ground-truth program with the interpreter.
            train_dataset = data_transform(
                train_dataset,
                Apply(
                    module=AddTestCases(interpreter),
                    in_keys=["ground_truth"],
                    out_key="test_cases",
                    is_out_supervision=False,
                ))
            encoder = self.prepare_encoder(dataset, Parser())

            # Per-key batching options; assumed layout is
            # (is_sequence?, dim, padding_value) — confirm against the
            # CollateOptions definition.
            collate = Collate(
                test_case_tensor=CollateOptions(False, 0, 0),
                variables_tensor=CollateOptions(True, 0, 0),
                previous_actions=CollateOptions(True, 0, -1),
                hidden_state=CollateOptions(False, 0, 0),
                state=CollateOptions(False, 0, 0),
                ground_truth_actions=CollateOptions(True, 0, -1)
            )
            # Batch pipeline: expand each sample into episodes, flatten,
            # transform into model inputs, then collate into one batch.
            collate_fn = Sequence(OrderedDict([
                ("to_episode", Map(self.to_episode(encoder,
                                                   interpreter))),
                ("flatten", Flatten()),
                ("transform", Map(self.transform(
                    encoder, interpreter, Parser()))),
                ("collate", collate.collate)
            ]))

            model = self.prepare_model(encoder)
            optimizer = self.prepare_optimizer(model)
            train_supervised(
                tmpdir, output_dir,
                train_dataset, model, optimizer,
                # Loss: summed action-sequence loss, normalized, then
                # picked out of the environment as the scalar "loss".
                torch.nn.Sequential(OrderedDict([
                    ("loss",
                     Apply(
                         module=Loss(
                             reduction="sum",
                         ),
                         in_keys=[
                             "rule_probs",
                             "token_probs",
                             "reference_probs",
                             "ground_truth_actions",
                         ],
                         out_key="action_sequence_loss",
                     )),
                    ("normalize",  # divided by batch_size
                     Apply(
                         [("action_sequence_loss", "lhs")],
                         "loss",
                         mlprogram.nn.Function(Div()),
                         constants={"rhs": 1})),
                    ("pick",
                     mlprogram.nn.Function(
                         Pick("loss")))
                ])),
                None, "score",
                collate_fn,
                1, Epoch(100), evaluation_interval=Epoch(10),
                snapshot_interval=Epoch(100)
            )
        return encoder, train_dataset