def input_prototype(self):
    return rlt.PreprocessedStateAction(
        state=rlt.PreprocessedFeatureVector(
            float_features=torch.randn(1, 1, self.state_dim)
        ),
        action=rlt.PreprocessedFeatureVector(
            float_features=torch.randn(1, 1, self.action_dim)
        ),
    )
def preprocess_batch(train_batch: Any) -> rlt.PreprocessedTrainingBatch:
    (
        obs,
        action,
        reward,
        next_obs,
        next_action,
        next_reward,
        terminal,
        idxs,
        possible_actions_mask,
        log_prob,
    ) = train_batch
    obs = torch.tensor(obs).squeeze(2)
    action = torch.tensor(action).float()
    reward = torch.tensor(reward).unsqueeze(1)
    next_obs = torch.tensor(next_obs).squeeze(2)
    next_action = torch.tensor(next_action)
    not_terminal = 1.0 - torch.tensor(terminal).unsqueeze(1).float()
    idxs = torch.tensor(idxs)
    possible_actions_mask = torch.tensor(possible_actions_mask).float()
    log_prob = torch.tensor(log_prob)
    return rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedPolicyNetworkInput(
            state=rlt.PreprocessedFeatureVector(float_features=obs),
            action=rlt.PreprocessedFeatureVector(float_features=action),
            next_state=rlt.PreprocessedFeatureVector(float_features=next_obs),
            next_action=rlt.PreprocessedFeatureVector(float_features=next_action),
            reward=reward,
            not_terminal=not_terminal,
            step=None,
            time_diff=None,
        ),
        extras=rlt.ExtraData(),
    )
def sample_memories(self, batch_size, use_gpu=False, batch_first=False):
    """
    :param batch_size: number of samples to return
    :param use_gpu: whether to put samples on gpu
    :param batch_first: If True, the first dimension of data is batch_size.
        If False (default), the first dimension is SEQ_LEN. Therefore,
        state's shape is SEQ_LEN x BATCH_SIZE x STATE_DIM, for example.
        By default, MDN-RNN consumes data with SEQ_LEN as the first dimension.
    """
    sample_indices = np.random.randint(self.memory_size, size=batch_size)
    device = torch.device("cuda") if use_gpu else torch.device("cpu")
    # state/next state shape: batch_size x seq_len x state_dim
    # action shape: batch_size x seq_len x action_dim
    # reward/not_terminal shape: batch_size x seq_len
    state, action, next_state, reward, not_terminal = map(
        lambda x: stack(x).float().to(device),
        zip(*self.deque_sample(sample_indices)),
    )
    if not batch_first:
        state, action, next_state, reward, not_terminal = transpose(
            state, action, next_state, reward, not_terminal
        )
    training_input = rlt.PreprocessedMemoryNetworkInput(
        state=rlt.PreprocessedFeatureVector(float_features=state),
        reward=reward,
        time_diff=torch.ones_like(reward).float(),
        action=action,
        next_state=rlt.PreprocessedFeatureVector(float_features=next_state),
        not_terminal=not_terminal,
        step=None,
    )
    return rlt.PreprocessedTrainingBatch(training_input=training_input, extras=None)
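# A minimal shape sketch (not part of the original code) for the batch_first
# flag documented above, with hypothetical sizes: when batch_first=False, the
# MDN-RNN consumes SEQ_LEN x BATCH_SIZE x DIM tensors.
import torch

batch_size, seq_len, state_dim = 4, 5, 3
state_batch_first = torch.randn(batch_size, seq_len, state_dim)
state_seq_first = state_batch_first.transpose(0, 1)
assert state_seq_first.shape == (seq_len, batch_size, state_dim)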
def forward(
    self,
    state: torch.Tensor,
    src_seq: torch.Tensor,
    src_src_mask: torch.Tensor,
    slate_reward: torch.Tensor,
    tgt_out_idx: torch.Tensor,
) -> torch.Tensor:
    return self.model(
        rlt.PreprocessedRankingInput(
            state=rlt.PreprocessedFeatureVector(float_features=state),
            src_seq=rlt.PreprocessedFeatureVector(float_features=src_seq),
            src_src_mask=src_src_mask,
            slate_reward=slate_reward,
            tgt_out_idx=tgt_out_idx,
        )
    ).predicted_reward
def acc_rewards_of_one_solution(
    self, init_state: torch.Tensor, solution: torch.Tensor, solution_idx: int
):
    """
    ensemble_pop_size trajectories will be sampled to evaluate a CEM solution.
    Each trajectory is generated by one world model.

    :param init_state: its shape is (state_dim, )
    :param solution: its shape is (plan_horizon_length, action_dim)
    :param solution_idx: the index of the solution
    :return reward: Reward of each of ensemble_pop_size trajectories
    """
    reward_matrix = np.zeros((self.ensemble_pop_size, self.plan_horizon_length))

    for i in range(self.ensemble_pop_size):
        state = init_state
        mem_net_idx = np.random.randint(0, len(self.mem_net_list))
        for j in range(self.plan_horizon_length):
            # world_model_input.state shape: (1, 1, state_dim)
            # world_model_input.action shape: (1, 1, action_dim)
            world_model_input = rlt.PreprocessedStateAction(
                state=rlt.PreprocessedFeatureVector(
                    float_features=state.reshape((1, 1, self.state_dim))
                ),
                action=rlt.PreprocessedFeatureVector(
                    float_features=solution[j, :].reshape((1, 1, self.action_dim))
                ),
            )
            (
                reward,
                next_state,
                not_terminal,
                not_terminal_prob,
            ) = self.sample_reward_next_state_terminal(
                world_model_input, self.mem_net_list[mem_net_idx]
            )
            reward_matrix[i, j] = reward * (self.gamma ** j)

            if not not_terminal:
                logger.debug(
                    f"Solution {solution_idx}: predict terminal at step {j}"
                    f" with prob. {1.0 - not_terminal_prob}"
                )
                break

            state = next_state

    return np.sum(reward_matrix, axis=1)
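# Stand-alone sketch (made-up numbers) of the discounted accumulation used
# above: each row of reward_matrix holds gamma^j-discounted per-step rewards
# for one sampled world model, and the return is the row-wise sum over the
# planning horizon.
import numpy as np

gamma, plan_horizon_length = 0.9, 3
per_step_rewards = np.array([[1.0, 1.0, 1.0], [2.0, 0.0, 1.0]])  # hypothetical
discounts = gamma ** np.arange(plan_horizon_length)  # [1.0, 0.9, 0.81]
reward_matrix = per_step_rewards * discounts
print(np.sum(reward_matrix, axis=1))  # [2.71 2.81]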
def as_slate_q_training_batch(self):
    batch_size, state_dim = self.states.shape
    action_dim = self.actions.shape[1]
    return rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedSlateQInput(
            state=rlt.PreprocessedFeatureVector(float_features=self.states),
            next_state=rlt.PreprocessedFeatureVector(float_features=self.next_states),
            tiled_state=rlt.PreprocessedTiledFeatureVector(
                float_features=self.possible_actions_state_concat[
                    :, :state_dim
                ].view(batch_size, -1, state_dim)
            ),
            tiled_next_state=rlt.PreprocessedTiledFeatureVector(
                float_features=self.possible_next_actions_state_concat[
                    :, :state_dim
                ].view(batch_size, -1, state_dim)
            ),
            action=rlt.PreprocessedSlateFeatureVector(
                float_features=self.possible_actions_state_concat[
                    :, state_dim:
                ].view(batch_size, -1, action_dim),
                item_mask=self.possible_actions_mask,
                item_probability=self.propensities,
            ),
            next_action=rlt.PreprocessedSlateFeatureVector(
                float_features=self.possible_next_actions_state_concat[
                    :, state_dim:
                ].view(batch_size, -1, action_dim),
                item_mask=self.possible_next_actions_mask,
                item_probability=self.next_propensities,
            ),
            reward=self.rewards,
            reward_mask=self.rewards_mask,
            time_diff=self.time_diffs,
            step=self.step,
            not_terminal=self.not_terminal,
        ),
        extras=rlt.ExtraData(
            mdp_id=self.mdp_ids,
            sequence_number=self.sequence_numbers,
            action_probability=self.propensities,
            max_num_actions=self.max_num_actions,
            metrics=self.metrics,
        ),
    )
def as_policy_network_training_batch(self):
    return rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedPolicyNetworkInput(
            state=rlt.PreprocessedFeatureVector(float_features=self.states),
            action=rlt.PreprocessedFeatureVector(float_features=self.actions),
            next_state=rlt.PreprocessedFeatureVector(float_features=self.next_states),
            next_action=rlt.PreprocessedFeatureVector(float_features=self.next_actions),
            reward=self.rewards,
            not_terminal=self.not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        ),
        extras=rlt.ExtraData(),
    )
def score(preprocessed_obs: rlt.PreprocessedState) -> GaussianSamplerScore:
    actor_network.eval()
    # TODO(kaiwenw) currently actor network demands a batched input.
    # should we make it single?
    state = rlt.PreprocessedFeatureVector(
        float_features=preprocessed_obs.state.float_features.unsqueeze(0)
    )
    loc, scale_log = actor_network._get_loc_and_scale_log(state)
    actor_network.train()
    return GaussianSamplerScore(loc=loc, scale_log=scale_log)
def input_prototype(self):
    return rlt.PreprocessedState(
        state=rlt.PreprocessedFeatureVector(
            float_features=torch.randn(1, self.state_dim),
            id_list_features={
                k: (
                    torch.zeros(1, dtype=torch.long),
                    torch.ones(1, dtype=torch.long),
                )
                for k in self.embedding_bags
            },
        )
    )
def __call__(self, batch: rlt.RawTrainingBatch) -> rlt.PreprocessedTrainingBatch:
    preprocessed_batch = super().__call__(batch)
    training_input = preprocessed_batch.training_input
    assert isinstance(training_input, rlt.PreprocessedMemoryNetworkInput)
    preprocessed_batch = preprocessed_batch._replace(
        training_input=training_input._replace(
            state=rlt.PreprocessedFeatureVector(
                float_features=training_input.state.float_features.reshape(
                    -1, self.seq_len, self.state_dim
                )
            ),
            action=training_input.action.reshape(
                -1, self.seq_len, self.action_dim
            ),
            next_state=rlt.PreprocessedFeatureVector(
                float_features=training_input.next_state.float_features.reshape(
                    -1, self.seq_len, self.state_dim
                )
            ),
            reward=training_input.reward.reshape(-1, self.seq_len),
            not_terminal=training_input.not_terminal.reshape(-1, self.seq_len),
        )
    )
    return preprocessed_batch
def as_cem_training_batch(self, batch_first=False):
    """
    Generate one-step samples needed by the CEM trainer.
    The samples will be used to train an ensemble of world models used by CEM.

    If batch_first = True:
        state/next state shape: batch_size x 1 x state_dim
        action shape: batch_size x 1 x action_dim
        reward/terminal shape: batch_size x 1
    else (default):
        state/next state shape: 1 x batch_size x state_dim
        action shape: 1 x batch_size x action_dim
        reward/terminal shape: 1 x batch_size
    """
    if batch_first:
        seq_len_dim = 1
        reward, not_terminal = self.rewards, self.not_terminal
    else:
        seq_len_dim = 0
        reward, not_terminal = transpose(self.rewards, self.not_terminal)
    training_input = rlt.PreprocessedMemoryNetworkInput(
        state=rlt.PreprocessedFeatureVector(
            float_features=self.states.unsqueeze(seq_len_dim)
        ),
        action=self.actions.unsqueeze(seq_len_dim),
        next_state=rlt.PreprocessedFeatureVector(
            float_features=self.next_states.unsqueeze(seq_len_dim)
        ),
        reward=reward,
        not_terminal=not_terminal,
        step=self.step,
        time_diff=self.time_diffs,
    )
    return rlt.PreprocessedTrainingBatch(
        training_input=training_input,
        extras=rlt.ExtraData(
            mdp_id=self.mdp_ids,
            sequence_number=self.sequence_numbers,
            action_probability=self.propensities,
            max_num_actions=self.max_num_actions,
            metrics=self.metrics,
        ),
    )
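# Shape sketch (hypothetical sizes, not part of the original code) for the
# seq_len_dim logic above: CEM world models train on one-step sequences, so a
# singleton SEQ_LEN axis is inserted either after the batch axis
# (batch_first=True) or before it (default).
import torch

batch_size, state_dim = 4, 3
states = torch.randn(batch_size, state_dim)
assert states.unsqueeze(1).shape == (batch_size, 1, state_dim)  # batch_first=True
assert states.unsqueeze(0).shape == (1, batch_size, state_dim)  # default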
def internal_prediction(
    self, state: torch.Tensor
) -> Union[rlt.SacPolicyActionSet, rlt.DqnPolicyActionSet]:
    """Only used by Gym. Return the predicted next action."""
    input = rlt.PreprocessedState(
        state=rlt.PreprocessedFeatureVector(float_features=state)
    )
    output = self.cem_planner_network(input)
    if not self.cem_planner_network.discrete_action:
        return rlt.SacPolicyActionSet(greedy=output, greedy_propensity=1.0)
    return rlt.DqnPolicyActionSet(greedy=output[0])
def as_parametric_maxq_training_batch(self):
    state_dim = self.states.shape[1]
    return rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedParametricDqnInput(
            state=rlt.PreprocessedFeatureVector(float_features=self.states),
            action=rlt.PreprocessedFeatureVector(float_features=self.actions),
            next_state=rlt.PreprocessedFeatureVector(float_features=self.next_states),
            next_action=rlt.PreprocessedFeatureVector(float_features=self.next_actions),
            tiled_next_state=rlt.PreprocessedFeatureVector(
                float_features=self.possible_next_actions_state_concat[:, :state_dim]
            ),
            possible_actions=None,
            possible_actions_mask=self.possible_actions_mask,
            possible_next_actions=rlt.PreprocessedFeatureVector(
                float_features=self.possible_next_actions_state_concat[:, state_dim:]
            ),
            possible_next_actions_mask=self.possible_next_actions_mask,
            reward=self.rewards,
            not_terminal=self.not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        ),
        extras=rlt.ExtraData(),
    )
def preprocess_batch(train_batch: Any) -> rlt.PreprocessedTrainingBatch:
    (
        obs,
        action,
        reward,
        next_obs,
        next_action,
        next_reward,
        terminal,
        idxs,
        possible_actions_mask,
        log_prob,
    ) = train_batch
    obs = torch.tensor(obs).squeeze(2)
    action = torch.tensor(action)
    reward = torch.tensor(reward).unsqueeze(1)
    next_obs = torch.tensor(next_obs).squeeze(2)
    next_action = torch.tensor(next_action)
    not_terminal = 1.0 - torch.tensor(terminal).unsqueeze(1).float()
    possible_actions_mask = torch.tensor(possible_actions_mask)
    next_possible_actions_mask = not_terminal.repeat(1, num_actions)
    log_prob = torch.tensor(log_prob)
    assert (
        action.size(1) == num_actions
    ), f"action size(1) is {action.size(1)} while num_actions is {num_actions}"
    return rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedDiscreteDqnInput(
            state=rlt.PreprocessedFeatureVector(float_features=obs),
            action=action,
            next_state=rlt.PreprocessedFeatureVector(float_features=next_obs),
            next_action=next_action,
            possible_actions_mask=possible_actions_mask,
            possible_next_actions_mask=next_possible_actions_mask,
            reward=reward,
            not_terminal=not_terminal,
            step=None,
            time_diff=None,
        ),
        extras=rlt.ExtraData(
            mdp_id=None,
            sequence_number=None,
            action_probability=log_prob.exp(),
            max_num_actions=None,
            metrics=None,
        ),
    )
def test_discrete_wrapper_with_id_list(self):
    state_normalization_parameters = {i: _cont_norm() for i in range(1, 5)}
    state_preprocessor = Preprocessor(state_normalization_parameters, False)
    action_dim = 2
    state_feature_config = rlt.ModelFeatureConfig(
        float_feature_infos=[
            rlt.FloatFeatureInfo(name=str(i), feature_id=i) for i in range(1, 5)
        ],
        id_list_feature_configs=[
            rlt.IdListFeatureConfig(
                name="A", feature_id=10, id_mapping_name="A_mapping"
            )
        ],
        id_mapping_config={"A_mapping": rlt.IdMapping(ids=[0, 1, 2])},
    )
    dqn = FullyConnectedDQNWithEmbedding(
        state_dim=len(state_normalization_parameters),
        action_dim=action_dim,
        sizes=[16],
        activations=["relu"],
        model_feature_config=state_feature_config,
        embedding_dim=8,
    )
    dqn_with_preprocessor = DiscreteDqnWithPreprocessorWithIdList(
        dqn, state_preprocessor, state_feature_config
    )
    action_names = ["L", "R"]
    wrapper = DiscreteDqnPredictorWrapperWithIdList(
        dqn_with_preprocessor, action_names, state_feature_config
    )
    input_prototype = dqn_with_preprocessor.input_prototype()
    output_action_names, q_values = wrapper(*input_prototype)
    self.assertEqual(action_names, output_action_names)
    self.assertEqual(q_values.shape, (1, 2))

    feature_id_to_name = {
        config.feature_id: config.name
        for config in state_feature_config.id_list_feature_configs
    }
    state_id_list_features = {
        feature_id_to_name[k]: v for k, v in input_prototype[1].items()
    }
    expected_output = dqn(
        rlt.PreprocessedState(
            state=rlt.PreprocessedFeatureVector(
                float_features=state_preprocessor(*input_prototype[0]),
                id_list_features=state_id_list_features,
            )
        )
    ).q_values
    self.assertTrue((expected_output == q_values).all())
def as_discrete_maxq_training_batch(self):
    return rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedDiscreteDqnInput(
            state=rlt.PreprocessedFeatureVector(float_features=self.states),
            action=self.actions,
            next_state=rlt.PreprocessedFeatureVector(float_features=self.next_states),
            next_action=self.next_actions,
            possible_actions_mask=self.possible_actions_mask,
            possible_next_actions_mask=self.possible_next_actions_mask,
            reward=self.rewards,
            not_terminal=self.not_terminal,
            step=self.step,
            time_diff=self.time_diffs,
        ),
        extras=rlt.ExtraData(
            mdp_id=self.mdp_ids,
            sequence_number=self.sequence_numbers,
            action_probability=self.propensities,
            max_num_actions=self.max_num_actions,
            metrics=self.metrics,
        ),
    )
def forward(
    self,
    state_with_presence: Tuple[torch.Tensor, torch.Tensor],
    state_id_list_features: Dict[int, Tuple[torch.Tensor, torch.Tensor]],
):
    preprocessed_state = self.state_preprocessor(
        state_with_presence[0], state_with_presence[1]
    )
    id_list_features = {
        id_list_feature_config.name: state_id_list_features[
            id_list_feature_config.feature_id
        ]
        for id_list_feature_config in self.id_list_feature_configs
    }
    state_feature_vector = rlt.PreprocessedState(
        state=rlt.PreprocessedFeatureVector(
            float_features=preprocessed_state,
            id_list_features=id_list_features,
        )
    )
    q_values = self.model(state_feature_vector).q_values
    return q_values
def preprocess_batch(train_batch: Any) -> rlt.PreprocessedTrainingBatch:
    (
        obs,
        action,
        reward,
        next_obs,
        next_action,
        next_reward,
        terminal,
        idxs,
        possible_actions_mask,
        log_prob,
    ) = train_batch
    batch_size = obs.shape[0]
    obs = torch.tensor(obs).squeeze(2)
    action = torch.tensor(action).float()
    next_obs = torch.tensor(next_obs).squeeze(2)
    next_action = torch.tensor(next_action).to(torch.float32)
    reward = torch.tensor(reward).unsqueeze(1)
    not_terminal = 1 - torch.tensor(terminal).unsqueeze(1).to(torch.uint8)
    possible_actions_mask = torch.ones_like(action).to(torch.bool)
    tiled_next_state = torch.repeat_interleave(next_obs, repeats=num_actions, dim=0)
    possible_next_actions = torch.eye(num_actions).repeat(batch_size, 1)
    possible_next_actions_mask = not_terminal.repeat(1, num_actions).to(torch.bool)
    return rlt.PreprocessedTrainingBatch(
        rlt.PreprocessedParametricDqnInput(
            state=rlt.PreprocessedFeatureVector(float_features=obs),
            action=rlt.PreprocessedFeatureVector(float_features=action),
            next_state=rlt.PreprocessedFeatureVector(float_features=next_obs),
            next_action=rlt.PreprocessedFeatureVector(float_features=next_action),
            possible_actions=None,
            possible_actions_mask=possible_actions_mask,
            possible_next_actions=rlt.PreprocessedFeatureVector(
                float_features=possible_next_actions
            ),
            possible_next_actions_mask=possible_next_actions_mask,
            tiled_next_state=rlt.PreprocessedFeatureVector(
                float_features=tiled_next_state
            ),
            reward=reward,
            not_terminal=not_terminal,
            step=None,
            time_diff=None,
        ),
        extras=rlt.ExtraData(),
    )
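# Shape sketch (hypothetical sizes, not part of the original code) for the
# tiling above: each next state is repeated once per discrete action so the
# parametric Q-network can score it against every one-hot action.
import torch

batch_size, num_actions, state_dim = 2, 3, 4
next_obs = torch.randn(batch_size, state_dim)
tiled_next_state = torch.repeat_interleave(next_obs, repeats=num_actions, dim=0)
possible_next_actions = torch.eye(num_actions).repeat(batch_size, 1)
assert tiled_next_state.shape == (batch_size * num_actions, state_dim)
assert possible_next_actions.shape == (batch_size * num_actions, num_actions)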
def get_loss(
    self,
    training_batch: rlt.PreprocessedTrainingBatch,
    state_dim: Optional[int] = None,
    batch_first: bool = False,
):
    """
    Compute losses:
        GMMLoss(next_state, GMMPredicted) / (STATE_DIM + 2)
        + MSE(reward, predicted_reward)
        + BCE(not_terminal, logit_not_terminal)

    The STATE_DIM + 2 factor is here to counteract the fact that the GMMLoss
    scales approximately linearly with STATE_DIM, the feature size of states.
    All losses are averaged both on the batch and the sequence dimensions
    (the first two dimensions).

    :param training_batch: training_batch.training_input has these fields:
        - state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
        - action: (BATCH_SIZE, SEQ_LEN, ACTION_DIM) torch tensor
        - reward: (BATCH_SIZE, SEQ_LEN) torch tensor
        - not_terminal: (BATCH_SIZE, SEQ_LEN) torch tensor
        - next_state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
        The first two dimensions may be swapped depending on batch_first.
    :param state_dim: the dimension of states. If provided, use it to
        normalize the gmm loss.
    :param batch_first: whether data's first dimension represents batch size.
        If False, the first two dimensions of state, action, reward,
        not_terminal, and next_state are SEQ_LEN and BATCH_SIZE.

    :returns: dictionary of losses, containing the gmm, the mse, the bce and
        the averaged loss.
    """
    learning_input = training_batch.training_input
    assert isinstance(learning_input, rlt.PreprocessedMemoryNetworkInput)
    # mdnrnn's input should have seq_len as the first dimension
    if batch_first:
        state, action, next_state, reward, not_terminal = transpose(
            learning_input.state.float_features,
            learning_input.action,
            learning_input.next_state.float_features,
            learning_input.reward,
            learning_input.not_terminal,  # type: ignore
        )
        learning_input = rlt.PreprocessedMemoryNetworkInput(  # type: ignore
            state=rlt.PreprocessedFeatureVector(float_features=state),
            reward=reward,
            time_diff=torch.ones_like(reward).float(),
            action=action,
            not_terminal=not_terminal,
            next_state=rlt.PreprocessedFeatureVector(float_features=next_state),
            step=None,
        )

    mdnrnn_input = rlt.PreprocessedStateAction(
        state=learning_input.state,  # type: ignore
        action=rlt.PreprocessedFeatureVector(
            float_features=learning_input.action  # type: ignore
        ),
    )
    mdnrnn_output = self.mdnrnn(mdnrnn_input)
    mus, sigmas, logpi, rs, nts = (
        mdnrnn_output.mus,
        mdnrnn_output.sigmas,
        mdnrnn_output.logpi,
        mdnrnn_output.reward,
        mdnrnn_output.not_terminal,
    )

    next_state = learning_input.next_state.float_features
    not_terminal = learning_input.not_terminal  # type: ignore
    reward = learning_input.reward
    if self.params.fit_only_one_next_step:
        next_state, not_terminal, reward, mus, sigmas, logpi, nts, rs = tuple(
            map(
                lambda x: x[-1:],
                (next_state, not_terminal, reward, mus, sigmas, logpi, nts, rs),
            )
        )

    gmm = (
        gmm_loss(next_state, mus, sigmas, logpi)
        * self.params.next_state_loss_weight
    )
    bce = (
        F.binary_cross_entropy_with_logits(nts, not_terminal)
        * self.params.not_terminal_loss_weight
    )
    mse = F.mse_loss(rs, reward) * self.params.reward_loss_weight
    if state_dim is not None:
        loss = gmm / (state_dim + 2) + bce + mse
    else:
        loss = gmm + bce + mse
    return {"gmm": gmm, "bce": bce, "mse": mse, "loss": loss}
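# Numeric sketch (made-up values, not part of the original code) of the
# normalization described in the docstring above: the GMM term grows roughly
# linearly with STATE_DIM, so dividing by STATE_DIM + 2 keeps it on the same
# scale as the BCE/MSE terms.
state_dim = 18
gmm, bce, mse = 40.0, 0.7, 0.3  # hypothetical per-batch loss values
loss = gmm / (state_dim + 2) + bce + mse  # 2.0 + 0.7 + 0.3 = 3.0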
def create_from_tensors_parametric_dqn(
    cls,
    trainer: ParametricDQNTrainer,
    mdp_ids: np.ndarray,
    sequence_numbers: torch.Tensor,
    states: rlt.PreprocessedFeatureVector,
    actions: rlt.PreprocessedFeatureVector,
    propensities: torch.Tensor,
    rewards: torch.Tensor,
    possible_actions_mask: torch.Tensor,
    possible_actions: rlt.PreprocessedFeatureVector,
    max_num_actions: int,
    metrics: Optional[torch.Tensor] = None,
):
    old_q_train_state = trainer.q_network.training
    old_reward_train_state = trainer.reward_network.training
    trainer.q_network.train(False)
    trainer.reward_network.train(False)

    state_action_pairs = rlt.PreprocessedStateAction(state=states, action=actions)
    tiled_state = states.float_features.repeat(1, max_num_actions).reshape(
        -1, states.float_features.shape[1]
    )
    assert possible_actions is not None
    # Get Q-value of action taken
    possible_actions_state_concat = rlt.PreprocessedStateAction(
        state=rlt.PreprocessedFeatureVector(float_features=tiled_state),
        action=possible_actions,
    )

    # FIXME: model_values, model_values_for_logged_action, and
    # model_metrics_values should be calculated using q_network_cpe (as in
    # discrete dqn). q_network_cpe has not been added in parametric dqn yet.
    model_values = trainer.q_network(
        possible_actions_state_concat
    ).q_value  # type: ignore
    optimal_q_values, _ = trainer.get_detached_q_values(
        possible_actions_state_concat.state, possible_actions_state_concat.action
    )
    eval_action_idxs = None

    assert (
        model_values.shape[1] == 1
        and model_values.shape[0]
        == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
    ), (
        "Invalid shapes: "
        + str(model_values.shape)
        + " != "
        + str(possible_actions_mask.shape)
    )
    model_values = model_values.reshape(possible_actions_mask.shape)
    optimal_q_values = optimal_q_values.reshape(possible_actions_mask.shape)
    model_propensities = masked_softmax(
        optimal_q_values, possible_actions_mask, trainer.rl_temperature
    )

    rewards_and_metric_rewards = trainer.reward_network(
        possible_actions_state_concat
    ).q_value  # type: ignore
    model_rewards = rewards_and_metric_rewards[:, :1]
    assert (
        model_rewards.shape[0] * model_rewards.shape[1]
        == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
    ), (
        "Invalid shapes: "
        + str(model_rewards.shape)
        + " != "
        + str(possible_actions_mask.shape)
    )
    model_rewards = model_rewards.reshape(possible_actions_mask.shape)

    model_metrics = rewards_and_metric_rewards[:, 1:]
    model_metrics = model_metrics.reshape(possible_actions_mask.shape[0], -1)

    model_values_for_logged_action = trainer.q_network(state_action_pairs).q_value
    model_rewards_and_metrics_for_logged_action = trainer.reward_network(
        state_action_pairs
    ).q_value
    model_rewards_for_logged_action = model_rewards_and_metrics_for_logged_action[
        :, :1
    ]

    action_dim = possible_actions.float_features.shape[1]
    action_mask = torch.all(
        possible_actions.float_features.view(-1, max_num_actions, action_dim)
        == actions.float_features.unsqueeze(dim=1),
        dim=2,
    ).float()
    assert torch.all(action_mask.sum(dim=1) == 1)

    num_metrics = model_metrics.shape[1] // max_num_actions
    model_metrics_values = None
    model_metrics_for_logged_action = None
    model_metrics_values_for_logged_action = None
    if num_metrics > 0:
        # FIXME: calculate model_metrics_values when q_network_cpe is added
        # to parametric dqn
        model_metrics_values = model_values.repeat(1, num_metrics)

    trainer.q_network.train(old_q_train_state)  # type: ignore
    trainer.reward_network.train(old_reward_train_state)  # type: ignore

    return cls(
        mdp_id=mdp_ids,
        sequence_number=sequence_numbers,
        logged_propensities=propensities,
        logged_rewards=rewards,
        action_mask=action_mask,
        model_rewards=model_rewards,
        model_rewards_for_logged_action=model_rewards_for_logged_action,
        model_values=model_values,
        model_values_for_logged_action=model_values_for_logged_action,
        model_metrics_values=model_metrics_values,
        model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
        model_propensities=model_propensities,
        logged_metrics=metrics,
        model_metrics=model_metrics,
        model_metrics_for_logged_action=model_metrics_for_logged_action,
        # Will compute later
        logged_values=None,
        logged_metrics_values=None,
        possible_actions_mask=possible_actions_mask,
        optimal_q_values=optimal_q_values,
        eval_action_idxs=eval_action_idxs,
    )
def train(self, training_batch) -> None:
    """
    IMPORTANT: the input action here is assumed to be preprocessed to match
    the range of the output of the actor.
    """
    if hasattr(training_batch, "as_policy_network_training_batch"):
        training_batch = training_batch.as_policy_network_training_batch()

    learning_input = training_batch.training_input
    self.minibatch += 1

    state = learning_input.state
    action = learning_input.action
    reward = learning_input.reward
    discount = torch.full_like(reward, self.gamma)
    not_done_mask = learning_input.not_terminal

    if self._should_scale_action_in_train():
        action = action._replace(
            float_features=rescale_torch_tensor(
                action.float_features,
                new_min=self.min_action_range_tensor_training,
                new_max=self.max_action_range_tensor_training,
                prev_min=self.min_action_range_tensor_serving,
                prev_max=self.max_action_range_tensor_serving,
            )
        )

    with torch.enable_grad():
        #
        # First, optimize Q networks; minimizing MSE between
        # Q(s, a) & r + discount * V'(next_s)
        #
        current_state_action = rlt.PreprocessedStateAction(
            state=state, action=action
        )
        q1_value = self.q1_network(current_state_action).q_value
        if self.q2_network:
            q2_value = self.q2_network(current_state_action).q_value
        actor_output = self.actor_network(rlt.PreprocessedState(state=state))

        # Optimize Alpha
        if self.alpha_optimizer is not None:
            alpha_loss = -(
                self.log_alpha
                * (actor_output.log_prob + self.target_entropy).detach()
            ).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            self.entropy_temperature = self.log_alpha.exp()

        with torch.no_grad():
            if self.value_network is not None:
                next_state_value = self.value_network_target(
                    learning_input.next_state.float_features
                )
            else:
                next_state_actor_output = self.actor_network(
                    rlt.PreprocessedState(state=learning_input.next_state)
                )
                next_state_actor_action = rlt.PreprocessedStateAction(
                    state=learning_input.next_state,
                    action=rlt.PreprocessedFeatureVector(
                        float_features=next_state_actor_output.action
                    ),
                )
                next_state_value = self.q1_network_target(
                    next_state_actor_action
                ).q_value

                if self.q2_network is not None:
                    target_q2_value = self.q2_network_target(
                        next_state_actor_action
                    ).q_value
                    next_state_value = torch.min(next_state_value, target_q2_value)

                log_prob_a = self.actor_network.get_log_prob(
                    learning_input.next_state, next_state_actor_output.action
                )
                log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                next_state_value -= self.entropy_temperature * log_prob_a

            target_q_value = (
                reward + discount * next_state_value * not_done_mask.float()
            )

        q1_loss = F.mse_loss(q1_value, target_q_value)
        q1_loss.backward()
        self._maybe_run_optimizer(
            self.q1_network_optimizer, self.minibatches_per_step
        )
        if self.q2_network:
            q2_loss = F.mse_loss(q2_value, target_q_value)
            q2_loss.backward()
            self._maybe_run_optimizer(
                self.q2_network_optimizer, self.minibatches_per_step
            )

        #
        # Second, optimize the actor; minimizing KL-divergence between action
        # propensity & softmax of value. Due to the reparameterization trick,
        # it ends up being log_prob(actor_action) - Q(s, actor_action).
        #
        state_actor_action = rlt.PreprocessedStateAction(
            state=state,
            action=rlt.PreprocessedFeatureVector(
                float_features=actor_output.action
            ),
        )
        q1_actor_value = self.q1_network(state_actor_action).q_value
        min_q_actor_value = q1_actor_value
        if self.q2_network:
            q2_actor_value = self.q2_network(state_actor_action).q_value
            min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

        actor_loss = (
            self.entropy_temperature * actor_output.log_prob - min_q_actor_value
        )
        # Do this in 2 steps so we can log histogram of actor loss
        actor_loss_mean = actor_loss.mean()
        actor_loss_mean.backward()
        self._maybe_run_optimizer(
            self.actor_network_optimizer, self.minibatches_per_step
        )

        #
        # Lastly, if applicable, optimize value network; minimizing MSE between
        # V(s) & E_a~pi(s) [ Q(s,a) - log(pi(a|s)) ]
        #
        if self.value_network is not None:
            state_value = self.value_network(state.float_features)

            if self.logged_action_uniform_prior:
                log_prob_a = torch.zeros_like(min_q_actor_value)
                target_value = min_q_actor_value
            else:
                with torch.no_grad():
                    log_prob_a = actor_output.log_prob
                    log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                    target_value = (
                        min_q_actor_value - self.entropy_temperature * log_prob_a
                    )

            value_loss = F.mse_loss(state_value, target_value.detach())
            value_loss.backward()
            self._maybe_run_optimizer(
                self.value_network_optimizer, self.minibatches_per_step
            )

    # Use the soft update rule to update the target networks
    if self.value_network is not None:
        self._maybe_soft_update(
            self.value_network,
            self.value_network_target,
            self.tau,
            self.minibatches_per_step,
        )
    else:
        self._maybe_soft_update(
            self.q1_network,
            self.q1_network_target,
            self.tau,
            self.minibatches_per_step,
        )
        if self.q2_network is not None:
            self._maybe_soft_update(
                self.q2_network,
                self.q2_network_target,
                self.tau,
                self.minibatches_per_step,
            )

    # Logging at the end to schedule all the cuda operations first
    if (
        self.tensorboard_logging_freq is not None
        and self.minibatch % self.tensorboard_logging_freq == 0
    ):
        SummaryWriterContext.add_histogram("q1/logged_state_value", q1_value)
        if self.q2_network:
            SummaryWriterContext.add_histogram("q2/logged_state_value", q2_value)

        SummaryWriterContext.add_histogram("log_prob_a", log_prob_a)
        if self.value_network:
            SummaryWriterContext.add_histogram("value_network/target", target_value)

        SummaryWriterContext.add_histogram(
            "q_network/next_state_value", next_state_value
        )
        SummaryWriterContext.add_histogram(
            "q_network/target_q_value", target_q_value
        )
        SummaryWriterContext.add_histogram(
            "actor/min_q_actor_value", min_q_actor_value
        )
        SummaryWriterContext.add_histogram(
            "actor/action_log_prob", actor_output.log_prob
        )
        SummaryWriterContext.add_histogram("actor/loss", actor_loss)

    self.loss_reporter.report(
        td_loss=float(q1_loss),
        reward_loss=None,
        logged_rewards=reward,
        model_values_on_logged_actions=q1_value,
        model_propensities=actor_output.log_prob.exp(),
        model_values=min_q_actor_value,
    )
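# Tiny numeric sketch (hypothetical values, not part of the original code) of
# the reparameterized SAC actor objective from the comment above:
# loss = alpha * log_prob(a) - min_Q(s, a), averaged over the batch.
import torch

entropy_temperature = 0.2
log_prob = torch.tensor([[-1.0], [-2.0]])
min_q_actor_value = torch.tensor([[3.0], [5.0]])
actor_loss = (entropy_temperature * log_prob - min_q_actor_value).mean()
# ((-0.2 - 3.0) + (-0.4 - 5.0)) / 2 = -4.3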
def create_from_tensors(
    cls,
    trainer: DQNTrainer,
    mdp_ids: np.ndarray,
    sequence_numbers: torch.Tensor,
    states: rlt.PreprocessedFeatureVector,
    actions: rlt.PreprocessedFeatureVector,
    propensities: torch.Tensor,
    rewards: torch.Tensor,
    possible_actions_mask: torch.Tensor,
    possible_actions: Optional[rlt.PreprocessedFeatureVector] = None,
    max_num_actions: Optional[int] = None,
    metrics: Optional[torch.Tensor] = None,
):
    # Switch to evaluation mode for the network
    old_q_train_state = trainer.q_network.training
    old_reward_train_state = trainer.reward_network.training
    trainer.q_network.train(False)
    trainer.reward_network.train(False)

    if max_num_actions:
        # Parametric model CPE
        state_action_pairs = rlt.PreprocessedStateAction(
            state=states, action=actions
        )
        tiled_state = states.float_features.repeat(1, max_num_actions).reshape(
            -1, states.float_features.shape[1]
        )
        assert possible_actions is not None
        # Get Q-value of action taken
        possible_actions_state_concat = rlt.PreprocessedStateAction(
            state=rlt.PreprocessedFeatureVector(float_features=tiled_state),
            action=possible_actions,
        )

        # Parametric actions
        # FIXME: model_values and model_propensities should be calculated
        # as in the discrete dqn model
        model_values = trainer.q_network(
            possible_actions_state_concat
        ).q_value  # type: ignore
        optimal_q_values = model_values
        eval_action_idxs = None

        assert (
            model_values.shape[0] * model_values.shape[1]
            == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
        ), (
            "Invalid shapes: "
            + str(model_values.shape)
            + " != "
            + str(possible_actions_mask.shape)
        )
        model_values = model_values.reshape(possible_actions_mask.shape)
        model_propensities = masked_softmax(
            model_values, possible_actions_mask, trainer.rl_temperature
        )

        model_rewards = trainer.reward_network(
            possible_actions_state_concat
        ).q_value  # type: ignore
        assert (
            model_rewards.shape[0] * model_rewards.shape[1]
            == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
        ), (
            "Invalid shapes: "
            + str(model_rewards.shape)
            + " != "
            + str(possible_actions_mask.shape)
        )
        model_rewards = model_rewards.reshape(possible_actions_mask.shape)

        model_values_for_logged_action = trainer.q_network(
            state_action_pairs
        ).q_value
        model_rewards_for_logged_action = trainer.reward_network(
            state_action_pairs
        ).q_value

        action_mask = (
            torch.abs(model_values - model_values_for_logged_action) < 1e-3
        ).float()

        model_metrics = None
        model_metrics_for_logged_action = None
        model_metrics_values = None
        model_metrics_values_for_logged_action = None
    else:
        num_actions = trainer.num_actions
        action_mask = actions.float()  # type: ignore

        # Switch to evaluation mode for the network
        old_q_cpe_train_state = trainer.q_network_cpe.training
        trainer.q_network_cpe.train(False)

        # Discrete actions
        rewards = trainer.boost_rewards(rewards, actions)  # type: ignore
        model_values = trainer.q_network_cpe(
            rlt.PreprocessedState(state=states)
        ).q_values[:, 0:num_actions]
        optimal_q_values = trainer.get_detached_q_values(
            states  # type: ignore
        )[0]
        eval_action_idxs = trainer.get_max_q_values(  # type: ignore
            optimal_q_values, possible_actions_mask
        )[1]
        model_propensities = masked_softmax(
            optimal_q_values, possible_actions_mask, trainer.rl_temperature
        )
        assert model_values.shape == actions.shape, (  # type: ignore
            "Invalid shape: "
            + str(model_values.shape)  # type: ignore
            + " != "
            + str(actions.shape)  # type: ignore
        )
        assert model_values.shape == possible_actions_mask.shape, (  # type: ignore
            "Invalid shape: "
            + str(model_values.shape)  # type: ignore
            + " != "
            + str(possible_actions_mask.shape)  # type: ignore
        )
        model_values_for_logged_action = torch.sum(
            model_values * action_mask, dim=1, keepdim=True
        )

        rewards_and_metric_rewards = trainer.reward_network(
            rlt.PreprocessedState(state=states)
        )
        # In case we reuse the modular for Q-network
        if hasattr(rewards_and_metric_rewards, "q_values"):
            rewards_and_metric_rewards = rewards_and_metric_rewards.q_values

        model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
        assert model_rewards.shape == actions.shape, (  # type: ignore
            "Invalid shape: "
            + str(model_rewards.shape)  # type: ignore
            + " != "
            + str(actions.shape)  # type: ignore
        )
        model_rewards_for_logged_action = torch.sum(
            model_rewards * action_mask, dim=1, keepdim=True
        )

        model_metrics = rewards_and_metric_rewards[:, num_actions:]
        assert model_metrics.shape[1] % num_actions == 0, (
            "Invalid metrics shape: "
            + str(model_metrics.shape)
            + " "
            + str(num_actions)
        )
        num_metrics = model_metrics.shape[1] // num_actions

        if num_metrics == 0:
            model_metrics_values = None
            model_metrics_for_logged_action = None
            model_metrics_values_for_logged_action = None
        else:
            model_metrics_values = trainer.q_network_cpe(
                rlt.PreprocessedState(state=states)
            )
            # Backward compatibility
            if hasattr(model_metrics_values, "q_values"):
                model_metrics_values = model_metrics_values.q_values
            model_metrics_values = model_metrics_values[:, num_actions:]
            assert (
                model_metrics_values.shape[1] == num_actions * num_metrics
            ), (  # type: ignore
                "Invalid shape: "
                + str(model_metrics_values.shape[1])  # type: ignore
                + " != "
                + str(actions.shape[1] * num_metrics)  # type: ignore
            )

            model_metrics_for_logged_action_list = []
            model_metrics_values_for_logged_action_list = []
            for metric_index in range(num_metrics):
                metric_start = metric_index * num_actions
                metric_end = (metric_index + 1) * num_actions
                model_metrics_for_logged_action_list.append(
                    torch.sum(
                        model_metrics[:, metric_start:metric_end] * action_mask,
                        dim=1,
                        keepdim=True,
                    )
                )
                model_metrics_values_for_logged_action_list.append(
                    torch.sum(
                        model_metrics_values[:, metric_start:metric_end]
                        * action_mask,
                        dim=1,
                        keepdim=True,
                    )
                )
            model_metrics_for_logged_action = torch.cat(
                model_metrics_for_logged_action_list, dim=1
            )
            model_metrics_values_for_logged_action = torch.cat(
                model_metrics_values_for_logged_action_list, dim=1
            )

        # Switch back to the old mode
        trainer.q_network_cpe.train(old_q_cpe_train_state)  # type: ignore

    # Switch back to the old mode
    trainer.q_network.train(old_q_train_state)  # type: ignore
    trainer.reward_network.train(old_reward_train_state)  # type: ignore

    return cls(
        mdp_id=mdp_ids,
        sequence_number=sequence_numbers,
        logged_propensities=propensities,
        logged_rewards=rewards,
        action_mask=action_mask,
        model_rewards=model_rewards,
        model_rewards_for_logged_action=model_rewards_for_logged_action,
        model_values=model_values,
        model_values_for_logged_action=model_values_for_logged_action,
        model_metrics_values=model_metrics_values,
        model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
        model_propensities=model_propensities,
        logged_metrics=metrics,
        model_metrics=model_metrics,
        model_metrics_for_logged_action=model_metrics_for_logged_action,
        # Will compute later
        logged_values=None,
        logged_metrics_values=None,
        possible_actions_mask=possible_actions_mask,
        optimal_q_values=optimal_q_values,
        eval_action_idxs=eval_action_idxs,
    )
def train(self, training_batch) -> None:
    """
    IMPORTANT: the input action here is assumed to be preprocessed to match
    the range of the output of the actor.
    """
    if hasattr(training_batch, "as_policy_network_training_batch"):
        training_batch = training_batch.as_policy_network_training_batch()

    learning_input = training_batch.training_input
    self.minibatch += 1

    state = learning_input.state
    action = learning_input.action
    next_state = learning_input.next_state
    reward = learning_input.reward
    not_done_mask = learning_input.not_terminal

    action = self._maybe_scale_action_in_train(action.float_features)
    max_action = (
        self.max_action_range_tensor_training
        if self.max_action_range_tensor_training
        else torch.ones(action.shape, device=self.device)
    )
    min_action = (
        self.min_action_range_tensor_training
        if self.min_action_range_tensor_training
        else -torch.ones(action.shape, device=self.device)
    )

    # Compute current value estimates
    current_state_action = rlt.PreprocessedStateAction(
        state=state, action=rlt.PreprocessedFeatureVector(float_features=action)
    )
    q1_value = self.q1_network(current_state_action).q_value
    if self.q2_network:
        q2_value = self.q2_network(current_state_action).q_value
    actor_action = self.actor_network(rlt.PreprocessedState(state=state)).action

    # Generate target = r + y * min (Q1(s',pi(s')), Q2(s',pi(s')))
    with torch.no_grad():
        next_actor = self.actor_network_target(
            rlt.PreprocessedState(state=next_state)
        ).action
        next_actor += (
            torch.randn_like(next_actor) * self.target_policy_smoothing
        ).clamp(-self.noise_clip, self.noise_clip)
        next_actor = torch.max(torch.min(next_actor, max_action), min_action)
        next_state_actor = rlt.PreprocessedStateAction(
            state=next_state,
            action=rlt.PreprocessedFeatureVector(float_features=next_actor),
        )
        next_state_value = self.q1_network_target(next_state_actor).q_value

        if self.q2_network is not None:
            next_state_value = torch.min(
                next_state_value, self.q2_network_target(next_state_actor).q_value
            )

        target_q_value = (
            reward + self.gamma * next_state_value * not_done_mask.float()
        )

    # Optimize Q1 and Q2
    q1_loss = F.mse_loss(q1_value, target_q_value)
    q1_loss.backward()
    self._maybe_run_optimizer(self.q1_network_optimizer, self.minibatches_per_step)
    if self.q2_network:
        q2_loss = F.mse_loss(q2_value, target_q_value)
        q2_loss.backward()
        self._maybe_run_optimizer(
            self.q2_network_optimizer, self.minibatches_per_step
        )

    # Only update actor and target networks after a fixed number of Q updates
    if self.minibatch % self.delayed_policy_update == 0:
        actor_loss = -self.q1_network(
            rlt.PreprocessedStateAction(
                state=state,
                action=rlt.PreprocessedFeatureVector(float_features=actor_action),
            )
        ).q_value.mean()
        actor_loss.backward()
        self._maybe_run_optimizer(
            self.actor_network_optimizer, self.minibatches_per_step
        )

        # Use the soft update rule to update the target networks
        self._maybe_soft_update(
            self.q1_network,
            self.q1_network_target,
            self.tau,
            self.minibatches_per_step,
        )
        self._maybe_soft_update(
            self.actor_network,
            self.actor_network_target,
            self.tau,
            self.minibatches_per_step,
        )
        if self.q2_network is not None:
            self._maybe_soft_update(
                self.q2_network,
                self.q2_network_target,
                self.tau,
                self.minibatches_per_step,
            )

    # Logging at the end to schedule all the cuda operations first
    if (
        self.tensorboard_logging_freq != 0
        and self.minibatch % self.tensorboard_logging_freq == 0
    ):
        SummaryWriterContext.add_histogram("q1/logged_state_value", q1_value)
        if self.q2_network:
            SummaryWriterContext.add_histogram("q2/logged_state_value", q2_value)
        SummaryWriterContext.add_histogram(
            "q_network/next_state_value", next_state_value
        )
        SummaryWriterContext.add_histogram(
            "q_network/target_q_value", target_q_value
        )
        SummaryWriterContext.add_histogram("actor/loss", actor_loss)

    self.loss_reporter.report(
        td_loss=float(q1_loss),
        reward_loss=None,
        logged_rewards=reward,
        model_values_on_logged_actions=q1_value,
    )
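# Worked sketch (made-up numbers, not part of the original code) of the clipped
# double-Q target computed above:
# target = r + gamma * min(Q1_target(s', a'), Q2_target(s', a')) * not_done,
# where a' is the smoothed, clipped target-policy action.
import torch

reward = torch.tensor([[1.0]])
gamma = 0.99
not_done = torch.tensor([[1.0]])
q1_next, q2_next = torch.tensor([[5.0]]), torch.tensor([[4.0]])
target_q_value = reward + gamma * torch.min(q1_next, q2_next) * not_done
# 1.0 + 0.99 * 4.0 = 4.96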
def test_seq2slate_eval_data_page(self):
    """
    Create 3 slate ranking logs and evaluate using Direct Method, Inverse
    Propensity Scores, and Doubly Robust.

    The logs are as follows:
        state: [1, 0, 0], [0, 1, 0], [0, 0, 1]
        indices in logged slates: [3, 2], [3, 2], [3, 2]
        model output indices: [2, 3], [3, 2], [2, 3]
        logged reward: 4, 5, 7
        logged propensities: 0.2, 0.5, 0.4
        predicted rewards on logged slates: 2, 4, 6
        predicted rewards on model outputted slates: 1, 4, 5

    Direct Method uses the predicted rewards on model outputted slates.
    Thus the result is expected to be (1 + 4 + 5) / 3.

    Inverse Propensity Scores would scale the reward by 1.0 / logged
    propensities whenever the model output slate matches with the logged
    slate. Since only the second log matches with the model output, the IPS
    result is expected to be 5 / 0.5 / 3.

    Doubly Robust is the sum of the direct method result and the
    propensity-scaled reward difference; the latter is defined as:
        1.0 / logged_propensities
        * (logged reward - predicted reward on logged slate)
        * Indicator(model slate == logged slate)
    Since only the second logged slate matches with the model outputted
    slate, the DR result is expected to be
        (1 + 4 + 5) / 3 + 1.0 / 0.5 * (5 - 4) / 3
    """
    batch_size = 3
    state_dim = 3
    src_seq_len = 2
    tgt_seq_len = 2
    candidate_dim = 2

    reward_net = FakeSeq2SlateRewardNetwork()
    seq2slate_net = FakeSeq2SlateTransformerNet()
    baseline_net = nn.Linear(1, 1)
    trainer = Seq2SlateTrainer(
        seq2slate_net,
        baseline_net,
        parameters=None,
        minibatch_size=3,
        use_gpu=False,
    )

    src_seq = torch.eye(candidate_dim).repeat(batch_size, 1, 1)
    tgt_out_idx = torch.LongTensor([[3, 2], [3, 2], [3, 2]])
    tgt_out_seq = src_seq[
        torch.arange(batch_size).repeat_interleave(tgt_seq_len),  # type: ignore
        tgt_out_idx.flatten() - 2,
    ].reshape(batch_size, tgt_seq_len, candidate_dim)

    ptb = rlt.PreprocessedTrainingBatch(
        training_input=rlt.PreprocessedRankingInput(
            state=rlt.PreprocessedFeatureVector(float_features=torch.eye(state_dim)),
            src_seq=rlt.PreprocessedFeatureVector(float_features=src_seq),
            tgt_out_seq=rlt.PreprocessedFeatureVector(float_features=tgt_out_seq),
            src_src_mask=torch.ones(batch_size, src_seq_len, src_seq_len),
            tgt_out_idx=tgt_out_idx,
            tgt_out_probs=torch.tensor([0.2, 0.5, 0.4]),
            slate_reward=torch.tensor([4.0, 5.0, 7.0]),
        ),
        extras=rlt.ExtraData(
            sequence_number=torch.tensor([0, 0, 0]),
            mdp_id=np.array(["0", "1", "2"]),
        ),
    )
    edp = EvaluationDataPage.create_from_training_batch(ptb, trainer, reward_net)
    doubly_robust_estimator = DoublyRobustEstimator()
    (
        direct_method,
        inverse_propensity,
        doubly_robust,
    ) = doubly_robust_estimator.estimate(edp)
    logger.info(f"{direct_method}, {inverse_propensity}, {doubly_robust}")

    avg_logged_reward = (4 + 5 + 7) / 3
    self.assertAlmostEqual(direct_method.raw, (1 + 4 + 5) / 3, delta=1e-6)
    self.assertAlmostEqual(
        direct_method.normalized, direct_method.raw / avg_logged_reward, delta=1e-6
    )
    self.assertAlmostEqual(inverse_propensity.raw, 5 / 0.5 / 3, delta=1e-6)
    self.assertAlmostEqual(
        inverse_propensity.normalized,
        inverse_propensity.raw / avg_logged_reward,
        delta=1e-6,
    )
    self.assertAlmostEqual(
        doubly_robust.raw, direct_method.raw + 1 / 0.5 * (5 - 4) / 3, delta=1e-6
    )
    self.assertAlmostEqual(
        doubly_robust.normalized, doubly_robust.raw / avg_logged_reward, delta=1e-6
    )
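# Stand-alone check (no ReAgent code involved) of the DM / IPS / DR arithmetic
# described in the docstring above, using only the numbers it lists.
logged_rewards = [4.0, 5.0, 7.0]
logged_propensities = [0.2, 0.5, 0.4]
predicted_on_logged_slates = [2.0, 4.0, 6.0]
predicted_on_model_slates = [1.0, 4.0, 5.0]
slate_matches = [False, True, False]  # only the second log matches the model output

n = len(logged_rewards)
direct_method = sum(predicted_on_model_slates) / n  # (1 + 4 + 5) / 3
inverse_propensity = (
    sum(
        r / p
        for r, p, m in zip(logged_rewards, logged_propensities, slate_matches)
        if m
    )
    / n
)  # 5 / 0.5 / 3
doubly_robust = direct_method + (
    sum(
        (r - q) / p
        for r, q, p, m in zip(
            logged_rewards,
            predicted_on_logged_slates,
            logged_propensities,
            slate_matches,
        )
        if m
    )
    / n
)  # (1 + 4 + 5) / 3 + 1.0 / 0.5 * (5 - 4) / 3
print(direct_method, inverse_propensity, doubly_robust)  # 3.33..., 3.33..., 4.0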