예제 #1
0
 def transform_cls(self, qencoder, cencoder, aencoder, parser):
     """Build the TreeGen training transform pipeline.

     Each step reads/writes keys of the sample environment; the
     OrderedDict keys name the steps inside the Sequence.
     """
     # Tokenize the raw text query into a list of reference tokens.
     extract_reference = Apply(
         module=mlprogram.nn.Function(tokenize),
         in_keys=[["text_query", "str"]],
         out_key="reference")
     # Word-level and character-level encodings of the query.
     encode_word_query = Apply(
         module=EncodeWordQuery(qencoder),
         in_keys=["reference"],
         out_key="word_nl_query")
     encode_char = Apply(
         module=EncodeCharacterQuery(cencoder, 10),
         in_keys=["reference"],
         out_key="char_nl_query")
     # Parse the ground-truth program into an action sequence.
     to_action_sequence = Apply(
         module=GroundTruthToActionSequence(parser),
         in_keys=["ground_truth"],
         out_key="action_sequence")
     # Teacher-forcing inputs derived from the action sequence.
     add_previous_action = Apply(
         module=AddPreviousActions(aencoder),
         in_keys=["action_sequence", "reference"],
         constants={"train": True},
         out_key="previous_actions")
     add_previous_action_rule = Apply(
         module=AddPreviousActionRules(aencoder, 4),
         in_keys=["action_sequence", "reference"],
         constants={"train": True},
         out_key="previous_action_rules")
     add_tree = Apply(
         module=AddActionSequenceAsTree(aencoder),
         in_keys=["action_sequence", "reference"],
         constants={"train": True},
         out_key=["adjacency_matrix", "depthes"])
     add_query = Apply(
         module=AddQueryForTreeGenDecoder(aencoder, 4),
         in_keys=["action_sequence", "reference"],
         constants={"train": True},
         out_key="action_queries")
     # Encode the target action sequence for the loss computation.
     encode_ground_truth = Apply(
         module=EncodeActionSequence(aencoder),
         in_keys=["action_sequence", "reference"],
         out_key="ground_truth_actions")
     return Sequence(OrderedDict([
         ("extract_reference", extract_reference),
         ("encode_word_query", encode_word_query),
         ("encode_char", encode_char),
         ("f2", to_action_sequence),
         ("add_previous_action", add_previous_action),
         ("add_previous_action_rule", add_previous_action_rule),
         ("add_tree", add_tree),
         ("add_query", add_query),
         ("f4", encode_ground_truth),
     ]))
예제 #2
0
 def test_constants(self):
     """Constant kwargs supplied at construction reach the module."""
     constant = torch.arange(3).reshape(-1, 1)
     apply = Apply([], "out", MockModule(1), constants={"x": constant})
     result = apply(Environment())
     expected = [[1], [2], [3]]
     assert np.array_equal(expected, result["out"].detach().numpy())
예제 #3
0
 def test_multiple_inputs(self):
     """All listed input keys are forwarded to the wrapped module."""
     env = Environment({
         "x": torch.arange(3).reshape(-1, 1),
         "y": 10,
     })
     apply = Apply(["x", "y"], "out", MockModule(1))
     result = apply(env)
     assert np.array_equal([[11], [12], [13]],
                           result["out"].detach().numpy())
예제 #4
0
    def train(self, output_dir):
        """Run supervised training and return the fitted encoders.

        NOTE(review): ``train_dataset`` and ``test_dataset`` are free
        names here — presumably fixtures from the enclosing module;
        confirm against the full file.

        Args:
            output_dir: directory receiving training artifacts.

        Returns:
            ``(qencoder, aencoder)``: the query and action-sequence
            encoders prepared from the training data.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            # Loss pipeline: compute the action-sequence loss from the
            # predictor's probability tensors, then pick the scalar out
            # of the environment.
            loss_fn = nn.Sequential(
                OrderedDict([
                    ("loss",
                     Apply(
                         module=Loss(),
                         in_keys=[
                             "rule_probs",
                             "token_probs",
                             "reference_probs",
                             "ground_truth_actions",
                         ],
                         out_key="action_sequence_loss",
                     )),
                    ("pick",
                     mlprogram.nn.Function(Pick("action_sequence_loss")))
                ]))
            # Per-key batching options (pad/stack behavior per tensor).
            collate = Collate(word_nl_query=CollateOptions(True, 0, -1),
                              nl_query_features=CollateOptions(True, 0, -1),
                              reference_features=CollateOptions(True, 0, -1),
                              actions=CollateOptions(True, 0, -1),
                              previous_actions=CollateOptions(True, 0, -1),
                              previous_action_rules=CollateOptions(
                                  True, 0, -1),
                              history=CollateOptions(False, 1, 0),
                              hidden_state=CollateOptions(False, 0, 0),
                              state=CollateOptions(False, 0, 0),
                              ground_truth_actions=CollateOptions(True, 0,
                                                                  -1)).collate

            qencoder, aencoder = \
                self.prepare_encoder(train_dataset, Parser())
            transform = Map(self.transform_cls(qencoder, aencoder, Parser()))
            model = self.prepare_model(qencoder, aencoder)
            optimizer = self.prepare_optimizer(model)
            # Train, evaluating accuracy@5 on test_dataset; evaluation
            # and snapshots both run every 100 epochs.
            train_supervised(tmpdir,
                             output_dir,
                             train_dataset,
                             model,
                             optimizer,
                             loss_fn,
                             EvaluateSynthesizer(test_dataset,
                                                 self.prepare_synthesizer(
                                                     model, qencoder,
                                                     aencoder),
                                                 {"accuracy": Accuracy()},
                                                 top_n=[5]),
                             "accuracy@5",
                             lambda x: collate(transform(x)),
                             1,
                             Epoch(100),
                             evaluation_interval=Epoch(100),
                             snapshot_interval=Epoch(100),
                             threshold=1.0)
        return qencoder, aencoder
예제 #5
0
 def prepare_synthesizer(self, model, qencoder, aencoder):
     """Build a beam-search synthesizer around *model*."""
     # Input transform: tokenize the query, then encode it word-wise.
     transform_input = Compose(OrderedDict([
         ("extract_reference",
          Apply(module=mlprogram.nn.Function(tokenize),
                in_keys=[["text_query", "str"]],
                out_key="reference")),
         ("encode_query",
          Apply(module=EncodeWordQuery(qencoder),
                in_keys=["reference"],
                out_key="word_nl_query")),
     ]))
     # Per-step transform applied while decoding an action sequence
     # (inference mode: train=False).
     transform_action_sequence = Compose(OrderedDict([
         ("add_previous_action",
          Apply(module=AddPreviousActions(aencoder, n_dependent=1),
                in_keys=["action_sequence", "reference"],
                constants={"train": False},
                out_key="previous_actions")),
         ("add_action",
          Apply(module=AddActions(aencoder, n_dependent=1),
                in_keys=["action_sequence", "reference"],
                constants={"train": False},
                out_key="actions")),
         ("add_state", AddState("state")),
         ("add_hidden_state", AddState("hidden_state")),
         ("add_history", AddState("history")),
     ]))
     # Per-key batching options.
     collate = Collate(
         word_nl_query=CollateOptions(True, 0, -1),
         nl_query_features=CollateOptions(True, 0, -1),
         reference_features=CollateOptions(True, 0, -1),
         actions=CollateOptions(True, 0, -1),
         previous_actions=CollateOptions(True, 0, -1),
         previous_action_rules=CollateOptions(True, 0, -1),
         history=CollateOptions(False, 1, 0),
         hidden_state=CollateOptions(False, 0, 0),
         state=CollateOptions(False, 0, 0),
         ground_truth_actions=CollateOptions(True, 0, -1))
     sampler = ActionSequenceSampler(aencoder, is_subtype, transform_input,
                                     transform_action_sequence, collate,
                                     model)
     # Beam width 5, at most 20 decoding steps.
     return BeamSearch(5, 20, sampler)
예제 #6
0
 def transform(self, encoder, interpreter, parser):
     """Build the per-sample training transform for the PBE setting.

     Converts test cases and variables to tensors, parses the ground
     truth into an action sequence, derives teacher-forcing inputs,
     seeds the decoder state keys, and encodes the target actions.
     """
     return Sequence(OrderedDict([
         ("tinput",
          Apply(module=TransformInputs(),
                in_keys=["test_cases"],
                out_key="test_case_tensor")),
         ("tvariable",
          Apply(module=TransformVariables(),
                in_keys=["variables", "test_case_tensor"],
                out_key="variables_tensor")),
         # Parse the ground-truth program into an action sequence.
         ("tcode",
          Apply(module=GroundTruthToActionSequence(parser),
                in_keys=["ground_truth"],
                out_key="action_sequence")),
         ("aaction",
          Apply(module=AddPreviousActions(encoder, n_dependent=1),
                in_keys=["action_sequence", "reference"],
                constants={"train": True},
                out_key="previous_actions")),
         # Initialize decoder state keys.
         ("add_state", AddState("state")),
         ("add_hidden_state", AddState("hidden_state")),
         # Encode the target action sequence for the loss.
         ("tgt",
          Apply(module=EncodeActionSequence(encoder),
                in_keys=["action_sequence", "reference"],
                out_key="ground_truth_actions")),
     ]))
예제 #7
0
    def test_propagate_supervision(self):
        """Supervision status of an input key propagates to the output."""
        apply = Apply(["x", "y"], "out", MockModule(1))

        # Without supervision keys the output is plain data.
        plain = apply(Environment({
            "x": torch.arange(3).reshape(-1, 1),
            "y": 10
        }))
        assert not plain.is_supervision("out")

        # Marking "x" as supervision makes the output supervision too.
        supervised = apply(Environment({
            "x": torch.arange(3).reshape(-1, 1),
            "y": 10
        }, set(["x"])))
        assert supervised.is_supervision("out")
예제 #8
0
File: metric.py  Project: nashid/mlprogram
 def __init__(self,
              metric: Callable,
              in_keys,
              value_key: str,
              transform: Optional[Callable[[float], float]] = None):
     """Wrap *metric* so it can be evaluated on an Environment.

     Args:
         metric: callable that computes the raw metric value from the
             picked environment entries.
         in_keys: keys read from the environment and passed to *metric*
             (same format as ``Apply``'s ``in_keys``; presumably may
             include ``(env_key, arg_name)`` pairs — confirm in Apply).
         value_key: environment key name stored on the instance; not
             used in this constructor — presumably consumed by the
             evaluation method of this class.
         transform: optional post-processing of the raw metric value
             (e.g. thresholding); ``None`` leaves the value unchanged.
     """
     super().__init__()
     self.value_key = value_key
     # Run the metric through Apply, then pick its scalar result back
     # out of the environment.
     self.metric = Sequence(
         OrderedDict([("metric",
                       Apply(
                           module=Function(metric),
                           in_keys=in_keys,
                           out_key="metric",
                       )), ("pick", Pick("metric"))]))
     self.transform = transform
예제 #9
0
 def prepare_model(self, encoder: ActionSequenceEncoder):
     """Assemble the PBE model: CNN input encoder, LSTM action-sequence
     decoder with predictor heads, and a sigmoid value head.

     Args:
         encoder: action-sequence encoder supplying rule/token vocab
             sizes for the embeddings and predictor.
     """
     return torch.nn.Sequential(OrderedDict([
         # Encode the test-case tensor with a 2D CNN.
         ("encode_input",
          Apply([("test_case_tensor", "x")],
                "test_case_feature",
                CNN2d(1, 16, 32, 2, 2, 2))),
         # Fuse test-case and variable tensors into reference features
         # and a pooled input feature.
         ("encoder",
          Apply(
              module=Encoder(CNN2d(2, 16, 32, 2, 2, 2)),
              in_keys=["test_case_tensor",
                       "variables_tensor", "test_case_feature"],
              out_key=["reference_features", "input_feature"]
          )),
         ("decoder",
          torch.nn.Sequential(OrderedDict([
              # Embed the previous actions using the encoder's rule and
              # token vocabularies.
              ("action_embedding",
               Apply(
                   module=a_s.PreviousActionsEmbedding(
                       n_rule=encoder._rule_encoder.vocab_size,
                       n_token=encoder._token_encoder.vocab_size,
                       embedding_size=256,
                   ),
                   in_keys=["previous_actions"],
                   out_key="action_features"
               )),
              # LSTM decoder; note it rewrites "action_features" and
              # threads hidden_state/state through each step.
              ("decoder",
               Apply(
                   module=a_s.LSTMDecoder(
                       inject_input=a_s.CatInput(),
                       input_feature_size=2 * 16 * 8 * 8,
                       action_feature_size=256,
                       output_feature_size=512,
                       dropout=0.0
                   ),
                   in_keys=["input_feature", "action_features", "hidden_state",
                            "state"],
                   out_key=["action_features", "hidden_state", "state"]
               )),
              # Predict rule/token/reference probabilities.
              ("predictor",
               Apply(
                   module=a_s.Predictor(512, 16 * 8 * 8,
                                        encoder._rule_encoder.vocab_size,
                                        encoder._token_encoder.vocab_size,
                                        512),
                   in_keys=["action_features", "reference_features"],
                   out_key=["rule_probs", "token_probs", "reference_probs"]))
          ]))),
         # Value head: sigmoid MLP over the pooled input feature.
         ("value",
          Apply([("input_feature", "x")], "value",
                MLP(16 * 8 * 8 * 2, 1, 512, 2,
                    activation=torch.nn.Sigmoid()),
                ))
     ]))
예제 #10
0
 def transform_cls(self, qencoder, aencoder, parser):
     """Build the NL2Code training transform pipeline.

     The OrderedDict keys name each step inside the Sequence; every
     step reads/writes keys of the sample environment.
     """
     return Sequence(OrderedDict([
         # Tokenize the raw text query into reference tokens.
         ("extract_reference",
          Apply(module=mlprogram.nn.Function(tokenize),
                in_keys=[["text_query", "str"]],
                out_key="reference")),
         ("encode_word_query",
          Apply(module=EncodeWordQuery(qencoder),
                in_keys=["reference"],
                out_key="word_nl_query")),
         # Parse the ground truth into an action sequence.
         ("f2",
          Apply(module=GroundTruthToActionSequence(parser),
                in_keys=["ground_truth"],
                out_key="action_sequence")),
         # Teacher-forcing inputs (train=True).
         ("add_previous_action",
          Apply(module=AddPreviousActions(aencoder, n_dependent=1),
                in_keys=["action_sequence", "reference"],
                constants={"train": True},
                out_key="previous_actions")),
         ("add_action",
          Apply(module=AddActions(aencoder, n_dependent=1),
                in_keys=["action_sequence", "reference"],
                constants={"train": True},
                out_key="actions")),
         # Initialize decoder state keys.
         ("add_state", AddState("state")),
         ("add_hidden_state", AddState("hidden_state")),
         ("add_history", AddState("history")),
         # Encode the target action sequence for the loss.
         ("f4",
          Apply(module=EncodeActionSequence(aencoder),
                in_keys=["action_sequence", "reference"],
                out_key="ground_truth_actions")),
     ]))
예제 #11
0
 def test_parameters(self):
     """The wrapped module's parameters appear under the "module." prefix."""
     apply = Apply(["x"], "out", MockModule(1))
     param_names = dict(apply.named_parameters()).keys()
     assert param_names == {"module.p"}
예제 #12
0
 def prepare_model(self, qencoder, cencoder, aencoder):
     """Assemble the TreeGen model: NL encoder plus tree-structured
     decoder with a predictor head.

     Args:
         qencoder: word-level query encoder (vocab for NlEmbedding).
         cencoder: character-level query encoder.
         aencoder: action-sequence encoder supplying rule/token/node-type
             vocab sizes.
     """
     rule_num = aencoder._rule_encoder.vocab_size
     token_num = aencoder._token_encoder.vocab_size
     node_type_num = aencoder._node_type_encoder.vocab_size
     # (removed a duplicate re-assignment of token_num that repeated the
     # identical expression above)
     return torch.nn.Sequential(
         OrderedDict([
             ("encoder",
              torch.nn.Sequential(
                  OrderedDict([
                      # Word- and character-level embeddings of the query.
                      ("embedding",
                       Apply(module=treegen.NlEmbedding(
                           qencoder.vocab_size, cencoder.vocab_size, 10,
                           256, 256),
                             in_keys=["word_nl_query", "char_nl_query"],
                             out_key=["word_nl_feature",
                                      "char_nl_feature"])),
                      ("encoder",
                       Apply(module=treegen.Encoder(256, 256, 1, 0.0, 5),
                             in_keys=["word_nl_feature", "char_nl_feature"],
                             out_key="reference_features"))
                  ]))),
             ("decoder",
              torch.nn.Sequential(
                  OrderedDict([
                      # Embed the decoder's action queries.
                      ("query_embedding",
                       Apply(module=treegen.QueryEmbedding(
                           n_rule=rule_num,
                           max_depth=4,
                           embedding_size=256,
                       ),
                             in_keys=["action_queries"],
                             out_key="action_query_features")),
                      # Embed previous actions and previous action rules.
                      ("action_embedding",
                       Apply(module=treegen.ActionEmbedding(
                           n_rule=rule_num,
                           n_token=token_num,
                           n_node_type=node_type_num,
                           max_arity=4,
                           rule_embedding_size=256,
                           embedding_size=256,
                       ),
                             in_keys=[
                                 "previous_actions", "previous_action_rules"
                             ],
                             out_key=[
                                 "action_features", "action_rule_features"
                             ])),
                      # Tree-convolution decoder; note it rewrites
                      # "action_features" in place.
                      ("decoder",
                       Apply(module=treegen.Decoder(
                           rule_embedding_size=256,
                           encoder_hidden_size=256,
                           decoder_hidden_size=1024,
                           out_size=256,
                           tree_conv_kernel_size=3,
                           n_head=1,
                           dropout=0.0,
                           n_encoder_block=5,
                           n_decoder_block=5,
                       ),
                             in_keys=[[
                                 "reference_features", "nl_query_features"
                             ], "action_query_features", "action_features",
                                      "action_rule_features", "depthes",
                                      "adjacency_matrix"],
                             out_key="action_features")),
                      # Predict rule/token/reference probabilities.
                      ("predictor",
                       Apply(module=Predictor(256, 256, rule_num, token_num,
                                              256),
                             in_keys=[
                                 "reference_features", "action_features"
                             ],
                             out_key=[
                                 "rule_probs", "token_probs",
                                 "reference_probs"
                             ]))
                  ])))
         ]))
예제 #13
0
 def test_multiple_outputs(self):
     """A list out_key spreads the module's multiple outputs over keys."""
     apply = Apply(["x"], ["out", "out2"], MockModule2(1))
     result = apply(Environment({"x": torch.arange(3).reshape(-1, 1)}))
     first = result["out"].detach().numpy()
     second = result["out2"].detach().numpy()
     assert np.array_equal([[1], [2], [3]], first)
     assert np.array_equal([[0], [1], [2]], second)
예제 #14
0
    def prepare_synthesizer(self, model, qencoder, cencoder, aencoder):
        """Build a beam-search synthesizer for the TreeGen model.

        Args:
            model: the trained TreeGen model.
            qencoder: word-level query encoder.
            cencoder: character-level query encoder.
            aencoder: action-sequence encoder.
        """
        # Input transform: tokenize the query, then word- and
        # character-level encodings.
        transform_input = Compose(
            OrderedDict([("extract_reference",
                          Apply(module=mlprogram.nn.Function(tokenize),
                                in_keys=[["text_query", "str"]],
                                out_key="reference")),
                         ("encode_word_query",
                          Apply(module=EncodeWordQuery(qencoder),
                                in_keys=["reference"],
                                out_key="word_nl_query")),
                         ("encode_char",
                          Apply(module=EncodeCharacterQuery(cencoder, 10),
                                in_keys=["reference"],
                                out_key="char_nl_query"))]))
        # Per-decoding-step transform (inference mode: train=False).
        transform_action_sequence = Compose(
            OrderedDict([("add_previous_action",
                          Apply(
                              module=AddPreviousActions(aencoder),
                              in_keys=["action_sequence", "reference"],
                              constants={"train": False},
                              out_key="previous_actions",
                          )),
                         ("add_previous_action_rule",
                          Apply(
                              module=AddPreviousActionRules(
                                  aencoder,
                                  4,
                              ),
                              in_keys=["action_sequence", "reference"],
                              constants={"train": False},
                              out_key="previous_action_rules",
                          )),
                         ("add_tree",
                          Apply(
                              module=AddActionSequenceAsTree(aencoder),
                              in_keys=["action_sequence", "reference"],
                              constants={"train": False},
                              out_key=["adjacency_matrix", "depthes"],
                          )),
                         ("add_query",
                          Apply(
                              module=AddQueryForTreeGenDecoder(aencoder, 4),
                              in_keys=["action_sequence", "reference"],
                              constants={"train": False},
                              out_key="action_queries",
                          ))]))

        # Per-key batching options.
        collate = Collate(word_nl_query=CollateOptions(True, 0, -1),
                          char_nl_query=CollateOptions(True, 0, -1),
                          nl_query_features=CollateOptions(True, 0, -1),
                          reference_features=CollateOptions(True, 0, -1),
                          previous_actions=CollateOptions(True, 0, -1),
                          previous_action_rules=CollateOptions(True, 0, -1),
                          depthes=CollateOptions(False, 1, 0),
                          adjacency_matrix=CollateOptions(False, 0, 0),
                          action_queries=CollateOptions(True, 0, -1),
                          ground_truth_actions=CollateOptions(True, 0, -1))
        # Beam width 5, at most 20 decoding steps.
        return BeamSearch(
            5, 20,
            ActionSequenceSampler(aencoder, is_subtype, transform_input,
                                  transform_action_sequence, collate, model))
예제 #15
0
    def reinforce(self, train_dataset, encoder, output_dir):
        """Fine-tune the pretrained model with REINFORCE.

        Args:
            train_dataset: dataset of synthesis tasks to roll out on.
            encoder: action-sequence encoder from pretraining.
            output_dir: directory receiving snapshots and results.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            interpreter = self.interpreter()

            # Per-key batching options, including the per-episode reward.
            collate = Collate(
                test_case_tensor=CollateOptions(False, 0, 0),
                variables_tensor=CollateOptions(True, 0, 0),
                previous_actions=CollateOptions(True, 0, -1),
                hidden_state=CollateOptions(False, 0, 0),
                state=CollateOptions(False, 0, 0),
                ground_truth_actions=CollateOptions(True, 0, -1),
                reward=CollateOptions(False, 0, 0)
            )
            # samples -> episodes -> flattened steps -> tensors -> batch.
            collate_fn = Sequence(OrderedDict([
                ("to_episode", Map(self.to_episode(encoder,
                                                   interpreter))),
                ("flatten", Flatten()),
                ("transform", Map(self.transform(
                    encoder, interpreter, Parser()))),
                ("collate", collate.collate)
            ]))

            model = self.prepare_model(encoder)
            optimizer = self.prepare_optimizer(model)
            # NOTE(review): output_dir is passed as both the first and
            # third positional argument — confirm against
            # train_REINFORCE's signature.
            train_REINFORCE(
                output_dir, tmpdir, output_dir,
                train_dataset,
                self.prepare_synthesizer(model, encoder, interpreter),
                model, optimizer,
                # Loss: reward-weighted policy loss plus a down-weighted
                # value loss, aggregated and reduced to a single scalar.
                torch.nn.Sequential(OrderedDict([
                    ("policy",
                     torch.nn.Sequential(OrderedDict([
                         # Per-sample (unreduced) action-sequence loss.
                         ("loss",
                          Apply(
                              module=mlprogram.nn.action_sequence.Loss(
                                  reduction="none"
                              ),
                              in_keys=[
                                  "rule_probs",
                                  "token_probs",
                                  "reference_probs",
                                  "ground_truth_actions",
                              ],
                              out_key="action_sequence_loss",
                          )),
                         # Scale each sample's loss by its reward.
                         ("weight_by_reward",
                             Apply(
                                 [("reward", "lhs"),
                                  ("action_sequence_loss", "rhs")],
                                 "action_sequence_loss",
                                 mlprogram.nn.Function(Mul())))
                     ]))),
                    ("value",
                     torch.nn.Sequential(OrderedDict([
                         # Reward reshaped to a column as the BCE target.
                         ("reshape_reward",
                             Apply(
                                 [("reward", "x")],
                                 "value_loss_target",
                                 Reshape([-1, 1]))),
                         ("BCE",
                             Apply(
                                 [("value", "input"),
                                  ("value_loss_target", "target")],
                                 "value_loss",
                                 torch.nn.BCELoss(reduction='sum'))),
                         # Down-weight the value loss by 1e-2.
                         ("reweight",
                             Apply(
                                 [("value_loss", "lhs")],
                                 "value_loss",
                                 mlprogram.nn.Function(Mul()),
                                 constants={"rhs": 1e-2})),
                     ]))),
                    ("aggregate",
                     Apply(
                         ["action_sequence_loss", "value_loss"],
                         "loss",
                         AggregatedLoss())),
                    # Division by 1 keeps the value unchanged here;
                    # presumably a placeholder for batch-size scaling.
                    ("normalize",
                     Apply(
                         [("loss", "lhs")],
                         "loss",
                         mlprogram.nn.Function(Div()),
                         constants={"rhs": 1})),
                    ("pick",
                     mlprogram.nn.Function(
                         Pick("loss")))
                ])),
                EvaluateSynthesizer(
                    train_dataset,
                    self.prepare_synthesizer(model, encoder, interpreter,
                                             rollout=False),
                    {}, top_n=[]),
                "generation_rate",
                # Score each output by IoU against the test cases,
                # thresholded at 0.9 into a 0/1 float.
                metrics.use_environment(
                    metric=metrics.TestCaseResult(
                        interpreter=interpreter,
                        metric=metrics.use_environment(
                            metric=metrics.Iou(),
                            in_keys=["actual", "expected"],
                            value_key="actual",
                        )
                    ),
                    in_keys=["test_cases", "actual"],
                    value_key="actual",
                    transform=Threshold(threshold=0.9, dtype="float"),
                ),
                collate_fn,
                1, 1,
                Epoch(10), evaluation_interval=Epoch(10),
                snapshot_interval=Epoch(10),
                use_pretrained_model=True,
                use_pretrained_optimizer=True,
                threshold=1.0)
예제 #16
0
    def pretrain(self, output_dir):
        """Supervised pretraining on a small fixed set of CSG programs.

        Args:
            output_dir: directory receiving training artifacts.

        Returns:
            ``(encoder, train_dataset)``: the fitted action-sequence
            encoder and the transformed training dataset.
        """
        # Dataset is only used below to prepare the encoder.
        dataset = Dataset(4, 1, 2, 1, 45, seed=0)
        # NOTE(review): the empty string literal below is a no-op
        # statement (it is not a docstring) — safe to delete.
        """
        """
        # Hand-picked ground-truth programs; "ground_truth" is marked
        # as a supervision key in each Environment.
        train_dataset = ListDataset([
            Environment(
                {"ground_truth": Circle(1)},
                set(["ground_truth"]),
            ),
            Environment(
                {"ground_truth": Rectangle(1, 2)},
                set(["ground_truth"]),
            ),
            Environment(
                {"ground_truth": Rectangle(1, 1)},
                set(["ground_truth"]),
            ),
            Environment(
                {"ground_truth": Rotation(45, Rectangle(1, 1))},
                set(["ground_truth"]),
            ),
            Environment(
                {"ground_truth": Translation(1, 1, Rectangle(1, 1))},
                set(["ground_truth"]),
            ),
            Environment(
                {"ground_truth": Difference(Circle(1), Circle(1))},
                set(["ground_truth"]),
            ),
            Environment(
                {"ground_truth": Union(Rectangle(1, 2), Circle(1))},
                set(["ground_truth"]),
            ),
            Environment(
                {"ground_truth": Difference(Rectangle(1, 1), Circle(1))},
                set(["ground_truth"]),
            ),
        ])

        with tempfile.TemporaryDirectory() as tmpdir:
            interpreter = self.interpreter()
            # Derive test cases by executing each ground truth; the
            # result is data, not supervision (is_out_supervision=False).
            train_dataset = data_transform(
                train_dataset,
                Apply(
                    module=AddTestCases(interpreter),
                    in_keys=["ground_truth"],
                    out_key="test_cases",
                    is_out_supervision=False,
                ))
            encoder = self.prepare_encoder(dataset, Parser())

            # Per-key batching options.
            collate = Collate(
                test_case_tensor=CollateOptions(False, 0, 0),
                variables_tensor=CollateOptions(True, 0, 0),
                previous_actions=CollateOptions(True, 0, -1),
                hidden_state=CollateOptions(False, 0, 0),
                state=CollateOptions(False, 0, 0),
                ground_truth_actions=CollateOptions(True, 0, -1)
            )
            # samples -> episodes -> flattened steps -> tensors -> batch.
            collate_fn = Sequence(OrderedDict([
                ("to_episode", Map(self.to_episode(encoder,
                                                   interpreter))),
                ("flatten", Flatten()),
                ("transform", Map(self.transform(
                    encoder, interpreter, Parser()))),
                ("collate", collate.collate)
            ]))

            model = self.prepare_model(encoder)
            optimizer = self.prepare_optimizer(model)
            # Loss: summed action-sequence loss, normalized, then picked
            # out of the environment as the scalar "loss".
            train_supervised(
                tmpdir, output_dir,
                train_dataset, model, optimizer,
                torch.nn.Sequential(OrderedDict([
                    ("loss",
                     Apply(
                         module=Loss(
                             reduction="sum",
                         ),
                         in_keys=[
                             "rule_probs",
                             "token_probs",
                             "reference_probs",
                             "ground_truth_actions",
                         ],
                         out_key="action_sequence_loss",
                     )),
                    ("normalize",  # divided by batch_size
                     Apply(
                         [("action_sequence_loss", "lhs")],
                         "loss",
                         mlprogram.nn.Function(Div()),
                         constants={"rhs": 1})),
                    ("pick",
                     mlprogram.nn.Function(
                         Pick("loss")))
                ])),
                None, "score",
                collate_fn,
                1, Epoch(100), evaluation_interval=Epoch(10),
                snapshot_interval=Epoch(100)
            )
        return encoder, train_dataset
예제 #17
0
 def test_rename_keys(self):
     """An (env_key, module_key) pair renames the input for the module."""
     apply = Apply([("in", "x")], "out", MockModule(1))
     result = apply(Environment({"in": torch.arange(3).reshape(-1, 1)}))
     assert np.array_equal([[1], [2], [3]],
                           result["out"].detach().numpy())
예제 #18
0
    def prepare_synthesizer(self, model, encoder, interpreter, rollout=True):
        """Build the program synthesizer used for evaluation/rollout.

        Composes a two-level SMC search: an inner SMC over action sequences
        (one sub-program at a time) wrapped by an outer SMC over sequential
        program states. Depending on ``rollout`` the outer search is either
        filtered by test-case IoU (rollout mode) or guided by the model's
        value network and wrapped with a timeout.

        Args:
            model: module providing ``encode_input``, ``encoder`` and
                ``value`` sub-modules (used below).
            encoder: action-sequence encoder passed to the sampler and to
                ``AddPreviousActions``.
            interpreter: interpreter used to execute candidate programs when
                computing the test-case metric.
            rollout: if True, return a filtered SMC sampler; otherwise return
                a value-network-guided, timeout-bounded, filtered synthesizer.

        Returns:
            An SMC synthesizer (rollout mode) or a FilteredSynthesizer.
        """
        # Batch-collation options for every tensor key the pipeline produces.
        # CollateOptions signature is presumably (use_pad_sequence, dim,
        # padding_value) — TODO confirm against mlprogram's Collate.
        collate = Collate(
            test_case_tensor=CollateOptions(False, 0, 0),
            input_feature=CollateOptions(False, 0, 0),
            test_case_feature=CollateOptions(False, 0, 0),
            reference_features=CollateOptions(True, 0, 0),
            variables_tensor=CollateOptions(True, 0, 0),
            previous_actions=CollateOptions(True, 0, -1),
            hidden_state=CollateOptions(False, 0, 0),
            state=CollateOptions(False, 0, 0),
            ground_truth_actions=CollateOptions(True, 0, -1)
        )
        # Inner sampler over action sequences. First pipeline tensorizes the
        # inputs/variables; second pipeline prepares per-step decoder state.
        subsampler = ActionSequenceSampler(
            encoder, IsSubtype(),
            Sequence(OrderedDict([
                ("tinput",
                 Apply(
                     module=TransformInputs(),
                     in_keys=["test_cases"],
                     out_key="test_case_tensor",
                 )),
                ("tvariable",
                 Apply(
                     module=TransformVariables(),
                     in_keys=["variables", "test_case_tensor"],
                     out_key="variables_tensor"
                 )),
            ])),
            Compose(OrderedDict([
                ("add_previous_actions",
                 Apply(
                    module=AddPreviousActions(encoder, n_dependent=1),
                    in_keys=["action_sequence", "reference"],
                    out_key="previous_actions",
                    constants={"train": False},  # inference-time transform
                    )),
                ("add_state", AddState("state")),
                ("add_hidden_state", AddState("hidden_state"))
            ])),
            collate, model,
            rng=np.random.RandomState(0))  # fixed seed for reproducibility
        # Map sampled action sequences back to program text via the parser.
        subsampler = mlprogram.samplers.transform(
            subsampler,
            Parser().unparse
        )
        # Inner SMC: searches one action sequence (5 particles, 1 step batch).
        subsynthesizer = SMC(
            5, 1,
            subsampler,
            max_try_num=1,
            to_key=Pick("action_sequence"),
            rng=np.random.RandomState(0)
        )

        # Outer sampler: grows a program one statement at a time, using the
        # inner synthesizer to propose each statement.
        sampler = SequentialProgramSampler(
            subsynthesizer,
            Apply(
                module=TransformInputs(),
                in_keys=["test_cases"],
                out_key="test_case_tensor",
            ),
            collate,
            model.encode_input,
            interpreter=interpreter,
            expander=Expander(),
            rng=np.random.RandomState(0))
        if rollout:
            # Rollout mode: keep only samples whose IoU against the expected
            # test-case outputs reaches 1.0 (exact match).
            sampler = FilteredSampler(
                sampler,
                metrics.use_environment(
                    metric=metrics.TestCaseResult(
                        interpreter,
                        metric=metrics.use_environment(
                            metric=metrics.Iou(),
                            in_keys=["actual", "expected"],
                            value_key="actual",
                        )
                    ),
                    in_keys=["test_cases", "actual"],
                    value_key="actual"
                ),
                1.0  # threshold: require a perfect IoU
            )
            return SMC(3, 1, sampler, rng=np.random.RandomState(0),
                       to_key=Pick("interpreter_state"), max_try_num=1)
        else:
            # Evaluation mode: re-rank candidate states with the learned
            # value network (same tensorization pipeline as the subsampler).
            sampler = SamplerWithValueNetwork(
                sampler,
                Sequence(OrderedDict([
                    ("tinput",
                     Apply(
                         module=TransformInputs(),
                         in_keys=["test_cases"],
                         out_key="test_case_tensor",
                     )),
                    ("tvariable",
                     Apply(
                         module=TransformVariables(),
                         in_keys=["variables", "test_case_tensor"],
                         out_key="variables_tensor"
                     )),
                ])),
                collate,
                # Value head: encode, score, then extract the scalar "value".
                torch.nn.Sequential(OrderedDict([
                    ("encoder", model.encoder),
                    ("value", model.value),
                    ("pick",
                     mlprogram.nn.Function(
                         Pick("value")))
                ])))

            # Bound the outer search; the second argument is the timeout
            # (presumably seconds — TODO confirm SynthesizerWithTimeout API).
            synthesizer = SynthesizerWithTimeout(
                SMC(3, 1, sampler, rng=np.random.RandomState(0),
                    to_key=Pick("interpreter_state"),
                    max_try_num=1),
                1
            )
            # Same perfect-IoU filter as rollout mode, applied to the final
            # synthesizer output instead of the sampler.
            return FilteredSynthesizer(
                synthesizer,
                metrics.use_environment(
                    metric=metrics.TestCaseResult(
                        interpreter,
                        metric=metrics.use_environment(
                            metric=metrics.Iou(),
                            in_keys=["actual", "expected"],
                            value_key="actual",
                        )
                    ),
                    in_keys=["test_cases", "actual"],
                    value_key="actual"
                ),
                1.0
            )
예제 #19
0
 def prepare_model(self, qencoder, aencoder):
     """Assemble the NL2Code encoder-decoder model.

     The encoder embeds the word-level NL query and runs a BiLSTM to
     produce reference features; the decoder embeds action sequences,
     decodes them with attention over the reference features, and
     predicts rule/token/reference probabilities.
     """
     action_embedding = mlprogram.nn.action_sequence.ActionsEmbedding(
         aencoder._rule_encoder.vocab_size,
         aencoder._token_encoder.vocab_size,
         aencoder._node_type_encoder.vocab_size, 64, 256)
     action_decoder = nl2code.Decoder(
         action_embedding.output_size, 256, 256, 64, 0.0)

     # NL query -> masked embedding -> BiLSTM reference features.
     nl_encoder = torch.nn.Sequential(OrderedDict([
         ("embedding",
          Apply(module=mlprogram.nn.EmbeddingWithMask(
              qencoder.vocab_size, 256, -1),
                in_keys=[["word_nl_query", "x"]],
                out_key="nl_features")),
         ("lstm",
          Apply(module=mlprogram.nn.BidirectionalLSTM(
              256, 256, 0.0),
                in_keys=[["nl_features", "x"]],
                out_key="reference_features")),
     ]))

     # Action sequence -> embedding -> attentive decoder -> predictor.
     seq_decoder = torch.nn.Sequential(OrderedDict([
         ("embedding",
          Apply(module=action_embedding,
                in_keys=[
                    "actions",
                    "previous_actions",
                ],
                out_key="action_features")),
         ("decoder",
          Apply(
              module=action_decoder,
              in_keys=[
                  ["reference_features", "nl_query_features"],
                  "actions",
                  "action_features",
                  "history",
                  "hidden_state",
                  "state",
              ],
              out_key=[
                  "action_features",
                  "action_contexts",
                  "history",
                  "hidden_state",
                  "state",
              ])),
         ("predictor",
          Apply(
              module=nl2code.Predictor(action_embedding, 256, 256,
                                       256, 64),
              in_keys=[
                  "reference_features", "action_features",
                  "action_contexts"
              ],
              out_key=[
                  "rule_probs", "token_probs",
                  "reference_probs"
              ],
          ))
     ]))

     return torch.nn.Sequential(OrderedDict([
         ("encoder", nl_encoder),
         ("decoder", seq_decoder),
     ]))