def create_from_tensors_dqn(
        cls,
        trainer: DQNTrainer,
        mdp_ids: np.ndarray,
        sequence_numbers: torch.Tensor,
        states: rlt.PreprocessedFeatureVector,
        actions: rlt.PreprocessedFeatureVector,
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_mask: torch.Tensor,
        metrics: Optional[torch.Tensor] = None,
    ):
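        # Save the current training flags and switch all networks to eval
        # mode so the CPE scoring below is deterministic (e.g. dropout and
        # batch-norm updates are disabled); the flags are restored at the end.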
        old_q_train_state = trainer.q_network.training
        old_reward_train_state = trainer.reward_network.training
        old_q_cpe_train_state = trainer.q_network_cpe.training
        trainer.q_network.train(False)
        trainer.reward_network.train(False)
        trainer.q_network_cpe.train(False)

        num_actions = trainer.num_actions
        action_mask = actions.float()  # type: ignore

        rewards = trainer.boost_rewards(rewards, actions)  # type: ignore
        model_values = trainer.q_network_cpe(
            rlt.PreprocessedState(state=states)
        ).q_values[:, 0:num_actions]
        optimal_q_values, _ = trainer.get_detached_q_values(
            states  # type: ignore
        )
        eval_action_idxs = trainer.get_max_q_values(  # type: ignore
            optimal_q_values, possible_actions_mask
        )[1]
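        # The target policy's propensities are approximated by a softmax
        # (with temperature rl_temperature) of the detached Q-values,
        # restricted to the actions allowed by possible_actions_mask.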
        model_propensities = masked_softmax(
            optimal_q_values, possible_actions_mask, trainer.rl_temperature
        )
        assert model_values.shape == actions.shape, (  # type: ignore
            "Invalid shape: "
            + str(model_values.shape)  # type: ignore
            + " != "
            + str(actions.shape)  # type: ignore
        )
        assert model_values.shape == possible_actions_mask.shape, (  # type: ignore
            "Invalid shape: "
            + str(model_values.shape)  # type: ignore
            + " != "
            + str(possible_actions_mask.shape)  # type: ignore
        )
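        # action_mask is the one-hot indicator of the logged action, so the
        # elementwise product followed by a row-sum extracts the model's
        # value for the action that was actually taken.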
        model_values_for_logged_action = torch.sum(
            model_values * action_mask, dim=1, keepdim=True
        )

        rewards_and_metric_rewards = trainer.reward_network(
            rlt.PreprocessedState(state=states)
        )

        # In case the Q-network module is reused as the reward network
        if hasattr(rewards_and_metric_rewards, "q_values"):
            rewards_and_metric_rewards = rewards_and_metric_rewards.q_values

        model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
        assert model_rewards.shape == actions.shape, (  # type: ignore
            "Invalid shape: "
            + str(model_rewards.shape)  # type: ignore
            + " != "
            + str(actions.shape)  # type: ignore
        )
        model_rewards_for_logged_action = torch.sum(
            model_rewards * action_mask, dim=1, keepdim=True
        )

        model_metrics = rewards_and_metric_rewards[:, num_actions:]

        assert model_metrics.shape[1] % num_actions == 0, (
            "Invalid metrics shape: "
            + str(model_metrics.shape)
            + " "
            + str(num_actions)
        )
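        # Layout implied by the slicing above: the first num_actions columns
        # of the reward/CPE network output score the reward per action, and
        # the remaining columns hold num_metrics contiguous blocks of
        # num_actions columns, one block per additional metric.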
        num_metrics = model_metrics.shape[1] // num_actions

        if num_metrics == 0:
            model_metrics_values = None
            model_metrics_for_logged_action = None
            model_metrics_values_for_logged_action = None
        else:
            model_metrics_values = trainer.q_network_cpe(
                rlt.PreprocessedState(state=states)
            )
            # Backward compatibility
            if hasattr(model_metrics_values, "q_values"):
                model_metrics_values = model_metrics_values.q_values
            model_metrics_values = model_metrics_values[:, num_actions:]
            assert (
                model_metrics_values.shape[1] == num_actions * num_metrics
            ), (  # type: ignore
                "Invalid shape: "
                + str(model_metrics_values.shape[1])  # type: ignore
                + " != "
                + str(actions.shape[1] * num_metrics)  # type: ignore
            )

            model_metrics_for_logged_action_list = []
            model_metrics_values_for_logged_action_list = []
            for metric_index in range(num_metrics):
                metric_start = metric_index * num_actions
                metric_end = (metric_index + 1) * num_actions
                model_metrics_for_logged_action_list.append(
                    torch.sum(
                        model_metrics[:, metric_start:metric_end] * action_mask,
                        dim=1,
                        keepdim=True,
                    )
                )

                model_metrics_values_for_logged_action_list.append(
                    torch.sum(
                        model_metrics_values[:, metric_start:metric_end] * action_mask,
                        dim=1,
                        keepdim=True,
                    )
                )
            model_metrics_for_logged_action = torch.cat(
                model_metrics_for_logged_action_list, dim=1
            )
            model_metrics_values_for_logged_action = torch.cat(
                model_metrics_values_for_logged_action_list, dim=1
            )

        trainer.q_network_cpe.train(old_q_cpe_train_state)  # type: ignore
        trainer.q_network.train(old_q_train_state)  # type: ignore
        trainer.reward_network.train(old_reward_train_state)  # type: ignore

        return cls(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            logged_propensities=propensities,
            logged_rewards=rewards,
            action_mask=action_mask,
            model_rewards=model_rewards,
            model_rewards_for_logged_action=model_rewards_for_logged_action,
            model_values=model_values,
            model_values_for_logged_action=model_values_for_logged_action,
            model_metrics_values=model_metrics_values,
            model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
            model_propensities=model_propensities,
            logged_metrics=metrics,
            model_metrics=model_metrics,
            model_metrics_for_logged_action=model_metrics_for_logged_action,
            # Will compute later
            logged_values=None,
            logged_metrics_values=None,
            possible_actions_mask=possible_actions_mask,
            optimal_q_values=optimal_q_values,
            eval_action_idxs=eval_action_idxs,
        )
    def evaluate(self, tdp: PreprocessedTrainingBatch):
        """ Calculate state feature sensitivity due to actions:
        randomly permutating actions and see how much the prediction of next
        state feature deviates. """
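        # Approach: run the MDN-RNN on the logged (state, action) sequences,
        # run it again with actions shuffled across the batch, and report the
        # mean absolute difference of the predicted next-state means,
        # aggregated per state feature.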
        mdnrnn_training_input = tdp.training_input
        assert isinstance(mdnrnn_training_input,
                          PreprocessedMemoryNetworkInput)

        self.trainer.mdnrnn.mdnrnn.eval()

        batch_size, seq_len, state_dim = (
            mdnrnn_training_input.next_state.float_features.size())
        state_feature_num = self.state_feature_num
        feature_sensitivity = torch.zeros(state_feature_num)

        state, action, next_state, reward, not_terminal = transpose(
            mdnrnn_training_input.state.float_features,
            mdnrnn_training_input.action,
            mdnrnn_training_input.next_state.float_features,
            mdnrnn_training_input.reward,
            mdnrnn_training_input.not_terminal,
        )
        mdnrnn_input = PreprocessedStateAction(
            state=PreprocessedFeatureVector(float_features=state),
            action=PreprocessedFeatureVector(float_features=action),
        )
        # the input of mdnrnn has seq-len as the first dimension
        mdnrnn_output = self.trainer.mdnrnn(mdnrnn_input)
        predicted_next_state_means = mdnrnn_output.mus

        shuffled_mdnrnn_input = PreprocessedStateAction(
            state=PreprocessedFeatureVector(float_features=state),
            # shuffle the actions
            action=PreprocessedFeatureVector(
                float_features=action[:, torch.randperm(batch_size), :]),
        )
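        # After transpose() the tensors are (seq_len, batch, dim), so indexing
        # dim 1 with randperm(batch_size) pairs each state sequence with the
        # actions of a randomly chosen batch element.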
        shuffled_mdnrnn_output = self.trainer.mdnrnn(shuffled_mdnrnn_input)
        shuffled_predicted_next_state_means = shuffled_mdnrnn_output.mus

        assert (predicted_next_state_means.size() ==
                shuffled_predicted_next_state_means.size() ==
                (seq_len, batch_size, self.trainer.params.num_gaussians,
                 state_dim))

        state_feature_boundaries = self.sorted_state_feature_start_indices + [
            state_dim
        ]
        for i in range(state_feature_num):
            boundary_start, boundary_end = (
                state_feature_boundaries[i],
                state_feature_boundaries[i + 1],
            )
            abs_diff = torch.mean(
                torch.sum(
                    torch.abs(
                        shuffled_predicted_next_state_means[:, :, :,
                                                            boundary_start:
                                                            boundary_end] -
                        predicted_next_state_means[:, :, :,
                                                   boundary_start:boundary_end]
                    ),
                    dim=3,
                ))
            feature_sensitivity[i] = abs_diff.cpu().detach().item()

        self.trainer.mdnrnn.mdnrnn.train()
        logger.info("**** Debug tool feature sensitivity ****: {}".format(
            feature_sensitivity))
        return {"feature_sensitivity": feature_sensitivity.numpy()}
    def evaluate(self, tdp: PreprocessedTrainingBatch):
        """ Calculate feature importance: setting each state/action feature to
        the mean value and observe loss increase. """
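        # Approach: compute the loss on the unmodified batch, then recompute
        # it with one action/state feature at a time replaced by a constant,
        # and report the per-feature loss increase.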
        assert isinstance(tdp.training_input, PreprocessedMemoryNetworkInput)

        self.trainer.mdnrnn.mdnrnn.eval()

        state_features = tdp.training_input.state.float_features
        action_features = tdp.training_input.action  # type: ignore
        batch_size, seq_len, state_dim = state_features.size()  # type: ignore
        action_dim = action_features.size()[2]  # type: ignore
        action_feature_num = self.action_feature_num
        state_feature_num = self.state_feature_num
        feature_importance = torch.zeros(action_feature_num +
                                         state_feature_num)

        orig_losses = self.trainer.get_loss(tdp,
                                            state_dim=state_dim,
                                            batch_first=True)
        orig_loss = orig_losses["loss"].cpu().detach().item()
        del orig_losses

        action_feature_boundaries = self.sorted_action_feature_start_indices + [
            action_dim
        ]
        state_feature_boundaries = self.sorted_state_feature_start_indices + [
            state_dim
        ]

        for i in range(action_feature_num):
            action_features = tdp.training_input.action.reshape(  # type: ignore
                (batch_size * seq_len, action_dim)).data.clone()

            # if actions are discrete, an action's feature importance is the loss
            # increase due to setting all actions to this action
            if self.discrete_action:
                assert action_dim == action_feature_num
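                # Build the one-hot vector for action i and broadcast it over
                # every row, forcing all timesteps to take action i.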
                action_vec = torch.zeros(action_dim)
                action_vec[i] = 1
                action_features[:] = action_vec  # type: ignore
            # if actions are continuous, an action feature's importance is the loss
            # increase due to masking this action feature to its median value
            else:
                boundary_start, boundary_end = (
                    action_feature_boundaries[i],
                    action_feature_boundaries[i + 1],
                )
                action_features[:, boundary_start:boundary_end] = (  # type: ignore
                    self.compute_median_feature_value(
                        action_features[:, boundary_start:boundary_end]  # type: ignore
                    )
                )

            action_features = action_features.reshape(  # type: ignore
                (batch_size, seq_len, action_dim))  # type: ignore
            new_tdp = PreprocessedTrainingBatch(
                training_input=PreprocessedMemoryNetworkInput(  # type: ignore
                    state=tdp.training_input.state,
                    action=action_features,
                    next_state=tdp.training_input.next_state,
                    reward=tdp.training_input.reward,
                    time_diff=torch.ones_like(
                        tdp.training_input.reward).float(),
                    not_terminal=tdp.training_input.
                    not_terminal,  # type: ignore
                    step=None,
                ),
                extras=ExtraData(),
            )
            losses = self.trainer.get_loss(new_tdp,
                                           state_dim=state_dim,
                                           batch_first=True)
            feature_importance[i] = (
                losses["loss"].cpu().detach().item() - orig_loss)
            del losses

        for i in range(state_feature_num):
            state_features = tdp.training_input.state.float_features.reshape(  # type: ignore
                (batch_size * seq_len, state_dim)).data.clone()
            boundary_start, boundary_end = (
                state_feature_boundaries[i],
                state_feature_boundaries[i + 1],
            )
            state_features[:, boundary_start:boundary_end] = (  # type: ignore
                self.compute_median_feature_value(
                    state_features[:, boundary_start:boundary_end]  # type: ignore
                )
            )
            state_features = state_features.reshape(  # type: ignore
                (batch_size, seq_len, state_dim))  # type: ignore
            new_tdp = PreprocessedTrainingBatch(
                training_input=PreprocessedMemoryNetworkInput(  # type: ignore
                    state=PreprocessedFeatureVector(
                        float_features=state_features),
                    action=tdp.training_input.action,  # type: ignore
                    next_state=tdp.training_input.next_state,
                    reward=tdp.training_input.reward,
                    time_diff=torch.ones_like(
                        tdp.training_input.reward).float(),
                    not_terminal=tdp.training_input.
                    not_terminal,  # type: ignore
                    step=None,
                ),
                extras=ExtraData(),
            )
            losses = self.trainer.get_loss(new_tdp,
                                           state_dim=state_dim,
                                           batch_first=True)
            feature_importance[i + action_feature_num] = (
                losses["loss"].cpu().detach().item() - orig_loss)
            del losses

        self.trainer.mdnrnn.mdnrnn.train()
        logger.info("**** Debug tool feature importance ****: {}".format(
            feature_importance))
        return {"feature_loss_increase": feature_importance.numpy()}
    def create_from_tensors(
        cls,
        trainer: DQNTrainer,
        mdp_ids: np.ndarray,
        sequence_numbers: torch.Tensor,
        states: rlt.PreprocessedFeatureVector,
        actions: rlt.PreprocessedFeatureVector,
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_mask: torch.Tensor,
        possible_actions: Optional[rlt.PreprocessedFeatureVector] = None,
        max_num_actions: Optional[int] = None,
        metrics: Optional[torch.Tensor] = None,
    ):
        # Switch to evaluation mode for the network
        old_q_train_state = trainer.q_network.training
        old_reward_train_state = trainer.reward_network.training
        trainer.q_network.train(False)
        trainer.reward_network.train(False)

        if max_num_actions:
            # Parametric model CPE
            state_action_pairs = rlt.PreprocessedStateAction(
                state=states, action=actions
            )
            tiled_state = states.float_features.repeat(1, max_num_actions).reshape(
                -1, states.float_features.shape[1]
            )
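            # Each state row is repeated max_num_actions times so that it can
            # be paired one-to-one with the rows of possible_actions below.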
            assert possible_actions is not None
            # Get Q-value of action taken
            possible_actions_state_concat = rlt.PreprocessedStateAction(
                state=rlt.PreprocessedFeatureVector(float_features=tiled_state),
                action=possible_actions,
            )

            # Parametric actions
            # FIXME: model_values and model propensities should be calculated
            # as in discrete dqn model
            model_values = trainer.q_network(
                possible_actions_state_concat
            ).q_value  # type: ignore
            optimal_q_values = model_values
            eval_action_idxs = None

            assert (
                model_values.shape[0] * model_values.shape[1]
                == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
            ), (
                "Invalid shapes: "
                + str(model_values.shape)
                + " != "
                + str(possible_actions_mask.shape)
            )
            model_values = model_values.reshape(possible_actions_mask.shape)
            model_propensities = masked_softmax(
                model_values, possible_actions_mask, trainer.rl_temperature
            )

            model_rewards = trainer.reward_network(
                possible_actions_state_concat
            ).q_value  # type: ignore
            assert (
                model_rewards.shape[0] * model_rewards.shape[1]
                == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
            ), (
                "Invalid shapes: "
                + str(model_rewards.shape)
                + " != "
                + str(possible_actions_mask.shape)
            )
            model_rewards = model_rewards.reshape(possible_actions_mask.shape)

            model_values_for_logged_action = trainer.q_network(
                state_action_pairs
            ).q_value
            model_rewards_for_logged_action = trainer.reward_network(
                state_action_pairs
            ).q_value

            action_mask = (
                torch.abs(model_values - model_values_for_logged_action) < 1e-3
            ).float()
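            # Recover which possible action was the logged one by matching
            # Q-values: entries within 1e-3 of the logged state/action
            # Q-value are flagged as the taken action.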

            model_metrics = None
            model_metrics_for_logged_action = None
            model_metrics_values = None
            model_metrics_values_for_logged_action = None
        else:
            num_actions = trainer.num_actions
            action_mask = actions.float()  # type: ignore

            # Switch to evaluation mode for the network
            old_q_cpe_train_state = trainer.q_network_cpe.training
            trainer.q_network_cpe.train(False)

            # Discrete actions
            rewards = trainer.boost_rewards(rewards, actions)  # type: ignore
            model_values = trainer.q_network_cpe(
                rlt.PreprocessedState(state=states)
            ).q_values[:, 0:num_actions]
            optimal_q_values = trainer.get_detached_q_values(
                states  # type: ignore
            )[0]
            eval_action_idxs = trainer.get_max_q_values(  # type: ignore
                optimal_q_values, possible_actions_mask
            )[1]
            model_propensities = masked_softmax(
                optimal_q_values, possible_actions_mask, trainer.rl_temperature
            )
            assert model_values.shape == actions.shape, (  # type: ignore
                "Invalid shape: "
                + str(model_values.shape)  # type: ignore
                + " != "
                + str(actions.shape)  # type: ignore
            )
            assert model_values.shape == possible_actions_mask.shape, (  # type: ignore
                "Invalid shape: "
                + str(model_values.shape)  # type: ignore
                + " != "
                + str(possible_actions_mask.shape)  # type: ignore
            )
            model_values_for_logged_action = torch.sum(
                model_values * action_mask, dim=1, keepdim=True
            )

            rewards_and_metric_rewards = trainer.reward_network(
                rlt.PreprocessedState(state=states)
            )

            # In case the Q-network module is reused as the reward network
            if hasattr(rewards_and_metric_rewards, "q_values"):
                rewards_and_metric_rewards = rewards_and_metric_rewards.q_values

            model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
            assert model_rewards.shape == actions.shape, (  # type: ignore
                "Invalid shape: "
                + str(model_rewards.shape)  # type: ignore
                + " != "
                + str(actions.shape)  # type: ignore
            )
            model_rewards_for_logged_action = torch.sum(
                model_rewards * action_mask, dim=1, keepdim=True
            )

            model_metrics = rewards_and_metric_rewards[:, num_actions:]

            assert model_metrics.shape[1] % num_actions == 0, (
                "Invalid metrics shape: "
                + str(model_metrics.shape)
                + " "
                + str(num_actions)
            )
            num_metrics = model_metrics.shape[1] // num_actions

            if num_metrics == 0:
                model_metrics_values = None
                model_metrics_for_logged_action = None
                model_metrics_values_for_logged_action = None
            else:
                model_metrics_values = trainer.q_network_cpe(
                    rlt.PreprocessedState(state=states)
                )
                # Backward compatibility
                if hasattr(model_metrics_values, "q_values"):
                    model_metrics_values = model_metrics_values.q_values
                model_metrics_values = model_metrics_values[:, num_actions:]
                assert (
                    model_metrics_values.shape[1] == num_actions * num_metrics
                ), (  # type: ignore
                    "Invalid shape: "
                    + str(model_metrics_values.shape[1])  # type: ignore
                    + " != "
                    + str(actions.shape[1] * num_metrics)  # type: ignore
                )

                model_metrics_for_logged_action_list = []
                model_metrics_values_for_logged_action_list = []
                for metric_index in range(num_metrics):
                    metric_start = metric_index * num_actions
                    metric_end = (metric_index + 1) * num_actions
                    model_metrics_for_logged_action_list.append(
                        torch.sum(
                            model_metrics[:, metric_start:metric_end] * action_mask,
                            dim=1,
                            keepdim=True,
                        )
                    )

                    model_metrics_values_for_logged_action_list.append(
                        torch.sum(
                            model_metrics_values[:, metric_start:metric_end]
                            * action_mask,
                            dim=1,
                            keepdim=True,
                        )
                    )
                model_metrics_for_logged_action = torch.cat(
                    model_metrics_for_logged_action_list, dim=1
                )
                model_metrics_values_for_logged_action = torch.cat(
                    model_metrics_values_for_logged_action_list, dim=1
                )

            # Switch back to the old mode
            trainer.q_network_cpe.train(old_q_cpe_train_state)  # type: ignore

        # Switch back to the old mode
        trainer.q_network.train(old_q_train_state)  # type: ignore
        trainer.reward_network.train(old_reward_train_state)  # type: ignore

        return cls(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            logged_propensities=propensities,
            logged_rewards=rewards,
            action_mask=action_mask,
            model_rewards=model_rewards,
            model_rewards_for_logged_action=model_rewards_for_logged_action,
            model_values=model_values,
            model_values_for_logged_action=model_values_for_logged_action,
            model_metrics_values=model_metrics_values,
            model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
            model_propensities=model_propensities,
            logged_metrics=metrics,
            model_metrics=model_metrics,
            model_metrics_for_logged_action=model_metrics_for_logged_action,
            # Will compute later
            logged_values=None,
            logged_metrics_values=None,
            possible_actions_mask=possible_actions_mask,
            optimal_q_values=optimal_q_values,
            eval_action_idxs=eval_action_idxs,
        )