def create_from_tensors_dqn(
    cls,
    trainer: DQNTrainer,
    mdp_ids: np.ndarray,
    sequence_numbers: torch.Tensor,
    states: rlt.PreprocessedFeatureVector,
    actions: rlt.PreprocessedFeatureVector,
    propensities: torch.Tensor,
    rewards: torch.Tensor,
    possible_actions_mask: torch.Tensor,
    metrics: Optional[torch.Tensor] = None,
):
    # Switch the networks to evaluation mode, remembering their old modes
    old_q_train_state = trainer.q_network.training
    old_reward_train_state = trainer.reward_network.training
    old_q_cpe_train_state = trainer.q_network_cpe.training
    trainer.q_network.train(False)
    trainer.reward_network.train(False)
    trainer.q_network_cpe.train(False)

    num_actions = trainer.num_actions
    action_mask = actions.float()  # type: ignore

    rewards = trainer.boost_rewards(rewards, actions)  # type: ignore
    model_values = trainer.q_network_cpe(rlt.PreprocessedState(state=states)).q_values[
        :, 0:num_actions
    ]
    optimal_q_values, _ = trainer.get_detached_q_values(states)  # type: ignore
    eval_action_idxs = trainer.get_max_q_values(  # type: ignore
        optimal_q_values, possible_actions_mask
    )[1]
    model_propensities = masked_softmax(
        optimal_q_values, possible_actions_mask, trainer.rl_temperature
    )
    assert model_values.shape == actions.shape, (  # type: ignore
        "Invalid shape: "
        + str(model_values.shape)  # type: ignore
        + " != "
        + str(actions.shape)  # type: ignore
    )
    assert model_values.shape == possible_actions_mask.shape, (  # type: ignore
        "Invalid shape: "
        + str(model_values.shape)  # type: ignore
        + " != "
        + str(possible_actions_mask.shape)  # type: ignore
    )
    model_values_for_logged_action = torch.sum(
        model_values * action_mask, dim=1, keepdim=True
    )

    rewards_and_metric_rewards = trainer.reward_network(
        rlt.PreprocessedState(state=states)
    )
    # In case we reuse the module for the Q-network
    if hasattr(rewards_and_metric_rewards, "q_values"):
        rewards_and_metric_rewards = rewards_and_metric_rewards.q_values
    model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
    assert model_rewards.shape == actions.shape, (  # type: ignore
        "Invalid shape: "
        + str(model_rewards.shape)  # type: ignore
        + " != "
        + str(actions.shape)  # type: ignore
    )
    model_rewards_for_logged_action = torch.sum(
        model_rewards * action_mask, dim=1, keepdim=True
    )

    model_metrics = rewards_and_metric_rewards[:, num_actions:]
    assert model_metrics.shape[1] % num_actions == 0, (
        "Invalid metrics shape: " + str(model_metrics.shape) + " " + str(num_actions)
    )
    num_metrics = model_metrics.shape[1] // num_actions

    if num_metrics == 0:
        model_metrics_values = None
        model_metrics_for_logged_action = None
        model_metrics_values_for_logged_action = None
    else:
        model_metrics_values = trainer.q_network_cpe(
            rlt.PreprocessedState(state=states)
        )
        # Backward compatibility
        if hasattr(model_metrics_values, "q_values"):
            model_metrics_values = model_metrics_values.q_values
        model_metrics_values = model_metrics_values[:, num_actions:]
        assert (
            model_metrics_values.shape[1] == num_actions * num_metrics
        ), (  # type: ignore
            "Invalid shape: "
            + str(model_metrics_values.shape[1])  # type: ignore
            + " != "
            + str(actions.shape[1] * num_metrics)  # type: ignore
        )

        model_metrics_for_logged_action_list = []
        model_metrics_values_for_logged_action_list = []
        for metric_index in range(num_metrics):
            metric_start = metric_index * num_actions
            metric_end = (metric_index + 1) * num_actions
            model_metrics_for_logged_action_list.append(
                torch.sum(
                    model_metrics[:, metric_start:metric_end] * action_mask,
                    dim=1,
                    keepdim=True,
                )
            )
            model_metrics_values_for_logged_action_list.append(
                torch.sum(
                    model_metrics_values[:, metric_start:metric_end] * action_mask,
                    dim=1,
                    keepdim=True,
                )
            )
        model_metrics_for_logged_action = torch.cat(
            model_metrics_for_logged_action_list, dim=1
        )
        model_metrics_values_for_logged_action = torch.cat(
            model_metrics_values_for_logged_action_list, dim=1
        )

    # Switch the networks back to their previous modes
    trainer.q_network_cpe.train(old_q_cpe_train_state)  # type: ignore
    trainer.q_network.train(old_q_train_state)  # type: ignore
    trainer.reward_network.train(old_reward_train_state)  # type: ignore

    return cls(
        mdp_id=mdp_ids,
        sequence_number=sequence_numbers,
        logged_propensities=propensities,
        logged_rewards=rewards,
        action_mask=action_mask,
        model_rewards=model_rewards,
        model_rewards_for_logged_action=model_rewards_for_logged_action,
        model_values=model_values,
        model_values_for_logged_action=model_values_for_logged_action,
        model_metrics_values=model_metrics_values,
        model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
        model_propensities=model_propensities,
        logged_metrics=metrics,
        model_metrics=model_metrics,
        model_metrics_for_logged_action=model_metrics_for_logged_action,
        # Will compute later
        logged_values=None,
        logged_metrics_values=None,
        possible_actions_mask=possible_actions_mask,
        optimal_q_values=optimal_q_values,
        eval_action_idxs=eval_action_idxs,
    )
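
# --- Illustrative sketch (not part of the library) --------------------------
# A minimal, self-contained example of the two core reductions performed by
# create_from_tensors_dqn: (1) turning detached Q-values into model
# propensities with a temperature-scaled masked softmax, and (2) selecting the
# value of the logged action with the one-hot action mask. All tensors are toy
# data, and masked_softmax_sketch is a local stand-in with the semantics
# assumed by the code above (impossible actions get zero probability).
import torch
import torch.nn.functional as F


def masked_softmax_sketch(q_values, mask, temperature):
    # Impossible actions get -inf logits, hence zero probability after softmax.
    logits = q_values / temperature
    logits = logits.masked_fill(mask == 0, float("-inf"))
    return F.softmax(logits, dim=1)


q = torch.tensor([[1.0, 2.0, 0.5], [0.3, 0.1, 0.9]])      # [batch, num_actions]
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])   # possible_actions_mask
act = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])    # one-hot logged actions

propensities = masked_softmax_sketch(q, mask, temperature=1.0)
values_for_logged_action = torch.sum(q * act, dim=1, keepdim=True)  # [batch, 1]
print(propensities)
print(values_for_logged_action)
# ----------------------------------------------------------------------------
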
def evaluate(self, tdp: PreprocessedTrainingBatch):
    """
    Calculate state feature sensitivity due to actions: randomly permute the
    actions and see how much the prediction of the next state features deviates.
    """
    mdnrnn_training_input = tdp.training_input
    assert isinstance(mdnrnn_training_input, PreprocessedMemoryNetworkInput)

    self.trainer.mdnrnn.mdnrnn.eval()

    batch_size, seq_len, state_dim = (
        mdnrnn_training_input.next_state.float_features.size()
    )
    state_feature_num = self.state_feature_num
    feature_sensitivity = torch.zeros(state_feature_num)

    state, action, next_state, reward, not_terminal = transpose(
        mdnrnn_training_input.state.float_features,
        mdnrnn_training_input.action,
        mdnrnn_training_input.next_state.float_features,
        mdnrnn_training_input.reward,
        mdnrnn_training_input.not_terminal,
    )
    mdnrnn_input = PreprocessedStateAction(
        state=PreprocessedFeatureVector(float_features=state),
        action=PreprocessedFeatureVector(float_features=action),
    )
    # the input of mdnrnn has seq_len as the first dimension
    mdnrnn_output = self.trainer.mdnrnn(mdnrnn_input)
    predicted_next_state_means = mdnrnn_output.mus

    shuffled_mdnrnn_input = PreprocessedStateAction(
        state=PreprocessedFeatureVector(float_features=state),
        # shuffle the actions across the batch dimension
        action=PreprocessedFeatureVector(
            float_features=action[:, torch.randperm(batch_size), :]
        ),
    )
    shuffled_mdnrnn_output = self.trainer.mdnrnn(shuffled_mdnrnn_input)
    shuffled_predicted_next_state_means = shuffled_mdnrnn_output.mus

    assert (
        predicted_next_state_means.size()
        == shuffled_predicted_next_state_means.size()
        == (seq_len, batch_size, self.trainer.params.num_gaussians, state_dim)
    )

    state_feature_boundaries = self.sorted_state_feature_start_indices + [state_dim]
    for i in range(state_feature_num):
        boundary_start, boundary_end = (
            state_feature_boundaries[i],
            state_feature_boundaries[i + 1],
        )
        abs_diff = torch.mean(
            torch.sum(
                torch.abs(
                    shuffled_predicted_next_state_means[
                        :, :, :, boundary_start:boundary_end
                    ]
                    - predicted_next_state_means[:, :, :, boundary_start:boundary_end]
                ),
                dim=3,
            )
        )
        feature_sensitivity[i] = abs_diff.cpu().detach().item()

    self.trainer.mdnrnn.mdnrnn.train()
    logger.info(
        "**** Debug tool feature sensitivity ****: {}".format(feature_sensitivity)
    )
    return {"feature_sensitivity": feature_sensitivity.numpy()}
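
# --- Illustrative sketch (not part of the library) --------------------------
# The sensitivity probe above boils down to: run the world model twice, once
# with the logged actions and once with the actions shuffled across the batch
# dimension, then report the mean absolute deviation of the predicted next
# state per feature. The toy model below (a single Linear layer standing in
# for the MDN-RNN) is an assumption used purely to keep the sketch runnable.
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, state_dim, action_dim = 8, 4, 3
world_model = nn.Linear(state_dim + action_dim, state_dim)  # toy stand-in

state = torch.randn(batch_size, state_dim)
action = torch.randn(batch_size, action_dim)

with torch.no_grad():
    pred = world_model(torch.cat([state, action], dim=1))
    shuffled_action = action[torch.randperm(batch_size)]
    pred_shuffled = world_model(torch.cat([state, shuffled_action], dim=1))

# Per-feature sensitivity: how much each predicted state feature moves when
# the actions are decoupled from their states.
sensitivity = (pred - pred_shuffled).abs().mean(dim=0)
print(sensitivity)
# ----------------------------------------------------------------------------
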
def evaluate(self, tdp: PreprocessedTrainingBatch):
    """
    Calculate feature importance: set each state/action feature to its mean
    value and observe the resulting loss increase.
    """
    assert isinstance(tdp.training_input, PreprocessedMemoryNetworkInput)
    self.trainer.mdnrnn.mdnrnn.eval()

    state_features = tdp.training_input.state.float_features
    action_features = tdp.training_input.action  # type: ignore
    batch_size, seq_len, state_dim = state_features.size()  # type: ignore
    action_dim = action_features.size()[2]  # type: ignore
    action_feature_num = self.action_feature_num
    state_feature_num = self.state_feature_num
    feature_importance = torch.zeros(action_feature_num + state_feature_num)

    # Reference loss on the unmodified batch
    orig_losses = self.trainer.get_loss(tdp, state_dim=state_dim, batch_first=True)
    orig_loss = orig_losses["loss"].cpu().detach().item()
    del orig_losses

    action_feature_boundaries = self.sorted_action_feature_start_indices + [action_dim]
    state_feature_boundaries = self.sorted_state_feature_start_indices + [state_dim]

    for i in range(action_feature_num):
        action_features = tdp.training_input.action.reshape(  # type: ignore
            (batch_size * seq_len, action_dim)
        ).data.clone()

        # if actions are discrete, an action's feature importance is the loss
        # increase due to setting all actions to this action
        if self.discrete_action:
            assert action_dim == action_feature_num
            action_vec = torch.zeros(action_dim)
            action_vec[i] = 1
            action_features[:] = action_vec  # type: ignore
        # if actions are continuous, an action's feature importance is the loss
        # increase due to masking this action feature to its mean value
        else:
            boundary_start, boundary_end = (
                action_feature_boundaries[i],
                action_feature_boundaries[i + 1],
            )
            action_features[  # type: ignore
                :, boundary_start:boundary_end
            ] = self.compute_median_feature_value(  # type: ignore
                action_features[:, boundary_start:boundary_end]  # type: ignore
            )

        action_features = action_features.reshape(  # type: ignore
            (batch_size, seq_len, action_dim)
        )
        new_tdp = PreprocessedTrainingBatch(
            training_input=PreprocessedMemoryNetworkInput(  # type: ignore
                state=tdp.training_input.state,
                action=action_features,
                next_state=tdp.training_input.next_state,
                reward=tdp.training_input.reward,
                time_diff=torch.ones_like(tdp.training_input.reward).float(),
                not_terminal=tdp.training_input.not_terminal,  # type: ignore
                step=None,
            ),
            extras=ExtraData(),
        )
        losses = self.trainer.get_loss(new_tdp, state_dim=state_dim, batch_first=True)
        feature_importance[i] = losses["loss"].cpu().detach().item() - orig_loss
        del losses

    for i in range(state_feature_num):
        state_features = tdp.training_input.state.float_features.reshape(  # type: ignore
            (batch_size * seq_len, state_dim)
        ).data.clone()
        boundary_start, boundary_end = (
            state_feature_boundaries[i],
            state_feature_boundaries[i + 1],
        )
        state_features[  # type: ignore
            :, boundary_start:boundary_end
        ] = self.compute_median_feature_value(
            state_features[:, boundary_start:boundary_end]  # type: ignore
        )
        state_features = state_features.reshape(  # type: ignore
            (batch_size, seq_len, state_dim)
        )
        new_tdp = PreprocessedTrainingBatch(
            training_input=PreprocessedMemoryNetworkInput(  # type: ignore
                state=PreprocessedFeatureVector(float_features=state_features),
                action=tdp.training_input.action,  # type: ignore
                next_state=tdp.training_input.next_state,
                reward=tdp.training_input.reward,
                time_diff=torch.ones_like(tdp.training_input.reward).float(),
                not_terminal=tdp.training_input.not_terminal,  # type: ignore
                step=None,
            ),
            extras=ExtraData(),
        )
        losses = self.trainer.get_loss(new_tdp, state_dim=state_dim, batch_first=True)
        feature_importance[i + action_feature_num] = (
            losses["loss"].cpu().detach().item() - orig_loss
        )
        del losses

    self.trainer.mdnrnn.mdnrnn.train()
    logger.info(
        "**** Debug tool feature importance ****: {}".format(feature_importance)
    )
    return {"feature_loss_increase": feature_importance.numpy()}
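
# --- Illustrative sketch (not part of the library) --------------------------
# The importance probe above measures, for each input feature, how much a
# reference loss grows when that feature is replaced by a central value over
# the batch. The regression model, MSE loss, and median replacement below are
# assumptions chosen only to make the sketch self-contained; the evaluator
# above uses the MDN-RNN loss via trainer.get_loss and
# compute_median_feature_value instead.
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, num_features = 32, 5
model = nn.Linear(num_features, 1)
loss_fn = nn.MSELoss()

x = torch.randn(batch_size, num_features)
y = torch.randn(batch_size, 1)

with torch.no_grad():
    orig_loss = loss_fn(model(x), y).item()
    importance = torch.zeros(num_features)
    for i in range(num_features):
        x_masked = x.clone()
        x_masked[:, i] = x[:, i].median()  # neutralize feature i
        importance[i] = loss_fn(model(x_masked), y).item() - orig_loss

# Larger values -> the model's loss depends more heavily on that feature.
print(importance)
# ----------------------------------------------------------------------------
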
def create_from_tensors(
    cls,
    trainer: DQNTrainer,
    mdp_ids: np.ndarray,
    sequence_numbers: torch.Tensor,
    states: rlt.PreprocessedFeatureVector,
    actions: rlt.PreprocessedFeatureVector,
    propensities: torch.Tensor,
    rewards: torch.Tensor,
    possible_actions_mask: torch.Tensor,
    possible_actions: Optional[rlt.PreprocessedFeatureVector] = None,
    max_num_actions: Optional[int] = None,
    metrics: Optional[torch.Tensor] = None,
):
    # Switch to evaluation mode for the network
    old_q_train_state = trainer.q_network.training
    old_reward_train_state = trainer.reward_network.training
    trainer.q_network.train(False)
    trainer.reward_network.train(False)

    if max_num_actions:
        # Parametric model CPE
        state_action_pairs = rlt.PreprocessedStateAction(state=states, action=actions)
        tiled_state = states.float_features.repeat(1, max_num_actions).reshape(
            -1, states.float_features.shape[1]
        )
        assert possible_actions is not None
        # Get Q-value of action taken
        possible_actions_state_concat = rlt.PreprocessedStateAction(
            state=rlt.PreprocessedFeatureVector(float_features=tiled_state),
            action=possible_actions,
        )

        # Parametric actions
        # FIXME: model_values and model_propensities should be calculated
        # as in the discrete dqn model
        model_values = trainer.q_network(
            possible_actions_state_concat
        ).q_value  # type: ignore
        optimal_q_values = model_values
        eval_action_idxs = None

        assert (
            model_values.shape[0] * model_values.shape[1]
            == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
        ), (
            "Invalid shapes: "
            + str(model_values.shape)
            + " != "
            + str(possible_actions_mask.shape)
        )
        model_values = model_values.reshape(possible_actions_mask.shape)
        model_propensities = masked_softmax(
            model_values, possible_actions_mask, trainer.rl_temperature
        )

        model_rewards = trainer.reward_network(
            possible_actions_state_concat
        ).q_value  # type: ignore
        assert (
            model_rewards.shape[0] * model_rewards.shape[1]
            == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
        ), (
            "Invalid shapes: "
            + str(model_rewards.shape)
            + " != "
            + str(possible_actions_mask.shape)
        )
        model_rewards = model_rewards.reshape(possible_actions_mask.shape)

        model_values_for_logged_action = trainer.q_network(state_action_pairs).q_value
        model_rewards_for_logged_action = trainer.reward_network(
            state_action_pairs
        ).q_value

        action_mask = (
            torch.abs(model_values - model_values_for_logged_action) < 1e-3
        ).float()

        model_metrics = None
        model_metrics_for_logged_action = None
        model_metrics_values = None
        model_metrics_values_for_logged_action = None
    else:
        num_actions = trainer.num_actions
        action_mask = actions.float()  # type: ignore

        # Switch to evaluation mode for the network
        old_q_cpe_train_state = trainer.q_network_cpe.training
        trainer.q_network_cpe.train(False)

        # Discrete actions
        rewards = trainer.boost_rewards(rewards, actions)  # type: ignore
        model_values = trainer.q_network_cpe(
            rlt.PreprocessedState(state=states)
        ).q_values[:, 0:num_actions]
        optimal_q_values = trainer.get_detached_q_values(states)[0]  # type: ignore
        eval_action_idxs = trainer.get_max_q_values(  # type: ignore
            optimal_q_values, possible_actions_mask
        )[1]
        model_propensities = masked_softmax(
            optimal_q_values, possible_actions_mask, trainer.rl_temperature
        )
        assert model_values.shape == actions.shape, (  # type: ignore
            "Invalid shape: "
            + str(model_values.shape)  # type: ignore
            + " != "
            + str(actions.shape)  # type: ignore
        )
        assert model_values.shape == possible_actions_mask.shape, (  # type: ignore
            "Invalid shape: "
            + str(model_values.shape)  # type: ignore
            + " != "
            + str(possible_actions_mask.shape)  # type: ignore
        )
        model_values_for_logged_action = torch.sum(
            model_values * action_mask, dim=1, keepdim=True
        )

        rewards_and_metric_rewards = trainer.reward_network(
            rlt.PreprocessedState(state=states)
        )
        # In case we reuse the module for the Q-network
        if hasattr(rewards_and_metric_rewards, "q_values"):
            rewards_and_metric_rewards = rewards_and_metric_rewards.q_values
        model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
        assert model_rewards.shape == actions.shape, (  # type: ignore
            "Invalid shape: "
            + str(model_rewards.shape)  # type: ignore
            + " != "
            + str(actions.shape)  # type: ignore
        )
        model_rewards_for_logged_action = torch.sum(
            model_rewards * action_mask, dim=1, keepdim=True
        )

        model_metrics = rewards_and_metric_rewards[:, num_actions:]
        assert model_metrics.shape[1] % num_actions == 0, (
            "Invalid metrics shape: "
            + str(model_metrics.shape)
            + " "
            + str(num_actions)
        )
        num_metrics = model_metrics.shape[1] // num_actions

        if num_metrics == 0:
            model_metrics_values = None
            model_metrics_for_logged_action = None
            model_metrics_values_for_logged_action = None
        else:
            model_metrics_values = trainer.q_network_cpe(
                rlt.PreprocessedState(state=states)
            )
            # Backward compatibility
            if hasattr(model_metrics_values, "q_values"):
                model_metrics_values = model_metrics_values.q_values
            model_metrics_values = model_metrics_values[:, num_actions:]
            assert (
                model_metrics_values.shape[1] == num_actions * num_metrics
            ), (  # type: ignore
                "Invalid shape: "
                + str(model_metrics_values.shape[1])  # type: ignore
                + " != "
                + str(actions.shape[1] * num_metrics)  # type: ignore
            )

            model_metrics_for_logged_action_list = []
            model_metrics_values_for_logged_action_list = []
            for metric_index in range(num_metrics):
                metric_start = metric_index * num_actions
                metric_end = (metric_index + 1) * num_actions
                model_metrics_for_logged_action_list.append(
                    torch.sum(
                        model_metrics[:, metric_start:metric_end] * action_mask,
                        dim=1,
                        keepdim=True,
                    )
                )
                model_metrics_values_for_logged_action_list.append(
                    torch.sum(
                        model_metrics_values[:, metric_start:metric_end] * action_mask,
                        dim=1,
                        keepdim=True,
                    )
                )
            model_metrics_for_logged_action = torch.cat(
                model_metrics_for_logged_action_list, dim=1
            )
            model_metrics_values_for_logged_action = torch.cat(
                model_metrics_values_for_logged_action_list, dim=1
            )

        # Switch back to the old mode
        trainer.q_network_cpe.train(old_q_cpe_train_state)  # type: ignore

    # Switch back to the old mode
    trainer.q_network.train(old_q_train_state)  # type: ignore
    trainer.reward_network.train(old_reward_train_state)  # type: ignore

    return cls(
        mdp_id=mdp_ids,
        sequence_number=sequence_numbers,
        logged_propensities=propensities,
        logged_rewards=rewards,
        action_mask=action_mask,
        model_rewards=model_rewards,
        model_rewards_for_logged_action=model_rewards_for_logged_action,
        model_values=model_values,
        model_values_for_logged_action=model_values_for_logged_action,
        model_metrics_values=model_metrics_values,
        model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
        model_propensities=model_propensities,
        logged_metrics=metrics,
        model_metrics=model_metrics,
        model_metrics_for_logged_action=model_metrics_for_logged_action,
        # Will compute later
        logged_values=None,
        logged_metrics_values=None,
        possible_actions_mask=possible_actions_mask,
        optimal_q_values=optimal_q_values,
        eval_action_idxs=eval_action_idxs,
    )
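
# --- Illustrative sketch (not part of the library) --------------------------
# In the parametric branch above, every state is tiled max_num_actions times so
# that each (state, possible_action) pair can be scored, and the flat per-pair
# scores are then reshaped back to [batch, max_num_actions]. The shapes below
# mirror that bookkeeping with toy tensors; the scoring network is a made-up
# Linear layer, not the library's q_network.
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, state_dim, action_dim, max_num_actions = 4, 6, 2, 3
q_network = nn.Linear(state_dim + action_dim, 1)  # toy pairwise scorer

states = torch.randn(batch_size, state_dim)
possible_actions = torch.randn(batch_size * max_num_actions, action_dim)

# repeat(1, k).reshape(-1, d) orders the copies so row i*k + j is state i, slot j
tiled_states = states.repeat(1, max_num_actions).reshape(-1, state_dim)
assert tiled_states.shape == (batch_size * max_num_actions, state_dim)

with torch.no_grad():
    flat_q = q_network(torch.cat([tiled_states, possible_actions], dim=1))

model_values = flat_q.reshape(batch_size, max_num_actions)
print(model_values.shape)  # torch.Size([4, 3])
# ----------------------------------------------------------------------------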