def act(self, obs: rlt.FeatureData) -> rlt.ActorOutput:
    """ Act randomly regardless of the observation. """
    obs: torch.Tensor = obs.float_features
    assert obs.dim() >= 2, f"obs has shape {obs.shape} (dim < 2)"
    batch_size = obs.size(0)
    # pyre-fixme[6]: Expected `Union[torch.Size, torch.Tensor]` for 1st param
    #  but got `Tuple[int]`.
    action = self.dist.sample((batch_size,))
    log_prob = self.dist.log_prob(action)
    return rlt.ActorOutput(action=action, log_prob=log_prob)
def evaluate(self, batch: PreprocessedMemoryNetworkInput):
    """Calculate state feature sensitivity due to actions:
    randomly permute the actions and measure how much the prediction
    of the next state features deviates."""
    assert isinstance(batch, PreprocessedMemoryNetworkInput)
    self.trainer.memory_network.mdnrnn.eval()

    seq_len, batch_size, state_dim = batch.next_state.float_features.size()
    state_feature_num = self.state_feature_num
    feature_sensitivity = torch.zeros(state_feature_num)

    # the input of world_model has seq_len as the first dimension
    mdnrnn_output = self.trainer.memory_network(
        batch.state, FeatureData(batch.action)
    )
    predicted_next_state_means = mdnrnn_output.mus

    shuffled_mdnrnn_output = self.trainer.memory_network(
        batch.state,
        # shuffle the actions
        FeatureData(batch.action[:, torch.randperm(batch_size), :]),
    )
    shuffled_predicted_next_state_means = shuffled_mdnrnn_output.mus

    assert (
        predicted_next_state_means.size()
        == shuffled_predicted_next_state_means.size()
        == (seq_len, batch_size, self.trainer.params.num_gaussians, state_dim)
    )

    state_feature_boundaries = self.sorted_state_feature_start_indices + [state_dim]
    for i in range(state_feature_num):
        boundary_start, boundary_end = (
            state_feature_boundaries[i],
            state_feature_boundaries[i + 1],
        )
        abs_diff = torch.mean(
            torch.sum(
                torch.abs(
                    shuffled_predicted_next_state_means[
                        :, :, :, boundary_start:boundary_end
                    ]
                    - predicted_next_state_means[:, :, :, boundary_start:boundary_end]
                ),
                dim=3,
            )
        )
        feature_sensitivity[i] = abs_diff.cpu().detach().item()

    self.trainer.memory_network.mdnrnn.train()
    logger.info(
        "**** Debug tool feature sensitivity ****: {}".format(feature_sensitivity)
    )
    return {"feature_sensitivity": feature_sensitivity.numpy()}
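# A minimal, library-free sketch of the permutation-sensitivity idea above,
# assuming a generic `model(state, action)` callable (hypothetical, not the
# ReAgent memory-network API): shuffle the actions across the batch and
# measure how far the model's predictions move.
import torch


def permutation_sensitivity_sketch(model, state, action):
    # state: (seq_len, batch, state_dim); action: (seq_len, batch, action_dim)
    batch_size = action.size(1)
    baseline = model(state, action)
    shuffled = model(state, action[:, torch.randperm(batch_size), :])
    # mean absolute deviation caused by breaking the state-action pairing
    return (shuffled - baseline).abs().sum(dim=-1).mean()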
def act(
    self, obs: rlt.FeatureData, possible_actions_mask: Optional[np.ndarray] = None
) -> rlt.ActorOutput:
    """ Act randomly regardless of the observation. """
    obs: torch.Tensor = obs.float_features
    assert obs.dim() >= 2, f"obs has shape {obs.shape} (dim < 2)"
    batch_size = obs.size(0)
    # pyre-fixme[6]: Expected `Union[torch.Size, torch.Tensor]` for 1st param
    #  but got `Tuple[int]`.
    action = self.dist.sample((batch_size,))
    # sum over action_dim (the distribution is assumed i.i.d. per coordinate)
    log_prob = self.dist.log_prob(action).sum(1)
    return rlt.ActorOutput(action=action, log_prob=log_prob)
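# Hedged illustration of why the log-prob above is summed over the action
# dimension: when each action coordinate is drawn i.i.d., the joint log-density
# of the action vector is the sum of the per-coordinate log-densities. The
# Normal distribution here is only a stand-in for whatever `self.dist` is.
import torch

dist = torch.distributions.Normal(torch.zeros(4), torch.ones(4))  # action_dim = 4
action = dist.sample((8,))                # (batch_size=8, action_dim=4)
log_prob = dist.log_prob(action).sum(1)   # (batch_size,) joint log-prob per action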
def get_parametric_input(max_num_actions: int, obs: rlt.FeatureData):
    assert (
        len(obs.float_features.shape) == 2
    ), f"{obs.float_features.shape} is not (batch_size, state_dim)."
    batch_size, _ = obs.float_features.shape
    possible_actions = get_possible_actions_for_gym(batch_size, max_num_actions).to(
        obs.float_features.device
    )
    return obs.get_tiled_batch(max_num_actions), possible_actions
def _get_unmasked_q_values(
    self, q_network, state: rlt.FeatureData, slate: rlt.DocList
) -> torch.Tensor:
    """ Gets the q values from the model and target networks """
    batch_size, slate_size, _ = slate.float_features.shape
    # TODO: Probably should create a new model type
    return q_network(
        state.repeat_interleave(slate_size, dim=0), slate.as_feature_data()
    ).view(batch_size, slate_size)
def _get_unmask_q_values(
    self,
    q_network,
    state: rlt.FeatureData,
    action: rlt.PreprocessedSlateFeatureVector,
) -> torch.Tensor:
    batch_size, slate_size, _ = action.float_features.shape
    return q_network(
        state.repeat_interleave(slate_size, dim=0),
        action.as_preprocessed_feature_vector(),
    ).view(batch_size, slate_size)
def act(self, obs: rlt.FeatureData) -> rlt.ActorOutput:
    """ Act randomly regardless of the observation. """
    obs: torch.Tensor = obs.float_features
    assert obs.dim() >= 2, f"obs has shape {obs.shape} (dim < 2)"
    batch_size = obs.shape[0]
    weights = torch.ones((batch_size, self.num_actions))
    # sample a random action
    m = torch.distributions.Categorical(weights)
    raw_action = m.sample()
    action = F.one_hot(raw_action, self.num_actions)
    log_prob = m.log_prob(raw_action).float()
    return rlt.ActorOutput(action=action, log_prob=log_prob)
def score(preprocessed_obs: rlt.FeatureData) -> torch.Tensor:
    tiled_state = preprocessed_obs.repeat_interleave(repeats=num_actions, axis=0)
    actions = rlt.FeatureData(float_features=torch.eye(num_actions))
    q_network.eval()
    scores = q_network(tiled_state.state, actions).view(-1, num_actions)
    assert (
        scores.size(1) == num_actions
    ), f"scores shape is {scores.shape}, num_actions is {num_actions}"
    q_network.train()
    return F.log_softmax(scores, dim=-1)
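# Minimal sketch of the tile-and-reshape pattern used by `score` above, with a
# plain linear layer standing in for `q_network` (an assumption; the real
# network consumes rlt.FeatureData). Each state is repeated `num_actions` times
# so every (state, one-hot action) pair can be scored in one forward pass, then
# the flat scores are viewed back into (batch_size, num_actions).
import torch
import torch.nn.functional as F

batch_size, state_dim, num_actions = 3, 5, 4
states = torch.randn(batch_size, state_dim)
tiled_states = states.repeat_interleave(num_actions, dim=0)  # (batch*actions, state_dim)
actions = torch.eye(num_actions).repeat(batch_size, 1)       # one-hot row per tiled state
q_stub = torch.nn.Linear(state_dim + num_actions, 1)         # stand-in Q-network
scores = q_stub(torch.cat([tiled_states, actions], dim=1)).view(batch_size, num_actions)
log_policy = F.log_softmax(scores, dim=-1)                   # log of a softmax policy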
def score(state: rlt.FeatureData) -> torch.Tensor:
    tiled_state = state.repeat_interleave(repeats=num_candidates, axis=0)
    candidate_docs = state.candidate_docs
    assert candidate_docs is not None
    actions = candidate_docs.as_feature_data()
    q_network.eval()
    scores = q_network(tiled_state, actions).view(-1, num_candidates)
    q_network.train()
    select_prob = F.softmax(candidate_docs.value, dim=1)
    assert select_prob.shape == scores.shape
    return select_prob * scores
def act(
    self, obs: rlt.FeatureData, possible_actions_mask: Optional[np.ndarray] = None
) -> rlt.ActorOutput:
    """ Act randomly regardless of the observation. """
    obs: torch.Tensor = obs.float_features
    assert obs.dim() >= 2, f"obs has shape {obs.shape} (dim < 2)"
    assert obs.shape[0] == 1, f"obs has shape {obs.shape} (0th dim != 1)"
    batch_size = obs.shape[0]
    scores = torch.ones((batch_size, self.num_actions))
    scores = apply_possible_actions_mask(
        scores, possible_actions_mask, invalid_score=0.0
    )
    # sample a random action
    m = torch.distributions.Categorical(scores)
    raw_action = m.sample()
    action = F.one_hot(raw_action, self.num_actions)
    log_prob = m.log_prob(raw_action).float()
    return rlt.ActorOutput(action=action, log_prob=log_prob)
def create_from_tensors_dqn(
    cls,
    trainer: DQNTrainer,
    mdp_ids: torch.Tensor,
    sequence_numbers: torch.Tensor,
    states: rlt.FeatureData,
    actions: rlt.FeatureData,
    propensities: torch.Tensor,
    rewards: torch.Tensor,
    possible_actions_mask: torch.Tensor,
    metrics: Optional[torch.Tensor] = None,
):
    old_q_train_state = trainer.q_network.training
    # pyre-fixme[16]: `DQNTrainer` has no attribute `reward_network`.
    old_reward_train_state = trainer.reward_network.training
    # pyre-fixme[16]: `DQNTrainer` has no attribute `q_network_cpe`.
    old_q_cpe_train_state = trainer.q_network_cpe.training
    trainer.q_network.train(False)
    trainer.reward_network.train(False)
    trainer.q_network_cpe.train(False)

    num_actions = trainer.num_actions
    action_mask = actions.float()

    rewards = trainer.boost_rewards(rewards, actions)
    model_values = trainer.q_network_cpe(states)[:, 0:num_actions]
    optimal_q_values, _ = trainer.get_detached_q_values(states)
    # Do we ever really use eval_action_idxs?
    eval_action_idxs = trainer.get_max_q_values(
        optimal_q_values, possible_actions_mask
    )[1]
    model_propensities = masked_softmax(
        optimal_q_values, possible_actions_mask, trainer.rl_temperature
    )
    assert model_values.shape == actions.shape, (
        "Invalid shape: " + str(model_values.shape) + " != " + str(actions.shape)
    )
    assert model_values.shape == possible_actions_mask.shape, (
        "Invalid shape: "
        + str(model_values.shape)
        + " != "
        + str(possible_actions_mask.shape)
    )
    model_values_for_logged_action = torch.sum(
        model_values * action_mask, dim=1, keepdim=True
    )

    rewards_and_metric_rewards = trainer.reward_network(states)

    # In case we reuse the module for the Q-network
    if hasattr(rewards_and_metric_rewards, "q_values"):
        rewards_and_metric_rewards = rewards_and_metric_rewards.q_values

    model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
    assert model_rewards.shape == actions.shape, (
        "Invalid shape: " + str(model_rewards.shape) + " != " + str(actions.shape)
    )
    model_rewards_for_logged_action = torch.sum(
        model_rewards * action_mask, dim=1, keepdim=True
    )

    model_metrics = rewards_and_metric_rewards[:, num_actions:]

    assert model_metrics.shape[1] % num_actions == 0, (
        "Invalid metrics shape: " + str(model_metrics.shape) + " " + str(num_actions)
    )
    num_metrics = model_metrics.shape[1] // num_actions

    if num_metrics == 0:
        model_metrics_values = None
        model_metrics_for_logged_action = None
        model_metrics_values_for_logged_action = None
    else:
        model_metrics_values = trainer.q_network_cpe(states)
        # Backward compatibility
        if hasattr(model_metrics_values, "q_values"):
            model_metrics_values = model_metrics_values.q_values
        model_metrics_values = model_metrics_values[:, num_actions:]
        assert model_metrics_values.shape[1] == num_actions * num_metrics, (
            "Invalid shape: "
            + str(model_metrics_values.shape[1])
            + " != "
            + str(actions.shape[1] * num_metrics)
        )

        model_metrics_for_logged_action_list = []
        model_metrics_values_for_logged_action_list = []
        for metric_index in range(num_metrics):
            metric_start = metric_index * num_actions
            metric_end = (metric_index + 1) * num_actions
            model_metrics_for_logged_action_list.append(
                torch.sum(
                    model_metrics[:, metric_start:metric_end] * action_mask,
                    dim=1,
                    keepdim=True,
                )
            )
            model_metrics_values_for_logged_action_list.append(
                torch.sum(
                    model_metrics_values[:, metric_start:metric_end] * action_mask,
                    dim=1,
                    keepdim=True,
                )
            )
        model_metrics_for_logged_action = torch.cat(
            model_metrics_for_logged_action_list, dim=1
        )
        model_metrics_values_for_logged_action = torch.cat(
            model_metrics_values_for_logged_action_list, dim=1
        )
    trainer.q_network_cpe.train(old_q_cpe_train_state)
    trainer.q_network.train(old_q_train_state)
    trainer.reward_network.train(old_reward_train_state)

    return cls(
        mdp_id=mdp_ids,
        sequence_number=sequence_numbers,
        logged_propensities=propensities,
        logged_rewards=rewards,
        action_mask=action_mask,
        model_rewards=model_rewards,
        model_rewards_for_logged_action=model_rewards_for_logged_action,
        model_values=model_values,
        model_values_for_logged_action=model_values_for_logged_action,
        model_metrics_values=model_metrics_values,
        model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
        model_propensities=model_propensities,
        logged_metrics=metrics,
        model_metrics=model_metrics,
        model_metrics_for_logged_action=model_metrics_for_logged_action,
        # Will compute later
        logged_values=None,
        logged_metrics_values=None,
        possible_actions_mask=possible_actions_mask,
        optimal_q_values=optimal_q_values,
        eval_action_idxs=eval_action_idxs,
    )
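# Hedged sketch of the masked softmax used above to turn Q-values into model
# propensities: invalid actions are pushed to -inf before the softmax so they
# receive zero probability, and a temperature flattens or sharpens the result.
# This illustrates the idea only; it is not ReAgent's `masked_softmax`.
import torch
import torch.nn.functional as F


def masked_softmax_sketch(q_values, possible_actions_mask, temperature=1.0):
    # q_values, possible_actions_mask: (batch_size, num_actions); mask entries are 0/1
    logits = q_values / temperature
    logits = logits.masked_fill(possible_actions_mask == 0, float("-inf"))
    return F.softmax(logits, dim=1)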
def evaluate(self, batch: PreprocessedMemoryNetworkInput):
    """Calculate feature importance: set each state/action feature to its
    mean value and observe how much the loss increases."""
    self.trainer.memory_network.mdnrnn.eval()

    state_features = batch.state.float_features
    action_features = batch.action
    seq_len, batch_size, state_dim = state_features.size()
    action_dim = action_features.size()[2]
    action_feature_num = self.action_feature_num
    state_feature_num = self.state_feature_num
    feature_importance = torch.zeros(action_feature_num + state_feature_num)

    orig_losses = self.trainer.get_loss(batch, state_dim=state_dim)
    orig_loss = orig_losses["loss"].cpu().detach().item()
    del orig_losses

    action_feature_boundaries = self.sorted_action_feature_start_indices + [action_dim]
    state_feature_boundaries = self.sorted_state_feature_start_indices + [state_dim]

    for i in range(action_feature_num):
        action_features = batch.action.reshape(
            (batch_size * seq_len, action_dim)
        ).data.clone()

        # if actions are discrete, an action's feature importance is the loss
        # increase due to setting all actions to this action
        if self.discrete_action:
            assert action_dim == action_feature_num
            action_vec = torch.zeros(action_dim)
            action_vec[i] = 1
            action_features[:] = action_vec
        # if actions are continuous, an action feature's importance is the loss
        # increase due to masking this feature to its mean value
        else:
            boundary_start, boundary_end = (
                action_feature_boundaries[i],
                action_feature_boundaries[i + 1],
            )
            action_features[
                :, boundary_start:boundary_end
            ] = self.compute_median_feature_value(
                action_features[:, boundary_start:boundary_end]
            )

        action_features = action_features.reshape((seq_len, batch_size, action_dim))
        new_batch = PreprocessedMemoryNetworkInput(
            state=batch.state,
            action=action_features,
            next_state=batch.next_state,
            reward=batch.reward,
            time_diff=torch.ones_like(batch.reward).float(),
            not_terminal=batch.not_terminal,
            step=None,
        )
        losses = self.trainer.get_loss(new_batch, state_dim=state_dim)
        feature_importance[i] = losses["loss"].cpu().detach().item() - orig_loss
        del losses

    for i in range(state_feature_num):
        state_features = batch.state.float_features.reshape(
            (batch_size * seq_len, state_dim)
        ).data.clone()
        boundary_start, boundary_end = (
            state_feature_boundaries[i],
            state_feature_boundaries[i + 1],
        )
        state_features[
            :, boundary_start:boundary_end
        ] = self.compute_median_feature_value(
            state_features[:, boundary_start:boundary_end]
        )
        state_features = state_features.reshape((seq_len, batch_size, state_dim))
        new_batch = PreprocessedMemoryNetworkInput(
            state=FeatureData(float_features=state_features),
            action=batch.action,
            next_state=batch.next_state,
            reward=batch.reward,
            time_diff=torch.ones_like(batch.reward).float(),
            not_terminal=batch.not_terminal,
            step=None,
        )
        losses = self.trainer.get_loss(new_batch, state_dim=state_dim)
        feature_importance[i + action_feature_num] = (
            losses["loss"].cpu().detach().item() - orig_loss
        )
        del losses

    self.trainer.memory_network.mdnrnn.train()
    logger.info(
        "**** Debug tool feature importance ****: {}".format(feature_importance)
    )
    return {"feature_loss_increase": feature_importance.numpy()}
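# Library-free sketch of the "mask a feature and measure the loss increase" idea
# above, assuming a generic `loss_fn(features)` callable (hypothetical). Each
# feature column is replaced by its median in turn; the importance score is the
# resulting increase over the baseline loss.
import torch


def masked_feature_importance_sketch(loss_fn, features):
    # features: (num_rows, num_features)
    baseline = loss_fn(features)
    scores = torch.zeros(features.size(1))
    for i in range(features.size(1)):
        masked = features.clone()
        masked[:, i] = masked[:, i].median()
        scores[i] = loss_fn(masked) - baseline
    return scores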