# Import paths follow ReAgent's layout and may need adjusting to match the
# library version in use. ON_POLICY / SIMULATION constants and TSPRewardModel
# are assumed to be defined elsewhere in this module.
import tempfile

import torch
from reagent.core.parameters import (
    LearningMethod,
    Seq2SlateParameters,
    SimulationParameters,
)
from reagent.optimizer.union import Optimizer__Union
from reagent.training.ranking.seq2slate_sim_trainer import Seq2SlateSimulationTrainer
from reagent.training.ranking.seq2slate_trainer import Seq2SlateTrainer


def create_trainer(seq2slate_net, learning_method, batch_size, learning_rate, device):
    use_gpu = device != torch.device("cpu")

    if learning_method == ON_POLICY:
        # On-policy REINFORCE: the policy trains on slates it samples itself.
        seq2slate_params = Seq2SlateParameters(
            on_policy=True,
            learning_method=LearningMethod.REINFORCEMENT_LEARNING,
        )
        trainer_cls = Seq2SlateTrainer
    elif learning_method == SIMULATION:
        # Simulation-based training: rewards come from a scripted reward model,
        # serialized to a temp file and loaded by the trainer via its path.
        temp_reward_model_path = tempfile.mkstemp(suffix=".pt")[1]
        reward_model = torch.jit.script(TSPRewardModel())
        torch.jit.save(reward_model, temp_reward_model_path)
        seq2slate_params = Seq2SlateParameters(
            on_policy=True,
            learning_method=LearningMethod.SIMULATION,
            simulation=SimulationParameters(
                reward_name_weight={"tour_length": 1.0},
                reward_name_path={"tour_length": temp_reward_model_path},
            ),
        )
        trainer_cls = Seq2SlateSimulationTrainer
    else:
        raise ValueError(f"Unsupported learning method: {learning_method}")

    param_dict = {
        "seq2slate_net": seq2slate_net,
        "minibatch_size": batch_size,
        "parameters": seq2slate_params,
        "policy_optimizer": Optimizer__Union.default(lr=learning_rate),
        "use_gpu": use_gpu,
        "print_interval": 100,
    }
    return trainer_cls(**param_dict)
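# Usage sketch (kept as comments since the network is built elsewhere in the
# test module; the batch size, learning rate, and device below are illustrative
# values, not ones taken from this repo):
#
#   seq2slate_net = ...  # a Seq2SlateTransformerNet (or compatible) instance
#   trainer = create_trainer(
#       seq2slate_net,
#       learning_method=ON_POLICY,
#       batch_size=256,
#       learning_rate=1e-3,
#       device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
#   )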
# Simpler variant that constructs a plain Seq2SlateTrainer directly; on- vs.
# off-policy behavior is selected via the `on_policy` flag rather than a
# learning-method dispatch.
def create_trainer(seq2slate_net, batch_size, learning_rate, device, on_policy):
    use_gpu = device != torch.device("cpu")
    return Seq2SlateTrainer(
        seq2slate_net=seq2slate_net,
        minibatch_size=batch_size,
        parameters=Seq2SlateParameters(on_policy=on_policy),
        policy_optimizer=Optimizer__Union.default(lr=learning_rate),
        use_gpu=use_gpu,
        print_interval=100,
    )
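# Usage sketch for the simpler variant (illustrative values; `seq2slate_net`
# again stands in for a network built elsewhere). Passing on_policy=False
# configures the trainer for off-policy data, e.g. logged slates:
#
#   trainer = create_trainer(
#       seq2slate_net,
#       batch_size=256,
#       learning_rate=1e-3,
#       device=torch.device("cpu"),
#       on_policy=False,
#   )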