def test_log_prob_padding(self):
    """Padded (shorter) slates must get higher log-prob, and gradients must flow.

    Index 5 (== number of candidates) acts as padding in the action tensor;
    the padded slate of length 2 should be strictly more probable than the
    full-length slate. The second slate's log-prob is also re-derived by hand
    and compared against FrechetSort's output.
    """
    shape = 2.0
    scores = torch.tensor(
        [
            [1.0, 2.0, 3.0, 4.0, 5.0],
            [1.0, 2.0, 3.0, 4.0, 5.0],
        ],
        requires_grad=True,
    )
    sampler = FrechetSort(topk=3, shape=shape, log_scores=True)
    # Second row is padded after the first two positions.
    action = torch.tensor(
        [
            [0, 1, 2, 3, 4],
            [0, 1, 5, 5, 5],
        ],
        dtype=torch.long,
    )
    log_probs = sampler.log_prob(scores, action)
    # A shorter (padded) sequence should have a higher probability.
    self.assertLess(log_probs[0], log_probs[1])
    # Backprop through the log-probs must produce non-trivial gradients.
    log_probs.sum().backward()
    self.assertGreater(scores.grad.sum(), 0)

    # Manually compute the log-prob of the second slate; 5 is padding,
    # so only the first two positions participate.
    chosen = scores[1][action[1][:2]]
    expected = 0.0
    for pos in range(2):
        expected -= torch.exp((chosen[pos:] - chosen[pos]) * shape).sum().log()
    self.assertAlmostEqual(expected, log_probs[1])
def test_log_prob(self):
    """With topk=3, only the first three positions determine the log-prob.

    Two actions that agree on the first three slots must score identically;
    an action that front-loads higher-scoring items must score higher. The
    higher-scoring case is also re-derived by hand and compared.
    """
    shape = 2.0
    scores = torch.tensor(
        [
            [1.0, 2.0, 3.0, 4.0, 5.0],
            [5.0, 1.0, 2.0, 3.0, 4.0],
        ]
    )
    sampler = FrechetSort(topk=3, shape=shape, log_scores=True)

    # Same first three picks -> identical log-prob; trailing slots are ignored.
    action = torch.tensor(
        [
            [0, 1, 2, 3, 4],
            [1, 2, 3, 0, 4],
        ],
        dtype=torch.long,
    )
    log_probs = sampler.log_prob(scores, action)
    self.assertEqual(log_probs[0], log_probs[1])

    # Putting higher-scoring candidates first raises the log-prob.
    action = torch.tensor(
        [
            [0, 1, 2, 3, 4],
            [3, 2, 1, 0, 4],
        ],
        dtype=torch.long,
    )
    log_probs = sampler.log_prob(scores, action)
    self.assertLess(log_probs[0], log_probs[1])

    # Re-derive the second case by hand over the topk=3 prefix.
    chosen = scores[1][action[1]]
    expected = 0.0
    for pos in range(3):
        expected -= torch.exp((chosen[pos:] - chosen[pos]) * shape).sum().log()
    self.assertAlmostEqual(expected, log_probs[1])
def test_ips_ratio_mean(self, output_arch, shape):
    """Check that the running mean of importance-sampling ratios approaches 1.

    Actions are sampled from FrechetSort over identical candidate sets, and
    the ratio model_propensity / logged_propensity is averaged over batches;
    it should converge to 1 (within atol=0.03) after 100+ batches.
    """
    # NOTE(review): the parameters below are immediately overwritten with
    # hard-coded values, which defeats any parameterization of this test —
    # confirm whether this override is intentional.
    output_arch = Seq2SlateOutputArch.FRECHET_SORT
    shape = 0.1
    logger.info(f"output arch: {output_arch}")
    logger.info(f"frechet shape: {shape}")

    candidate_num = 5
    candidate_dim = 2
    state_dim = 1
    hidden_size = 8
    device = torch.device("cpu")
    batch_size = 1024
    num_batches = 400
    learning_rate = 0.001
    policy_gradient_interval = 1

    state = torch.zeros(batch_size, state_dim)
    # Every row shares the same candidate set so propensities are comparable.
    candidates = torch.randint(5, (batch_size, candidate_num, candidate_dim)).float()
    candidates[1:] = candidates[0]
    candidate_scores = torch.sum(candidates, dim=-1)

    seq2slate_params = Seq2SlateParameters(on_policy=False)
    seq2slate_net = create_seq2slate_transformer(
        state_dim, candidate_num, candidate_dim, hidden_size, output_arch, device
    )
    trainer = create_trainer(
        seq2slate_net,
        batch_size,
        learning_rate,
        device,
        seq2slate_params,
        policy_gradient_interval,
    )
    sampler = FrechetSort(shape=shape, topk=candidate_num)

    sum_of_ips_ratio = 0
    for i in range(num_batches):
        sample_outputs = [
            sampler.sample_action(candidate_scores[j : j + 1])
            for j in range(batch_size)
        ]
        action = torch.stack([out.action.squeeze(0) for out in sample_outputs])
        logged_propensity = torch.stack(
            [torch.exp(out.log_prob) for out in sample_outputs]
        )
        batch = rlt.PreprocessedRankingInput.from_input(
            state=state,
            candidates=candidates,
            device=device,
            action=action,
            logged_propensities=logged_propensity,
        )
        model_propensities = torch.exp(
            seq2slate_net(batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE).log_probs
        )
        impt_smpl, _ = trainer._compute_impt_smpl(
            model_propensities, logged_propensity
        )
        sum_of_ips_ratio += torch.mean(impt_smpl).detach().numpy()
        mean_of_ips_ratio = sum_of_ips_ratio / (i + 1)
        logger.info(f"{i}-th batch, mean ips ratio={mean_of_ips_ratio}")

        # Require a burn-in of 100 batches before accepting convergence.
        if i > 100 and np.allclose(mean_of_ips_ratio, 1, atol=0.03):
            return

    raise Exception(f"Mean ips ratio {mean_of_ips_ratio} is not close to 1")
def test_compute_impt_smpl(self, output_arch, clamp_method, clamp_max, shape):
    """Exhaustively verify importance-sampling weight clamping.

    Enumerates every permutation of a small candidate slate and checks that:
    - when the raw importance weight exceeds ``clamp_max``, the clamped weight
      is 0 for AGGRESSIVE clamping and ``clamp_max`` otherwise;
    - logged propensities (FrechetSort) and model propensities (seq2slate)
      each sum to 1 over all permutations, i.e. both are valid distributions.
    """
    logger.info(f"output arch: {output_arch}")
    logger.info(f"clamp method: {clamp_method}")
    logger.info(f"clamp max: {clamp_max}")
    logger.info(f"frechet shape: {shape}")

    candidate_num = 5
    candidate_dim = 2
    state_dim = 1
    hidden_size = 32
    device = torch.device("cpu")
    batch_size = 32
    learning_rate = 0.001
    policy_gradient_interval = 1

    candidates = torch.randint(5, (candidate_num, candidate_dim)).float()
    candidate_scores = torch.sum(candidates, dim=1)

    seq2slate_params = Seq2SlateParameters(
        on_policy=False,
        ips_clamp=IPSClamp(clamp_method=clamp_method, clamp_max=clamp_max),
    )
    seq2slate_net = create_seq2slate_transformer(
        state_dim, candidate_num, candidate_dim, hidden_size, output_arch, device
    )
    trainer = create_trainer(
        seq2slate_net,
        batch_size,
        learning_rate,
        device,
        seq2slate_params,
        policy_gradient_interval,
    )
    # All 5! orderings of the slate; iterating them lets us check that the
    # propensities form proper probability distributions.
    all_permt = torch.tensor(
        list(permutations(range(candidate_num), candidate_num))
    )
    sampler = FrechetSort(shape=shape, topk=candidate_num)

    sum_of_logged_propensity = 0
    sum_of_model_propensity = 0
    sum_of_ips_ratio = 0
    for i in range(len(all_permt)):
        sample_action = all_permt[i]
        logged_propensity = torch.exp(
            sampler.log_prob(candidate_scores, sample_action)
        )
        batch = rlt.PreprocessedRankingInput.from_input(
            state=torch.zeros(1, state_dim),
            candidates=candidates.unsqueeze(0),
            device=device,
            action=sample_action.unsqueeze(0),
            logged_propensities=logged_propensity.reshape(1, 1),
        )
        model_propensities = torch.exp(
            seq2slate_net(
                batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE
            ).log_probs
        )
        impt_smpl, clamped_impt_smpl = trainer._compute_impt_smpl(
            model_propensities, logged_propensity
        )
        if impt_smpl > clamp_max:
            if clamp_method == IPSClampMethod.AGGRESSIVE:
                # Aggressive clamping zeroes out over-large weights.
                # BUGFIX: was `npt.asset_allclose` (a typo), which would
                # raise AttributeError whenever this branch was reached.
                npt.assert_allclose(
                    clamped_impt_smpl.detach().numpy(), 0, rtol=1e-5
                )
            else:
                # Universal clamping caps the weight at clamp_max.
                npt.assert_allclose(
                    clamped_impt_smpl.detach().numpy(), clamp_max, rtol=1e-5
                )
        sum_of_model_propensity += model_propensities
        sum_of_logged_propensity += logged_propensity
        sum_of_ips_ratio += model_propensities / logged_propensity
        logger.info(
            f"shape={shape}, sample_action={sample_action}, logged_propensity={logged_propensity},"
            f" model_propensity={model_propensities}"
        )

    logger.info(
        f"shape {shape}, sum_of_logged_propensity={sum_of_logged_propensity}, "
        f"sum_of_model_propensity={sum_of_model_propensity}, "
        f"mean sum_of_ips_ratio={sum_of_ips_ratio / len(all_permt)}"
    )
    # Both propensity distributions must sum to 1 over all permutations.
    npt.assert_allclose(sum_of_logged_propensity.detach().numpy(), 1, rtol=1e-5)
    npt.assert_allclose(sum_of_model_propensity.detach().numpy(), 1, rtol=1e-5)