def reinforce(self, train_dataset, encoder, output_dir):
    """Run REINFORCE training on ``train_dataset``.

    Builds the batching pipeline, a combined policy/value loss module and a
    thresholded test-case IoU metric, then hands everything to
    ``train_REINFORCE``. The ``pretrain`` and ``RL`` sub-directories of
    ``output_dir`` are passed as the first two arguments; with
    ``use_pretrained_model=True`` the ``pretrain`` one presumably holds the
    pretrained weights (confirm against ``train_REINFORCE``).

    Args:
        train_dataset: dataset used both for training and for evaluation.
        encoder: encoder forwarded to the model/synthesizer/transform helpers.
        output_dir: base directory for the ``pretrain`` and ``RL`` folders.
    """
    interpreter = self.interpreter()

    # Padding/stacking options per batch key; unlike pretraining, the RL
    # pipeline also carries a ``reward`` entry.
    collate = Collate(
        test_case_tensor=CollateOptions(False, 0, 0),
        variables_tensor=CollateOptions(True, 0, 0),
        previous_actions=CollateOptions(True, 0, -1),
        hidden_state=CollateOptions(False, 0, 0),
        state=CollateOptions(False, 0, 0),
        ground_truth_actions=CollateOptions(True, 0, -1),
        reward=CollateOptions(False, 0, 0),
    )

    # Stages run in insertion order: to_episode -> flatten -> transform -> collate.
    stages = OrderedDict()
    stages["to_episode"] = Map(self.to_episode(encoder, interpreter))
    stages["flatten"] = Flatten()
    stages["transform"] = Map(self.transform(encoder, interpreter, Parser()))
    stages["collate"] = collate.collate
    collate_fn = Sequence(stages)

    model = self.prepare_model(encoder)
    optimizer = self.prepare_optimizer(model)

    pretrain_dir = os.path.join(output_dir, "pretrain")
    rl_dir = os.path.join(output_dir, "RL")
    synthesizer = self.prepare_synthesizer(model, encoder, interpreter)

    # Policy term: unreduced action-sequence loss multiplied by the reward.
    policy_term = torch.nn.Sequential(OrderedDict([
        ("loss", Apply(
            module=mlprogram.nn.action_sequence.Loss(reduction="none"),
            in_keys=[
                "rule_probs",
                "token_probs",
                "reference_probs",
                "ground_truth_actions",
            ],
            out_key="action_sequence_loss",
        )),
        ("weight_by_reward", Apply(
            [("reward", "lhs"), ("action_sequence_loss", "rhs")],
            "action_sequence_loss",
            mlprogram.nn.Function(Mul()),
        )),
    ]))

    # Value term: summed BCE between the value output and the reward
    # (reshaped to a column), then scaled down by 1e-2.
    value_term = torch.nn.Sequential(OrderedDict([
        ("reshape_reward", Apply(
            [("reward", "x")],
            "value_loss_target",
            Reshape([-1, 1]),
        )),
        ("BCE", Apply(
            [("value", "input"), ("value_loss_target", "target")],
            "value_loss",
            torch.nn.BCELoss(reduction='sum'),
        )),
        ("reweight", Apply(
            [("value_loss", "lhs")],
            "value_loss",
            mlprogram.nn.Function(Mul()),
            constants={"rhs": 1e-2},
        )),
    ]))

    loss_fn = torch.nn.Sequential(OrderedDict([
        ("policy", policy_term),
        ("value", value_term),
        ("aggregate", Apply(
            ["action_sequence_loss", "value_loss"],
            "loss",
            AggregatedLoss(),
        )),
        # Division by the constant 1 (a no-op placeholder normalization).
        ("normalize", Apply(
            [("loss", "lhs")],
            "loss",
            mlprogram.nn.Function(Div()),
            constants={"rhs": 1},
        )),
        ("pick", mlprogram.nn.Function(Pick("loss"))),
    ]))

    evaluate = EvaluateSynthesizer(
        train_dataset,
        self.prepare_synthesizer(model, encoder, interpreter, rollout=False),
        {},
        top_n=[],
    )

    # Test-case IoU, turned into a float by a 0.9 threshold and reported
    # under the name "generation_rate".
    iou_metric = metrics.use_environment(
        metric=metrics.Iou(),
        in_keys=["actual", "expected"],
        value_key="actual",
    )
    generation_metric = metrics.use_environment(
        metric=metrics.TestCaseResult(
            interpreter=interpreter,
            metric=iou_metric,
        ),
        in_keys=["test_cases", "actual"],
        value_key="actual",
        transform=Threshold(threshold=0.9, dtype="float"),
    )

    train_REINFORCE(
        pretrain_dir,
        rl_dir,
        train_dataset,
        synthesizer,
        model,
        optimizer,
        loss_fn,
        evaluate,
        "generation_rate",
        generation_metric,
        collate_fn,
        1,
        1,
        Epoch(10),
        evaluation_interval=Epoch(10),
        snapshot_interval=Epoch(10),
        use_pretrained_model=True,
        use_pretrained_optimizer=True,
        threshold=1.0,
    )
def prepare_synthesizer(self, model, encoder, interpreter, rollout=True):
    """Build the SMC-based program synthesizer around ``model``.

    An inner SMC over action sequences (``ActionSequenceSampler``) is wrapped
    by a ``SequentialProgramSampler``, which an outer SMC then drives.

    Args:
        model: network providing ``encode_input``, ``encoder`` and ``value``.
        encoder: action-sequence encoder for the inner sampler.
        interpreter: interpreter used by the program sampler and the metric.
        rollout: if true, return an outer SMC over a ``FilteredSampler``;
            otherwise return a ``FilteredSynthesizer`` over a timeout-wrapped
            SMC whose sampler is guided by the model's value network.

    Returns:
        The assembled synthesizer object.
    """

    def make_input_transform():
        # New pipeline on every call: test_cases -> test_case_tensor, then
        # (variables, test_case_tensor) -> variables_tensor.
        return Sequence(OrderedDict([
            ("tinput", Apply(
                module=TransformInputs(),
                in_keys=["test_cases"],
                out_key="test_case_tensor",
            )),
            ("tvariable", Apply(
                module=TransformVariables(),
                in_keys=["variables", "test_case_tensor"],
                out_key="variables_tensor",
            )),
        ]))

    def make_iou_metric():
        # New test-case IoU metric on every call.
        return metrics.use_environment(
            metric=metrics.TestCaseResult(
                interpreter,
                metric=metrics.use_environment(
                    metric=metrics.Iou(),
                    in_keys=["actual", "expected"],
                    value_key="actual",
                ),
            ),
            in_keys=["test_cases", "actual"],
            value_key="actual",
        )

    collate = Collate(
        test_case_tensor=CollateOptions(False, 0, 0),
        input_feature=CollateOptions(False, 0, 0),
        test_case_feature=CollateOptions(False, 0, 0),
        reference_features=CollateOptions(True, 0, 0),
        variables_tensor=CollateOptions(True, 0, 0),
        previous_actions=CollateOptions(True, 0, -1),
        hidden_state=CollateOptions(False, 0, 0),
        state=CollateOptions(False, 0, 0),
        ground_truth_actions=CollateOptions(True, 0, -1),
    )

    step_sampler = ActionSequenceSampler(
        encoder,
        IsSubtype(),
        make_input_transform(),
        Compose(OrderedDict([
            ("add_previous_actions", Apply(
                module=AddPreviousActions(encoder, n_dependent=1),
                in_keys=["action_sequence", "reference"],
                out_key="previous_actions",
                constants={"train": False},
            )),
            ("add_state", AddState("state")),
            ("add_hidden_state", AddState("hidden_state")),
        ])),
        collate,
        model,
        rng=np.random.RandomState(0),
    )
    # Map sampled action sequences back to programs via Parser().unparse.
    step_sampler = mlprogram.samplers.transform(step_sampler, Parser().unparse)

    inner_smc = SMC(
        5, 1, step_sampler,
        max_try_num=1,
        to_key=Pick("action_sequence"),
        rng=np.random.RandomState(0),
    )
    program_sampler = SequentialProgramSampler(
        inner_smc,
        Apply(
            module=TransformInputs(),
            in_keys=["test_cases"],
            out_key="test_case_tensor",
        ),
        collate,
        model.encode_input,
        interpreter=interpreter,
        expander=Expander(),
        rng=np.random.RandomState(0),
    )

    if rollout:
        # The meaning of 1.0 is defined by FilteredSampler (presumably a
        # score threshold) — confirm there.
        filtered_sampler = FilteredSampler(
            program_sampler, make_iou_metric(), 1.0
        )
        return SMC(
            3, 1, filtered_sampler,
            rng=np.random.RandomState(0),
            to_key=Pick("interpreter_state"),
            max_try_num=1,
        )

    # Evaluation path: the sampler is paired with encoder -> value network.
    guided_sampler = SamplerWithValueNetwork(
        program_sampler,
        make_input_transform(),
        collate,
        torch.nn.Sequential(OrderedDict([
            ("encoder", model.encoder),
            ("value", model.value),
            ("pick", mlprogram.nn.Function(Pick("value"))),
        ])),
    )
    outer_smc = SMC(
        3, 1, guided_sampler,
        rng=np.random.RandomState(0),
        to_key=Pick("interpreter_state"),
        max_try_num=1,
    )
    timed_synthesizer = SynthesizerWithTimeout(outer_smc, 1)
    return FilteredSynthesizer(timed_synthesizer, make_iou_metric(), 1.0)
def pretrain(self, output_dir):
    """Run supervised pretraining on a small hand-written set of programs.

    Eight ground-truth programs built from ``Circle``, ``Rectangle``,
    ``Rotation``, ``Translation``, ``Difference`` and ``Union`` are wrapped in
    ``Environment`` objects. Test cases are attached with ``AddTestCases``, and
    the model is trained with ``train_supervised`` into
    ``<output_dir>/pretrain``.

    Args:
        output_dir: base directory; training output goes to its ``pretrain``
            sub-directory.

    Returns:
        ``(encoder, train_dataset)``: the encoder built here and the
        test-case-augmented training dataset.
    """
    # Only used to build the encoder below.
    dataset = Dataset(4, 1, 2, 1, 45, seed=0)

    ground_truths = [
        Circle(1),
        Rectangle(1, 2),
        Rectangle(1, 1),
        Rotation(45, Rectangle(1, 1)),
        Translation(1, 1, Rectangle(1, 1)),
        Difference(Circle(1), Circle(1)),
        Union(Rectangle(1, 2), Circle(1)),
        Difference(Rectangle(1, 1), Circle(1)),
    ]
    # Each Environment gets its own supervision-key set, as before.
    train_dataset = ListDataset([
        Environment({"ground_truth": ground_truth}, {"ground_truth"})
        for ground_truth in ground_truths
    ])

    interpreter = self.interpreter()
    train_dataset = data_transform(
        train_dataset,
        Apply(
            module=AddTestCases(interpreter),
            in_keys=["ground_truth"],
            out_key="test_cases",
            is_out_supervision=False,
        ),
    )
    encoder = self.prepare_encoder(dataset, Parser())

    collate = Collate(
        test_case_tensor=CollateOptions(False, 0, 0),
        variables_tensor=CollateOptions(True, 0, 0),
        previous_actions=CollateOptions(True, 0, -1),
        hidden_state=CollateOptions(False, 0, 0),
        state=CollateOptions(False, 0, 0),
        ground_truth_actions=CollateOptions(True, 0, -1),
    )
    collate_fn = Sequence(OrderedDict([
        ("to_episode", Map(self.to_episode(encoder, interpreter))),
        ("flatten", Flatten()),
        ("transform", Map(self.transform(encoder, interpreter, Parser()))),
        ("collate", collate.collate),
    ]))

    model = self.prepare_model(encoder)
    optimizer = self.prepare_optimizer(model)

    loss_fn = torch.nn.Sequential(OrderedDict([
        ("loss", Apply(
            module=Loss(reduction="sum"),
            in_keys=[
                "rule_probs",
                "token_probs",
                "reference_probs",
                "ground_truth_actions",
            ],
            out_key="action_sequence_loss",
        )),
        # Divided by batch_size. rhs is hard-coded to 1; presumably this
        # matches the batch size of 1 passed to train_supervised — confirm.
        ("normalize", Apply(
            [("action_sequence_loss", "lhs")],
            "loss",
            mlprogram.nn.Function(Div()),
            constants={"rhs": 1},
        )),
        ("pick", mlprogram.nn.Function(Pick("loss"))),
    ]))

    train_supervised(
        os.path.join(output_dir, "pretrain"),
        train_dataset,
        model,
        optimizer,
        loss_fn,
        None,
        "score",
        collate_fn,
        1,
        Epoch(100),
        evaluation_interval=Epoch(10),
        snapshot_interval=Epoch(100),
    )
    return encoder, train_dataset
def parser():
    """Create and return a new ``Parser`` instance.

    NOTE(review): unlike the neighbouring methods this takes no ``self``. If
    it is defined inside a class, calling it on an instance will fail. It
    would need ``@staticmethod`` or to be called via the class — confirm.
    """
    fresh_parser = Parser()
    return fresh_parser