def create_trainer(seq2slate_net, learning_method, batch_size, learning_rate, device):
    use_gpu = False if device == torch.device("cpu") else True
    if learning_method == ON_POLICY:
        seq2slate_params = Seq2SlateParameters(
            on_policy=True, learning_method=LearningMethod.REINFORCEMENT_LEARNING
        )
        trainer_cls = Seq2SlateTrainer
    elif learning_method == SIMULATION:
        temp_reward_model_path = tempfile.mkstemp(suffix=".pt")[1]
        reward_model = torch.jit.script(TSPRewardModel())
        torch.jit.save(reward_model, temp_reward_model_path)
        seq2slate_params = Seq2SlateParameters(
            on_policy=True,
            learning_method=LearningMethod.SIMULATION,
            simulation=SimulationParameters(
                reward_name_weight={"tour_length": 1.0},
                reward_name_path={"tour_length": temp_reward_model_path},
            ),
        )
        trainer_cls = Seq2SlateSimulationTrainer

    param_dict = {
        "seq2slate_net": seq2slate_net,
        "minibatch_size": batch_size,
        "parameters": seq2slate_params,
        "policy_optimizer": Optimizer__Union.default(lr=learning_rate),
        "use_gpu": use_gpu,
        "print_interval": 100,
    }
    return trainer_cls(**param_dict)
def _test_seq2slate_trainer_off_policy(self, policy_gradient_interval, output_arch, device):
    batch_size = 32
    state_dim = 2
    candidate_num = 15
    candidate_dim = 4
    hidden_size = 16
    learning_rate = 1.0
    on_policy = False
    seq2slate_params = Seq2SlateParameters(on_policy=on_policy)
    seq2slate_net = create_seq2slate_transformer(
        state_dim, candidate_num, candidate_dim, hidden_size, output_arch, device
    )
    seq2slate_net_copy = copy.deepcopy(seq2slate_net)
    seq2slate_net_copy_copy = copy.deepcopy(seq2slate_net)
    trainer = create_trainer(
        seq2slate_net,
        batch_size,
        learning_rate,
        device,
        seq2slate_params,
        policy_gradient_interval,
    )
    batch = create_off_policy_batch(
        seq2slate_net, batch_size, state_dim, candidate_num, candidate_dim, device
    )

    for _ in range(policy_gradient_interval):
        trainer.train(rlt.PreprocessedTrainingBatch(training_input=batch))

    # manually compute the gradient
    ranked_per_seq_log_probs = seq2slate_net_copy(
        batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE
    ).log_probs
    loss = -(
        torch.mean(
            ranked_per_seq_log_probs
            * torch.exp(ranked_per_seq_log_probs).detach()
            / batch.tgt_out_probs
            * batch.slate_reward
        )
    )
    loss.backward()
    self.assert_correct_gradient(
        seq2slate_net_copy, seq2slate_net, policy_gradient_interval, learning_rate
    )

    # another way to compute the gradient manually
    ranked_per_seq_probs = torch.exp(
        seq2slate_net_copy_copy(
            batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE
        ).log_probs
    )
    loss = -(
        torch.mean(ranked_per_seq_probs / batch.tgt_out_probs * batch.slate_reward)
    )
    loss.backward()
    self.assert_correct_gradient(
        seq2slate_net_copy_copy,
        seq2slate_net,
        policy_gradient_interval,
        learning_rate,
    )
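# Why the two manual losses above agree: batch.tgt_out_probs and batch.slate_reward are
# constants w.r.t. the network, and d/dtheta [log p * p.detach()] = p.detach() * (dp/dtheta) / p
# = dp/dtheta, i.e. the gradient of p itself.  The helper below is only a minimal standalone
# illustration of that identity; its name and the sigmoid "policy" are hypothetical, not part
# of the trainer API.
def _off_policy_loss_equivalence_sketch():
    theta = torch.tensor([0.3], requires_grad=True)

    # form 1: log-prob weighted by its own detached probability
    p1 = torch.sigmoid(theta)
    (-(torch.log(p1) * p1.detach())).backward()
    grad_form_1 = theta.grad.clone()

    # form 2: the probability itself
    theta.grad = None
    p2 = torch.sigmoid(theta)
    (-p2).backward()
    grad_form_2 = theta.grad.clone()

    assert torch.allclose(grad_form_1, grad_form_2)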
def create_trainer(
    seq2slate_net,
    batch_size,
    learning_rate,
    device,
    seq2slate_params,
    policy_gradient_interval,
):
    # signature mirrors the call sites in the tests in this file, which pass the full
    # Seq2SlateParameters (e.g. with ips_clamp) and the gradient accumulation interval;
    # assumes Seq2SlateTrainer accepts policy_gradient_interval, as the accumulation
    # tests require
    use_gpu = False if device == torch.device("cpu") else True
    return Seq2SlateTrainer(
        seq2slate_net=seq2slate_net,
        minibatch_size=batch_size,
        parameters=seq2slate_params,
        policy_optimizer=Optimizer__Union.default(lr=learning_rate),
        policy_gradient_interval=policy_gradient_interval,
        use_gpu=use_gpu,
        print_interval=100,
    )
def test_seq2slate_trainer_off_policy_with_clamp(self, clamp_method, output_arch):
    batch_size = 32
    state_dim = 2
    candidate_num = 15
    candidate_dim = 4
    hidden_size = 16
    learning_rate = 1.0
    device = torch.device("cpu")
    policy_gradient_interval = 1
    seq2slate_params = Seq2SlateParameters(
        on_policy=False,
        ips_clamp=IPSClamp(clamp_method=clamp_method, clamp_max=0.3),
    )
    seq2slate_net = create_seq2slate_transformer(
        state_dim, candidate_num, candidate_dim, hidden_size, output_arch, device
    )
    seq2slate_net_copy = copy.deepcopy(seq2slate_net)
    trainer = create_trainer(
        seq2slate_net,
        batch_size,
        learning_rate,
        device,
        seq2slate_params,
        policy_gradient_interval,
    )
    batch = create_off_policy_batch(
        seq2slate_net, batch_size, state_dim, candidate_num, candidate_dim, device
    )

    for _ in range(policy_gradient_interval):
        trainer.train(rlt.PreprocessedTrainingBatch(training_input=batch))

    # manually compute the gradient
    ranked_per_seq_probs = torch.exp(
        seq2slate_net_copy(batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE).log_probs
    )
    logger.info(f"ips ratio={ranked_per_seq_probs / batch.tgt_out_probs}")
    loss = -(
        torch.mean(
            ips_clamp(
                ranked_per_seq_probs / batch.tgt_out_probs,
                seq2slate_params.ips_clamp,
            )
            * batch.slate_reward
        )
    )
    loss.backward()
    self.assert_correct_gradient(
        seq2slate_net_copy, seq2slate_net, policy_gradient_interval, learning_rate
    )
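# A minimal sketch of the clamping behavior the clamp tests assume (not the library's
# ips_clamp implementation): UNIVERSAL caps importance ratios at clamp_max, while
# AGGRESSIVE zeroes out any ratio that exceeds clamp_max.  The function name and the
# boolean flag are illustrative only.
def _ips_clamp_sketch(impt_smpl, clamp_max, aggressive):
    if aggressive:
        # ratios above the cap are considered unreliable and dropped entirely
        return torch.where(
            impt_smpl > clamp_max, torch.zeros_like(impt_smpl), impt_smpl
        )
    # otherwise simply cap the ratios at clamp_max
    return torch.clamp(impt_smpl, max=clamp_max)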
def test_ips_ratio_mean(self, output_arch, shape):
    output_arch = Seq2SlateOutputArch.FRECHET_SORT
    shape = 0.1
    logger.info(f"output arch: {output_arch}")
    logger.info(f"frechet shape: {shape}")
    candidate_num = 5
    candidate_dim = 2
    state_dim = 1
    hidden_size = 8
    device = torch.device("cpu")
    batch_size = 1024
    num_batches = 400
    learning_rate = 0.001
    policy_gradient_interval = 1

    state = torch.zeros(batch_size, state_dim)
    # all rows share the same candidate set
    candidates = torch.randint(5, (batch_size, candidate_num, candidate_dim)).float()
    candidates[1:] = candidates[0]
    candidate_scores = torch.sum(candidates, dim=-1)

    seq2slate_params = Seq2SlateParameters(on_policy=False)
    seq2slate_net = create_seq2slate_transformer(
        state_dim, candidate_num, candidate_dim, hidden_size, output_arch, device
    )
    trainer = create_trainer(
        seq2slate_net,
        batch_size,
        learning_rate,
        device,
        seq2slate_params,
        policy_gradient_interval,
    )
    sampler = FrechetSort(shape=shape, topk=candidate_num)

    sum_of_ips_ratio = 0
    for i in range(num_batches):
        sample_outputs = [
            sampler.sample_action(candidate_scores[j : j + 1])
            for j in range(batch_size)
        ]
        action = torch.stack(
            list(map(lambda x: x.action.squeeze(0), sample_outputs))
        )
        logged_propensity = torch.stack(
            list(map(lambda x: torch.exp(x.log_prob), sample_outputs))
        )
        batch = rlt.PreprocessedRankingInput.from_input(
            state=state,
            candidates=candidates,
            device=device,
            action=action,
            logged_propensities=logged_propensity,
        )
        model_propensities = torch.exp(
            seq2slate_net(batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE).log_probs
        )
        impt_smpl, _ = trainer._compute_impt_smpl(
            model_propensities, logged_propensity
        )
        sum_of_ips_ratio += torch.mean(impt_smpl).detach().numpy()
        mean_of_ips_ratio = sum_of_ips_ratio / (i + 1)
        logger.info(f"{i}-th batch, mean ips ratio={mean_of_ips_ratio}")

        if i > 100 and np.allclose(mean_of_ips_ratio, 1, atol=0.03):
            return

    raise Exception(f"Mean ips ratio {mean_of_ips_ratio} is not close to 1")
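# The identity this test checks empirically: when actions are sampled from the logging
# policy, the expected importance ratio is
#   E_{a ~ p_logged}[p_model(a) / p_logged(a)] = sum_a p_model(a) = 1.
# The helper below is a tiny discrete illustration of that identity with toy
# distributions; it does not use FrechetSort or the seq2slate model.
def _ips_mean_is_one_sketch():
    p_logged = torch.tensor([0.2, 0.5, 0.3])
    p_model = torch.tensor([0.6, 0.1, 0.3])
    # expectation of the ratio under the logging distribution
    expected_ips = torch.sum((p_model / p_logged) * p_logged)
    assert torch.isclose(expected_ips, torch.tensor(1.0))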
def test_compute_impt_smpl(self, output_arch, clamp_method, clamp_max, shape):
    logger.info(f"output arch: {output_arch}")
    logger.info(f"clamp method: {clamp_method}")
    logger.info(f"clamp max: {clamp_max}")
    logger.info(f"frechet shape: {shape}")
    candidate_num = 5
    candidate_dim = 2
    state_dim = 1
    hidden_size = 32
    device = torch.device("cpu")
    batch_size = 32
    learning_rate = 0.001
    policy_gradient_interval = 1

    candidates = torch.randint(5, (candidate_num, candidate_dim)).float()
    candidate_scores = torch.sum(candidates, dim=1)

    seq2slate_params = Seq2SlateParameters(
        on_policy=False,
        ips_clamp=IPSClamp(clamp_method=clamp_method, clamp_max=clamp_max),
    )
    seq2slate_net = create_seq2slate_transformer(
        state_dim, candidate_num, candidate_dim, hidden_size, output_arch, device
    )
    trainer = create_trainer(
        seq2slate_net,
        batch_size,
        learning_rate,
        device,
        seq2slate_params,
        policy_gradient_interval,
    )
    all_permt = torch.tensor(
        list(permutations(range(candidate_num), candidate_num))
    )
    sampler = FrechetSort(shape=shape, topk=candidate_num)

    sum_of_logged_propensity = 0
    sum_of_model_propensity = 0
    sum_of_ips_ratio = 0
    for i in range(len(all_permt)):
        sample_action = all_permt[i]
        logged_propensity = torch.exp(
            sampler.log_prob(candidate_scores, sample_action)
        )
        batch = rlt.PreprocessedRankingInput.from_input(
            state=torch.zeros(1, state_dim),
            candidates=candidates.unsqueeze(0),
            device=device,
            action=sample_action.unsqueeze(0),
            logged_propensities=logged_propensity.reshape(1, 1),
        )
        model_propensities = torch.exp(
            seq2slate_net(batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE).log_probs
        )
        impt_smpl, clamped_impt_smpl = trainer._compute_impt_smpl(
            model_propensities, logged_propensity
        )
        if impt_smpl > clamp_max:
            if clamp_method == IPSClampMethod.AGGRESSIVE:
                npt.assert_allclose(
                    clamped_impt_smpl.detach().numpy(), 0, rtol=1e-5
                )
            else:
                npt.assert_allclose(
                    clamped_impt_smpl.detach().numpy(), clamp_max, rtol=1e-5
                )
        sum_of_model_propensity += model_propensities
        sum_of_logged_propensity += logged_propensity
        sum_of_ips_ratio += model_propensities / logged_propensity
        logger.info(
            f"shape={shape}, sample_action={sample_action}, "
            f"logged_propensity={logged_propensity}, "
            f"model_propensity={model_propensities}"
        )

    logger.info(
        f"shape {shape}, sum_of_logged_propensity={sum_of_logged_propensity}, "
        f"sum_of_model_propensity={sum_of_model_propensity}, "
        f"mean sum_of_ips_ratio={sum_of_ips_ratio / len(all_permt)}"
    )
    # both the logging policy and the model define proper distributions over all
    # candidate_num! permutations, so their propensities must each sum to 1
    npt.assert_allclose(sum_of_logged_propensity.detach().numpy(), 1, rtol=1e-5)
    npt.assert_allclose(sum_of_model_propensity.detach().numpy(), 1, rtol=1e-5)
def _test_seq2slate_trainer_on_policy(self, policy_gradient_interval, output_arch, device):
    batch_size = 32
    state_dim = 2
    candidate_num = 15
    candidate_dim = 4
    hidden_size = 16
    learning_rate = 1.0
    on_policy = True
    rank_seed = 111
    seq2slate_params = Seq2SlateParameters(on_policy=on_policy)
    seq2slate_net = create_seq2slate_transformer(
        state_dim, candidate_num, candidate_dim, hidden_size, output_arch, device
    )
    seq2slate_net_copy = copy.deepcopy(seq2slate_net)
    seq2slate_net_copy_copy = copy.deepcopy(seq2slate_net)
    trainer = create_trainer(
        seq2slate_net,
        batch_size,
        learning_rate,
        device,
        seq2slate_params,
        policy_gradient_interval,
    )
    batch = create_on_policy_batch(
        seq2slate_net,
        batch_size,
        state_dim,
        candidate_num,
        candidate_dim,
        rank_seed,
        device,
    )

    for _ in range(policy_gradient_interval):
        trainer.train(rlt.PreprocessedTrainingBatch(training_input=batch))

    # manually compute the gradient
    torch.manual_seed(rank_seed)
    rank_output = seq2slate_net_copy(
        batch, mode=Seq2SlateMode.RANK_MODE, tgt_seq_len=candidate_num, greedy=False
    )
    loss = -(
        torch.mean(torch.log(rank_output.ranked_per_seq_probs) * batch.slate_reward)
    )
    loss.backward()
    self.assert_correct_gradient(
        seq2slate_net_copy, seq2slate_net, policy_gradient_interval, learning_rate
    )

    # another way to compute the gradient manually
    torch.manual_seed(rank_seed)
    ranked_per_seq_probs = seq2slate_net_copy_copy(
        batch, mode=Seq2SlateMode.RANK_MODE, tgt_seq_len=candidate_num, greedy=False
    ).ranked_per_seq_probs
    loss = -(
        torch.mean(
            ranked_per_seq_probs / ranked_per_seq_probs.detach() * batch.slate_reward
        )
    )
    loss.backward()
    self.assert_correct_gradient(
        seq2slate_net_copy_copy,
        seq2slate_net,
        policy_gradient_interval,
        learning_rate,
    )
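# Why the two on-policy losses above yield the same gradient: d/dtheta (p / p.detach()) =
# (dp/dtheta) / p = d/dtheta log p, i.e. dividing the sequence probability by its own
# detached value recovers the REINFORCE score function.  The helper below is a standalone
# illustration with a hypothetical scalar "policy"; it is not part of the trainer API.
def _reinforce_equivalence_sketch():
    theta = torch.tensor([0.7], requires_grad=True)

    # form 1: gradient of -log(p)
    p1 = torch.sigmoid(theta)
    (-torch.log(p1)).backward()
    grad_log = theta.grad.clone()

    # form 2: gradient of -p / p.detach()
    theta.grad = None
    p2 = torch.sigmoid(theta)
    (-(p2 / p2.detach())).backward()
    grad_ratio = theta.grad.clone()

    assert torch.allclose(grad_log, grad_ratio)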