예제 #1
0
 def input_prototype(self):
     """Return a random example state/action input for this network.

     The state features have shape ``(1, 1, state_dim)`` and the action
     features have shape ``(1, 1, action_dim)``; both are drawn from a standard
     normal distribution (state first, then action).
     """
     state_features = torch.randn(1, 1, self.state_dim)
     action_features = torch.randn(1, 1, self.action_dim)
     return rlt.PreprocessedStateAction(
         state=rlt.FeatureVector(float_features=state_features),
         action=rlt.FeatureVector(float_features=action_features),
     )
예제 #2
0
 def get_detached_q_values(
         self, state, action) -> Tuple[rlt.SingleQValue, rlt.SingleQValue]:
     """Gets the Q-values from the model and target networks.

     Both returned values are detached from the autograd graph, so callers can
     use them as bootstrap targets or evaluation values without
     back-propagating into either network, whatever grad mode the caller is in.

     NOTE(review): the annotation says ``rlt.SingleQValue`` but the values
     returned are the ``.q_value`` fields of the network outputs — confirm and
     tighten the annotation.

     :param state: state features passed to both networks.
     :param action: action features passed to both networks.
     :return: ``(q_value, q_value_target)`` from ``self.q_network`` and
         ``self.q_network_target`` respectively.
     """
     state_action = rlt.PreprocessedStateAction(state=state, action=action)
     # .detach() makes the "detached" contract hold even when grad is enabled.
     q_value = self.q_network(state_action).q_value.detach()
     q_value_target = self.q_network_target(state_action).q_value.detach()
     return q_value, q_value_target
예제 #3
0
 def forward(self, input):
     """Apply the optional preprocessors, then evaluate the wrapped Q-network.

     :param input: object exposing ``state`` and ``action``. Each one goes
         through its preprocessor when that preprocessor is set (truthy) and is
         passed through unchanged otherwise.
     :return: the output of ``self.q_network`` on the preprocessed pair.
     """
     state = input.state
     if self.state_preprocessor:
         state = self.state_preprocessor(state)

     action = input.action
     if self.action_preprocessor:
         action = self.action_preprocessor(action)

     return self.q_network(
         rlt.PreprocessedStateAction(state=state, action=action))
예제 #4
0
    def train(self, training_batch: rlt.PreprocessedTrainingBatch):
        """Run one SlateQ training step (only the SARSA variant is supported).

        The TD target is ``reward + gamma * Q_target(s', a') * not_terminal``.
        Only the entries selected by ``reward_mask`` are kept, and they are
        regressed against the online Q-network's values for the logged slate
        items. The target network is then soft-updated.

        :raises NotImplementedError: if ``self.maxq_learning`` is set.
        """
        slate_input = training_batch.training_input
        assert isinstance(slate_input, rlt.PreprocessedSlateQInput)
        self.minibatch += 1

        reward = slate_input.reward
        reward_mask = slate_input.reward_mask
        discount = torch.full_like(reward, self.gamma)

        if self.maxq_learning:
            raise NotImplementedError("Q-Learning for SlateQ is not implemented")

        # SARSA: bootstrap from the target network on the logged next slate.
        next_slate_q = self.get_detached_q_values_target(
            slate_input.tiled_next_state, slate_input.next_action
        )
        masked_next_q = next_slate_q * slate_input.not_terminal.float()
        # Only positions flagged by reward_mask contribute to the loss.
        bootstrapped_target = (reward + discount * masked_next_q)[reward_mask]

        with torch.enable_grad():
            # Q-values of the logged (state, item) pairs, laid out like reward_mask.
            logged_pairs = rlt.PreprocessedStateAction(
                state=slate_input.tiled_state.as_preprocessed_feature_vector(),
                action=slate_input.action.as_preprocessed_feature_vector(),
            )
            predicted_q = self.q_network(logged_pairs).q_value
            predicted_q = predicted_q.view(*reward_mask.shape)[reward_mask]
            logged_scores = predicted_q.detach()

            value_loss = self.q_network_loss(predicted_q, bootstrapped_target)
            td_loss = value_loss.detach()
            value_loss.backward()
            self._maybe_run_optimizer(
                self.q_network_optimizer, self.minibatches_per_step
            )

        # Move the target network towards the online network (soft update).
        self._maybe_soft_update(
            self.q_network, self.q_network_target, self.tau, self.minibatches_per_step
        )

        self.loss_reporter.report(
            td_loss=td_loss, model_values_on_logged_actions=logged_scores
        )
예제 #5
0
 def get_slate_q_value(
     self,
     q_network,
     tiled_state: rlt.PreprocessedTiledFeatureVector,
     action: rlt.PreprocessedSlateFeatureVector,
 ) -> torch.Tensor:
     """Compute slate-level Q-values with the given ``q_network``.

     The network is evaluated on every (state, item) pair. The per-item values
     are reshaped to ``(batch, slate_size)``, the first two dims of
     ``action.float_features``. They are then weighted by ``action.item_mask``
     and ``action.item_probability`` and summed over the slate.

     :param q_network: the network to evaluate (e.g. the online or the target
         Q-network).
     :param tiled_state: state features repeated once per slate item.
     :param action: the slate's item features, mask and probabilities.
     :return: tensor of shape ``(batch, 1)`` with one Q-value per slate.
     """
     state_action = rlt.PreprocessedStateAction(
         state=tiled_state.as_preprocessed_feature_vector(),
         action=action.as_preprocessed_feature_vector(),
     )
     batch_size, slate_size = action.float_features.shape[:2]
     # Use the network the caller asked for instead of a hard-coded one.
     per_item_q = q_network(state_action).q_value.view(batch_size, slate_size)
     weighted_q = per_item_q * action.item_mask * action.item_probability
     return weighted_q.sum(dim=1, keepdim=True)
예제 #6
0
    def acc_rewards_of_one_solution(
        self, init_state: torch.Tensor, solution: torch.Tensor, solution_idx: int
    ):
        """
        Evaluate one CEM solution by rolling it out ensemble_pop_size times.

        Each rollout picks one world model uniformly at random from
        ``self.mem_net_list`` and applies the solution's actions step by step.
        Rewards are accumulated with a ``gamma ** step`` discount until the
        planning horizon ends or the model predicts a terminal state.

        :param init_state: its shape is (state_dim, )
        :param solution: its shape is (plan_horizon_length, action_dim)
        :param solution_idx: the index of the solution (used only in log messages)
        :return reward: array with the discounted return of each of the
            ensemble_pop_size trajectories
        """
        discounted_rewards = np.zeros(
            (self.ensemble_pop_size, self.plan_horizon_length)
        )
        # World models consume (seq_len=1, batch=1, dim) shaped inputs.
        state_shape = (1, 1, self.state_dim)
        action_shape = (1, 1, self.action_dim)

        for traj_idx in range(self.ensemble_pop_size):
            world_model = self.mem_net_list[
                np.random.randint(0, len(self.mem_net_list))
            ]
            current_state = init_state
            for step in range(self.plan_horizon_length):
                model_input = rlt.PreprocessedStateAction(
                    state=rlt.PreprocessedFeatureVector(
                        float_features=current_state.reshape(state_shape)
                    ),
                    action=rlt.PreprocessedFeatureVector(
                        float_features=solution[step, :].reshape(action_shape)
                    ),
                )
                (
                    step_reward,
                    predicted_next_state,
                    not_terminal,
                    not_terminal_prob,
                ) = self.sample_reward_next_state_terminal(model_input, world_model)
                discounted_rewards[traj_idx, step] = step_reward * (self.gamma ** step)

                if not not_terminal:
                    logger.debug(
                        f"Solution {solution_idx}: predict terminal at step {step}"
                        f" with prob. {1.0 - not_terminal_prob}"
                    )
                    break

                current_state = predicted_next_state

        return np.sum(discounted_rewards, axis=1)
예제 #7
0
    def get_loss(
        self,
        training_batch: rlt.PreprocessedTrainingBatch,
        state_dim: Optional[int] = None,
        batch_first: bool = False,
    ):
        """
        Compute the MDN-RNN world-model losses:
            GMMLoss(next_state, GMMPredicted) [/ (STATE_DIM + 2) if state_dim given]
            + MSE(reward, predicted_reward)
            + BCE(not_terminal, logit_not_terminal)

        Each term is multiplied by its weight from ``self.params``
        (``next_state_loss_weight``, ``reward_loss_weight``,
        ``not_terminal_loss_weight``). The optional STATE_DIM + 2 divisor offsets
        the fact that the GMM loss grows roughly linearly with the state size.

        :param training_batch:
            training_batch.training_input has these fields:
            - state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
            - action: (BATCH_SIZE, SEQ_LEN, ACTION_DIM) torch tensor
            - reward: (BATCH_SIZE, SEQ_LEN) torch tensor
            - not_terminal: (BATCH_SIZE, SEQ_LEN) torch tensor
            - next_state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
            the first two dimensions may be swapped depending on batch_first

        :param state_dim: the dimension of states. If provided, use it to normalize
            the gmm loss

        :param batch_first: whether the data's first dimension is the batch. If
            False, the first two dimensions are (SEQ_LEN, BATCH_SIZE).

        :param fit_only_one_next_step (read from self.params): when set, only the
            last element along the leading (sequence) axis is fitted.

        :returns: dict with keys "gmm", "bce", "mse" (each already weighted) and
            "loss" (their combination).
        """
        mem_input = training_batch.training_input
        assert isinstance(mem_input, rlt.PreprocessedMemoryNetworkInput)

        if batch_first:
            # The MDN-RNN expects sequence-major tensors: swap (batch, seq) axes
            # and rebuild the input around the transposed tensors.
            (
                seq_state,
                seq_action,
                seq_next_state,
                seq_reward,
                seq_not_terminal,
            ) = transpose(
                mem_input.state.float_features,
                mem_input.action,
                mem_input.next_state.float_features,
                mem_input.reward,
                mem_input.not_terminal,  # type: ignore
            )
            mem_input = rlt.PreprocessedMemoryNetworkInput(  # type: ignore
                state=rlt.PreprocessedFeatureVector(float_features=seq_state),
                reward=seq_reward,
                time_diff=torch.ones_like(seq_reward).float(),
                action=seq_action,
                not_terminal=seq_not_terminal,
                next_state=rlt.PreprocessedFeatureVector(
                    float_features=seq_next_state
                ),
                step=None,
            )

        prediction = self.mdnrnn(
            rlt.PreprocessedStateAction(
                state=mem_input.state,  # type: ignore
                action=rlt.PreprocessedFeatureVector(
                    float_features=mem_input.action  # type: ignore
                ),
            )
        )

        targets_and_predictions = (
            mem_input.next_state.float_features,
            mem_input.not_terminal,  # type: ignore
            mem_input.reward,
            prediction.mus,
            prediction.sigmas,
            prediction.logpi,
            prediction.not_terminal,
            prediction.reward,
        )
        if self.params.fit_only_one_next_step:
            # Keep only the final element along the leading axis of every tensor.
            targets_and_predictions = tuple(
                tensor[-1:] for tensor in targets_and_predictions
            )
        (
            next_state,
            not_terminal,
            reward,
            mus,
            sigmas,
            logpi,
            nts,
            rs,
        ) = targets_and_predictions

        weights = self.params
        gmm = gmm_loss(next_state, mus, sigmas, logpi) * weights.next_state_loss_weight
        bce = (
            F.binary_cross_entropy_with_logits(nts, not_terminal)
            * weights.not_terminal_loss_weight
        )
        mse = F.mse_loss(rs, reward) * weights.reward_loss_weight
        gmm_term = gmm if state_dim is None else gmm / (state_dim + 2)
        loss = gmm_term + bce + mse
        return {"gmm": gmm, "bce": bce, "mse": mse, "loss": loss}
예제 #8
0
    def create_from_tensors_parametric_dqn(
        cls,
        trainer: ParametricDQNTrainer,
        mdp_ids: np.ndarray,
        sequence_numbers: torch.Tensor,
        states: rlt.PreprocessedFeatureVector,
        actions: rlt.PreprocessedFeatureVector,
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_mask: torch.Tensor,
        possible_actions: rlt.PreprocessedFeatureVector,
        max_num_actions: int,
        metrics: Optional[torch.Tensor] = None,
    ):
        """Build evaluation data from a batch of parametric-DQN tensors.

        The trainer's Q-network and reward network are put in evaluation mode
        while the model outputs are computed. Their previous ``training`` flags
        are restored afterwards, also when a shape assertion fails, so a failed
        evaluation cannot leave the trainer in eval mode.

        :param trainer: provides ``q_network``, ``reward_network``,
            ``get_detached_q_values`` and ``rl_temperature``.
        :param states: logged states, one row per sample.
        :param actions: logged actions, one row per sample.
        :param possible_actions_mask: 2-D mask with one entry per
            (sample, candidate action); model outputs are reshaped to it.
        :param possible_actions: candidate actions, flattened so that each
            group of ``max_num_actions`` rows belongs to one sample.
        :param max_num_actions: number of candidate actions per sample.
        :param metrics: optional logged metrics, stored as ``logged_metrics``.
        :return: an instance of ``cls``; ``logged_values`` and
            ``logged_metrics_values`` are left as ``None`` to be computed later.
        """
        old_q_train_state = trainer.q_network.training
        old_reward_train_state = trainer.reward_network.training
        trainer.q_network.train(False)
        trainer.reward_network.train(False)

        try:
            state_action_pairs = rlt.PreprocessedStateAction(
                state=states, action=actions
            )
            # Repeat each state once per candidate action so rows line up with
            # the flattened possible_actions.
            tiled_state = states.float_features.repeat(1, max_num_actions).reshape(
                -1, states.float_features.shape[1]
            )
            assert possible_actions is not None
            # Get Q-value of action taken
            possible_actions_state_concat = rlt.PreprocessedStateAction(
                state=rlt.PreprocessedFeatureVector(float_features=tiled_state),
                action=possible_actions,
            )

            # FIXME: model_values, model_values_for_logged_action, and
            # model_metrics_values should be calculated using q_network_cpe (as
            # in discrete dqn). q_network_cpe has not been added in parametric
            # dqn yet.
            model_values = trainer.q_network(
                possible_actions_state_concat
            ).q_value  # type: ignore
            optimal_q_values, _ = trainer.get_detached_q_values(
                possible_actions_state_concat.state,
                possible_actions_state_concat.action,
            )
            eval_action_idxs = None

            assert (
                model_values.shape[1] == 1
                and model_values.shape[0]
                == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
            ), (
                "Invalid shapes: "
                + str(model_values.shape)
                + " != "
                + str(possible_actions_mask.shape)
            )
            model_values = model_values.reshape(possible_actions_mask.shape)
            optimal_q_values = optimal_q_values.reshape(possible_actions_mask.shape)
            model_propensities = masked_softmax(
                optimal_q_values, possible_actions_mask, trainer.rl_temperature
            )

            # Column 0 of the reward network output is the reward; any further
            # columns are metric predictions.
            rewards_and_metric_rewards = trainer.reward_network(
                possible_actions_state_concat
            ).q_value  # type: ignore
            model_rewards = rewards_and_metric_rewards[:, :1]
            assert (
                model_rewards.shape[0] * model_rewards.shape[1]
                == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
            ), (
                "Invalid shapes: "
                + str(model_rewards.shape)
                + " != "
                + str(possible_actions_mask.shape)
            )
            model_rewards = model_rewards.reshape(possible_actions_mask.shape)

            model_metrics = rewards_and_metric_rewards[:, 1:]
            model_metrics = model_metrics.reshape(possible_actions_mask.shape[0], -1)

            model_values_for_logged_action = trainer.q_network(
                state_action_pairs
            ).q_value
            model_rewards_and_metrics_for_logged_action = trainer.reward_network(
                state_action_pairs
            ).q_value
            model_rewards_for_logged_action = (
                model_rewards_and_metrics_for_logged_action[:, :1]
            )

            # One-hot mask over candidate actions marking the logged action:
            # a candidate matches when every action feature is equal.
            action_dim = possible_actions.float_features.shape[1]
            action_mask = torch.all(
                possible_actions.float_features.view(-1, max_num_actions, action_dim)
                == actions.float_features.unsqueeze(dim=1),
                dim=2,
            ).float()
            assert torch.all(action_mask.sum(dim=1) == 1)
            num_metrics = model_metrics.shape[1] // max_num_actions

            model_metrics_values = None
            model_metrics_for_logged_action = None
            model_metrics_values_for_logged_action = None
            if num_metrics > 0:
                # FIXME: calculate model_metrics_values when q_network_cpe is added
                # to parametric dqn
                model_metrics_values = model_values.repeat(1, num_metrics)
        finally:
            # Restore the caller's train/eval modes even if a check above failed.
            trainer.q_network.train(old_q_train_state)  # type: ignore
            trainer.reward_network.train(old_reward_train_state)  # type: ignore

        return cls(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            logged_propensities=propensities,
            logged_rewards=rewards,
            action_mask=action_mask,
            model_rewards=model_rewards,
            model_rewards_for_logged_action=model_rewards_for_logged_action,
            model_values=model_values,
            model_values_for_logged_action=model_values_for_logged_action,
            model_metrics_values=model_metrics_values,
            model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
            model_propensities=model_propensities,
            logged_metrics=metrics,
            model_metrics=model_metrics,
            model_metrics_for_logged_action=model_metrics_for_logged_action,
            # Will compute later
            logged_values=None,
            logged_metrics_values=None,
            possible_actions_mask=possible_actions_mask,
            optimal_q_values=optimal_q_values,
            eval_action_idxs=eval_action_idxs,
        )
예제 #9
0
    def train(self, training_batch) -> None:
        """Run one parametric-DQN training step.

        1. Build TD targets ``reward + discount * next_q * not_terminal``.
           ``next_q`` is the target network's value of the best possible next
           action when ``maxq_learning`` is set, and of the logged next action
           otherwise (SARSA).
        2. Regress the online Q-network onto those targets, then soft-update the
           target network.
        3. Regress the reward network onto the logged reward, concatenated with
           the logged metrics when present.

        The discount is ``gamma`` by default. It becomes ``gamma ** time_diff``
        when ``use_seq_num_diff_as_time_diff`` is set, or ``gamma ** step``
        when ``multi_steps`` is set.
        """
        if isinstance(training_batch, TrainingDataPage):
            training_batch = training_batch.as_parametric_maxq_training_batch()

        batch_input = training_batch.training_input
        self.minibatch += 1

        reward = batch_input.reward

        discount = torch.full_like(reward, self.gamma)
        if self.use_seq_num_diff_as_time_diff:
            assert self.multi_steps is None
            discount = torch.pow(self.gamma, batch_input.time_diff.float())
        if self.multi_steps is not None:
            discount = torch.pow(self.gamma, batch_input.step.float())

        if not self.maxq_learning:
            # SARSA (Use the target network)
            _, next_q = self.get_detached_q_values(
                batch_input.next_state, batch_input.next_action
            )
        else:
            online_next_q, target_next_q = self.get_detached_q_values(
                batch_input.tiled_next_state, batch_input.possible_next_actions
            )
            # max a' Q(s', a') over all possible actions, using the target network
            next_q, _ = self.get_max_q_values_with_target(
                online_next_q,
                target_next_q,
                batch_input.possible_next_actions_mask.float(),
            )

        masked_next_q = next_q * batch_input.not_terminal.float()
        target_q_values = reward + discount * masked_next_q

        with torch.enable_grad():
            # Q-value of the logged (state, action) pair.
            logged_state_action = rlt.PreprocessedStateAction(
                state=batch_input.state, action=batch_input.action
            )
            q_values = self.q_network(logged_state_action).q_value
            self.all_action_scores = q_values.detach()

            td_loss = self.q_network_loss(q_values, target_q_values)
            self.loss = td_loss.detach()
            td_loss.backward()
            self._maybe_run_optimizer(
                self.q_network_optimizer, self.minibatches_per_step
            )

        # Soft-update the target network towards the online network.
        self._maybe_soft_update(
            self.q_network, self.q_network_target, self.tau, self.minibatches_per_step
        )

        with torch.enable_grad():
            logged_metrics = training_batch.extras.metrics
            if logged_metrics is None:
                reward_targets = reward
            else:
                reward_targets = torch.cat((reward, logged_metrics), dim=1)
            reward_estimates = self.reward_network(logged_state_action).q_value
            reward_loss = F.mse_loss(reward_estimates, reward_targets)
            reward_loss.backward()
            self._maybe_run_optimizer(
                self.reward_network_optimizer, self.minibatches_per_step
            )

        self.loss_reporter.report(
            td_loss=self.loss,
            reward_loss=reward_loss,
            logged_rewards=reward,
            model_values_on_logged_actions=self.all_action_scores,
        )
예제 #10
0
    def train(self, training_batch) -> None:
        """
        IMPORTANT: the input action here is assumed to be preprocessed to match the
        range of the output of the actor.

        Runs one TD3 step:

        - Both critics regress onto ``r + gamma * min(Q1'(s', a'), Q2'(s', a'))
          * not_terminal``. Here ``a'`` is the target actor's action plus
          clipped Gaussian noise, clamped to the training action range. The
          ``min`` only applies when a second critic exists.
        - Every ``delayed_policy_update`` minibatches, the actor is updated to
          maximise Q1 and all target networks are soft-updated.
        - Histograms are logged every ``tensorboard_logging_freq`` minibatches.
          The actor loss is logged only on steps where it was computed.
        """
        if hasattr(training_batch, "as_policy_network_training_batch"):
            training_batch = training_batch.as_policy_network_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        state = learning_input.state
        action = learning_input.action
        next_state = learning_input.next_state
        reward = learning_input.reward
        not_done_mask = learning_input.not_terminal

        action = self._maybe_scale_action_in_train(action.float_features)

        # Clamp bounds for the smoothed target action. Check `is None` instead of
        # truth-testing: bool() of a multi-element tensor raises, and an all-zero
        # bound (e.g. for a [0, 1] action range) must not fall back to the
        # default. Both bounds come from the training range so they agree.
        max_action = self.max_action_range_tensor_training
        if max_action is None:
            max_action = torch.ones(action.shape, device=self.device)
        min_action = self.min_action_range_tensor_training
        if min_action is None:
            min_action = -torch.ones(action.shape, device=self.device)

        # Compute current value estimates
        current_state_action = rlt.PreprocessedStateAction(
            state=state, action=rlt.PreprocessedFeatureVector(float_features=action)
        )
        q1_value = self.q1_network(current_state_action).q_value
        if self.q2_network is not None:
            q2_value = self.q2_network(current_state_action).q_value
        actor_action = self.actor_network(rlt.PreprocessedState(state=state)).action

        # Generate target = r + y * min (Q1(s',pi(s')), Q2(s',pi(s')))
        with torch.no_grad():
            next_actor = self.actor_network_target(
                rlt.PreprocessedState(state=next_state)
            ).action
            # Target policy smoothing: clipped Gaussian noise on the target action.
            next_actor += (
                torch.randn_like(next_actor) * self.target_policy_smoothing
            ).clamp(-self.noise_clip, self.noise_clip)
            next_actor = torch.max(torch.min(next_actor, max_action), min_action)
            next_state_actor = rlt.PreprocessedStateAction(
                state=next_state,
                action=rlt.PreprocessedFeatureVector(float_features=next_actor),
            )
            next_state_value = self.q1_network_target(next_state_actor).q_value

            if self.q2_network is not None:
                next_state_value = torch.min(
                    next_state_value,
                    self.q2_network_target(next_state_actor).q_value,
                )

            target_q_value = (
                reward + self.gamma * next_state_value * not_done_mask.float()
            )

        # Optimize Q1 and Q2
        q1_loss = F.mse_loss(q1_value, target_q_value)
        q1_loss.backward()
        self._maybe_run_optimizer(self.q1_network_optimizer, self.minibatches_per_step)
        if self.q2_network is not None:
            q2_loss = F.mse_loss(q2_value, target_q_value)
            q2_loss.backward()
            self._maybe_run_optimizer(
                self.q2_network_optimizer, self.minibatches_per_step
            )

        # Only update actor and target networks after a fixed number of Q updates.
        # actor_loss stays None on the other steps so logging can skip it.
        actor_loss = None
        if self.minibatch % self.delayed_policy_update == 0:
            actor_loss = -self.q1_network(
                rlt.PreprocessedStateAction(
                    state=state,
                    action=rlt.PreprocessedFeatureVector(float_features=actor_action),
                )
            ).q_value.mean()
            actor_loss.backward()
            self._maybe_run_optimizer(
                self.actor_network_optimizer, self.minibatches_per_step
            )

            # Use the soft update rule to update the target networks
            self._maybe_soft_update(
                self.q1_network,
                self.q1_network_target,
                self.tau,
                self.minibatches_per_step,
            )
            self._maybe_soft_update(
                self.actor_network,
                self.actor_network_target,
                self.tau,
                self.minibatches_per_step,
            )
            if self.q2_network is not None:
                self._maybe_soft_update(
                    self.q2_network,
                    self.q2_network_target,
                    self.tau,
                    self.minibatches_per_step,
                )

        # Logging at the end to schedule all the cuda operations first
        if (
            self.tensorboard_logging_freq != 0
            and self.minibatch % self.tensorboard_logging_freq == 0
        ):
            SummaryWriterContext.add_histogram("q1/logged_state_value", q1_value)
            if self.q2_network is not None:
                SummaryWriterContext.add_histogram("q2/logged_state_value", q2_value)

            SummaryWriterContext.add_histogram(
                "q_network/next_state_value", next_state_value
            )
            SummaryWriterContext.add_histogram(
                "q_network/target_q_value", target_q_value
            )
            if actor_loss is not None:
                SummaryWriterContext.add_histogram("actor/loss", actor_loss)

        self.loss_reporter.report(
            td_loss=float(q1_loss),
            reward_loss=None,
            logged_rewards=reward,
            model_values_on_logged_actions=q1_value,
        )
예제 #11
0
    def create_from_tensors(
        cls,
        trainer: DQNTrainer,
        mdp_ids: np.ndarray,
        sequence_numbers: torch.Tensor,
        states: rlt.PreprocessedFeatureVector,
        actions: rlt.PreprocessedFeatureVector,
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_mask: torch.Tensor,
        possible_actions: Optional[rlt.PreprocessedFeatureVector] = None,
        max_num_actions: Optional[int] = None,
        metrics: Optional[torch.Tensor] = None,
    ):
        # Switch to evaluation mode for the network
        old_q_train_state = trainer.q_network.training
        old_reward_train_state = trainer.reward_network.training
        trainer.q_network.train(False)
        trainer.reward_network.train(False)

        if max_num_actions:
            # Parametric model CPE
            state_action_pairs = rlt.PreprocessedStateAction(
                state=states, action=actions
            )
            tiled_state = states.float_features.repeat(1, max_num_actions).reshape(
                -1, states.float_features.shape[1]
            )
            assert possible_actions is not None
            # Get Q-value of action taken
            possible_actions_state_concat = rlt.PreprocessedStateAction(
                state=rlt.PreprocessedFeatureVector(float_features=tiled_state),
                action=possible_actions,
            )

            # Parametric actions
            # FIXME: model_values and model propensities should be calculated
            # as in discrete dqn model
            model_values = trainer.q_network(
                possible_actions_state_concat
            ).q_value  # type: ignore
            optimal_q_values = model_values
            eval_action_idxs = None

            assert (
                model_values.shape[0] * model_values.shape[1]
                == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
            ), (
                "Invalid shapes: "
                + str(model_values.shape)
                + " != "
                + str(possible_actions_mask.shape)
            )
            model_values = model_values.reshape(possible_actions_mask.shape)
            model_propensities = masked_softmax(
                model_values, possible_actions_mask, trainer.rl_temperature
            )

            model_rewards = trainer.reward_network(
                possible_actions_state_concat
            ).q_value  # type: ignore
            assert (
                model_rewards.shape[0] * model_rewards.shape[1]
                == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
            ), (
                "Invalid shapes: "
                + str(model_rewards.shape)
                + " != "
                + str(possible_actions_mask.shape)
            )
            model_rewards = model_rewards.reshape(possible_actions_mask.shape)

            model_values_for_logged_action = trainer.q_network(
                state_action_pairs
            ).q_value
            model_rewards_for_logged_action = trainer.reward_network(
                state_action_pairs
            ).q_value

            action_mask = (
                torch.abs(model_values - model_values_for_logged_action) < 1e-3
            ).float()

            model_metrics = None
            model_metrics_for_logged_action = None
            model_metrics_values = None
            model_metrics_values_for_logged_action = None
        else:
            num_actions = trainer.num_actions
            action_mask = actions.float()  # type: ignore

            # Switch to evaluation mode for the network
            old_q_cpe_train_state = trainer.q_network_cpe.training
            trainer.q_network_cpe.train(False)

            # Discrete actions
            rewards = trainer.boost_rewards(rewards, actions)  # type: ignore
            model_values = trainer.q_network_cpe(
                rlt.PreprocessedState(state=states)
            ).q_values[:, 0:num_actions]
            optimal_q_values = trainer.get_detached_q_values(
                states  # type: ignore
            )[  # type: ignore
                0
            ]  # type: ignore
            eval_action_idxs = trainer.get_max_q_values(  # type: ignore
                optimal_q_values, possible_actions_mask
            )[1]
            model_propensities = masked_softmax(
                optimal_q_values, possible_actions_mask, trainer.rl_temperature
            )
            assert model_values.shape == actions.shape, (  # type: ignore
                "Invalid shape: "
                + str(model_values.shape)  # type: ignore
                + " != "
                + str(actions.shape)  # type: ignore
            )
            assert model_values.shape == possible_actions_mask.shape, (  # type: ignore
                "Invalid shape: "
                + str(model_values.shape)  # type: ignore
                + " != "
                + str(possible_actions_mask.shape)  # type: ignore
            )
            model_values_for_logged_action = torch.sum(
                model_values * action_mask, dim=1, keepdim=True
            )

            rewards_and_metric_rewards = trainer.reward_network(
                rlt.PreprocessedState(state=states)
            )

            # In case we reuse the modular for Q-network
            if hasattr(rewards_and_metric_rewards, "q_values"):
                rewards_and_metric_rewards = rewards_and_metric_rewards.q_values

            model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
            assert model_rewards.shape == actions.shape, (  # type: ignore
                "Invalid shape: "
                + str(model_rewards.shape)  # type: ignore
                + " != "
                + str(actions.shape)  # type: ignore
            )
            model_rewards_for_logged_action = torch.sum(
                model_rewards * action_mask, dim=1, keepdim=True
            )

            model_metrics = rewards_and_metric_rewards[:, num_actions:]

            assert model_metrics.shape[1] % num_actions == 0, (
                "Invalid metrics shape: "
                + str(model_metrics.shape)
                + " "
                + str(num_actions)
            )
            num_metrics = model_metrics.shape[1] // num_actions

            if num_metrics == 0:
                model_metrics_values = None
                model_metrics_for_logged_action = None
                model_metrics_values_for_logged_action = None
            else:
                model_metrics_values = trainer.q_network_cpe(
                    rlt.PreprocessedState(state=states)
                )
                # Backward compatility
                if hasattr(model_metrics_values, "q_values"):
                    model_metrics_values = model_metrics_values.q_values
                model_metrics_values = model_metrics_values[:, num_actions:]
                assert (
                    model_metrics_values.shape[1] == num_actions * num_metrics
                ), (  # type: ignore
                    "Invalid shape: "
                    + str(model_metrics_values.shape[1])  # type: ignore
                    + " != "
                    + str(actions.shape[1] * num_metrics)  # type: ignore
                )

                model_metrics_for_logged_action_list = []
                model_metrics_values_for_logged_action_list = []
                for metric_index in range(num_metrics):
                    metric_start = metric_index * num_actions
                    metric_end = (metric_index + 1) * num_actions
                    model_metrics_for_logged_action_list.append(
                        torch.sum(
                            model_metrics[:, metric_start:metric_end] * action_mask,
                            dim=1,
                            keepdim=True,
                        )
                    )

                    model_metrics_values_for_logged_action_list.append(
                        torch.sum(
                            model_metrics_values[:, metric_start:metric_end]
                            * action_mask,
                            dim=1,
                            keepdim=True,
                        )
                    )
                model_metrics_for_logged_action = torch.cat(
                    model_metrics_for_logged_action_list, dim=1
                )
                model_metrics_values_for_logged_action = torch.cat(
                    model_metrics_values_for_logged_action_list, dim=1
                )

            # Switch back to the old mode
            trainer.q_network_cpe.train(old_q_cpe_train_state)  # type: ignore

        # Switch back to the old mode
        trainer.q_network.train(old_q_train_state)  # type: ignore
        trainer.reward_network.train(old_reward_train_state)  # type: ignore

        return cls(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            logged_propensities=propensities,
            logged_rewards=rewards,
            action_mask=action_mask,
            model_rewards=model_rewards,
            model_rewards_for_logged_action=model_rewards_for_logged_action,
            model_values=model_values,
            model_values_for_logged_action=model_values_for_logged_action,
            model_metrics_values=model_metrics_values,
            model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
            model_propensities=model_propensities,
            logged_metrics=metrics,
            model_metrics=model_metrics,
            model_metrics_for_logged_action=model_metrics_for_logged_action,
            # Will compute later
            logged_values=None,
            logged_metrics_values=None,
            possible_actions_mask=possible_actions_mask,
            optimal_q_values=optimal_q_values,
            eval_action_idxs=eval_action_idxs,
        )
예제 #12
0
    def train(self, training_batch) -> None:
        """
        Run one Soft Actor-Critic style training step on a single minibatch.

        IMPORTANT: the input action here is assumed to be preprocessed to match the
        range of the output of the actor.

        Order of operations:
          1. Forward the Q network(s) on the logged (state, action) and the actor
             on the current state.
          2. If ``alpha_optimizer`` is set, update ``log_alpha`` and refresh
             ``self.entropy_temperature = exp(log_alpha)``.
          3. Update ``q1_network`` (and ``q2_network`` if present) with an MSE
             loss towards ``reward + discount * V(next_state) * not_terminal``.
          4. Update ``actor_network`` with the loss
             ``entropy_temperature * log_prob(actor_action) - min_Q(s, actor_action)``.
          5. If ``value_network`` is set, update it towards
             ``min_Q(s, actor_action)`` minus the entropy term, unless
             ``logged_action_uniform_prior`` is set.
          6. Soft-update target networks, optionally log histograms, and report
             losses to ``self.loss_reporter``.

        Args:
            training_batch: a batch whose ``training_input`` exposes ``state``,
                ``action``, ``reward``, ``not_terminal`` and ``next_state``; or an
                object with ``as_policy_network_training_batch()`` that converts
                into such a batch.

        Returns:
            None. Network parameters are updated through the trainer's
            optimizers; whether a step happens on this call is decided by
            ``_maybe_run_optimizer`` (not visible here; presumably it steps every
            ``minibatches_per_step`` minibatches — TODO confirm).
        """
        # Some batch types must first be converted into the policy-network
        # training-batch format.
        if hasattr(training_batch, "as_policy_network_training_batch"):
            training_batch = training_batch.as_policy_network_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        state = learning_input.state
        action = learning_input.action
        reward = learning_input.reward
        # Per-sample discount factor, broadcast to the same shape as `reward`.
        discount = torch.full_like(reward, self.gamma)
        not_done_mask = learning_input.not_terminal

        # Optionally map logged actions from the serving action range into the
        # training action range before feeding them to the Q networks.
        if self._should_scale_action_in_train():
            action = action._replace(
                float_features=rescale_torch_tensor(
                    action.float_features,
                    new_min=self.min_action_range_tensor_training,
                    new_max=self.max_action_range_tensor_training,
                    prev_min=self.min_action_range_tensor_serving,
                    prev_max=self.max_action_range_tensor_serving,
                )
            )

        with torch.enable_grad():
            #
            # First, optimize Q networks; minimizing MSE between
            # Q(s, a) & r + discount * V'(next_s)
            #

            current_state_action = rlt.PreprocessedStateAction(
                state=state, action=action
            )
            q1_value = self.q1_network(current_state_action).q_value
            # NOTE(review): this uses module truthiness while other checks use
            # `is not None`; they only differ for modules defining __len__/__bool__
            # (e.g. an empty nn.Sequential) — consider unifying.
            if self.q2_network:
                q2_value = self.q2_network(current_state_action).q_value
            actor_output = self.actor_network(rlt.PreprocessedState(state=state))

            # Optimize Alpha
            # The log-prob term is detached, so this loss only produces gradients
            # for `log_alpha`, not for the actor.
            if self.alpha_optimizer is not None:
                alpha_loss = -(
                    self.log_alpha
                    * (actor_output.log_prob + self.target_entropy).detach()
                ).mean()
                self.alpha_optimizer.zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer.step()
                # The temperature becomes a tensor derived from `log_alpha`; it is
                # used below in the targets (under no_grad) and in the actor loss.
                self.entropy_temperature = self.log_alpha.exp()

            # Bootstrap targets must not receive gradients.
            with torch.no_grad():
                if self.value_network is not None:
                    # With a separate value network, V'(next_s) comes straight
                    # from its target copy.
                    next_state_value = self.value_network_target(
                        learning_input.next_state.float_features
                    )
                else:
                    # Otherwise estimate the soft value from the target Q network(s)
                    # at the action the current actor picks for next_state.
                    next_state_actor_output = self.actor_network(
                        rlt.PreprocessedState(state=learning_input.next_state)
                    )
                    next_state_actor_action = rlt.PreprocessedStateAction(
                        state=learning_input.next_state,
                        action=rlt.PreprocessedFeatureVector(
                            float_features=next_state_actor_output.action
                        ),
                    )
                    next_state_value = self.q1_network_target(
                        next_state_actor_action
                    ).q_value

                    # Clipped double-Q: take the element-wise min of both targets.
                    if self.q2_network is not None:
                        target_q2_value = self.q2_network_target(
                            next_state_actor_action
                        ).q_value
                        next_state_value = torch.min(next_state_value, target_q2_value)

                    log_prob_a = self.actor_network.get_log_prob(
                        learning_input.next_state, next_state_actor_output.action
                    )
                    # Bound the log-prob so the entropy term cannot dominate the
                    # target (presumably for numerical stability).
                    log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                    # Soft value: subtract the entropy penalty (in place).
                    next_state_value -= self.entropy_temperature * log_prob_a

                # Terminal transitions contribute only their immediate reward.
                target_q_value = (
                    reward + discount * next_state_value * not_done_mask.float()
                )

            q1_loss = F.mse_loss(q1_value, target_q_value)
            q1_loss.backward()
            self._maybe_run_optimizer(
                self.q1_network_optimizer, self.minibatches_per_step
            )
            if self.q2_network:
                q2_loss = F.mse_loss(q2_value, target_q_value)
                q2_loss.backward()
                self._maybe_run_optimizer(
                    self.q2_network_optimizer, self.minibatches_per_step
                )

            #
            # Second, optimize the actor; minimizing KL-divergence between action propensity
            # & softmax of value. Due to reparameterization trick, it ends up being
            # log_prob(actor_action) - Q(s, actor_action)
            #

            state_actor_action = rlt.PreprocessedStateAction(
                state=state,
                action=rlt.PreprocessedFeatureVector(
                    float_features=actor_output.action
                ),
            )
            q1_actor_value = self.q1_network(state_actor_action).q_value
            min_q_actor_value = q1_actor_value
            if self.q2_network:
                q2_actor_value = self.q2_network(state_actor_action).q_value
                min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

            actor_loss = (
                self.entropy_temperature * actor_output.log_prob - min_q_actor_value
            )
            # Do this in 2 steps so we can log histogram of actor loss
            actor_loss_mean = actor_loss.mean()
            # NOTE(review): `min_q_actor_value` is not detached, so this backward
            # also writes gradients into the q1/q2 network parameters (and into
            # `log_alpha` when the temperature is learned). Their optimizers have
            # already run above; unless those grads are cleared before the next
            # Q update (depends on `_maybe_run_optimizer`, not visible here), they
            # leak into it — confirm.
            actor_loss_mean.backward()
            self._maybe_run_optimizer(
                self.actor_network_optimizer, self.minibatches_per_step
            )

            #
            # Lastly, if applicable, optimize value network; minimizing MSE between
            # V(s) & E_a~pi(s) [ Q(s,a) - log(pi(a|s)) ]
            #

            if self.value_network is not None:
                state_value = self.value_network(state.float_features)

                if self.logged_action_uniform_prior:
                    # No entropy term in the target; zeros are kept only so
                    # `log_prob_a` is defined for logging below.
                    log_prob_a = torch.zeros_like(min_q_actor_value)
                    target_value = min_q_actor_value
                else:
                    with torch.no_grad():
                        log_prob_a = actor_output.log_prob
                        log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                        target_value = (
                            min_q_actor_value - self.entropy_temperature * log_prob_a
                        )

                # The target is detached so the value loss does not train the
                # Q networks or the actor.
                value_loss = F.mse_loss(state_value, target_value.detach())
                value_loss.backward()
                self._maybe_run_optimizer(
                    self.value_network_optimizer, self.minibatches_per_step
                )

        # Use the soft update rule to update the target networks
        # Only the target networks read in this method are updated: the value
        # target when a value network exists, otherwise the Q targets.
        if self.value_network is not None:
            self._maybe_soft_update(
                self.value_network,
                self.value_network_target,
                self.tau,
                self.minibatches_per_step,
            )
        else:
            self._maybe_soft_update(
                self.q1_network,
                self.q1_network_target,
                self.tau,
                self.minibatches_per_step,
            )
            if self.q2_network is not None:
                self._maybe_soft_update(
                    self.q2_network,
                    self.q2_network_target,
                    self.tau,
                    self.minibatches_per_step,
                )

        # Logging at the end to schedule all the cuda operations first
        if (
            self.tensorboard_logging_freq is not None
            and self.minibatch % self.tensorboard_logging_freq == 0
        ):
            SummaryWriterContext.add_histogram("q1/logged_state_value", q1_value)
            if self.q2_network:
                SummaryWriterContext.add_histogram("q2/logged_state_value", q2_value)

            # NOTE(review): `log_prob_a` is the (clamped) next-state log-prob when
            # there is no value network, but the current-state log-prob (or zeros
            # under `logged_action_uniform_prior`) when there is one.
            SummaryWriterContext.add_histogram("log_prob_a", log_prob_a)
            if self.value_network:
                SummaryWriterContext.add_histogram("value_network/target", target_value)

            SummaryWriterContext.add_histogram(
                "q_network/next_state_value", next_state_value
            )
            SummaryWriterContext.add_histogram(
                "q_network/target_q_value", target_q_value
            )
            SummaryWriterContext.add_histogram(
                "actor/min_q_actor_value", min_q_actor_value
            )
            SummaryWriterContext.add_histogram(
                "actor/action_log_prob", actor_output.log_prob
            )
            SummaryWriterContext.add_histogram("actor/loss", actor_loss)

        # Propensities are reported as probabilities (exp of the actor's log-prob).
        self.loss_reporter.report(
            td_loss=float(q1_loss),
            reward_loss=None,
            logged_rewards=reward,
            model_values_on_logged_actions=q1_value,
            model_propensities=actor_output.log_prob.exp(),
            model_values=min_q_actor_value,
        )