Example #1
0
 def handle(self, tdp: TrainingBatch) -> None:
     """Train on a copy of ``tdp`` whose targets are shuffled across the batch.

     ``next_state``, ``reward``, and ``not_terminal`` are permuted with one
     shared permutation so the three shuffled target tensors still belong to
     the same (randomly chosen) transition. The previous code drew a fresh
     ``torch.randperm`` per tensor, tearing the targets apart from each other.
     Resulting losses are appended to ``self.results``.
     """
     batch_size, _, _ = tdp.training_input.next_state.size()
     # One shared permutation keeps the shuffled targets mutually consistent.
     perm = torch.randperm(batch_size)
     tdp = TrainingBatch(
         training_input=MemoryNetworkInput(
             state=tdp.training_input.state,
             action=tdp.training_input.action,
             # shuffle the data
             next_state=tdp.training_input.next_state[perm],
             reward=tdp.training_input.reward[perm],
             not_terminal=tdp.training_input.not_terminal[perm],
         ),
         extras=ExtraData(),
     )
     losses = self.trainer_or_evaluator.train(tdp, batch_first=True)
     self.results.append(losses)
    def __call__(self, batch: TrainingBatch) -> TrainingBatch:
        """Preprocess the state-side float features of a batch.

        Runs ``self.state_preprocessor`` over both the current and the next
        state and returns a new ``TrainingBatch`` with the preprocessed
        features swapped in; the input batch itself is not mutated.
        """
        training_input = cast(Union[DiscreteDqnInput, ParametricDqnInput],
                              batch.training_input)

        def _run_state_preprocessor(feature_vector):
            # Feed a FeatureVector's (value, presence) pair to the preprocessor.
            return self.state_preprocessor(
                feature_vector.float_features.value,
                feature_vector.float_features.presence,
            )

        new_state = training_input.state._replace(
            float_features=_run_state_preprocessor(training_input.state))
        new_next_state = training_input.next_state._replace(
            float_features=_run_state_preprocessor(training_input.next_state))
        updated_input = training_input._replace(
            state=new_state,
            next_state=new_next_state,
        )
        return batch._replace(training_input=updated_input)
    def preprocess(self, batch) -> TrainingBatch:
        """Convert a raw feature-dict batch into a ``TrainingBatch``.

        Sparse state / next-state features are densified through
        ``self.sparse_to_dense_processor``; per-row scalar fields are shaped
        into (N, 1) tensors. When ``"action_probability"`` is missing,
        propensities default to all ones.
        """
        dense_state, dense_state_presence = self.sparse_to_dense_processor(
            batch["state_features"])
        dense_next_state, dense_next_state_presence = (
            self.sparse_to_dense_processor(batch["next_state_features"]))

        mdp_ids = np.array(batch["mdp_id"]).reshape(-1, 1)

        def _as_column(key, dtype):
            # Shape a flat per-row field into an (N, 1) tensor.
            return torch.tensor(batch[key], dtype=dtype).reshape(-1, 1)

        sequence_numbers = _as_column("sequence_number", torch.int32)
        rewards = _as_column("reward", torch.float32)
        time_diffs = _as_column("time_diff", torch.int32)
        if "action_probability" in batch:
            propensities = _as_column("action_probability", torch.float32)
        else:
            propensities = torch.ones(rewards.shape, dtype=torch.float32)

        state_vector = FeatureVector(float_features=ValuePresence(
            value=dense_state,
            presence=dense_state_presence,
        ))
        next_state_vector = FeatureVector(float_features=ValuePresence(
            value=dense_next_state,
            presence=dense_next_state_presence,
        ))
        return TrainingBatch(
            training_input=BaseInput(
                state=state_vector,
                next_state=next_state_vector,
                reward=rewards,
                time_diff=time_diffs,
            ),
            extras=ExtraData(
                mdp_id=mdp_ids,
                sequence_number=sequence_numbers,
                action_probability=propensities,
            ),
        )
    def __call__(self, batch: TrainingBatch) -> TrainingBatch:
        """Preprocess ``action`` and ``next_action`` of a policy-network batch.

        Delegates state preprocessing to the base class, then runs
        ``self.action_preprocessor`` over both action fields and returns a
        new batch with the preprocessed float features substituted in.
        """
        batch = super().__call__(batch)

        training_input = cast(PolicyNetworkInput, batch.training_input)

        action_fv = cast(FeatureVector, training_input.action)
        new_action_features = self.action_preprocessor(
            action_fv.float_features.value,
            action_fv.float_features.presence,
        )
        next_action_fv = cast(FeatureVector, training_input.next_action)
        new_next_action_features = self.action_preprocessor(
            next_action_fv.float_features.value,
            next_action_fv.float_features.presence,
        )
        updated_input = training_input._replace(
            action=action_fv._replace(
                float_features=new_action_features),
            next_action=next_action_fv._replace(
                float_features=new_next_action_features),
        )
        return batch._replace(training_input=updated_input)
Example #5
0
    def train(self, training_batch: rlt.TrainingBatch) -> None:
        """Run one DDPG update step on ``training_batch``.

        In order: update the critic toward the one-step TD target
        r + gamma * Q_target(s2, actor_target(s2)); update the actor to
        maximize the critic's value of its own action; soft-update both
        target networks. Losses are reported through ``loss_reporter``.
        """
        if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
            training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        state = learning_input.state

        # As far as ddpg is concerned all actions are [-1, 1] due to actor tanh,
        # so rescale logged (serving-range) actions into the training range.
        action = rlt.FeatureVector(
            rescale_torch_tensor(
                learning_input.action.float_features,
                new_min=self.min_action_range_tensor_training,
                new_max=self.max_action_range_tensor_training,
                prev_min=self.min_action_range_tensor_serving,
                prev_max=self.max_action_range_tensor_serving,
            )
        )

        rewards = learning_input.reward
        next_state = learning_input.next_state
        time_diffs = learning_input.time_diff
        discount_tensor = torch.full_like(rewards, self.gamma)
        not_done_mask = learning_input.not_terminal

        # Optimize the critic network subject to mean squared error:
        # L = ([r + gamma * Q(s2, a2)] - Q(s1, a1)) ^ 2
        q_s1_a1 = self.critic.forward(
            rlt.StateAction(state=state, action=action)
        ).q_value
        # a2 comes from the *target* actor; no exploration noise is added here.
        next_action = rlt.FeatureVector(
            float_features=self.actor_target(
                rlt.StateAction(state=next_state, action=None)
            ).action
        )

        q_s2_a2 = self.critic_target.forward(
            rlt.StateAction(state=next_state, action=next_action)
        ).q_value
        # Zero out the bootstrap term for terminal transitions.
        filtered_q_s2_a2 = not_done_mask.float() * q_s2_a2

        if self.use_seq_num_diff_as_time_diff:
            # gamma^k discounting when consecutive samples are k steps apart.
            discount_tensor = discount_tensor.pow(time_diffs)

        target_q_values = rewards + (discount_tensor * filtered_q_s2_a2)

        # compute loss and update the critic network; the target is detached
        # so gradients flow only through Q(s1, a1)
        critic_predictions = q_s1_a1
        loss_critic = self.q_network_loss(critic_predictions, target_q_values.detach())
        loss_critic_for_eval = loss_critic.detach()
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()

        # Optimize the actor network subject to the following:
        # max mean(Q(s1, a1)) or min -mean(Q(s1, a1))
        actor_output = self.actor(rlt.StateAction(state=state, action=None))
        loss_actor = -(
            self.critic.forward(
                rlt.StateAction(
                    state=state,
                    action=rlt.FeatureVector(float_features=actor_output.action),
                )
            ).q_value.mean()
        )

        # The actor loss backprops through the critic, so the critic also
        # accumulates gradients here; only the actor optimizer steps, and the
        # stray critic gradients are discarded by critic_optimizer.zero_grad()
        # at the start of the next call.
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        self.actor_optimizer.step()

        # Use the soft update rule (Polyak averaging with rate tau) to update
        # both target networks.
        self._soft_update(self.actor, self.actor_target, self.tau)
        self._soft_update(self.critic, self.critic_target, self.tau)

        self.loss_reporter.report(
            td_loss=float(loss_critic_for_eval),
            reward_loss=None,
            model_values_on_logged_actions=critic_predictions,
        )
Example #6
0
    def evaluate(self, tdp: TrainingBatch):
        """ Calculate feature importance: setting each state/action feature to
        the mean value and observe loss increase.

        NOTE(review): for continuous features the code actually masks with
        ``compute_median_feature_value`` — confirm whether mean or median is
        the intended statistic.

        Returns ``{"feature_loss_increase": ndarray}`` holding the per-feature
        loss increase, action features first, then state features.
        """
        # Evaluate in eval mode; restored to train mode at the end.
        self.trainer.mdnrnn.mdnrnn.eval()

        state_features = tdp.training_input.state.float_features
        action_features = tdp.training_input.action.float_features  # type: ignore
        # assumes (batch, seq, feature) layout — matches batch_first=True below
        batch_size, seq_len, state_dim = state_features.size()  # type: ignore
        action_dim = action_features.size()[2]  # type: ignore
        action_feature_num = self.action_feature_num
        state_feature_num = self.state_feature_num
        feature_importance = torch.zeros(action_feature_num + state_feature_num)

        # Baseline loss on the unmodified batch.
        orig_losses = self.trainer.get_loss(tdp, state_dim=state_dim, batch_first=True)
        orig_loss = orig_losses["loss"].cpu().detach().item()
        del orig_losses

        # Start index of each feature within the flattened feature axis; the
        # full dim is appended so [i], [i + 1] bracket feature i's columns.
        action_feature_boundaries = self.sorted_action_feature_start_indices + [
            action_dim
        ]
        state_feature_boundaries = self.sorted_state_feature_start_indices + [state_dim]

        for i in range(action_feature_num):
            # Fresh copy each iteration so masks don't accumulate across features.
            action_features = tdp.training_input.action.float_features.reshape(  # type: ignore
                (batch_size * seq_len, action_dim)
            ).data.clone()

            # if actions are discrete, an action's feature importance is the loss
            # increase due to setting all actions to this action
            if self.discrete_action:
                assert action_dim == action_feature_num
                action_vec = torch.zeros(action_dim)
                action_vec[i] = 1
                action_features[:] = action_vec  # type: ignore
            # if actions are continuous, an action's feature importance is the loss
            # increase due to masking this action feature to its mean value
            else:
                boundary_start, boundary_end = (
                    action_feature_boundaries[i],
                    action_feature_boundaries[i + 1],
                )
                action_features[  # type: ignore
                    :, boundary_start:boundary_end
                ] = self.compute_median_feature_value(  # type: ignore
                    action_features[:, boundary_start:boundary_end]  # type: ignore
                )

            action_features = action_features.reshape(  # type: ignore
                (batch_size, seq_len, action_dim)
            )  # type: ignore
            new_tdp = TrainingBatch(
                training_input=MemoryNetworkInput(  # type: ignore
                    state=tdp.training_input.state,
                    action=FeatureVector(  # type: ignore
                        float_features=action_features
                    ),  # type: ignore
                    next_state=tdp.training_input.next_state,
                    reward=tdp.training_input.reward,
                    not_terminal=tdp.training_input.not_terminal,
                ),
                extras=ExtraData(),
            )
            losses = self.trainer.get_loss(
                new_tdp, state_dim=state_dim, batch_first=True
            )
            # Importance = loss increase over the unperturbed baseline.
            feature_importance[i] = losses["loss"].cpu().detach().item() - orig_loss
            del losses

        for i in range(state_feature_num):
            # Fresh copy each iteration so masks don't accumulate across features.
            state_features = tdp.training_input.state.float_features.reshape(  # type: ignore
                (batch_size * seq_len, state_dim)
            ).data.clone()
            boundary_start, boundary_end = (
                state_feature_boundaries[i],
                state_feature_boundaries[i + 1],
            )
            state_features[  # type: ignore
                :, boundary_start:boundary_end
            ] = self.compute_median_feature_value(
                state_features[:, boundary_start:boundary_end]  # type: ignore
            )
            state_features = state_features.reshape(  # type: ignore
                (batch_size, seq_len, state_dim)
            )  # type: ignore
            new_tdp = TrainingBatch(
                training_input=MemoryNetworkInput(  # type: ignore
                    state=FeatureVector(float_features=state_features),  # type: ignore
                    action=tdp.training_input.action,
                    next_state=tdp.training_input.next_state,
                    reward=tdp.training_input.reward,
                    not_terminal=tdp.training_input.not_terminal,
                ),
                extras=ExtraData(),
            )
            losses = self.trainer.get_loss(
                new_tdp, state_dim=state_dim, batch_first=True
            )
            # State features are stored after all action features.
            feature_importance[i + action_feature_num] = (
                losses["loss"].cpu().detach().item() - orig_loss
            )
            del losses

        # Restore training mode.
        self.trainer.mdnrnn.mdnrnn.train()
        logger.info(
            "**** Debug tool feature importance ****: {}".format(feature_importance)
        )
        return {"feature_loss_increase": feature_importance.numpy()}
    def __call__(self, batch: TrainingBatch) -> TrainingBatch:
        """Preprocess the action-side features of a parametric-DQN/SARSA batch.

        After the base class has preprocessed the state features, the state
        preprocessor is run over ``tiled_next_state`` and the action
        preprocessor over every action-valued field that the concrete
        training-input type carries. A new batch is returned; the input
        batch is not mutated.
        """
        batch = super().__call__(batch)

        def _prep_state(fv):
            # State preprocessor over a FeatureVector's (value, presence) pair.
            return self.state_preprocessor(
                fv.float_features.value,
                fv.float_features.presence,
            )

        def _prep_action(fv):
            # Action preprocessor over a FeatureVector's (value, presence) pair.
            return self.action_preprocessor(
                fv.float_features.value,
                fv.float_features.presence,
            )

        if isinstance(batch.training_input, ParametricDqnInput):
            parametric_input = cast(ParametricDqnInput, batch.training_input)
            tiled_next_state_features = _prep_state(
                parametric_input.tiled_next_state)
            action_features = _prep_action(parametric_input.action)
            next_action_features = _prep_action(parametric_input.next_action)
            possible_action_features = _prep_action(
                parametric_input.possible_actions)
            possible_next_action_features = _prep_action(
                parametric_input.possible_next_actions)
            updated_input = parametric_input._replace(
                action=parametric_input.action._replace(
                    float_features=action_features),
                next_action=parametric_input.next_action._replace(
                    float_features=next_action_features),
                possible_actions=parametric_input.possible_actions._replace(
                    float_features=possible_action_features),
                possible_next_actions=parametric_input.possible_next_actions.
                _replace(float_features=possible_next_action_features),
                tiled_next_state=parametric_input.tiled_next_state._replace(
                    float_features=tiled_next_state_features),
            )
            return batch._replace(training_input=updated_input)
        elif isinstance(batch.training_input, SARSAInput):
            sarsa_input = cast(SARSAInput, batch.training_input)
            tiled_next_state_features = _prep_state(
                sarsa_input.tiled_next_state)  # type: ignore
            action_features = _prep_action(sarsa_input.action)  # type: ignore
            next_action_features = _prep_action(
                sarsa_input.next_action)  # type: ignore
            updated_input = sarsa_input._replace(
                action=sarsa_input.action._replace(  # type: ignore
                    float_features=action_features),
                next_action=sarsa_input.next_action._replace(  # type: ignore
                    float_features=next_action_features),
                tiled_next_state=sarsa_input.tiled_next_state.
                _replace(  # type: ignore
                    float_features=tiled_next_state_features),
            )
            return batch._replace(training_input=updated_input)
        else:
            assert False, "Invalid training_input type: " + str(
                type(batch.training_input))
    def evaluate(self, tdp: TrainingBatch):
        """ Calculate feature importance: setting each state/action feature to
        the mean value and observe loss increase.

        Returns ``{"feature_loss_increase": ndarray}`` holding the per-feature
        loss increase, action features first, then state features.
        """
        # Evaluate in eval mode; restored to train mode at the end.
        self.trainer.mdnrnn.mdnrnn.eval()

        state_features = tdp.training_input.state.float_features
        action_features = tdp.training_input.action.float_features
        # assumes (batch, seq, feature) layout — matches batch_first=True below
        batch_size, seq_len, state_dim = state_features.size()
        action_dim = action_features.size()[2]
        action_feature_num = self.feature_extractor.action_feature_num
        state_feature_num = self.feature_extractor.state_feature_num
        feature_importance = torch.zeros(action_feature_num +
                                         state_feature_num)

        # Baseline loss on the unmodified batch.
        orig_losses = self.trainer.get_loss(tdp,
                                            state_dim=state_dim,
                                            batch_first=True)
        orig_loss = orig_losses["loss"].cpu().detach().item()
        del orig_losses

        # Start index of each feature within the flattened feature axis; the
        # full dim is appended so [i], [i + 1] bracket feature i's columns.
        action_feature_boundaries = (
            self.feature_extractor.sorted_action_feature_start_indices +
            [action_dim])
        state_feature_boundaries = (
            self.feature_extractor.sorted_state_feature_start_indices +
            [state_dim])

        for i in range(action_feature_num):
            # Fresh copy each iteration so masks don't accumulate across features.
            action_features = tdp.training_input.action.float_features.reshape(
                (batch_size * seq_len, action_dim)).data.clone()
            boundary_start, boundary_end = (
                action_feature_boundaries[i],
                action_feature_boundaries[i + 1],
            )
            # Mask feature i's columns to their per-column mean over the batch.
            action_features[:, boundary_start:
                            boundary_end] = action_features[:, boundary_start:
                                                            boundary_end].mean(
                                                                dim=0)
            action_features = action_features.reshape(
                (batch_size, seq_len, action_dim))
            new_tdp = TrainingBatch(
                training_input=MemoryNetworkInput(
                    state=tdp.training_input.state,
                    action=FeatureVector(float_features=action_features),
                    next_state=tdp.training_input.next_state,
                    reward=tdp.training_input.reward,
                    not_terminal=tdp.training_input.not_terminal,
                ),
                extras=ExtraData(),
            )
            losses = self.trainer.get_loss(new_tdp,
                                           state_dim=state_dim,
                                           batch_first=True)
            # Importance = loss increase over the unperturbed baseline.
            feature_importance[i] = losses["loss"].cpu().detach().item(
            ) - orig_loss
            del losses

        for i in range(state_feature_num):
            # Fresh copy each iteration so masks don't accumulate across features.
            state_features = tdp.training_input.state.float_features.reshape(
                (batch_size * seq_len, state_dim)).data.clone()
            boundary_start, boundary_end = (
                state_feature_boundaries[i],
                state_feature_boundaries[i + 1],
            )
            # Mask feature i's columns to their per-column mean over the batch.
            state_features[:, boundary_start:
                           boundary_end] = state_features[:, boundary_start:
                                                          boundary_end].mean(
                                                              dim=0)
            state_features = state_features.reshape(
                (batch_size, seq_len, state_dim))
            new_tdp = TrainingBatch(
                training_input=MemoryNetworkInput(
                    state=FeatureVector(float_features=state_features),
                    action=tdp.training_input.action,
                    next_state=tdp.training_input.next_state,
                    reward=tdp.training_input.reward,
                    not_terminal=tdp.training_input.not_terminal,
                ),
                extras=ExtraData(),
            )
            losses = self.trainer.get_loss(new_tdp,
                                           state_dim=state_dim,
                                           batch_first=True)
            # State features are stored after all action features.
            feature_importance[i + action_feature_num] = (
                losses["loss"].cpu().detach().item() - orig_loss)
            del losses

        # Restore training mode.
        self.trainer.mdnrnn.mdnrnn.train()
        logger.info("**** Debug tool feature importance ****: {}".format(
            feature_importance))
        return {"feature_loss_increase": feature_importance.numpy()}