Beispiel #1
0
    def handle(self, tdp: PreprocessedTrainingBatch) -> None:
        """Train on a copy of ``tdp`` whose per-sample fields are shuffled.

        ``state`` and ``action`` are passed through unchanged. ``next_state``
        float features, ``reward`` and ``not_terminal`` are each re-indexed
        along dim 0 with their own, independently drawn random permutation,
        so rows no longer line up across those fields. ``time_diff`` becomes
        a float tensor of ones shaped like ``reward`` and ``step`` is None.
        The losses returned by ``self.trainer_or_evaluator.train`` are
        appended to ``self.results``.
        """
        orig_input = tdp.training_input
        assert isinstance(orig_input, PreprocessedMemoryNetworkInput)
        # The three-way unpack enforces rank-3 next_state float features.
        n_samples, _, _ = orig_input.next_state.float_features.size()

        # Permutations are drawn in the order the fields were originally
        # built (time_diff, next_state, reward, not_terminal), so the RNG
        # stream is consumed identically.
        time_perm, next_state_perm, reward_perm, not_terminal_perm = [
            torch.randperm(n_samples) for _ in range(4)
        ]

        shuffled_input = PreprocessedMemoryNetworkInput(
            state=orig_input.state,
            action=orig_input.action,  # type: ignore
            time_diff=torch.ones_like(orig_input.reward[time_perm]).float(),
            next_state=orig_input.next_state._replace(
                float_features=orig_input.next_state.float_features[
                    next_state_perm
                ]
            ),
            reward=orig_input.reward[reward_perm],
            not_terminal=orig_input.not_terminal[  # type: ignore
                not_terminal_perm
            ],
            step=None,
        )
        shuffled_batch = PreprocessedTrainingBatch(
            training_input=shuffled_input,
            extras=ExtraData(),
        )
        self.results.append(
            self.trainer_or_evaluator.train(shuffled_batch, batch_first=True)
        )
Beispiel #2
0
 def handle(self, tdp: TrainingBatch) -> None:
     """Train on a copy of ``tdp`` with shuffled next_state/reward/not_terminal.

     ``state`` and ``action`` are kept as-is. ``next_state``, ``reward`` and
     ``not_terminal`` are each re-indexed along dim 0 by a separately drawn
     ``torch.randperm``. The resulting losses are appended to
     ``self.results``.
     """
     source = tdp.training_input
     # Three-way unpack: next_state must be a rank-3 tensor here.
     n_rows, _, _ = source.next_state.size()
     # Same draw order as the field order below: next_state, reward, not_terminal.
     next_state_perm, reward_perm, not_terminal_perm = [
         torch.randperm(n_rows) for _ in range(3)
     ]
     shuffled = TrainingBatch(
         training_input=MemoryNetworkInput(
             state=source.state,
             action=source.action,
             next_state=source.next_state[next_state_perm],
             reward=source.reward[reward_perm],
             not_terminal=source.not_terminal[not_terminal_perm],
         ),
         extras=ExtraData(),
     )
     result = self.trainer_or_evaluator.train(shuffled, batch_first=True)
     self.results.append(result)
    def preprocess(self, batch) -> TrainingBatch:
        """Convert a raw ``batch`` mapping into a ``TrainingBatch``.

        Reads the keys ``state_features``, ``next_state_features``,
        ``mdp_id``, ``sequence_number``, ``reward``, ``time_diff`` and,
        optionally, ``action_probability``. The sparse state features are
        densified by ``self.sparse_to_dense_processor``, which returns a
        (value, presence) pair. Every scalar column is reshaped to a
        ``(-1, 1)`` column. When ``action_probability`` is absent, the
        propensities default to ones shaped like the rewards.
        """
        state_dense, state_presence = self.sparse_to_dense_processor(
            batch["state_features"])
        next_dense, next_presence = self.sparse_to_dense_processor(
            batch["next_state_features"])

        def _column(key, dtype):
            # One scalar field -> (N, 1) tensor of the requested dtype.
            return torch.tensor(batch[key], dtype=dtype).reshape(-1, 1)

        mdp_ids = np.array(batch["mdp_id"]).reshape(-1, 1)
        sequence_numbers = _column("sequence_number", torch.int32)
        rewards = _column("reward", torch.float32)
        time_diffs = _column("time_diff", torch.int32)
        propensities = (
            _column("action_probability", torch.float32)
            if "action_probability" in batch
            else torch.ones(rewards.shape, dtype=torch.float32)
        )

        state = FeatureVector(float_features=ValuePresence(
            value=state_dense,
            presence=state_presence,
        ))
        next_state = FeatureVector(float_features=ValuePresence(
            value=next_dense,
            presence=next_presence,
        ))
        training_input = BaseInput(
            state=state,
            next_state=next_state,
            reward=rewards,
            time_diff=time_diffs,
        )
        extras = ExtraData(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            action_probability=propensities,
        )
        return TrainingBatch(training_input=training_input, extras=extras)
Beispiel #4
0
    def evaluate(self, tdp: TrainingBatch):
        """Calculate feature importance as the loss increase from masking each feature.

        The loss of the unmodified batch is computed first. Then, for each
        action feature and afterwards each state feature, a copy of ``tdp``
        with that one feature replaced is scored with
        ``self.trainer.get_loss``, and the difference from the original loss
        is recorded:

        * discrete actions (``self.discrete_action``): every action in the
          batch is set to the one-hot vector of action ``i``;
        * continuous actions and all state features: the feature's columns
          are replaced by the value ``self.compute_median_feature_value``
          returns for those columns over all ``batch_size * seq_len`` rows
          (a median, judging by the helper's name).

        ``tdp`` itself is never modified. The MDN-RNN is put in eval mode for
        the computation and is switched back to train mode afterwards, even
        if a loss computation raises.

        Returns:
            dict with key ``"feature_loss_increase"`` mapping to a numpy array
            of length ``action_feature_num + state_feature_num``. Action
            features come first, then state features.
        """
        self.trainer.mdnrnn.mdnrnn.eval()
        try:
            state_features = tdp.training_input.state.float_features
            action_features = tdp.training_input.action.float_features  # type: ignore
            # The 3-way unpack requires rank-3 state features; batch_first=True
            # suggests (batch, seq, dim) -- confirm against the caller.
            batch_size, seq_len, state_dim = state_features.size()  # type: ignore
            action_dim = action_features.size()[2]  # type: ignore
            action_feature_num = self.action_feature_num
            state_feature_num = self.state_feature_num
            feature_importance = torch.zeros(action_feature_num + state_feature_num)

            orig_losses = self.trainer.get_loss(
                tdp, state_dim=state_dim, batch_first=True
            )
            orig_loss = orig_losses["loss"].cpu().detach().item()
            del orig_losses

            # Feature k occupies columns [boundaries[k], boundaries[k + 1]).
            action_feature_boundaries = self.sorted_action_feature_start_indices + [
                action_dim
            ]
            state_feature_boundaries = self.sorted_state_feature_start_indices + [
                state_dim
            ]

            for i in range(action_feature_num):
                # Flattened (batch_size * seq_len, action_dim) copy, so the
                # caller's tdp is left untouched.
                masked_actions = (
                    tdp.training_input.action.float_features.reshape(  # type: ignore
                        (batch_size * seq_len, action_dim)
                    )
                    .detach()
                    .clone()
                )

                if self.discrete_action:
                    # Discrete actions: importance of action i is the loss
                    # increase from setting every action to action i.
                    assert action_dim == action_feature_num
                    action_vec = torch.zeros(
                        action_dim,
                        dtype=masked_actions.dtype,
                        device=masked_actions.device,
                    )
                    action_vec[i] = 1
                    masked_actions[:] = action_vec  # type: ignore
                else:
                    # Continuous actions: overwrite this feature's columns with
                    # the value from compute_median_feature_value.
                    boundary_start = action_feature_boundaries[i]
                    boundary_end = action_feature_boundaries[i + 1]
                    masked_actions[  # type: ignore
                        :, boundary_start:boundary_end
                    ] = self.compute_median_feature_value(  # type: ignore
                        masked_actions[:, boundary_start:boundary_end]  # type: ignore
                    )

                masked_actions = masked_actions.reshape(  # type: ignore
                    (batch_size, seq_len, action_dim)
                )
                new_tdp = TrainingBatch(
                    training_input=MemoryNetworkInput(  # type: ignore
                        state=tdp.training_input.state,
                        action=FeatureVector(  # type: ignore
                            float_features=masked_actions
                        ),
                        next_state=tdp.training_input.next_state,
                        reward=tdp.training_input.reward,
                        not_terminal=tdp.training_input.not_terminal,
                    ),
                    extras=ExtraData(),
                )
                losses = self.trainer.get_loss(
                    new_tdp, state_dim=state_dim, batch_first=True
                )
                feature_importance[i] = losses["loss"].cpu().detach().item() - orig_loss
                del losses

            for i in range(state_feature_num):
                masked_states = (
                    tdp.training_input.state.float_features.reshape(  # type: ignore
                        (batch_size * seq_len, state_dim)
                    )
                    .detach()
                    .clone()
                )
                boundary_start = state_feature_boundaries[i]
                boundary_end = state_feature_boundaries[i + 1]
                masked_states[  # type: ignore
                    :, boundary_start:boundary_end
                ] = self.compute_median_feature_value(
                    masked_states[:, boundary_start:boundary_end]  # type: ignore
                )
                masked_states = masked_states.reshape(  # type: ignore
                    (batch_size, seq_len, state_dim)
                )
                new_tdp = TrainingBatch(
                    training_input=MemoryNetworkInput(  # type: ignore
                        state=FeatureVector(float_features=masked_states),  # type: ignore
                        action=tdp.training_input.action,
                        next_state=tdp.training_input.next_state,
                        reward=tdp.training_input.reward,
                        not_terminal=tdp.training_input.not_terminal,
                    ),
                    extras=ExtraData(),
                )
                losses = self.trainer.get_loss(
                    new_tdp, state_dim=state_dim, batch_first=True
                )
                # State importances are stored after all action importances.
                feature_importance[i + action_feature_num] = (
                    losses["loss"].cpu().detach().item() - orig_loss
                )
                del losses
        finally:
            # Always leave the model in training mode, even if get_loss raised.
            self.trainer.mdnrnn.mdnrnn.train()

        logger.info(
            "**** Debug tool feature importance ****: {}".format(feature_importance)
        )
        return {"feature_loss_increase": feature_importance.numpy()}
    def evaluate(self, tdp: TrainingBatch):
        """Calculate feature importance: set each state/action feature to its
        mean value and observe the loss increase.

        The loss of the unmodified batch is computed first. Then, for each
        action feature and afterwards each state feature, a copy of ``tdp`` is
        built in which that feature's columns are replaced by their mean over
        all ``batch_size * seq_len`` rows. That copy is scored with
        ``self.trainer.get_loss``, and the difference from the original loss
        is recorded. Feature counts and column start indices come from
        ``self.feature_extractor``.

        ``tdp`` itself is never modified. The MDN-RNN is put in eval mode for
        the computation and is switched back to train mode afterwards, even
        if a loss computation raises.

        Returns:
            dict with key ``"feature_loss_increase"`` mapping to a numpy array
            of length ``action_feature_num + state_feature_num``. Action
            features come first, then state features.
        """
        self.trainer.mdnrnn.mdnrnn.eval()
        try:
            state_features = tdp.training_input.state.float_features
            action_features = tdp.training_input.action.float_features
            # The 3-way unpack requires rank-3 state features; batch_first=True
            # suggests (batch, seq, dim) -- confirm against the caller.
            batch_size, seq_len, state_dim = state_features.size()
            action_dim = action_features.size()[2]
            action_feature_num = self.feature_extractor.action_feature_num
            state_feature_num = self.feature_extractor.state_feature_num
            feature_importance = torch.zeros(action_feature_num +
                                             state_feature_num)

            orig_losses = self.trainer.get_loss(tdp,
                                                state_dim=state_dim,
                                                batch_first=True)
            orig_loss = orig_losses["loss"].cpu().detach().item()
            del orig_losses

            # Feature k occupies columns [boundaries[k], boundaries[k + 1]).
            action_feature_boundaries = (
                self.feature_extractor.sorted_action_feature_start_indices +
                [action_dim])
            state_feature_boundaries = (
                self.feature_extractor.sorted_state_feature_start_indices +
                [state_dim])

            for i in range(action_feature_num):
                # Flattened (batch_size * seq_len, action_dim) copy, so the
                # caller's tdp is left untouched.
                masked_actions = tdp.training_input.action.float_features.reshape(
                    (batch_size * seq_len, action_dim)).detach().clone()
                boundary_start = action_feature_boundaries[i]
                boundary_end = action_feature_boundaries[i + 1]
                columns = masked_actions[:, boundary_start:boundary_end]
                masked_actions[:, boundary_start:boundary_end] = columns.mean(
                    dim=0)
                masked_actions = masked_actions.reshape(
                    (batch_size, seq_len, action_dim))
                new_tdp = TrainingBatch(
                    training_input=MemoryNetworkInput(
                        state=tdp.training_input.state,
                        action=FeatureVector(float_features=masked_actions),
                        next_state=tdp.training_input.next_state,
                        reward=tdp.training_input.reward,
                        not_terminal=tdp.training_input.not_terminal,
                    ),
                    extras=ExtraData(),
                )
                losses = self.trainer.get_loss(new_tdp,
                                               state_dim=state_dim,
                                               batch_first=True)
                feature_importance[i] = (
                    losses["loss"].cpu().detach().item() - orig_loss)
                del losses

            for i in range(state_feature_num):
                masked_states = tdp.training_input.state.float_features.reshape(
                    (batch_size * seq_len, state_dim)).detach().clone()
                boundary_start = state_feature_boundaries[i]
                boundary_end = state_feature_boundaries[i + 1]
                columns = masked_states[:, boundary_start:boundary_end]
                masked_states[:, boundary_start:boundary_end] = columns.mean(
                    dim=0)
                masked_states = masked_states.reshape(
                    (batch_size, seq_len, state_dim))
                new_tdp = TrainingBatch(
                    training_input=MemoryNetworkInput(
                        state=FeatureVector(float_features=masked_states),
                        action=tdp.training_input.action,
                        next_state=tdp.training_input.next_state,
                        reward=tdp.training_input.reward,
                        not_terminal=tdp.training_input.not_terminal,
                    ),
                    extras=ExtraData(),
                )
                losses = self.trainer.get_loss(new_tdp,
                                               state_dim=state_dim,
                                               batch_first=True)
                # State importances are stored after all action importances.
                feature_importance[i + action_feature_num] = (
                    losses["loss"].cpu().detach().item() - orig_loss)
                del losses
        finally:
            # Always leave the model in training mode, even if get_loss raised.
            self.trainer.mdnrnn.mdnrnn.train()

        logger.info("**** Debug tool feature importance ****: {}".format(
            feature_importance))
        return {"feature_loss_increase": feature_importance.numpy()}