def test_ips_clamp(self):
    """Verify ips_clamp semantics: no-op without a config, zeroing under
    AGGRESSIVE, and capping under UNIVERSAL."""
    ratios = torch.tensor([0.5, 0.3, 3.0, 10.0, 40.0])
    # No clamp configured: the ratios must come back untouched.
    assert torch.all(ips_clamp(ratios, None) == ratios)
    # AGGRESSIVE clamp: ratios above clamp_max are zeroed out.
    aggressive = ips_clamp(ratios, IPSClamp(IPSClampMethod.AGGRESSIVE, 3.0))
    assert torch.all(aggressive == torch.tensor([0.5, 0.3, 3.0, 0.0, 0.0]))
    # UNIVERSAL clamp: ratios above clamp_max are capped at clamp_max.
    universal = ips_clamp(ratios, IPSClamp(IPSClampMethod.UNIVERSAL, 3.0))
    assert torch.all(universal == torch.tensor([0.5, 0.3, 3.0, 3.0, 3.0]))
def test_seq2slate_trainer_off_policy_with_clamp(self, clamp_method, output_arch):
    """Off-policy training with an IPS clamp must produce the same gradient
    as manually differentiating the clamped importance-weighted slate reward."""
    batch_size, state_dim, candidate_num, candidate_dim = 32, 2, 15, 4
    hidden_size = 16
    learning_rate = 1.0
    policy_gradient_interval = 1
    device = torch.device("cpu")
    seq2slate_params = Seq2SlateParameters(
        on_policy=False,
        ips_clamp=IPSClamp(clamp_method=clamp_method, clamp_max=0.3),
    )
    seq2slate_net = create_seq2slate_transformer(
        state_dim, candidate_num, candidate_dim, hidden_size, output_arch, device
    )
    # Snapshot the network before training so the reference gradient is
    # computed against identical parameters.
    reference_net = copy.deepcopy(seq2slate_net)
    trainer = create_trainer(
        seq2slate_net,
        batch_size,
        learning_rate,
        device,
        seq2slate_params,
        policy_gradient_interval,
    )
    batch = create_off_policy_batch(
        seq2slate_net, batch_size, state_dim, candidate_num, candidate_dim, device
    )
    for _ in range(policy_gradient_interval):
        trainer.train(rlt.PreprocessedTrainingBatch(training_input=batch))

    # Reference gradient: clamped importance ratio times slate reward,
    # averaged over the batch and negated (gradient ascent on reward).
    ranked_per_seq_probs = torch.exp(
        reference_net(batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE).log_probs
    )
    logger.info(f"ips ratio={ranked_per_seq_probs / batch.tgt_out_probs}")
    clamped_ratio = ips_clamp(
        ranked_per_seq_probs / batch.tgt_out_probs,
        seq2slate_params.ips_clamp,
    )
    loss = -torch.mean(clamped_ratio * batch.slate_reward)
    loss.backward()
    self.assert_correct_gradient(
        reference_net, seq2slate_net, policy_gradient_interval, learning_rate
    )
def test_compute_impt_smpl(self, output_arch, clamp_method, clamp_max, shape):
    """Enumerate every permutation of a small candidate set and verify that
    (a) the trainer's clamped importance sampling ratios obey the configured
    clamp method, and (b) both the logged (FrechetSort) propensities and the
    model propensities sum to 1 over all permutations.
    """
    logger.info(f"output arch: {output_arch}")
    logger.info(f"clamp method: {clamp_method}")
    logger.info(f"clamp max: {clamp_max}")
    logger.info(f"frechet shape: {shape}")
    candidate_num = 5
    candidate_dim = 2
    state_dim = 1
    hidden_size = 32
    device = torch.device("cpu")
    batch_size = 32
    learning_rate = 0.001
    policy_gradient_interval = 1
    candidates = torch.randint(5, (candidate_num, candidate_dim)).float()
    candidate_scores = torch.sum(candidates, dim=1)
    seq2slate_params = Seq2SlateParameters(
        on_policy=False,
        ips_clamp=IPSClamp(clamp_method=clamp_method, clamp_max=clamp_max),
    )
    seq2slate_net = create_seq2slate_transformer(
        state_dim, candidate_num, candidate_dim, hidden_size, output_arch, device
    )
    trainer = create_trainer(
        seq2slate_net,
        batch_size,
        learning_rate,
        device,
        seq2slate_params,
        policy_gradient_interval,
    )
    all_permt = torch.tensor(
        list(permutations(range(candidate_num), candidate_num))
    )
    sampler = FrechetSort(shape=shape, topk=candidate_num)
    sum_of_logged_propensity = 0
    sum_of_model_propensity = 0
    sum_of_ips_ratio = 0
    for i in range(len(all_permt)):
        sample_action = all_permt[i]
        logged_propensity = torch.exp(
            sampler.log_prob(candidate_scores, sample_action)
        )
        batch = rlt.PreprocessedRankingInput.from_input(
            state=torch.zeros(1, state_dim),
            candidates=candidates.unsqueeze(0),
            device=device,
            action=sample_action.unsqueeze(0),
            logged_propensities=logged_propensity.reshape(1, 1),
        )
        model_propensities = torch.exp(
            seq2slate_net(
                batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE
            ).log_probs
        )
        impt_smpl, clamped_impt_smpl = trainer._compute_impt_smpl(
            model_propensities, logged_propensity
        )
        if impt_smpl > clamp_max:
            if clamp_method == IPSClampMethod.AGGRESSIVE:
                # BUG FIX: was `npt.asset_allclose`, a typo that raised
                # AttributeError whenever this branch executed. Aggressive
                # clamping zeroes out ratios above clamp_max.
                npt.assert_allclose(
                    clamped_impt_smpl.detach().numpy(), 0, rtol=1e-5
                )
            else:
                # Universal clamping caps ratios at clamp_max.
                npt.assert_allclose(
                    clamped_impt_smpl.detach().numpy(), clamp_max, rtol=1e-5
                )
        sum_of_model_propensity += model_propensities
        sum_of_logged_propensity += logged_propensity
        sum_of_ips_ratio += model_propensities / logged_propensity
        logger.info(
            f"shape={shape}, sample_action={sample_action}, logged_propensity={logged_propensity},"
            f" model_propensity={model_propensities}"
        )
    logger.info(
        f"shape {shape}, sum_of_logged_propensity={sum_of_logged_propensity}, "
        f"sum_of_model_propensity={sum_of_model_propensity}, "
        f"mean sum_of_ips_ratio={sum_of_ips_ratio / len(all_permt)}"
    )
    # Both distributions over all candidate_num! permutations must be proper
    # probability distributions.
    npt.assert_allclose(sum_of_logged_propensity.detach().numpy(), 1, rtol=1e-5)
    npt.assert_allclose(sum_of_model_propensity.detach().numpy(), 1, rtol=1e-5)