def preprocess_batch(train_batch: Any) -> rlt.PreprocessedTrainingBatch:
    (
        obs,
        action,
        reward,
        next_obs,
        next_action,
        next_reward,
        terminal,
        idxs,
        possible_actions_mask,
        log_prob,
    ) = train_batch
    # Convert the raw replay-buffer arrays into tensors with the shapes the trainer expects.
    obs = torch.tensor(obs).squeeze(2)
    action = torch.tensor(action).float()
    reward = torch.tensor(reward).unsqueeze(1)
    next_obs = torch.tensor(next_obs).squeeze(2)
    next_action = torch.tensor(next_action)
    not_terminal = 1.0 - torch.tensor(terminal).unsqueeze(1).float()
    idxs = torch.tensor(idxs)
    possible_actions_mask = torch.tensor(possible_actions_mask).float()
    log_prob = torch.tensor(log_prob)
    return rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedPolicyNetworkInput(
            state=rlt.PreprocessedFeatureVector(float_features=obs),
            action=rlt.PreprocessedFeatureVector(float_features=action),
            next_state=rlt.PreprocessedFeatureVector(float_features=next_obs),
            next_action=rlt.PreprocessedFeatureVector(float_features=next_action),
            reward=reward,
            not_terminal=not_terminal,
            step=None,
            time_diff=None,
        ),
        extras=rlt.ExtraData(),
    )
def as_parametric_maxq_training_batch(self):
    state_dim = self.states.shape[1]
    return rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedParametricDqnInput(
            state=rlt.FeatureData(float_features=self.states),
            action=rlt.FeatureData(float_features=self.actions),
            next_state=rlt.FeatureData(float_features=self.next_states),
            next_action=rlt.FeatureData(float_features=self.next_actions),
            tiled_next_state=rlt.FeatureData(
                float_features=self.possible_next_actions_state_concat[:, :state_dim]
            ),
            possible_actions=None,
            possible_actions_mask=self.possible_actions_mask,
            possible_next_actions=rlt.FeatureData(
                float_features=self.possible_next_actions_state_concat[:, state_dim:]
            ),
            possible_next_actions_mask=self.possible_next_actions_mask,
            reward=self.rewards,
            not_terminal=self.not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        ),
        extras=rlt.ExtraData(),
    )
def train(self, training_batch: rlt.PreprocessedTrainingBatch):
    assert type(training_batch) is rlt.PreprocessedTrainingBatch
    training_input = training_batch.training_input
    assert isinstance(training_input, rlt.PreprocessedRankingInput)

    batch_size = training_input.state.float_features.shape[0]

    # randomly pick a permutation for every slate
    random_indices = torch.randint(0, len(self.permutation_index), (batch_size,))
    sim_tgt_out_idx = self.permutation_index[random_indices] + 2
    if self.parameters.simulation_distance_penalty is not None:
        sim_distance = self.permutation_distance[random_indices]
    else:
        sim_distance = None

    with torch.no_grad():
        # format data according to the new ordering
        training_input = self._simulated_training_input(
            training_input, sim_tgt_out_idx, sim_distance, self.device
        )

    return self.trainer.train(
        rlt.PreprocessedTrainingBatch(
            training_input=training_input, extras=training_batch.extras
        )
    )
def run_seq2slate_tsp(
    model_str,
    batch_size,
    epochs,
    candidate_num,
    num_batches,
    hidden_size,
    diverse_input,
    learning_rate,
    expect_reward_threshold,
    learning_method,
    device,
):
    candidate_dim = 2
    eval_sample_size = 1

    train_batches, test_batch = create_train_and_test_batches(
        batch_size, candidate_num, candidate_dim, device, num_batches, diverse_input
    )
    best_test_possible_reward = compute_best_reward(test_batch.src_seq.float_features)

    seq2slate_net = create_seq2slate_net(
        model_str,
        candidate_num,
        candidate_dim,
        hidden_size,
        Seq2SlateOutputArch.AUTOREGRESSIVE,
        1.0,
        device,
    )

    trainer = create_trainer(
        seq2slate_net, learning_method, batch_size, learning_rate, device
    )

    for e in range(epochs):
        # training
        for batch in train_batches:
            batch = post_preprocess_batch(
                learning_method, seq2slate_net, candidate_num, batch, device, e
            )
            trainer.train(rlt.PreprocessedTrainingBatch(training_input=batch))

        # evaluation
        best_test_reward = torch.full((batch_size,), 1e9).to(device)
        for _ in range(eval_sample_size):
            model_propensities, _, reward = rank_on_policy_and_eval(
                seq2slate_net, test_batch, candidate_num, greedy=True
            )
            best_test_reward = torch.where(
                reward < best_test_reward, reward, best_test_reward
            )
        logger.info(
            f"Test mean model_propensities {torch.mean(model_propensities)}, "
            f"Test mean reward: {torch.mean(best_test_reward)}, "
            f"best possible reward {best_test_possible_reward}"
        )
        if (
            torch.mean(best_test_reward)
            < best_test_possible_reward * expect_reward_threshold
        ):
            return

    raise AssertionError("Test failed because it did not reach expected test reward")
def train(self, training_batch: rlt.PreprocessedTrainingBatch):
    assert type(training_batch) is rlt.PreprocessedTrainingBatch
    training_input = training_batch.training_input
    assert isinstance(training_input, rlt.PreprocessedRankingInput)
    training_input = self._simulated_training_input(training_input)
    return self.trainer.train(
        rlt.PreprocessedTrainingBatch(
            training_input=training_input, extras=training_batch.extras
        )
    )
def _test_seq2slate_trainer_off_policy(
    self, policy_gradient_interval, output_arch, device
):
    batch_size = 32
    state_dim = 2
    candidate_num = 15
    candidate_dim = 4
    hidden_size = 16
    learning_rate = 1.0
    on_policy = False
    seq2slate_params = Seq2SlateParameters(on_policy=on_policy)
    seq2slate_net = create_seq2slate_transformer(
        state_dim, candidate_num, candidate_dim, hidden_size, output_arch, device
    )
    seq2slate_net_copy = copy.deepcopy(seq2slate_net)
    seq2slate_net_copy_copy = copy.deepcopy(seq2slate_net)
    trainer = create_trainer(
        seq2slate_net,
        batch_size,
        learning_rate,
        device,
        seq2slate_params,
        policy_gradient_interval,
    )
    batch = create_off_policy_batch(
        seq2slate_net, batch_size, state_dim, candidate_num, candidate_dim, device
    )

    for _ in range(policy_gradient_interval):
        trainer.train(rlt.PreprocessedTrainingBatch(training_input=batch))

    # manually compute the gradient
    ranked_per_seq_log_probs = seq2slate_net_copy(
        batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE
    ).log_probs
    loss = -(
        torch.mean(
            ranked_per_seq_log_probs
            * torch.exp(ranked_per_seq_log_probs).detach()
            / batch.tgt_out_probs
            * batch.slate_reward
        )
    )
    loss.backward()
    self.assert_correct_gradient(
        seq2slate_net_copy, seq2slate_net, policy_gradient_interval, learning_rate
    )

    # another way to compute the gradient manually
    ranked_per_seq_probs = torch.exp(
        seq2slate_net_copy_copy(
            batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE
        ).log_probs
    )
    loss = -(
        torch.mean(ranked_per_seq_probs / batch.tgt_out_probs * batch.slate_reward)
    )
    loss.backward()
    self.assert_correct_gradient(
        seq2slate_net_copy_copy,
        seq2slate_net,
        policy_gradient_interval,
        learning_rate,
    )
def test_seq2slate_trainer_off_policy_with_clamp(self, clamp_method, output_arch):
    batch_size = 32
    state_dim = 2
    candidate_num = 15
    candidate_dim = 4
    hidden_size = 16
    learning_rate = 1.0
    device = torch.device("cpu")
    policy_gradient_interval = 1

    seq2slate_params = Seq2SlateParameters(
        on_policy=False,
        ips_clamp=IPSClamp(clamp_method=clamp_method, clamp_max=0.3),
    )
    seq2slate_net = create_seq2slate_transformer(
        state_dim, candidate_num, candidate_dim, hidden_size, output_arch, device
    )
    seq2slate_net_copy = copy.deepcopy(seq2slate_net)
    trainer = create_trainer(
        seq2slate_net,
        batch_size,
        learning_rate,
        device,
        seq2slate_params,
        policy_gradient_interval,
    )
    batch = create_off_policy_batch(
        seq2slate_net, batch_size, state_dim, candidate_num, candidate_dim, device
    )

    for _ in range(policy_gradient_interval):
        trainer.train(rlt.PreprocessedTrainingBatch(training_input=batch))

    # manually compute the gradient, applying the same IPS clamping as the trainer
    ranked_per_seq_probs = torch.exp(
        seq2slate_net_copy(batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE).log_probs
    )
    logger.info(f"ips ratio={ranked_per_seq_probs / batch.tgt_out_probs}")
    loss = -(
        torch.mean(
            ips_clamp(
                ranked_per_seq_probs / batch.tgt_out_probs,
                seq2slate_params.ips_clamp,
            )
            * batch.slate_reward
        )
    )
    loss.backward()
    self.assert_correct_gradient(
        seq2slate_net_copy, seq2slate_net, policy_gradient_interval, learning_rate
    )
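# A minimal sketch (an assumption, not the library's implementation) of the clamping
# behavior the test above relies on from `ips_clamp`: "universal" clamping caps every
# importance weight at clamp_max, while "aggressive" clamping zeroes out any weight
# that exceeds clamp_max. The helper name and string-valued method below are
# hypothetical, used only for illustration.
import torch


def clamp_importance_weights(
    importance_weights: torch.Tensor, clamp_method: str, clamp_max: float
) -> torch.Tensor:
    if clamp_method == "universal":
        # Cap all weights into [0, clamp_max].
        return torch.clamp(importance_weights, 0.0, clamp_max)
    # "aggressive": drop (zero out) any sample whose weight exceeds clamp_max.
    return torch.where(
        importance_weights > clamp_max,
        torch.zeros_like(importance_weights),
        importance_weights,
    )


# Example with clamp_max=0.3: weights [0.1, 0.5] become [0.1, 0.3] under "universal"
# clamping and [0.1, 0.0] under "aggressive" clamping.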
def as_slate_q_training_batch(self):
    batch_size, state_dim = self.states.shape
    action_dim = self.actions.shape[1]
    return rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedSlateQInput(
            state=rlt.PreprocessedFeatureVector(float_features=self.states),
            next_state=rlt.PreprocessedFeatureVector(float_features=self.next_states),
            tiled_state=rlt.PreprocessedTiledFeatureVector(
                float_features=self.possible_actions_state_concat[
                    :, :state_dim
                ].view(batch_size, -1, state_dim)
            ),
            tiled_next_state=rlt.PreprocessedTiledFeatureVector(
                float_features=self.possible_next_actions_state_concat[
                    :, :state_dim
                ].view(batch_size, -1, state_dim)
            ),
            action=rlt.PreprocessedSlateFeatureVector(
                float_features=self.possible_actions_state_concat[
                    :, state_dim:
                ].view(batch_size, -1, action_dim),
                item_mask=self.possible_actions_mask,
                item_probability=self.propensities,
            ),
            next_action=rlt.PreprocessedSlateFeatureVector(
                float_features=self.possible_next_actions_state_concat[
                    :, state_dim:
                ].view(batch_size, -1, action_dim),
                item_mask=self.possible_next_actions_mask,
                item_probability=self.next_propensities,
            ),
            reward=self.rewards,
            reward_mask=self.rewards_mask,
            time_diff=self.time_diffs,
            step=self.step,
            not_terminal=self.not_terminal,
        ),
        extras=rlt.ExtraData(
            mdp_id=self.mdp_ids,
            sequence_number=self.sequence_numbers,
            action_probability=self.propensities,
            max_num_actions=self.max_num_actions,
            metrics=self.metrics,
        ),
    )
def as_policy_network_training_batch(self):
    return rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedPolicyNetworkInput(
            state=rlt.PreprocessedFeatureVector(float_features=self.states),
            action=rlt.PreprocessedFeatureVector(float_features=self.actions),
            next_state=rlt.PreprocessedFeatureVector(float_features=self.next_states),
            next_action=rlt.PreprocessedFeatureVector(float_features=self.next_actions),
            reward=self.rewards,
            not_terminal=self.not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        ),
        extras=rlt.ExtraData(),
    )
def train(self, training_batch: rlt.PreprocessedTrainingBatch):
    assert type(training_batch) is rlt.PreprocessedTrainingBatch
    training_input = training_batch.training_input
    assert isinstance(training_input, rlt.PreprocessedRankingInput)

    batch_size = training_input.state.float_features.shape[0]

    # randomly pick a permutation for every slate
    random_indices = torch.randint(0, len(self.permutation_index), (batch_size,))
    sim_tgt_out_idx = self.permutation_index[random_indices] + 2
    if self.parameters.simulation_distance_penalty is not None:
        sim_distance = self.permutation_distance[random_indices]
    else:
        sim_distance = None

    with torch.no_grad():
        # format data according to the new ordering
        training_input = self._simulated_training_input(
            training_input, sim_tgt_out_idx, sim_distance, self.device
        )

    # data in the results_dict:
    # {
    #     "per_seq_probs": np.exp(log_probs),
    #     "advantage": advantage,
    #     "obj_rl_loss": obj_rl_loss,
    #     "ips_rl_loss": ips_rl_loss,
    #     "baseline_loss": baseline_loss,
    # }
    results_dict = self.trainer.train(
        rlt.PreprocessedTrainingBatch(
            training_input=training_input, extras=training_batch.extras
        )
    )
    # pyre-fixme[16]: `Seq2SlateSimulationTrainer` has no attribute
    #  `notify_observers`.
    self.notify_observers(
        pg_loss=torch.tensor(results_dict["ips_rl_loss"]).reshape(1),
        train_baseline_loss=torch.tensor(results_dict["baseline_loss"]).reshape(1),
        train_log_probs=torch.FloatTensor(np.log(results_dict["per_seq_probs"])),
    )
    return results_dict
def sample_memories(
    self, batch_size, use_gpu=False, batch_first=False
) -> rlt.PreprocessedTrainingBatch:
    """
    :param batch_size: number of samples to return
    :param use_gpu: whether to put samples on gpu
    :param batch_first: If True, the first dimension of data is batch_size.
        If False (default), the first dimension is SEQ_LEN. Therefore,
        state's shape is SEQ_LEN x BATCH_SIZE x STATE_DIM, for example. By default,
        MDN-RNN consumes data with SEQ_LEN as the first dimension.
    """
    sample_indices = np.random.randint(self.memory_size, size=batch_size)
    device = torch.device("cuda") if use_gpu else torch.device("cpu")  # type: ignore
    # state/next state shape: batch_size x seq_len x state_dim
    # action shape: batch_size x seq_len x action_dim
    # reward/not_terminal shape: batch_size x seq_len
    state, action, next_state, reward, not_terminal = map(
        lambda x: stack(x).float().to(device),
        zip(*self.deque_sample(sample_indices)),
    )

    if not batch_first:
        state, action, next_state, reward, not_terminal = transpose(
            state, action, next_state, reward, not_terminal
        )

    training_input = rlt.PreprocessedMemoryNetworkInput(
        state=rlt.PreprocessedFeatureVector(float_features=state),
        reward=reward,
        time_diff=torch.ones_like(reward).float(),
        action=action,
        next_state=rlt.PreprocessedFeatureVector(float_features=next_state),
        not_terminal=not_terminal,
        step=None,
    )
    return rlt.PreprocessedTrainingBatch(training_input=training_input, extras=None)
def preprocess_batch(train_batch: Any) -> rlt.PreprocessedTrainingBatch:
    (
        obs,
        action,
        reward,
        next_obs,
        next_action,
        next_reward,
        terminal,
        idxs,
        possible_actions_mask,
        log_prob,
    ) = train_batch
    batch_size = obs.shape[0]

    obs = torch.tensor(obs).squeeze(2)
    action = torch.tensor(action).float()
    next_obs = torch.tensor(next_obs).squeeze(2)
    next_action = torch.tensor(next_action).to(torch.float32)
    reward = torch.tensor(reward).unsqueeze(1)
    not_terminal = 1 - torch.tensor(terminal).unsqueeze(1).to(torch.uint8)
    possible_actions_mask = torch.ones_like(action).to(torch.bool)

    # num_actions is assumed to be defined at module scope (the number of discrete actions).
    # Tile each next state once per candidate action so that (state, action) pairs line up.
    tiled_next_state = torch.repeat_interleave(next_obs, repeats=num_actions, dim=0)
    possible_next_actions = torch.eye(num_actions).repeat(batch_size, 1)
    possible_next_actions_mask = not_terminal.repeat(1, num_actions).to(torch.bool)

    return rlt.PreprocessedTrainingBatch(
        rlt.PreprocessedParametricDqnInput(
            state=rlt.PreprocessedFeatureVector(float_features=obs),
            action=rlt.PreprocessedFeatureVector(float_features=action),
            next_state=rlt.PreprocessedFeatureVector(float_features=next_obs),
            next_action=rlt.PreprocessedFeatureVector(float_features=next_action),
            possible_actions=None,
            possible_actions_mask=possible_actions_mask,
            possible_next_actions=rlt.PreprocessedFeatureVector(
                float_features=possible_next_actions
            ),
            possible_next_actions_mask=possible_next_actions_mask,
            tiled_next_state=rlt.PreprocessedFeatureVector(
                float_features=tiled_next_state
            ),
            reward=reward,
            not_terminal=not_terminal,
            step=None,
            time_diff=None,
        ),
        extras=rlt.ExtraData(),
    )
def as_cem_training_batch(self, batch_first=False):
    """
    Generate one-step samples needed by CEM trainer.
    The samples will be used to train an ensemble of world models used by CEM.

    If batch_first = True:
        state/next state shape: batch_size x 1 x state_dim
        action shape: batch_size x 1 x action_dim
        reward/terminal shape: batch_size x 1
    else (default):
        state/next state shape: 1 x batch_size x state_dim
        action shape: 1 x batch_size x action_dim
        reward/terminal shape: 1 x batch_size
    """
    if batch_first:
        seq_len_dim = 1
        reward, not_terminal = self.rewards, self.not_terminal
    else:
        seq_len_dim = 0
        reward, not_terminal = transpose(self.rewards, self.not_terminal)
    training_input = rlt.PreprocessedMemoryNetworkInput(
        state=rlt.PreprocessedFeatureVector(
            float_features=self.states.unsqueeze(seq_len_dim)
        ),
        action=self.actions.unsqueeze(seq_len_dim),
        next_state=rlt.PreprocessedFeatureVector(
            float_features=self.next_states.unsqueeze(seq_len_dim)
        ),
        reward=reward,
        not_terminal=not_terminal,
        step=self.step,
        time_diff=self.time_diffs,
    )
    return rlt.PreprocessedTrainingBatch(
        training_input=training_input,
        extras=rlt.ExtraData(
            mdp_id=self.mdp_ids,
            sequence_number=self.sequence_numbers,
            action_probability=self.propensities,
            max_num_actions=self.max_num_actions,
            metrics=self.metrics,
        ),
    )
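# A tiny illustrative shape check (not part of the library) for the docstring above,
# assuming `states` has shape [batch_size, state_dim]: unsqueezing dim 1 produces the
# batch-first layout, while unsqueezing dim 0 produces the seq-len-first layout.
import torch

_states = torch.zeros(4, 3)  # batch_size=4, state_dim=3
assert _states.unsqueeze(1).shape == (4, 1, 3)  # batch_first=True: batch_size x 1 x state_dim
assert _states.unsqueeze(0).shape == (1, 4, 3)  # batch_first=False: 1 x batch_size x state_dim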
def test_seq2slate_eval_data_page(self):
    """
    Create 3 slate ranking logs and evaluate using Direct Method, Inverse
    Propensity Scores, and Doubly Robust.

    The logs are as follows:
    state: [1, 0, 0], [0, 1, 0], [0, 0, 1]
    indices in logged slates: [3, 2], [3, 2], [3, 2]
    model output indices: [2, 3], [3, 2], [2, 3]
    logged reward: 4, 5, 7
    logged propensities: 0.2, 0.5, 0.4
    predicted rewards on logged slates: 2, 4, 6
    predicted rewards on model outputted slates: 1, 4, 5
    predicted propensities: 0.4, 0.3, 0.7

    When eval_greedy=True:

    Direct Method uses the predicted rewards on model outputted slates.
    Thus the result is expected to be (1 + 4 + 5) / 3

    Inverse Propensity Scores would scale the reward by 1.0 / logged propensities
    whenever the model output slate matches with the logged slate.
    Since only the second log matches with the model output, the IPS result
    is expected to be 5 / 0.5 / 3

    Doubly Robust is the sum of the direct method result and propensity-scaled
    reward difference; the latter is defined as:
    1.0 / logged_propensities * (logged reward - predicted reward on logged slate)
     * Indicator(model slate == logged slate)
    Since only the second logged slate matches with the model outputted slate,
    the DR result is expected to be (1 + 4 + 5) / 3 + 1.0 / 0.5 * (5 - 4) / 3

    When eval_greedy=False:

    Only Inverse Propensity Scores would be accurate, because it would be too
    expensive to compute all possible slates' propensities and predicted rewards
    for Direct Method.

    The expected IPS = (0.4 / 0.2 * 4 + 0.3 / 0.5 * 5 + 0.7 / 0.4 * 7) / 3
    """
    batch_size = 3
    state_dim = 3
    src_seq_len = 2
    tgt_seq_len = 2
    candidate_dim = 2

    reward_net = FakeSeq2SlateRewardNetwork()
    seq2slate_net = FakeSeq2SlateTransformerNet()

    src_seq = torch.eye(candidate_dim).repeat(batch_size, 1, 1)
    tgt_out_idx = torch.LongTensor([[3, 2], [3, 2], [3, 2]])
    tgt_out_seq = src_seq[
        torch.arange(batch_size).repeat_interleave(tgt_seq_len),
        tgt_out_idx.flatten() - 2,
    ].reshape(batch_size, tgt_seq_len, candidate_dim)

    ptb = rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedRankingInput(
            state=rlt.FeatureData(float_features=torch.eye(state_dim)),
            src_seq=rlt.FeatureData(float_features=src_seq),
            tgt_out_seq=rlt.FeatureData(float_features=tgt_out_seq),
            src_src_mask=torch.ones(batch_size, src_seq_len, src_seq_len),
            tgt_out_idx=tgt_out_idx,
            tgt_out_probs=torch.tensor([0.2, 0.5, 0.4]),
            slate_reward=torch.tensor([4.0, 5.0, 7.0]),
        ),
        extras=rlt.ExtraData(
            sequence_number=torch.tensor([0, 0, 0]),
            mdp_id=np.array(["0", "1", "2"]),
        ),
    )
    edp = EvaluationDataPage.create_from_tensors_seq2slate(
        seq2slate_net, reward_net, ptb.training_input, eval_greedy=True
    )
    logger.info("---------- Start evaluating eval_greedy=True -----------------")
    doubly_robust_estimator = OPEstimatorAdapter(DoublyRobustEstimator())
    dm_estimator = OPEstimatorAdapter(DMEstimator())
    ips_estimator = OPEstimatorAdapter(IPSEstimator())
    switch_estimator = OPEstimatorAdapter(SwitchEstimator())
    switch_dr_estimator = OPEstimatorAdapter(SwitchDREstimator())

    doubly_robust = doubly_robust_estimator.estimate(edp)
    inverse_propensity = ips_estimator.estimate(edp)
    direct_method = dm_estimator.estimate(edp)
    # Verify that Switch with low exponent is equivalent to IPS
    switch_ips = switch_estimator.estimate(edp, exp_base=1)
    # Verify that Switch with no candidates is equivalent to DM
    switch_dm = switch_estimator.estimate(edp, candidates=0)
    # Verify that SwitchDR with low exponent is equivalent to DR
    switch_dr_dr = switch_dr_estimator.estimate(edp, exp_base=1)
    # Verify that SwitchDR with no candidates is equivalent to DM
    switch_dr_dm = switch_dr_estimator.estimate(edp, candidates=0)

    logger.info(f"{direct_method}, {inverse_propensity}, {doubly_robust}")

    avg_logged_reward = (4 + 5 + 7) / 3
    self.assertAlmostEqual(direct_method.raw, (1 + 4 + 5) / 3, delta=1e-6)
    self.assertAlmostEqual(
        direct_method.normalized, direct_method.raw / avg_logged_reward, delta=1e-6
    )
    self.assertAlmostEqual(inverse_propensity.raw, 5 / 0.5 / 3, delta=1e-6)
    self.assertAlmostEqual(
        inverse_propensity.normalized,
        inverse_propensity.raw / avg_logged_reward,
        delta=1e-6,
    )
    self.assertAlmostEqual(
        doubly_robust.raw, direct_method.raw + 1 / 0.5 * (5 - 4) / 3, delta=1e-6
    )
    self.assertAlmostEqual(
        doubly_robust.normalized, doubly_robust.raw / avg_logged_reward, delta=1e-6
    )
    self.assertAlmostEqual(switch_ips.raw, inverse_propensity.raw, delta=1e-6)
    self.assertAlmostEqual(switch_dm.raw, direct_method.raw, delta=1e-6)
    self.assertAlmostEqual(switch_dr_dr.raw, doubly_robust.raw, delta=1e-6)
    self.assertAlmostEqual(switch_dr_dm.raw, direct_method.raw, delta=1e-6)
    logger.info("---------- Finish evaluating eval_greedy=True -----------------")

    logger.info("---------- Start evaluating eval_greedy=False -----------------")
    edp = EvaluationDataPage.create_from_tensors_seq2slate(
        seq2slate_net, reward_net, ptb.training_input, eval_greedy=False
    )
    doubly_robust_estimator = OPEstimatorAdapter(DoublyRobustEstimator())
    dm_estimator = OPEstimatorAdapter(DMEstimator())
    ips_estimator = OPEstimatorAdapter(IPSEstimator())

    doubly_robust = doubly_robust_estimator.estimate(edp)
    inverse_propensity = ips_estimator.estimate(edp)
    direct_method = dm_estimator.estimate(edp)
    self.assertAlmostEqual(
        inverse_propensity.raw,
        (0.4 / 0.2 * 4 + 0.3 / 0.5 * 5 + 0.7 / 0.4 * 7) / 3,
        delta=1e-6,
    )
    self.assertAlmostEqual(
        inverse_propensity.normalized,
        inverse_propensity.raw / avg_logged_reward,
        delta=1e-6,
    )
    logger.info("---------- Finish evaluating eval_greedy=False -----------------")
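# A standalone arithmetic sketch (illustrative only, not part of the test above) that
# reproduces the expected estimator values from the docstring by hand, using just the
# logged numbers: rewards, propensities, predicted rewards, and whether the model's
# slate matched the logged slate.
logged_rewards = [4.0, 5.0, 7.0]
logged_propensities = [0.2, 0.5, 0.4]
predicted_rewards_on_logged_slates = [2.0, 4.0, 6.0]
predicted_rewards_on_model_slates = [1.0, 4.0, 5.0]
predicted_propensities = [0.4, 0.3, 0.7]
model_matches_log = [False, True, False]  # only the second slate matches

# Direct Method: mean predicted reward on the model's own slates.
dm = sum(predicted_rewards_on_model_slates) / 3  # (1 + 4 + 5) / 3

# IPS (greedy): reward / propensity, counted only where the model slate matches the log.
ips = sum(
    r / p
    for r, p, m in zip(logged_rewards, logged_propensities, model_matches_log)
    if m
) / 3  # 5 / 0.5 / 3

# Doubly Robust: DM plus the propensity-scaled residual on matching slates.
dr = dm + sum(
    (r - r_hat) / p
    for r, r_hat, p, m in zip(
        logged_rewards,
        predicted_rewards_on_logged_slates,
        logged_propensities,
        model_matches_log,
    )
    if m
) / 3  # (1 + 4 + 5) / 3 + (5 - 4) / 0.5 / 3

# Non-greedy IPS: predicted propensity / logged propensity applied to every logged reward.
ips_non_greedy = sum(
    q / p * r
    for q, p, r in zip(predicted_propensities, logged_propensities, logged_rewards)
) / 3  # (0.4 / 0.2 * 4 + 0.3 / 0.5 * 5 + 0.7 / 0.4 * 7) / 3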
def _test_seq2slate_trainer_on_policy(
    self, policy_gradient_interval, output_arch, device
):
    batch_size = 32
    state_dim = 2
    candidate_num = 15
    candidate_dim = 4
    hidden_size = 16
    learning_rate = 1.0
    on_policy = True
    rank_seed = 111
    seq2slate_params = Seq2SlateParameters(on_policy=on_policy)
    seq2slate_net = create_seq2slate_transformer(
        state_dim, candidate_num, candidate_dim, hidden_size, output_arch, device
    )
    seq2slate_net_copy = copy.deepcopy(seq2slate_net)
    seq2slate_net_copy_copy = copy.deepcopy(seq2slate_net)
    trainer = create_trainer(
        seq2slate_net,
        batch_size,
        learning_rate,
        device,
        seq2slate_params,
        policy_gradient_interval,
    )
    batch = create_on_policy_batch(
        seq2slate_net,
        batch_size,
        state_dim,
        candidate_num,
        candidate_dim,
        rank_seed,
        device,
    )
    for _ in range(policy_gradient_interval):
        trainer.train(rlt.PreprocessedTrainingBatch(training_input=batch))

    # manually compute the gradient
    torch.manual_seed(rank_seed)
    rank_output = seq2slate_net_copy(
        batch, mode=Seq2SlateMode.RANK_MODE, tgt_seq_len=candidate_num, greedy=False
    )
    loss = -(
        torch.mean(torch.log(rank_output.ranked_per_seq_probs) * batch.slate_reward)
    )
    loss.backward()
    self.assert_correct_gradient(
        seq2slate_net_copy, seq2slate_net, policy_gradient_interval, learning_rate
    )

    # another way to compute the gradient manually: since
    # d(p / p.detach())/dtheta = (dp/dtheta) / p = d(log p)/dtheta,
    # this loss has the same gradient as the log-prob form above
    torch.manual_seed(rank_seed)
    ranked_per_seq_probs = seq2slate_net_copy_copy(
        batch, mode=Seq2SlateMode.RANK_MODE, tgt_seq_len=candidate_num, greedy=False
    ).ranked_per_seq_probs
    loss = -(
        torch.mean(
            ranked_per_seq_probs / ranked_per_seq_probs.detach() * batch.slate_reward
        )
    )
    loss.backward()
    self.assert_correct_gradient(
        seq2slate_net_copy_copy,
        seq2slate_net,
        policy_gradient_interval,
        learning_rate,
    )
def _test_seq2slate_on_policy_tsp(
    self,
    model_str,
    batch_size,
    epochs,
    candidate_num,
    num_batches,
    hidden_size,
    diverse_input,
    learning_rate,
    expect_reward_threshold,
    device,
):
    candidate_dim = 2
    eval_sample_size = 1

    batch_list = [
        create_batch(
            batch_size,
            candidate_num,
            candidate_dim,
            device,
            diverse_input=diverse_input,
        )
        for _ in range(num_batches)
    ]
    if diverse_input:
        test_batch = create_batch(
            batch_size,
            candidate_num,
            candidate_dim,
            device,
            diverse_input=diverse_input,
        )
    else:
        test_batch = batch_list[0]
    best_test_possible_reward = compute_best_reward(test_batch.src_seq.float_features)

    if model_str == MODEL_TRANSFORMER:
        seq2slate_net = create_seq2slate_transformer(
            candidate_num,
            candidate_dim,
            hidden_size,
            Seq2SlateOutputArch.AUTOREGRESSIVE,
            1.0,
            device,
        )
    else:
        raise NotImplementedError(f"unknown model type {model_str}")

    trainer = create_trainer(
        seq2slate_net, batch_size, learning_rate, device, on_policy=True
    )

    for e in range(epochs):
        # training
        for batch in batch_list:
            model_propensity, model_action, reward = rank_on_policy_and_eval(
                seq2slate_net, batch, candidate_num, greedy=False
            )
            on_policy_batch = rlt.PreprocessedRankingInput.from_input(
                state=batch.state.float_features,
                candidates=batch.src_seq.float_features,
                device=device,
                action=model_action,
                logged_propensities=model_propensity,
                slate_reward=-reward,  # negate because we want to minimize
            )
            trainer.train(
                rlt.PreprocessedTrainingBatch(training_input=on_policy_batch)
            )
            logger.info(f"Epoch {e} mean on_policy reward: {torch.mean(reward)}")
            logger.info(
                f"Epoch {e} mean model_propensity: {torch.mean(model_propensity)}"
            )

        # evaluation
        best_test_reward = torch.full((batch_size,), 1e9).to(device)
        for _ in range(eval_sample_size):
            _, _, reward = rank_on_policy_and_eval(
                seq2slate_net, test_batch, candidate_num, greedy=True
            )
            best_test_reward = torch.where(
                reward < best_test_reward, reward, best_test_reward
            )
        logger.info(
            f"Test mean reward: {torch.mean(best_test_reward)}, "
            f"best possible reward {best_test_possible_reward}"
        )
        if (
            torch.mean(best_test_reward)
            < best_test_possible_reward * expect_reward_threshold
        ):
            return

    raise AssertionError("Test failed because it did not reach expected test reward")