def create_from_tensors_dqn(
    cls,
    trainer: DQNTrainer,
    mdp_ids: np.ndarray,
    sequence_numbers: torch.Tensor,
    states: rlt.PreprocessedFeatureVector,
    actions: rlt.PreprocessedFeatureVector,
    propensities: torch.Tensor,
    rewards: torch.Tensor,
    possible_actions_mask: torch.Tensor,
    metrics: Optional[torch.Tensor] = None,
):
    # Switch to evaluation mode for the networks
    old_q_train_state = trainer.q_network.training
    old_reward_train_state = trainer.reward_network.training
    old_q_cpe_train_state = trainer.q_network_cpe.training
    trainer.q_network.train(False)
    trainer.reward_network.train(False)
    trainer.q_network_cpe.train(False)

    num_actions = trainer.num_actions
    action_mask = actions.float()  # type: ignore

    rewards = trainer.boost_rewards(rewards, actions)  # type: ignore
    model_values = trainer.q_network_cpe(
        rlt.PreprocessedState(state=states)
    ).q_values[:, 0:num_actions]
    optimal_q_values, _ = trainer.get_detached_q_values(
        states  # type: ignore
    )
    eval_action_idxs = trainer.get_max_q_values(  # type: ignore
        optimal_q_values, possible_actions_mask
    )[1]
    model_propensities = masked_softmax(
        optimal_q_values, possible_actions_mask, trainer.rl_temperature
    )
    assert model_values.shape == actions.shape, (  # type: ignore
        "Invalid shape: "
        + str(model_values.shape)  # type: ignore
        + " != "
        + str(actions.shape)  # type: ignore
    )
    assert model_values.shape == possible_actions_mask.shape, (  # type: ignore
        "Invalid shape: "
        + str(model_values.shape)  # type: ignore
        + " != "
        + str(possible_actions_mask.shape)  # type: ignore
    )
    model_values_for_logged_action = torch.sum(
        model_values * action_mask, dim=1, keepdim=True
    )

    rewards_and_metric_rewards = trainer.reward_network(
        rlt.PreprocessedState(state=states)
    )

    # In case we reuse the same module for the Q-network
    if hasattr(rewards_and_metric_rewards, "q_values"):
        rewards_and_metric_rewards = rewards_and_metric_rewards.q_values

    model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
    assert model_rewards.shape == actions.shape, (  # type: ignore
        "Invalid shape: "
        + str(model_rewards.shape)  # type: ignore
        + " != "
        + str(actions.shape)  # type: ignore
    )
    model_rewards_for_logged_action = torch.sum(
        model_rewards * action_mask, dim=1, keepdim=True
    )

    model_metrics = rewards_and_metric_rewards[:, num_actions:]

    assert model_metrics.shape[1] % num_actions == 0, (
        "Invalid metrics shape: " + str(model_metrics.shape) + " " + str(num_actions)
    )
    num_metrics = model_metrics.shape[1] // num_actions

    if num_metrics == 0:
        model_metrics_values = None
        model_metrics_for_logged_action = None
        model_metrics_values_for_logged_action = None
    else:
        model_metrics_values = trainer.q_network_cpe(
            rlt.PreprocessedState(state=states)
        )
        # Backward compatibility
        if hasattr(model_metrics_values, "q_values"):
            model_metrics_values = model_metrics_values.q_values
        model_metrics_values = model_metrics_values[:, num_actions:]
        assert (
            model_metrics_values.shape[1] == num_actions * num_metrics
        ), (  # type: ignore
            "Invalid shape: "
            + str(model_metrics_values.shape[1])  # type: ignore
            + " != "
            + str(actions.shape[1] * num_metrics)  # type: ignore
        )

        model_metrics_for_logged_action_list = []
        model_metrics_values_for_logged_action_list = []
        for metric_index in range(num_metrics):
            metric_start = metric_index * num_actions
            metric_end = (metric_index + 1) * num_actions
            model_metrics_for_logged_action_list.append(
                torch.sum(
                    model_metrics[:, metric_start:metric_end] * action_mask,
                    dim=1,
                    keepdim=True,
                )
            )
            model_metrics_values_for_logged_action_list.append(
                torch.sum(
                    model_metrics_values[:, metric_start:metric_end] * action_mask,
                    dim=1,
                    keepdim=True,
                )
            )
        model_metrics_for_logged_action = torch.cat(
            model_metrics_for_logged_action_list, dim=1
        )
        model_metrics_values_for_logged_action = torch.cat(
            model_metrics_values_for_logged_action_list, dim=1
        )

    # Switch back to the old mode
    trainer.q_network_cpe.train(old_q_cpe_train_state)  # type: ignore
    trainer.q_network.train(old_q_train_state)  # type: ignore
    trainer.reward_network.train(old_reward_train_state)  # type: ignore

    return cls(
        mdp_id=mdp_ids,
        sequence_number=sequence_numbers,
        logged_propensities=propensities,
        logged_rewards=rewards,
        action_mask=action_mask,
        model_rewards=model_rewards,
        model_rewards_for_logged_action=model_rewards_for_logged_action,
        model_values=model_values,
        model_values_for_logged_action=model_values_for_logged_action,
        model_metrics_values=model_metrics_values,
        model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
        model_propensities=model_propensities,
        logged_metrics=metrics,
        model_metrics=model_metrics,
        model_metrics_for_logged_action=model_metrics_for_logged_action,
        # Will compute later
        logged_values=None,
        logged_metrics_values=None,
        possible_actions_mask=possible_actions_mask,
        optimal_q_values=optimal_q_values,
        eval_action_idxs=eval_action_idxs,
    )
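
# `masked_softmax` is used above (and in `create_from_tensors` below) to turn
# detached Q-values into model propensities over the allowed actions. The
# helper that follows is a minimal, illustrative sketch of the assumed
# semantics; it is a hypothetical example, not the library's implementation.
# Impossible actions are pushed to -inf before a temperature-scaled softmax,
# so they receive zero probability.
def _example_masked_softmax(
    q_values: torch.Tensor,
    possible_actions_mask: torch.Tensor,
    temperature: float,
) -> torch.Tensor:
    # Scale by temperature, then give impossible actions an effectively
    # -inf score so the softmax assigns them zero probability.
    masked_q = q_values / temperature
    masked_q = masked_q.masked_fill(possible_actions_mask == 0, float("-inf"))
    return torch.nn.functional.softmax(masked_q, dim=1)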

def create_from_tensors(
    cls,
    trainer: DQNTrainer,
    mdp_ids: np.ndarray,
    sequence_numbers: torch.Tensor,
    states: Union[mt.State, torch.Tensor],
    actions: Union[mt.Action, torch.Tensor],
    propensities: torch.Tensor,
    rewards: torch.Tensor,
    possible_actions_mask: torch.Tensor,
    possible_actions: Optional[mt.FeatureVector] = None,
    max_num_actions: Optional[int] = None,
    metrics: Optional[torch.Tensor] = None,
):
    # Switch to evaluation mode for the networks
    old_q_train_state = trainer.q_network.training
    old_reward_train_state = trainer.reward_network.training
    trainer.q_network.train(False)
    trainer.reward_network.train(False)

    if max_num_actions:
        # Parametric model CPE
        state_action_pairs = mt.StateAction(  # type: ignore
            state=states, action=actions
        )
        tiled_state = mt.FeatureVector(
            states.float_features.repeat(1, max_num_actions).reshape(  # type: ignore
                -1, states.float_features.shape[1]  # type: ignore
            )
        )
        # Get Q-value of action taken
        possible_actions_state_concat = mt.StateAction(  # type: ignore
            state=tiled_state, action=possible_actions  # type: ignore
        )

        # Parametric actions
        # FIXME: model_values and model_propensities should be calculated
        # as in the discrete dqn model
        model_values = trainer.q_network(
            possible_actions_state_concat
        ).q_value  # type: ignore
        optimal_q_values = model_values
        eval_action_idxs = None

        assert (
            model_values.shape[0] * model_values.shape[1]
            == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
        ), (
            "Invalid shapes: "
            + str(model_values.shape)
            + " != "
            + str(possible_actions_mask.shape)
        )
        model_values = model_values.reshape(possible_actions_mask.shape)
        model_propensities = masked_softmax(
            model_values, possible_actions_mask, trainer.rl_temperature
        )

        model_rewards = trainer.reward_network(
            possible_actions_state_concat
        ).q_value  # type: ignore
        assert (
            model_rewards.shape[0] * model_rewards.shape[1]
            == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
        ), (
            "Invalid shapes: "
            + str(model_rewards.shape)
            + " != "
            + str(possible_actions_mask.shape)
        )
        model_rewards = model_rewards.reshape(possible_actions_mask.shape)

        model_values_for_logged_action = trainer.q_network(state_action_pairs).q_value
        model_rewards_for_logged_action = trainer.reward_network(
            state_action_pairs
        ).q_value

        action_mask = (
            torch.abs(model_values - model_values_for_logged_action) < 1e-3
        ).float()

        model_metrics = None
        model_metrics_for_logged_action = None
        model_metrics_values = None
        model_metrics_values_for_logged_action = None
    else:
        if isinstance(states, mt.State):
            states = mt.StateInput(state=states)  # type: ignore
        num_actions = trainer.num_actions
        action_mask = actions.float()  # type: ignore

        # Switch to evaluation mode for the network
        old_q_cpe_train_state = trainer.q_network_cpe.training
        trainer.q_network_cpe.train(False)

        # Discrete actions
        rewards = trainer.boost_rewards(rewards, actions)  # type: ignore
        model_values = trainer.q_network_cpe(states).q_values[:, 0:num_actions]
        optimal_q_values = trainer.get_detached_q_values(
            states.state  # type: ignore
        )[0]
        eval_action_idxs = trainer.get_max_q_values(  # type: ignore
            optimal_q_values, possible_actions_mask
        )[1]
        model_propensities = masked_softmax(
            optimal_q_values, possible_actions_mask, trainer.rl_temperature
        )
        assert model_values.shape == actions.shape, (  # type: ignore
            "Invalid shape: "
            + str(model_values.shape)  # type: ignore
            + " != "
            + str(actions.shape)  # type: ignore
        )
        assert model_values.shape == possible_actions_mask.shape, (  # type: ignore
            "Invalid shape: "
            + str(model_values.shape)  # type: ignore
            + " != "
            + str(possible_actions_mask.shape)  # type: ignore
        )
        model_values_for_logged_action = torch.sum(
            model_values * action_mask, dim=1, keepdim=True
        )

        rewards_and_metric_rewards = trainer.reward_network(states)

        # In case we reuse the same module for the Q-network
        if hasattr(rewards_and_metric_rewards, "q_values"):
            rewards_and_metric_rewards = rewards_and_metric_rewards.q_values

        model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
        assert model_rewards.shape == actions.shape, (  # type: ignore
            "Invalid shape: "
            + str(model_rewards.shape)  # type: ignore
            + " != "
            + str(actions.shape)  # type: ignore
        )
        model_rewards_for_logged_action = torch.sum(
            model_rewards * action_mask, dim=1, keepdim=True
        )

        model_metrics = rewards_and_metric_rewards[:, num_actions:]

        assert model_metrics.shape[1] % num_actions == 0, (
            "Invalid metrics shape: "
            + str(model_metrics.shape)
            + " "
            + str(num_actions)
        )
        num_metrics = model_metrics.shape[1] // num_actions

        if num_metrics == 0:
            model_metrics_values = None
            model_metrics_for_logged_action = None
            model_metrics_values_for_logged_action = None
        else:
            model_metrics_values = trainer.q_network_cpe(states)
            # Backward compatibility
            if hasattr(model_metrics_values, "q_values"):
                model_metrics_values = model_metrics_values.q_values
            model_metrics_values = model_metrics_values[:, num_actions:]
            assert (
                model_metrics_values.shape[1] == num_actions * num_metrics
            ), (  # type: ignore
                "Invalid shape: "
                + str(model_metrics_values.shape[1])  # type: ignore
                + " != "
                + str(actions.shape[1] * num_metrics)  # type: ignore
            )

            model_metrics_for_logged_action_list = []
            model_metrics_values_for_logged_action_list = []
            for metric_index in range(num_metrics):
                metric_start = metric_index * num_actions
                metric_end = (metric_index + 1) * num_actions
                model_metrics_for_logged_action_list.append(
                    torch.sum(
                        model_metrics[:, metric_start:metric_end] * action_mask,
                        dim=1,
                        keepdim=True,
                    )
                )
                model_metrics_values_for_logged_action_list.append(
                    torch.sum(
                        model_metrics_values[:, metric_start:metric_end] * action_mask,
                        dim=1,
                        keepdim=True,
                    )
                )
            model_metrics_for_logged_action = torch.cat(
                model_metrics_for_logged_action_list, dim=1
            )
            model_metrics_values_for_logged_action = torch.cat(
                model_metrics_values_for_logged_action_list, dim=1
            )
        # Switch back to the old mode
        trainer.q_network_cpe.train(old_q_cpe_train_state)  # type: ignore

    # Switch back to the old mode
    trainer.q_network.train(old_q_train_state)  # type: ignore
    trainer.reward_network.train(old_reward_train_state)  # type: ignore

    return cls(
        mdp_id=mdp_ids,
        sequence_number=sequence_numbers,
        logged_propensities=propensities,
        logged_rewards=rewards,
        action_mask=action_mask,
        model_rewards=model_rewards,
        model_rewards_for_logged_action=model_rewards_for_logged_action,
        model_values=model_values,
        model_values_for_logged_action=model_values_for_logged_action,
        model_metrics_values=model_metrics_values,
        model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
        model_propensities=model_propensities,
        logged_metrics=metrics,
        model_metrics=model_metrics,
        model_metrics_for_logged_action=model_metrics_for_logged_action,
        # Will compute later
        logged_values=None,
        logged_metrics_values=None,
        possible_actions_mask=possible_actions_mask,
        optimal_q_values=optimal_q_values,
        eval_action_idxs=eval_action_idxs,
    )
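
# A minimal, self-contained sketch (a hypothetical helper, not part of the
# class the constructors above belong to) of the per-metric slicing both
# constructors perform: `model_metrics` is laid out as
# [batch, num_metrics * num_actions], and the logged-action value for metric i
# is the dot product of its num_actions-wide slice with the one-hot action mask.
def _example_metrics_for_logged_action(
    model_metrics: torch.Tensor,
    action_mask: torch.Tensor,
    num_actions: int,
) -> torch.Tensor:
    num_metrics = model_metrics.shape[1] // num_actions
    per_metric = []
    for metric_index in range(num_metrics):
        metric_start = metric_index * num_actions
        metric_end = (metric_index + 1) * num_actions
        # Pick out the logged action's entry within this metric's slice.
        per_metric.append(
            torch.sum(
                model_metrics[:, metric_start:metric_end] * action_mask,
                dim=1,
                keepdim=True,
            )
        )
    return torch.cat(per_metric, dim=1)  # shape: [batch, num_metrics]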