Code example #1
 def act(self, obs: rlt.FeatureData) -> rlt.ActorOutput:
     """ Act randomly regardless of the observation. """
     obs: torch.Tensor = obs.float_features
     assert obs.dim() >= 2, f"obs has shape {obs.shape} (dim < 2)"
     batch_size = obs.size(0)
     # pyre-fixme[6]: Expected `Union[torch.Size, torch.Tensor]` for 1st param
     #  but got `Tuple[int]`.
     action = self.dist.sample((batch_size, ))
     log_prob = self.dist.log_prob(action)
     return rlt.ActorOutput(action=action, log_prob=log_prob)
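
The `self.dist.sample((batch_size,))` call above draws one action per row of the observation batch from a fixed distribution. A minimal PyTorch sketch of that pattern, independent of ReAgent; the `Uniform` distribution and the 3-dimensional action box are assumptions for illustration only.

import torch

# hypothetical stand-in for self.dist: a uniform distribution over a 3-dim action box
dist = torch.distributions.Uniform(low=torch.zeros(3), high=torch.ones(3))

batch_size = 4
action = dist.sample((batch_size,))   # shape (batch_size, 3)
log_prob = dist.log_prob(action)      # per-coordinate log-density, shape (batch_size, 3)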
Code example #2
    def evaluate(self, batch: PreprocessedMemoryNetworkInput):
        """ Calculate state feature sensitivity due to actions:
        randomly permutating actions and see how much the prediction of next
        state feature deviates. """
        assert isinstance(batch, PreprocessedMemoryNetworkInput)

        self.trainer.memory_network.mdnrnn.eval()

        seq_len, batch_size, state_dim = batch.next_state.float_features.size()
        state_feature_num = self.state_feature_num
        feature_sensitivity = torch.zeros(state_feature_num)

        # the input of world_model has seq-len as the first dimension
        mdnrnn_output = self.trainer.memory_network(batch.state,
                                                    FeatureData(batch.action))
        predicted_next_state_means = mdnrnn_output.mus

        shuffled_mdnrnn_output = self.trainer.memory_network(
            batch.state,
            # shuffle the actions
            FeatureData(batch.action[:, torch.randperm(batch_size), :]),
        )
        shuffled_predicted_next_state_means = shuffled_mdnrnn_output.mus

        assert (predicted_next_state_means.size() ==
                shuffled_predicted_next_state_means.size() ==
                (seq_len, batch_size, self.trainer.params.num_gaussians,
                 state_dim))

        state_feature_boundaries = self.sorted_state_feature_start_indices + [
            state_dim
        ]
        for i in range(state_feature_num):
            boundary_start, boundary_end = (
                state_feature_boundaries[i],
                state_feature_boundaries[i + 1],
            )
            abs_diff = torch.mean(
                torch.sum(
                    torch.abs(
                        shuffled_predicted_next_state_means[
                            :, :, :, boundary_start:boundary_end
                        ]
                        - predicted_next_state_means[
                            :, :, :, boundary_start:boundary_end
                        ]
                    ),
                    dim=3,
                )
            )
            feature_sensitivity[i] = abs_diff.cpu().detach().item()

        self.trainer.memory_network.mdnrnn.train()
        logger.info("**** Debug tool feature sensitivity ****: {}".format(
            feature_sensitivity))
        return {"feature_sensitivity": feature_sensitivity.numpy()}
Code example #3
 def act(
     self, obs: rlt.FeatureData, possible_actions_mask: Optional[np.ndarray] = None
 ) -> rlt.ActorOutput:
     """ Act randomly regardless of the observation. """
     obs: torch.Tensor = obs.float_features
     assert obs.dim() >= 2, f"obs has shape {obs.shape} (dim < 2)"
     batch_size = obs.size(0)
     # pyre-fixme[6]: Expected `Union[torch.Size, torch.Tensor]` for 1st param
     #  but got `Tuple[int]`.
     action = self.dist.sample((batch_size,))
     # sum over action_dim (since assuming i.i.d. per coordinate)
     log_prob = self.dist.log_prob(action).sum(1)
     return rlt.ActorOutput(action=action, log_prob=log_prob)
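
The `.sum(1)` above turns per-coordinate log-densities into one log-probability per sampled action, which is valid when the coordinates are treated as independent. A short sketch with a factorized Normal (the dimensions are arbitrary):

import torch

dist = torch.distributions.Normal(loc=torch.zeros(3), scale=torch.ones(3))
action = dist.sample((5,))               # (batch_size=5, action_dim=3)
log_prob = dist.log_prob(action).sum(1)  # joint log-prob under independence, shape (5,)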
Code example #4
def get_parametric_input(max_num_actions: int, obs: rlt.FeatureData):
    assert (len(obs.float_features.shape) == 2
            ), f"{obs.float_features.shape} is not (batch_size, state_dim)."
    batch_size, _ = obs.float_features.shape
    possible_actions = get_possible_actions_for_gym(
        batch_size, max_num_actions).to(obs.float_features.device)
    return obs.get_tiled_batch(max_num_actions), possible_actions
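
`get_tiled_batch` pairs every state with every candidate action so a parametric-action Q-network can score all of them in one forward pass. A hedged sketch of that tiling in plain PyTorch, assuming one-hot candidate actions (the real `get_possible_actions_for_gym` helper may enumerate them differently):

import torch

batch_size, state_dim, max_num_actions = 2, 4, 3
obs = torch.randn(batch_size, state_dim)

# repeat each state row once per candidate action ...
tiled_obs = obs.repeat_interleave(max_num_actions, dim=0)            # (6, 4)
# ... and enumerate a one-hot candidate action set for every state
possible_actions = torch.eye(max_num_actions).repeat(batch_size, 1)  # (6, 3)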
Code example #5
File: slate_q_trainer.py  Project: hermes2k/ReAgent
 def _get_unmasked_q_values(self, q_network, state: rlt.FeatureData,
                            slate: rlt.DocList) -> torch.Tensor:
     """ Gets the q values from the model and target networks """
     batch_size, slate_size, _ = slate.float_features.shape
     # TODO: Probably should create a new model type
     return q_network(state.repeat_interleave(slate_size, dim=0),
                      slate.as_feature_data()).view(batch_size, slate_size)
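
The trainer above scores every document in a slate by tiling the state once per slate position and flattening the slate into a plain batch. A self-contained sketch with a stand-in Q-network (the concatenation-based network and the shapes are assumptions):

import torch

batch_size, slate_size, state_dim, doc_dim = 2, 3, 4, 5
state = torch.randn(batch_size, state_dim)
slate = torch.randn(batch_size, slate_size, doc_dim)

q_network = torch.nn.Linear(state_dim + doc_dim, 1)           # stand-in for the real model
tiled_state = state.repeat_interleave(slate_size, dim=0)      # (6, 4)
flat_docs = slate.reshape(batch_size * slate_size, doc_dim)   # (6, 5)
q_values = q_network(torch.cat([tiled_state, flat_docs], dim=1)).view(batch_size, slate_size)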
Code example #6
 def _get_unmask_q_values(
     self,
     q_network,
     state: rlt.FeatureData,
     action: rlt.PreprocessedSlateFeatureVector,
 ) -> torch.Tensor:
     batch_size, slate_size, _ = action.float_features.shape
     return q_network(
         state.repeat_interleave(slate_size, dim=0),
         action.as_preprocessed_feature_vector(),
     ).view(batch_size, slate_size)
Code example #7
    def act(self, obs: rlt.FeatureData) -> rlt.ActorOutput:
        """ Act randomly regardless of the observation. """
        obs: torch.Tensor = obs.float_features
        assert obs.dim() >= 2, f"obs has shape {obs.shape} (dim < 2)"
        batch_size = obs.shape[0]
        weights = torch.ones((batch_size, self.num_actions))

        # sample a random action
        m = torch.distributions.Categorical(weights)
        raw_action = m.sample()
        action = F.one_hot(raw_action, self.num_actions)
        log_prob = m.log_prob(raw_action).float()
        return rlt.ActorOutput(action=action, log_prob=log_prob)
Code example #8
    def score(preprocessed_obs: rlt.FeatureData) -> torch.Tensor:
        tiled_state = preprocessed_obs.repeat_interleave(repeats=num_actions,
                                                         axis=0)

        actions = rlt.FeatureData(float_features=torch.eye(num_actions))

        q_network.eval()
        scores = q_network(tiled_state.state, actions).view(-1, num_actions)
        assert (
            scores.size(1) == num_actions
        ), f"scores size is {scores.size(1)}, num_actions is {num_actions}"
        q_network.train()
        return F.log_softmax(scores, dim=-1)
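
The scorer above evaluates one batch of states against every discrete action by tiling the states and feeding `torch.eye(num_actions)` as the action inputs, then normalizes the scores with `log_softmax`. A minimal sketch with a stand-in Q-network (the concatenation-based network is an assumption, not ReAgent's API):

import torch
import torch.nn.functional as F

num_actions, state_dim = 3, 4
state = torch.randn(2, state_dim)
q_network = torch.nn.Linear(state_dim + num_actions, 1)        # stand-in Q-network

tiled_state = state.repeat_interleave(num_actions, dim=0)      # (6, 4)
actions = torch.eye(num_actions).repeat(state.size(0), 1)      # (6, 3) one-hot actions
scores = q_network(torch.cat([tiled_state, actions], dim=1)).view(-1, num_actions)
log_policy = F.log_softmax(scores, dim=-1)                     # (2, 3)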
Code example #9
File: slate_q_scorer.py  Project: zachkeer/ReAgent
    def score(state: rlt.FeatureData) -> torch.Tensor:
        tiled_state = state.repeat_interleave(repeats=num_candidates, axis=0)
        candidate_docs = state.candidate_docs
        assert candidate_docs is not None
        actions = candidate_docs.as_feature_data()

        q_network.eval()
        scores = q_network(tiled_state, actions).view(-1, num_candidates)
        q_network.train()

        select_prob = F.softmax(candidate_docs.value, dim=1)
        assert select_prob.shape == scores.shape

        return select_prob * scores
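
The last line weights each candidate's Q-value by a softmax over the candidates' value scores, so the returned score is an expected per-item contribution rather than a raw Q-value. A toy numeric sketch of that weighting (the numbers are made up):

import torch
import torch.nn.functional as F

candidate_values = torch.tensor([[1.0, 2.0, 0.5]])   # (batch, num_candidates)
q_scores = torch.tensor([[0.3, 0.7, 0.1]])

select_prob = F.softmax(candidate_values, dim=1)      # selection probability per candidate
weighted_scores = select_prob * q_scores              # same shape as q_scores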
Code example #10
    def act(
        self, obs: rlt.FeatureData, possible_actions_mask: Optional[np.ndarray] = None
    ) -> rlt.ActorOutput:
        """ Act randomly regardless of the observation. """
        obs: torch.Tensor = obs.float_features
        assert obs.dim() >= 2, f"obs has shape {obs.shape} (dim < 2)"
        assert obs.shape[0] == 1, f"obs has shape {obs.shape} (0th dim != 1)"
        batch_size = obs.shape[0]
        scores = torch.ones((batch_size, self.num_actions))
        scores = apply_possible_actions_mask(
            scores, possible_actions_mask, invalid_score=0.0
        )

        # sample a random action
        m = torch.distributions.Categorical(scores)
        raw_action = m.sample()
        action = F.one_hot(raw_action, self.num_actions)
        log_prob = m.log_prob(raw_action).float()
        return rlt.ActorOutput(action=action, log_prob=log_prob)
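
`apply_possible_actions_mask` above zeroes the scores of disallowed actions so that the `Categorical` distribution can never sample them. A self-contained sketch of the same masking; the helper is re-implemented inline here rather than imported from ReAgent:

import torch
import torch.nn.functional as F

num_actions = 4
scores = torch.ones((1, num_actions))
possible_actions_mask = torch.tensor([[1.0, 0.0, 1.0, 1.0]])

masked_scores = scores * possible_actions_mask  # invalid actions get weight 0.0
m = torch.distributions.Categorical(masked_scores)
raw_action = m.sample()                         # index 1 is never drawn
action = F.one_hot(raw_action, num_actions)
log_prob = m.log_prob(raw_action)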
Code example #11
    def create_from_tensors_dqn(
        cls,
        trainer: DQNTrainer,
        mdp_ids: torch.Tensor,
        sequence_numbers: torch.Tensor,
        states: rlt.FeatureData,
        actions: rlt.FeatureData,
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_mask: torch.Tensor,
        metrics: Optional[torch.Tensor] = None,
    ):
        old_q_train_state = trainer.q_network.training
        # pyre-fixme[16]: `DQNTrainer` has no attribute `reward_network`.
        old_reward_train_state = trainer.reward_network.training
        # pyre-fixme[16]: `DQNTrainer` has no attribute `q_network_cpe`.
        old_q_cpe_train_state = trainer.q_network_cpe.training
        trainer.q_network.train(False)
        trainer.reward_network.train(False)
        trainer.q_network_cpe.train(False)

        num_actions = trainer.num_actions
        action_mask = actions.float()

        rewards = trainer.boost_rewards(rewards, actions)
        model_values = trainer.q_network_cpe(states)[:, 0:num_actions]
        optimal_q_values, _ = trainer.get_detached_q_values(states)
        # Do we ever really use eval_action_idxs?
        eval_action_idxs = trainer.get_max_q_values(
            optimal_q_values, possible_actions_mask
        )[1]
        model_propensities = masked_softmax(
            optimal_q_values, possible_actions_mask, trainer.rl_temperature
        )
        assert model_values.shape == actions.shape, (
            "Invalid shape: " + str(model_values.shape) + " != " + str(actions.shape)
        )
        assert model_values.shape == possible_actions_mask.shape, (
            "Invalid shape: "
            + str(model_values.shape)
            + " != "
            + str(possible_actions_mask.shape)
        )
        model_values_for_logged_action = torch.sum(
            model_values * action_mask, dim=1, keepdim=True
        )

        rewards_and_metric_rewards = trainer.reward_network(states)

        # In case we reuse the module for the Q-network
        if hasattr(rewards_and_metric_rewards, "q_values"):
            rewards_and_metric_rewards = rewards_and_metric_rewards.q_values

        model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
        assert model_rewards.shape == actions.shape, (
            "Invalid shape: " + str(model_rewards.shape) + " != " + str(actions.shape)
        )
        model_rewards_for_logged_action = torch.sum(
            model_rewards * action_mask, dim=1, keepdim=True
        )

        model_metrics = rewards_and_metric_rewards[:, num_actions:]

        assert model_metrics.shape[1] % num_actions == 0, (
            "Invalid metrics shape: "
            + str(model_metrics.shape)
            + " "
            + str(num_actions)
        )
        num_metrics = model_metrics.shape[1] // num_actions

        if num_metrics == 0:
            model_metrics_values = None
            model_metrics_for_logged_action = None
            model_metrics_values_for_logged_action = None
        else:
            model_metrics_values = trainer.q_network_cpe(states)
            # Backward compatibility
            if hasattr(model_metrics_values, "q_values"):
                model_metrics_values = model_metrics_values.q_values
            model_metrics_values = model_metrics_values[:, num_actions:]
            assert model_metrics_values.shape[1] == num_actions * num_metrics, (
                "Invalid shape: "
                + str(model_metrics_values.shape[1])
                + " != "
                + str(actions.shape[1] * num_metrics)
            )

            model_metrics_for_logged_action_list = []
            model_metrics_values_for_logged_action_list = []
            for metric_index in range(num_metrics):
                metric_start = metric_index * num_actions
                metric_end = (metric_index + 1) * num_actions
                model_metrics_for_logged_action_list.append(
                    torch.sum(
                        model_metrics[:, metric_start:metric_end] * action_mask,
                        dim=1,
                        keepdim=True,
                    )
                )

                model_metrics_values_for_logged_action_list.append(
                    torch.sum(
                        model_metrics_values[:, metric_start:metric_end] * action_mask,
                        dim=1,
                        keepdim=True,
                    )
                )
            model_metrics_for_logged_action = torch.cat(
                model_metrics_for_logged_action_list, dim=1
            )
            model_metrics_values_for_logged_action = torch.cat(
                model_metrics_values_for_logged_action_list, dim=1
            )

        trainer.q_network_cpe.train(old_q_cpe_train_state)
        trainer.q_network.train(old_q_train_state)
        trainer.reward_network.train(old_reward_train_state)

        return cls(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            logged_propensities=propensities,
            logged_rewards=rewards,
            action_mask=action_mask,
            model_rewards=model_rewards,
            model_rewards_for_logged_action=model_rewards_for_logged_action,
            model_values=model_values,
            model_values_for_logged_action=model_values_for_logged_action,
            model_metrics_values=model_metrics_values,
            model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
            model_propensities=model_propensities,
            logged_metrics=metrics,
            model_metrics=model_metrics,
            model_metrics_for_logged_action=model_metrics_for_logged_action,
            # Will compute later
            logged_values=None,
            logged_metrics_values=None,
            possible_actions_mask=possible_actions_mask,
            optimal_q_values=optimal_q_values,
            eval_action_idxs=eval_action_idxs,
        )
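
The `masked_softmax` call above converts the detached Q-values into model propensities: a temperature-scaled softmax restricted to the allowed actions. A hedged re-implementation sketch of that idea (not ReAgent's own helper):

import torch
import torch.nn.functional as F

def masked_softmax(q_values: torch.Tensor, mask: torch.Tensor, temperature: float) -> torch.Tensor:
    """Softmax of q_values / temperature with masked-out actions forced to zero probability."""
    logits = q_values / temperature
    logits = logits.masked_fill(mask == 0, float("-inf"))
    return F.softmax(logits, dim=1)

q = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[1.0, 0.0, 1.0]])
print(masked_softmax(q, mask, temperature=1.0))  # action 1 gets probability 0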
Code example #12
    def evaluate(self, batch: PreprocessedMemoryNetworkInput):
        """ Calculate feature importance: setting each state/action feature to
        the mean value and observe loss increase. """

        self.trainer.memory_network.mdnrnn.eval()
        state_features = batch.state.float_features
        action_features = batch.action
        seq_len, batch_size, state_dim = state_features.size()
        action_dim = action_features.size()[2]
        action_feature_num = self.action_feature_num
        state_feature_num = self.state_feature_num
        feature_importance = torch.zeros(action_feature_num +
                                         state_feature_num)

        orig_losses = self.trainer.get_loss(batch, state_dim=state_dim)
        orig_loss = orig_losses["loss"].cpu().detach().item()
        del orig_losses

        action_feature_boundaries = self.sorted_action_feature_start_indices + [
            action_dim
        ]
        state_feature_boundaries = self.sorted_state_feature_start_indices + [
            state_dim
        ]

        for i in range(action_feature_num):
            action_features = batch.action.reshape(
                (batch_size * seq_len, action_dim)).data.clone()

            # if actions are discrete, an action's feature importance is the loss
            # increase due to setting all actions to this action
            if self.discrete_action:
                assert action_dim == action_feature_num
                action_vec = torch.zeros(action_dim)
                action_vec[i] = 1
                action_features[:] = action_vec
            # if actions are continuous, an action's feature importance is the loss
            # increase due to masking this action feature to its mean value
            else:
                boundary_start, boundary_end = (
                    action_feature_boundaries[i],
                    action_feature_boundaries[i + 1],
                )
                action_features[:, boundary_start:boundary_end] = (
                    self.compute_median_feature_value(
                        action_features[:, boundary_start:boundary_end]
                    )
                )

            action_features = action_features.reshape(
                (seq_len, batch_size, action_dim))
            new_batch = PreprocessedMemoryNetworkInput(
                state=batch.state,
                action=action_features,
                next_state=batch.next_state,
                reward=batch.reward,
                time_diff=torch.ones_like(batch.reward).float(),
                not_terminal=batch.not_terminal,
                step=None,
            )
            losses = self.trainer.get_loss(new_batch, state_dim=state_dim)
            feature_importance[i] = (
                losses["loss"].cpu().detach().item() - orig_loss
            )
            del losses

        for i in range(state_feature_num):
            state_features = batch.state.float_features.reshape(
                (batch_size * seq_len, state_dim)).data.clone()
            boundary_start, boundary_end = (
                state_feature_boundaries[i],
                state_feature_boundaries[i + 1],
            )
            state_features[:, boundary_start:boundary_end] = (
                self.compute_median_feature_value(
                    state_features[:, boundary_start:boundary_end]
                )
            )
            state_features = state_features.reshape(
                (seq_len, batch_size, state_dim))
            new_batch = PreprocessedMemoryNetworkInput(
                state=FeatureData(float_features=state_features),
                action=batch.action,
                next_state=batch.next_state,
                reward=batch.reward,
                time_diff=torch.ones_like(batch.reward).float(),
                not_terminal=batch.not_terminal,
                step=None,
            )
            losses = self.trainer.get_loss(new_batch, state_dim=state_dim)
            feature_importance[i + action_feature_num] = (
                losses["loss"].cpu().detach().item() - orig_loss)
            del losses

        self.trainer.memory_network.mdnrnn.train()
        logger.info("**** Debug tool feature importance ****: {}".format(
            feature_importance))
        return {"feature_loss_increase": feature_importance.numpy()}