def test_seq2slate_transformer_propensity_computation(
    self, output_arch, temperature
):
    """
    Check that the transformer seq2slate net yields consistent propensities.

    Every permutation of the candidates is supplied as an action. The
    per-sequence log probs returned by the net must agree with the ones
    derived from its per-symbol log prob distribution, and the
    probabilities over all permutations must add up to one.
    """
    device = torch.device("cpu")
    candidate_num, candidate_dim, hidden_size = 4, 2, 32

    # One row per possible ordering of the candidate indices.
    candidate_indices = torch.arange(candidate_num)
    all_perm = torch.tensor(
        list(permutations(candidate_indices, candidate_num))
    )
    batch_size = all_perm.shape[0]

    seq2slate_net = create_seq2slate_net(
        MODEL_TRANSFORMER,
        candidate_num,
        candidate_dim,
        hidden_size,
        output_arch,
        temperature,
        device,
    )

    # Build a batch first, then swap in the exhaustive permutations as the
    # actions while reusing the generated state and candidate features.
    raw_batch = create_batch(
        batch_size,
        candidate_num,
        candidate_dim,
        device,
        ON_POLICY,
        diverse_input=False,
    )
    batch = rlt.PreprocessedRankingInput.from_input(
        state=raw_batch.state.float_features,
        candidates=raw_batch.src_seq.float_features,
        device=device,
        action=all_perm,
    )

    symbol_output = seq2slate_net(
        batch, mode=Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE
    )
    per_symbol_log_prob = symbol_output.log_probs
    seq_output = seq2slate_net(batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE)
    per_seq_log_prob = seq_output.log_probs

    # NOTE(review): the +2 offset presumably skips reserved symbols
    # (e.g. padding / start) in the per-symbol distribution -- confirm
    # against per_symbol_to_per_seq_log_probs.
    shifted_actions = all_perm + 2
    per_seq_log_prob_computed = per_symbol_to_per_seq_log_probs(
        per_symbol_log_prob, shifted_actions
    )

    tolerance = 0.00001

    # Both modes must report the same sequence log probabilities.
    np.testing.assert_allclose(
        per_seq_log_prob, per_seq_log_prob_computed, atol=tolerance
    )

    # Summed over every permutation, the probability mass must be 1.
    total_probability = per_seq_log_prob.exp().sum()
    np.testing.assert_allclose(total_probability, 1.0, atol=tolerance)
def test_seq2slate_transformer_onpolicy_basic_logic(
    self, output_arch, temperature
):
    """
    Check on-policy sampling of the transformer seq2slate net.

    Actions are sampled repeatedly (non-greedy) from one fixed batch. The
    test asserts that a given action is always reported with the same
    propensity, and that the observed frequency of each action is close to
    the propensity reported for it.
    """
    num_rounds = 50
    same_action_atol = 0.001
    distribution_atol = 0.01

    device = torch.device("cpu")
    candidate_num, candidate_dim = 4, 2
    batch_size, hidden_size = 4096, 32

    seq2slate_net = create_seq2slate_net(
        MODEL_TRANSFORMER,
        candidate_num,
        candidate_dim,
        hidden_size,
        output_arch,
        temperature,
        device,
    )
    batch = create_batch(
        batch_size,
        candidate_num,
        candidate_dim,
        device,
        ON_POLICY,
        diverse_input=False,
    )

    # Keyed by the comma-joined action indices.
    action_to_propensity_map = {}
    action_count = defaultdict(int)
    total_count = 0

    for i in range(num_rounds):
        model_propensity, model_action = rank_on_policy(
            seq2slate_net, batch, candidate_num, greedy=False
        )
        for propensity, action in zip(model_propensity, model_action):
            action_str = ",".join(str(idx) for idx in action.numpy().tolist())
            current = float(propensity)
            recorded = action_to_propensity_map.get(action_str)
            if recorded is not None:
                # A previously seen action must report the same propensity.
                # (Presumably holds because diverse_input=False gives every
                # row the same input -- verify against create_batch.)
                np.testing.assert_allclose(
                    recorded,
                    current,
                    atol=same_action_atol,
                    rtol=0.0,
                )
            else:
                action_to_propensity_map[action_str] = current
            action_count[action_str] += 1
            total_count += 1
        logger.info(f"Finish {i} round, {total_count} data counts")

    # The empirical action frequencies should track the model propensities.
    for action_str, count in action_count.items():
        empirical_propensity = count / total_count
        computed_propensity = action_to_propensity_map[action_str]
        logger.info(
            f"action={action_str}, empirical propensity={empirical_propensity}, "
            f"computed propensity={computed_propensity}"
        )
        np.testing.assert_allclose(
            computed_propensity,
            empirical_propensity,
            atol=distribution_atol,
            rtol=0.0,
        )