def create_from_tensors(
        cls,
        trainer: DQNTrainer,
        mdp_ids: np.ndarray,
        sequence_numbers: torch.Tensor,
        states: Union[mt.State, torch.Tensor],
        actions: Union[mt.Action, torch.Tensor],
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_mask: torch.Tensor,
        possible_actions: Optional[mt.FeatureVector] = None,
        max_num_actions: Optional[int] = None,
        metrics: Optional[torch.Tensor] = None,
    ):
        # Switch to evaluation mode for the network
        old_q_train_state = trainer.q_network.training
        old_reward_train_state = trainer.reward_network.training
        trainer.q_network.train(False)
        trainer.reward_network.train(False)

        if max_num_actions:
            # Parametric model CPE
            state_action_pairs = mt.StateAction(  # type: ignore
                state=states, action=actions)
            tiled_state = mt.FeatureVector(
                states.float_features.repeat(1, max_num_actions).reshape(  # type: ignore
                    -1, states.float_features.shape[1]))  # type: ignore
            # Build (tiled state, possible action) pairs to score every possible action
            possible_actions_state_concat = mt.StateAction(  # type: ignore
                state=tiled_state,
                action=possible_actions  # type: ignore
            )

            # Parametric actions
            # FIXME: model_values and model_propensities should be calculated
            # as in the discrete DQN model
            model_values = trainer.q_network(
                possible_actions_state_concat).q_value  # type: ignore
            optimal_q_values = model_values
            eval_action_idxs = None

            assert (model_values.shape[0] *
                    model_values.shape[1] == possible_actions_mask.shape[0] *
                    possible_actions_mask.shape[1]), (
                        "Invalid shapes: " + str(model_values.shape) + " != " +
                        str(possible_actions_mask.shape))
            model_values = model_values.reshape(possible_actions_mask.shape)
            model_propensities = masked_softmax(model_values,
                                                possible_actions_mask,
                                                trainer.rl_temperature)

            model_rewards = trainer.reward_network(
                possible_actions_state_concat).q_value  # type: ignore
            assert (model_rewards.shape[0] *
                    model_rewards.shape[1] == possible_actions_mask.shape[0] *
                    possible_actions_mask.shape[1]), (
                        "Invalid shapes: " + str(model_rewards.shape) +
                        " != " + str(possible_actions_mask.shape))
            model_rewards = model_rewards.reshape(possible_actions_mask.shape)

            model_values_for_logged_action = trainer.q_network(
                state_action_pairs).q_value
            model_rewards_for_logged_action = trainer.reward_network(
                state_action_pairs).q_value

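            # Recover a mask for the logged action by matching its Q-value
            # against the Q-values of all possible actions (within a small tolerance)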
            action_mask = (
                torch.abs(model_values - model_values_for_logged_action) <
                1e-3).float()

            model_metrics = None
            model_metrics_for_logged_action = None
            model_metrics_values = None
            model_metrics_values_for_logged_action = None
        else:
            if isinstance(states, mt.State):
                states = mt.StateInput(state=states)  # type: ignore

            num_actions = trainer.num_actions
            action_mask = actions.float()  # type: ignore

            # Switch to evaluation mode for the network
            old_q_cpe_train_state = trainer.q_network_cpe.training
            trainer.q_network_cpe.train(False)

            # Discrete actions
            rewards = trainer.boost_rewards(rewards, actions)  # type: ignore
            model_values = trainer.q_network_cpe(
                states).q_values[:, 0:num_actions]
            optimal_q_values = trainer.get_detached_q_values(
                states.state)[0]  # type: ignore
            eval_action_idxs = trainer.get_max_q_values(  # type: ignore
                optimal_q_values, possible_actions_mask)[1]
            model_propensities = masked_softmax(optimal_q_values,
                                                possible_actions_mask,
                                                trainer.rl_temperature)
            assert model_values.shape == actions.shape, (  # type: ignore
                "Invalid shape: " + str(model_values.shape)  # type: ignore
                + " != " + str(actions.shape)  # type: ignore
            )
            assert model_values.shape == possible_actions_mask.shape, (  # type: ignore
                "Invalid shape: " + str(model_values.shape)  # type: ignore
                + " != " + str(possible_actions_mask.shape)  # type: ignore
            )
            model_values_for_logged_action = torch.sum(model_values *
                                                       action_mask,
                                                       dim=1,
                                                       keepdim=True)

            rewards_and_metric_rewards = trainer.reward_network(states)

            # In case we reuse the modular Q-network here
            if hasattr(rewards_and_metric_rewards, "q_values"):
                rewards_and_metric_rewards = rewards_and_metric_rewards.q_values

            model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
            assert model_rewards.shape == actions.shape, (  # type: ignore
                "Invalid shape: " + str(model_rewards.shape)  # type: ignore
                + " != " + str(actions.shape)  # type: ignore
            )
            model_rewards_for_logged_action = torch.sum(model_rewards *
                                                        action_mask,
                                                        dim=1,
                                                        keepdim=True)

            model_metrics = rewards_and_metric_rewards[:, num_actions:]

            assert model_metrics.shape[1] % num_actions == 0, (
                "Invalid metrics shape: " + str(model_metrics.shape) + " " +
                str(num_actions))
            num_metrics = model_metrics.shape[1] // num_actions

            if num_metrics == 0:
                model_metrics_values = None
                model_metrics_for_logged_action = None
                model_metrics_values_for_logged_action = None
            else:
                model_metrics_values = trainer.q_network_cpe(states)
                # Backward compatibility
                if hasattr(model_metrics_values, "q_values"):
                    model_metrics_values = model_metrics_values.q_values
                model_metrics_values = model_metrics_values[:, num_actions:]
                assert (model_metrics_values.shape[1] == num_actions *
                        num_metrics), (  # type: ignore
                            "Invalid shape: " +
                            str(model_metrics_values.shape[1])  # type: ignore
                            + " != " +
                            str(actions.shape[1] * num_metrics)  # type: ignore
                        )

                model_metrics_for_logged_action_list = []
                model_metrics_values_for_logged_action_list = []
                for metric_index in range(num_metrics):
                    metric_start = metric_index * num_actions
                    metric_end = (metric_index + 1) * num_actions
                    model_metrics_for_logged_action_list.append(
                        torch.sum(
                            model_metrics[:, metric_start:metric_end] *
                            action_mask,
                            dim=1,
                            keepdim=True,
                        ))

                    model_metrics_values_for_logged_action_list.append(
                        torch.sum(
                            model_metrics_values[:, metric_start:metric_end] *
                            action_mask,
                            dim=1,
                            keepdim=True,
                        ))
                model_metrics_for_logged_action = torch.cat(
                    model_metrics_for_logged_action_list, dim=1)
                model_metrics_values_for_logged_action = torch.cat(
                    model_metrics_values_for_logged_action_list, dim=1)

            # Switch back to the old mode
            trainer.q_network_cpe.train(old_q_cpe_train_state)  # type: ignore

        # Switch back to the old mode
        trainer.q_network.train(old_q_train_state)  # type: ignore
        trainer.reward_network.train(old_reward_train_state)  # type: ignore

        return cls(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            logged_propensities=propensities,
            logged_rewards=rewards,
            action_mask=action_mask,
            model_rewards=model_rewards,
            model_rewards_for_logged_action=model_rewards_for_logged_action,
            model_values=model_values,
            model_values_for_logged_action=model_values_for_logged_action,
            model_metrics_values=model_metrics_values,
            model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
            model_propensities=model_propensities,
            logged_metrics=metrics,
            model_metrics=model_metrics,
            model_metrics_for_logged_action=model_metrics_for_logged_action,
            # Will compute later
            logged_values=None,
            logged_metrics_values=None,
            possible_actions_mask=possible_actions_mask,
            optimal_q_values=optimal_q_values,
            eval_action_idxs=eval_action_idxs,
        )
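The snippet above turns Q-values into model propensities with a masked_softmax helper that is not shown here. The following is a minimal sketch of what such a helper plausibly does (a temperature-scaled softmax restricted to the allowed actions); the exact numerics are an assumption, not the library's verbatim implementation:

import torch

def masked_softmax(x, mask, temperature):
    # Scale the scores by the softmax temperature.
    x = x / temperature
    # Push impossible actions to a very large negative value so they get ~0 probability.
    masked_x = x - (1.0 - mask) * 1e20
    probs = torch.softmax(masked_x, dim=1)
    # Zero out any residual mass on impossible actions and renormalize.
    probs = probs * mask
    return probs / probs.sum(dim=1, keepdim=True).clamp(min=1e-12)

Masking before the softmax keeps the normalization over valid actions only, which is what the CPE propensities computed above require.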
Example #2
    def train(self, training_batch):
        if isinstance(training_batch, TrainingDataPage):
            if self.maxq_learning:
                training_batch = training_batch.as_discrete_maxq_training_batch()
            else:
                training_batch = training_batch.as_discrete_sarsa_training_batch()

        learning_input = training_batch.training_input
        boosted_rewards = self.boost_rewards(learning_input.reward,
                                             learning_input.action)

        self.minibatch += 1
        rewards = boosted_rewards
        discount_tensor = torch.full_like(rewards, self.gamma)
        not_done_mask = learning_input.not_terminal.float()

        if self.use_seq_num_diff_as_time_diff:
            assert self.multi_steps is None
            discount_tensor = torch.pow(self.gamma,
                                        learning_input.time_diff.float())
        if self.multi_steps is not None:
            discount_tensor = torch.pow(self.gamma,
                                        learning_input.step.float())

        all_next_q_values, all_next_q_values_target = self.get_detached_q_values(
            learning_input.next_state)

        if self.maxq_learning:
            # Compute max a' Q(s', a') over all possible actions using target network
            possible_next_actions_mask = (
                learning_input.possible_next_actions_mask.float())
            if self.bcq:
                action_on_policy = get_valid_actions_from_imitator(
                    self.bcq_imitator,
                    learning_input.next_state.float_features,
                    self.bcq_drop_threshold,
                )
                possible_next_actions_mask *= action_on_policy
            next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
                all_next_q_values, all_next_q_values_target,
                possible_next_actions_mask)
        else:
            # SARSA
            next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
                all_next_q_values, all_next_q_values_target,
                learning_input.next_action)

        filtered_next_q_vals = next_q_values * not_done_mask

        target_q_values = rewards + (discount_tensor * filtered_next_q_vals)

        # Get Q-value of action taken
        current_state = rlt.StateInput(state=learning_input.state)
        all_q_values = self.q_network(current_state).q_values
        self.all_action_scores = all_q_values.detach()
        q_values = torch.sum(all_q_values * learning_input.action,
                             1,
                             keepdim=True)

        loss = self.q_network_loss(q_values, target_q_values)
        self.loss = loss.detach()

        self.q_network_optimizer.zero_grad()
        loss.backward()
        self.q_network_optimizer.step()

        # Use the soft update rule to update target network
        self._soft_update(self.q_network, self.q_network_target, self.tau)

        # Get Q-values of next states, used in computing cpe
        with torch.no_grad():
            next_state = rlt.StateInput(state=learning_input.next_state)
            all_next_action_scores = self.q_network(
                next_state).q_values.detach()

        logged_action_idxs = learning_input.action.argmax(dim=1, keepdim=True)
        reward_loss, model_rewards, model_propensities = self.calculate_cpes(
            training_batch,
            current_state,
            next_state,
            all_next_action_scores,
            logged_action_idxs,
            discount_tensor,
            not_done_mask,
        )

        if self.maxq_learning:
            possible_actions_mask = learning_input.possible_actions_mask

            # BCQ filtering only applies when max-Q learning defines the mask;
            # the SARSA path below falls back to the logged action instead.
            if self.bcq:
                action_on_policy = get_valid_actions_from_imitator(
                    self.bcq_imitator,
                    learning_input.state.float_features,
                    self.bcq_drop_threshold,
                )
                possible_actions_mask *= action_on_policy

        self.loss_reporter.report(
            td_loss=self.loss,
            reward_loss=reward_loss,
            logged_actions=logged_action_idxs,
            logged_propensities=training_batch.extras.action_probability,
            logged_rewards=rewards,
            logged_values=None,  # Compute at end of each epoch for CPE
            model_propensities=model_propensities,
            model_rewards=model_rewards,
            model_values=self.all_action_scores,
            model_values_on_logged_actions=None,  # Compute at end of each epoch for CPE
            model_action_idxs=self.get_max_q_values(
                self.all_action_scores,
                possible_actions_mask
                if self.maxq_learning else learning_input.action,
            )[1],
        )
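The training step above refreshes the target network through self._soft_update, which is defined elsewhere on the trainer. A minimal sketch of the standard Polyak-averaging update it presumably implements (the method name and signature come from the call sites; the body is an assumption):

def _soft_update(self, network, target_network, tau):
    # Polyak averaging: blend each online parameter into its target copy.
    # tau = 1.0 degenerates to a hard copy, which is how the reward-burnin
    # branches in the later snippets force the target network.
    for target_param, param in zip(target_network.parameters(),
                                   network.parameters()):
        target_param.data.copy_(tau * param.data +
                                (1.0 - tau) * target_param.data)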
Example #3
 def input_prototype(self):
     return rlt.StateInput(state=rlt.FeatureVector(
         float_features=torch.randn(1, self.state_dim)))
Example #4
    def train(self, training_batch) -> None:
        """
        IMPORTANT: the input action here is assumed to be preprocessed to match the
        range of the output of the actor.
        """
        if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
            training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        state = learning_input.state
        action = learning_input.action
        next_state = learning_input.next_state
        reward = learning_input.reward
        not_done_mask = learning_input.not_terminal

        action = self._maybe_scale_action_in_train(action)

        max_action = (self.max_action_range_tensor_training
                      if self.max_action_range_tensor_training is not None else
                      torch.ones(action.float_features.shape).type(self.dtype))
        min_action = (self.min_action_range_tensor_serving
                      if self.min_action_range_tensor_serving is not None else
                      -torch.ones(action.float_features.shape).type(self.dtype))

        # Compute current value estimates
        current_state_action = rlt.StateAction(state=state, action=action)
        q1_value = self.q1_network(current_state_action).q_value
        if self.q2_network:
            q2_value = self.q2_network(current_state_action).q_value
        actor_action = self.actor_network(rlt.StateInput(state=state)).action

        # Generate target = r + gamma * min(Q1(s', pi(s')), Q2(s', pi(s')))
        with torch.no_grad():
            next_actor = self.actor_network_target(
                rlt.StateInput(state=next_state)).action
            next_actor += (torch.randn_like(next_actor) *
                           self.target_policy_smoothing).clamp(
                               -self.noise_clip, self.noise_clip)
            next_actor = torch.max(torch.min(next_actor, max_action),
                                   min_action)
            next_state_actor = rlt.StateAction(
                state=next_state,
                action=rlt.FeatureVector(float_features=next_actor))
            next_state_value = self.q1_network_target(next_state_actor).q_value

            if self.q2_network is not None:
                next_state_value = torch.min(
                    next_state_value,
                    self.q2_network_target(next_state_actor).q_value)

            target_q_value = (
                reward + self.gamma * next_state_value * not_done_mask.float())

        # Optimize Q1 and Q2
        q1_loss = F.mse_loss(q1_value, target_q_value)
        q1_loss.backward()
        self._maybe_run_optimizer(self.q1_network_optimizer,
                                  self.minibatches_per_step)
        if self.q2_network:
            q2_loss = F.mse_loss(q2_value, target_q_value)
            q2_loss.backward()
            self._maybe_run_optimizer(self.q2_network_optimizer,
                                      self.minibatches_per_step)

        # Only update actor and target networks after a fixed number of Q updates
        if self.minibatch % self.delayed_policy_update == 0:
            actor_loss = -self.q1_network(
                rlt.StateAction(
                    state=state,
                    action=rlt.FeatureVector(
                        float_features=actor_action))).q_value.mean()
            actor_loss.backward()
            self._maybe_run_optimizer(self.actor_network_optimizer,
                                      self.minibatches_per_step)

            # Use the soft update rule to update the target networks
            self._maybe_soft_update(
                self.q1_network,
                self.q1_network_target,
                self.tau,
                self.minibatches_per_step,
            )
            self._maybe_soft_update(
                self.actor_network,
                self.actor_network_target,
                self.tau,
                self.minibatches_per_step,
            )
            if self.q2_network is not None:
                self._maybe_soft_update(
                    self.q2_network,
                    self.q2_network_target,
                    self.tau,
                    self.minibatches_per_step,
                )

        # Logging at the end to schedule all the cuda operations first
        if (self.tensorboard_logging_freq is not None
                and self.minibatch % self.tensorboard_logging_freq == 0):
            SummaryWriterContext.add_histogram("q1/logged_state_value",
                                               q1_value)
            if self.q2_network:
                SummaryWriterContext.add_histogram("q2/logged_state_value",
                                                   q2_value)

            SummaryWriterContext.add_histogram("q_network/next_state_value",
                                               next_state_value)
            SummaryWriterContext.add_histogram("q_network/target_q_value",
                                               target_q_value)
            SummaryWriterContext.add_histogram("actor/loss", actor_loss)

        self.loss_reporter.report(
            td_loss=float(q1_loss),
            reward_loss=None,
            logged_rewards=reward,
            model_values_on_logged_actions=q1_value,
        )
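Example #4 steps its optimizers through self._maybe_run_optimizer and syncs targets through self._maybe_soft_update, neither of which is shown. A plausible sketch, assuming both simply gate their action so it runs once every minibatches_per_step backward passes (gradients accumulate in between); the gating condition is an assumption:

def _maybe_run_optimizer(self, optimizer, minibatches_per_step):
    # Apply the accumulated gradients only every `minibatches_per_step` minibatches.
    if self.minibatch % minibatches_per_step == 0:
        optimizer.step()
        optimizer.zero_grad()

def _maybe_soft_update(self, network, target_network, tau, minibatches_per_step):
    # Run the Polyak update on the same schedule as the optimizer step.
    if self.minibatch % minibatches_per_step == 0:
        self._soft_update(network, target_network, tau)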
Example #5
 def input_prototype(self):
     return rlt.StateInput(state=self.state_preprocessor.input_prototype())
Example #6
 def forward(self, input):
     preprocessed_state = self.state_preprocessor(input.state)
     return self.actor_network(rlt.StateInput(state=preprocessed_state))
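Examples #3, #5, and #6 illustrate the input_prototype pattern: a module advertises a one-row example of its own input so callers can run (or trace and export) it without knowing the feature layout. A self-contained toy version of the pattern, with made-up class and dimension names:

import torch

class ToyStateNetwork(torch.nn.Module):
    # Hypothetical module following the same pattern as Examples #3, #5, and #6.
    def __init__(self, state_dim):
        super().__init__()
        self.state_dim = state_dim
        self.linear = torch.nn.Linear(state_dim, 1)

    def input_prototype(self):
        # One random state row with the right width, mirroring Example #3.
        return torch.randn(1, self.state_dim)

    def forward(self, state):
        return self.linear(state)

# Smoke-test the forward pass with the module's own prototype input.
net = ToyStateNetwork(state_dim=4)
out = net(net.input_prototype())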
Example #7
    def train(self, training_batch) -> None:
        """
        IMPORTANT: the input action here is assumed to be preprocessed to match the
        range of the output of the actor.
        """
        if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
            training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        state = learning_input.state
        action = learning_input.action
        reward = learning_input.reward
        discount = torch.full_like(reward, self.gamma)
        not_done_mask = learning_input.not_terminal

        if self._should_scale_action_in_train():
            action = rlt.FeatureVector(
                rescale_torch_tensor(
                    action.float_features,
                    new_min=self.min_action_range_tensor_training,
                    new_max=self.max_action_range_tensor_training,
                    prev_min=self.min_action_range_tensor_serving,
                    prev_max=self.max_action_range_tensor_serving,
                ))

        #
        # First, optimize Q networks; minimizing MSE between
        # Q(s, a) & r + discount * V'(next_s)
        #

        current_state_action = rlt.StateAction(state=state, action=action)
        q1_value = self.q1_network(current_state_action).q_value
        if self.q2_network:
            q2_value = self.q2_network(current_state_action).q_value

        with torch.no_grad():
            if self.value_network is not None:
                next_state_value = self.value_network_target(
                    learning_input.next_state.float_features)
            else:
                actor_output = self.actor_network(
                    rlt.StateInput(state=learning_input.next_state))
                next_state_actor_action = rlt.StateAction(
                    state=learning_input.next_state,
                    action=rlt.FeatureVector(
                        float_features=actor_output.action),
                )
                next_state_value = self.q1_network_target(
                    next_state_actor_action).q_value

                if self.q2_network is not None:
                    target_q2_value = self.q2_network_target(
                        next_state_actor_action).q_value
                    next_state_value = torch.min(next_state_value,
                                                 target_q2_value)

                log_prob_a = self.actor_network.get_log_prob(
                    learning_input.next_state, actor_output.action)
                log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                next_state_value -= self.entropy_temperature * log_prob_a

            target_q_value = (
                reward + discount * next_state_value * not_done_mask.float())

        q1_loss = F.mse_loss(q1_value, target_q_value)
        q1_loss.backward()
        self._maybe_run_optimizer(self.q1_network_optimizer,
                                  self.minibatches_per_step)
        if self.q2_network:
            q2_loss = F.mse_loss(q2_value, target_q_value)
            q2_loss.backward()
            self._maybe_run_optimizer(self.q2_network_optimizer,
                                      self.minibatches_per_step)

        #
        # Second, optimize the actor; minimizing KL-divergence between action propensity
        # & softmax of value. Due to reparameterization trick, it ends up being
        # log_prob(actor_action) - Q(s, actor_action)
        #

        actor_output = self.actor_network(rlt.StateInput(state=state))

        state_actor_action = rlt.StateAction(
            state=state,
            action=rlt.FeatureVector(float_features=actor_output.action))
        q1_actor_value = self.q1_network(state_actor_action).q_value
        min_q_actor_value = q1_actor_value
        if self.q2_network:
            q2_actor_value = self.q2_network(state_actor_action).q_value
            min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

        actor_loss = (self.entropy_temperature * actor_output.log_prob -
                      min_q_actor_value)
        # Do this in 2 steps so we can log histogram of actor loss
        actor_loss_mean = actor_loss.mean()
        actor_loss_mean.backward()
        self._maybe_run_optimizer(self.actor_network_optimizer,
                                  self.minibatches_per_step)

        #
        # Lastly, if applicable, optimize value network; minimizing MSE between
        # V(s) & E_a~pi(s) [ Q(s,a) - log(pi(a|s)) ]
        #

        if self.value_network is not None:
            state_value = self.value_network(state.float_features)

            if self.logged_action_uniform_prior:
                log_prob_a = torch.zeros_like(min_q_actor_value)
                target_value = min_q_actor_value
            else:
                with torch.no_grad():
                    log_prob_a = actor_output.log_prob
                    log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                    target_value = (min_q_actor_value -
                                    self.entropy_temperature * log_prob_a)

            value_loss = F.mse_loss(state_value, target_value.detach())
            value_loss.backward()
            self._maybe_run_optimizer(self.value_network_optimizer,
                                      self.minibatches_per_step)

        # Use the soft update rule to update the target networks
        if self.value_network is not None:
            self._maybe_soft_update(
                self.value_network,
                self.value_network_target,
                self.tau,
                self.minibatches_per_step,
            )
        else:
            self._maybe_soft_update(
                self.q1_network,
                self.q1_network_target,
                self.tau,
                self.minibatches_per_step,
            )
            if self.q2_network is not None:
                self._maybe_soft_update(
                    self.q2_network,
                    self.q2_network_target,
                    self.tau,
                    self.minibatches_per_step,
                )

        # Logging at the end to schedule all the cuda operations first
        if (self.tensorboard_logging_freq is not None
                and self.minibatch % self.tensorboard_logging_freq == 0):
            SummaryWriterContext.add_histogram("q1/logged_state_value",
                                               q1_value)
            if self.q2_network:
                SummaryWriterContext.add_histogram("q2/logged_state_value",
                                                   q2_value)

            SummaryWriterContext.add_histogram("log_prob_a", log_prob_a)
            if self.value_network:
                SummaryWriterContext.add_histogram("value_network/target",
                                                   target_value)

            SummaryWriterContext.add_histogram("q_network/next_state_value",
                                               next_state_value)
            SummaryWriterContext.add_histogram("q_network/target_q_value",
                                               target_q_value)
            SummaryWriterContext.add_histogram("actor/min_q_actor_value",
                                               min_q_actor_value)
            SummaryWriterContext.add_histogram("actor/action_log_prob",
                                               actor_output.log_prob)
            SummaryWriterContext.add_histogram("actor/loss", actor_loss)

        self.loss_reporter.report(
            td_loss=float(q1_loss),
            reward_loss=None,
            logged_rewards=reward,
            model_values_on_logged_actions=q1_value,
            model_propensities=actor_output.log_prob.exp(),
            model_values=min_q_actor_value,
        )
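Example #7 maps the logged action from the serving range into the training range with rescale_torch_tensor, which is not shown. A minimal sketch of the linear rescaling it presumably performs; the signature mirrors the call site, the body is an assumption:

def rescale_torch_tensor(tensor, new_min, new_max, prev_min, prev_max):
    # Element-wise linear map from [prev_min, prev_max] to [new_min, new_max];
    # the range tensors broadcast against `tensor`.
    prev_range = prev_max - prev_min
    new_range = new_max - new_min
    return ((tensor - prev_min) / prev_range) * new_range + new_min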
Example #8
    def train(self, training_batch, evaluator=None) -> None:
        """
        IMPORTANT: the input action here is assumed to be preprocessed to match the
        range of the output of the actor.
        """
        if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
            training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        state = learning_input.state
        action = learning_input.action
        reward = learning_input.reward
        discount = torch.full_like(reward, self.gamma)
        not_done_mask = learning_input.not_terminal

        current_state_action = rlt.StateAction(state=state, action=action)

        q1_value = self.q1_network(current_state_action).q_value
        min_q_value = q1_value

        if self.q2_network:
            q2_value = self.q2_network(current_state_action).q_value
            min_q_value = torch.min(q1_value, q2_value)

        # Use the minimum as target, ensure no gradient going through
        min_q_value = min_q_value.detach()

        #
        # First, optimize value network; minimizing MSE between
        # V(s) & Q(s, a) - log(pi(a|s))
        #

        state_value = self.value_network(state.float_features)  # .q_value

        with torch.no_grad():
            log_prob_a = self.actor_network.get_log_prob(state, action.float_features)
            target_value = min_q_value - self.entropy_temperature * log_prob_a

        value_loss = F.mse_loss(state_value, target_value)
        self.value_network_optimizer.zero_grad()
        value_loss.backward()
        self.value_network_optimizer.step()

        #
        # Second, optimize Q networks; minimizing MSE between
        # Q(s, a) & r + discount * V'(next_s)
        #

        with torch.no_grad():
            next_state_value = (
                self.value_network_target(learning_input.next_state.float_features)
                * not_done_mask
            )

            if self.minibatch < self.reward_burnin:
                target_q_value = reward
            else:
                target_q_value = reward + discount * next_state_value

        q1_loss = F.mse_loss(q1_value, target_q_value)
        self.q1_network_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_network_optimizer.step()
        if self.q2_network:
            q2_loss = F.mse_loss(q2_value, target_q_value)
            self.q2_network_optimizer.zero_grad()
            q2_loss.backward()
            self.q2_network_optimizer.step()

        #
        # Lastly, optimize the actor; minimizing KL-divergence between action propensity
        # & softmax of value. Due to reparameterization trick, it ends up being
        # log_prob(actor_action) - Q(s, actor_action)
        #

        actor_output = self.actor_network(rlt.StateInput(state=state))

        state_actor_action = rlt.StateAction(
            state=state, action=rlt.FeatureVector(float_features=actor_output.action)
        )
        q1_actor_value = self.q1_network(state_actor_action).q_value
        min_q_actor_value = q1_actor_value
        if self.q2_network:
            q2_actor_value = self.q2_network(state_actor_action).q_value
            min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

        actor_loss = (
            self.entropy_temperature * actor_output.log_prob - min_q_actor_value
        )
        # Do this in 2 steps so we can log histogram of actor loss
        actor_loss_mean = actor_loss.mean()
        self.actor_network_optimizer.zero_grad()
        actor_loss_mean.backward()
        self.actor_network_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.value_network, self.value_network_target, 1.0)
        else:
            # Use the soft update rule to update both target networks
            self._soft_update(self.value_network, self.value_network_target, self.tau)

        # Logging at the end to schedule all the cuda operations first
        if (
            self.tensorboard_logging_freq is not None
            and self.minibatch % self.tensorboard_logging_freq == 0
        ):
            SummaryWriterContext.add_histogram("q1/logged_state_value", q1_value)
            if self.q2_network:
                SummaryWriterContext.add_histogram("q2/logged_state_value", q2_value)

            SummaryWriterContext.add_histogram("log_prob_a", log_prob_a)
            SummaryWriterContext.add_histogram("min_q/logged_state_value", min_q_value)
            SummaryWriterContext.add_histogram("value_network/target", target_value)
            SummaryWriterContext.add_histogram(
                "q_network/next_state_value", next_state_value
            )
            SummaryWriterContext.add_histogram(
                "q_network/target_q_value", target_q_value
            )
            SummaryWriterContext.add_histogram(
                "actor/min_q_actor_value", min_q_actor_value
            )
            SummaryWriterContext.add_histogram(
                "actor/action_log_prob", actor_output.log_prob
            )
            SummaryWriterContext.add_histogram("actor/loss", actor_loss)

        if evaluator is not None:
            cpe_stats = BatchStatsForCPE(
                td_loss=q1_loss.detach().cpu().numpy(),
                model_values_on_logged_actions=q1_value.detach().cpu().numpy(),
            )
            evaluator.report(cpe_stats)
Example #9
    def train(self, training_batch):
        if isinstance(training_batch, TrainingDataPage):
            training_batch = training_batch.as_discrete_sarsa_training_batch()

        learning_input = training_batch.training_input
        # Apply reward boost if specified
        reward_boosts = torch.sum(learning_input.action * self.reward_boosts,
                                  dim=1,
                                  keepdim=True)
        boosted_rewards = learning_input.reward + reward_boosts

        self.minibatch += 1
        rewards = boosted_rewards
        discount_tensor = torch.full_like(rewards, self.gamma)
        not_done_mask = learning_input.not_terminal

        if self.use_seq_num_diff_as_time_diff:
            # TODO: Implement this in another diff
            logger.warning(
                "_dqn_trainer has not implemented use_seq_num_diff_as_time_diff feature"
            )
            pass

        all_next_q_values, all_next_q_values_target = self.get_detached_q_values(
            learning_input.next_state)

        if self.maxq_learning:
            # Compute max a' Q(s', a') over all possible actions using target network
            next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
                all_next_q_values.q_values,
                all_next_q_values_target.q_values
                if self.double_q_learning else None,
                learning_input.possible_next_actions_mask.float(),
            )
        else:
            # SARSA
            next_q_values, max_q_action_idxs = self.get_max_q_values_with_target(
                all_next_q_values.q_values,
                all_next_q_values_target.q_values
                if self.double_q_learning else None,
                learning_input.next_action,
            )

        filtered_next_q_vals = next_q_values * not_done_mask.float()

        if self.minibatch < self.reward_burnin:
            target_q_values = rewards
        else:
            target_q_values = rewards + (discount_tensor *
                                         filtered_next_q_vals)

        # Get Q-value of action taken
        current_state = rlt.StateInput(state=learning_input.state)
        all_q_values = self.q_network(current_state).q_values
        self.all_action_scores = all_q_values.detach()
        q_values = torch.sum(all_q_values * learning_input.action,
                             1,
                             keepdim=True)

        loss = self.q_network_loss(q_values, target_q_values)
        self.loss = loss.detach()

        self.q_network_optimizer.zero_grad()
        loss.backward()
        if self.gradient_handler:
            self.gradient_handler(self.q_network.parameters())
        self.q_network_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.q_network, self.q_network_target, 1.0)
        else:
            # Use the soft update rule to update target network
            self._soft_update(self.q_network, self.q_network_target, self.tau)

        # get reward estimates
        reward_estimates = self.reward_network(current_state).q_values
        reward_loss = F.mse_loss(reward_estimates, rewards)
        self.reward_network_optimizer.zero_grad()
        reward_loss.backward()
        self.reward_network_optimizer.step()

        self.loss_reporter.report(
            td_loss=self.loss,
            reward_loss=reward_loss,
            model_values_on_logged_actions=q_values,
        )
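The first snippet and Example #2 call boost_rewards(rewards, actions), while this last trainer applies the boost inline. A sketch of a helper consistent with that inline computation (an assumption, not the library's verbatim method):

import torch

def boost_rewards(self, rewards, actions):
    # Add the configured per-action boost of the logged (one-hot) action.
    reward_boosts = torch.sum(actions * self.reward_boosts, dim=1, keepdim=True)
    return rewards + reward_boosts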