Code example #1
    def extract(self, ws, input_record, extract_record):
        def fetch(b):
            data = ws.fetch_blob(str(b()))
            return torch.tensor(data)

        state = mt.FeatureVector(float_features=fetch(extract_record.state))
        if self.sorted_action_features is None:
            action = None
        else:
            action = mt.FeatureVector(
                float_features=fetch(extract_record.action))
        return mt.StateAction(state=state, action=action)
Code example #2
 def input_prototype(self):
     if self.parametric_action:
         return rlt.StateAction(
             state=rlt.FeatureVector(float_features=torch.randn(1, self.state_dim)),
             action=rlt.FeatureVector(
                 float_features=torch.randn(1, self.action_dim)
             ),
         )
     else:
         return rlt.StateInput(
             state=rlt.FeatureVector(float_features=torch.randn(1, self.state_dim))
         )
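Examples #1 and #2 build rlt.FeatureVector / rlt.StateAction containers and feed them to a parametric Q-network whose input_prototype mirrors that structure. The self-contained sketch below illustrates the same pattern with simplified stand-in containers and a toy network; these definitions are illustrative assumptions, not Horizon's actual rlt types.

from typing import NamedTuple, Optional

import torch
import torch.nn as nn


class FeatureVector(NamedTuple):
    # Stand-in for rlt.FeatureVector: a dense float feature tensor.
    float_features: torch.Tensor


class StateAction(NamedTuple):
    # Stand-in for rlt.StateAction: a (state, action) pair of feature vectors.
    state: FeatureVector
    action: Optional[FeatureVector]


class ToyParametricQNetwork(nn.Module):
    def __init__(self, state_dim: int, action_dim: int):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, 32), nn.ReLU(), nn.Linear(32, 1)
        )

    def forward(self, input: StateAction) -> torch.Tensor:
        # Parametric Q-networks concatenate state and action features.
        x = torch.cat(
            [input.state.float_features, input.action.float_features], dim=1
        )
        return self.net(x)

    def input_prototype(self) -> StateAction:
        # Same idiom as in example #2: random tensors of the expected shapes.
        return StateAction(
            state=FeatureVector(float_features=torch.randn(1, self.state_dim)),
            action=FeatureVector(float_features=torch.randn(1, self.action_dim)),
        )


q_net = ToyParametricQNetwork(state_dim=4, action_dim=2)
q_value = q_net(q_net.input_prototype())  # shape: (1, 1)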
Code example #3
 def internal_reward_estimation(self, state, action):
     """
     Only used by Gym
     """
     self.reward_network.eval()
     reward_estimates = self.reward_network(
         rlt.StateAction(
             state=rlt.FeatureVector(float_features=state),
             action=rlt.FeatureVector(float_features=action),
         ))
     self.reward_network.train()
     return reward_estimates.q_value.cpu()
Code example #4
 def get_detached_q_values(
         self, state,
         action) -> Tuple[rlt.SingleQValue, Optional[rlt.SingleQValue]]:
     """ Gets the q values from the model and target networks """
     with torch.no_grad():
         input = rlt.StateAction(state=state, action=action)
         q_values = self.q_network(input)
         if self.double_q_learning:
             q_values_target = self.q_network_target(input)
         else:
             q_values_target = None
     return q_values, q_values_target
Code example #5
 def internal_reward_estimation(self, state, action):
     """
     Only used by Gym
     """
     self.reward_network.eval()
     with torch.no_grad():
         state = torch.from_numpy(np.array(state)).type(self.dtype)
         action = torch.from_numpy(np.array(action)).type(self.dtype)
         reward_estimates = self.reward_network(
             rlt.StateAction(
                 state=rlt.FeatureVector(float_features=state),
                 action=rlt.FeatureVector(float_features=action),
             ))
     self.reward_network.train()
     return reward_estimates.q_value.cpu().data.numpy()
Code example #6
    def get_max_q_values(self, tiled_next_state, possible_next_actions,
                         double_q_learning):
        """
        :param double_q_learning: bool to use double q-learning
        """

        lengths = possible_next_actions.lengths
        row_nums = np.arange(len(lengths))
        row_idxs = np.repeat(row_nums, lengths.cpu().numpy())
        col_idxs = arange_expand(lengths).cpu().numpy()

        dense_idxs = torch.tensor((row_idxs, col_idxs),
                                  device=lengths.device,
                                  dtype=torch.int64)

        q_network_input = rlt.StateAction(state=tiled_next_state,
                                          action=possible_next_actions.actions)
        if double_q_learning:
            q_values = self.q_network(
                q_network_input).q_value.squeeze().detach()
            q_values_target = (self.q_network_target(
                q_network_input).q_value.squeeze().detach())
        else:
            q_values = self.q_network_target(
                q_network_input).q_value.squeeze().detach()

        dense_dim = [len(lengths), max(lengths)]
        # Add specific fingerprint to q-values so that after sparse -> dense we can
        # subtract the fingerprint to identify the 0's added in sparse -> dense
        q_values.add_(self.FINGERPRINT)
        sparse_q = torch.sparse_coo_tensor(dense_idxs, q_values, dense_dim)
        dense_q = sparse_q.to_dense()
        dense_q.add_(self.FINGERPRINT * -1)
        dense_q[dense_q == self.FINGERPRINT *
                -1] = self.ACTION_NOT_POSSIBLE_VAL
        max_q_values, max_indexes = torch.max(dense_q, dim=1)

        if double_q_learning:
            sparse_q_target = torch.sparse_coo_tensor(dense_idxs,
                                                      q_values_target,
                                                      dense_dim)
            dense_q_values_target = sparse_q_target.to_dense()
            max_q_values = torch.gather(dense_q_values_target, 1,
                                        max_indexes.unsqueeze(1))

        return max_q_values.squeeze()
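The comment in example #6 describes a fingerprint trick: Q-values for variable-length action sets are scattered into a sparse tensor, and a large offset distinguishes the zeros introduced by the sparse -> dense conversion from genuine zero Q-values. The standalone sketch below reproduces the trick with plain PyTorch, using hypothetical constants rather than Horizon's actual FINGERPRINT / ACTION_NOT_POSSIBLE_VAL.

import torch

# Hypothetical constants for illustration only.
FINGERPRINT = 12345.0
ACTION_NOT_POSSIBLE_VAL = -1e9

lengths = torch.tensor([2, 3, 1])  # number of possible actions per row
q_values = torch.tensor([0.5, 0.0, 1.2, -0.3, 0.7, 0.9])  # flattened Q-values

# Scatter indices: row i gets lengths[i] consecutive columns.
row_idxs = torch.repeat_interleave(torch.arange(len(lengths)), lengths)
col_idxs = torch.cat([torch.arange(int(n)) for n in lengths])
dense_idxs = torch.stack([row_idxs, col_idxs])

dense_dim = [len(lengths), int(lengths.max())]
# Offset real values, densify, then undo the offset; padded slots end up at
# -FINGERPRINT and can be replaced by the "action not possible" value.
sparse_q = torch.sparse_coo_tensor(dense_idxs, q_values + FINGERPRINT, dense_dim)
dense_q = sparse_q.to_dense() - FINGERPRINT
dense_q[dense_q == -FINGERPRINT] = ACTION_NOT_POSSIBLE_VAL
max_q_values, max_indexes = torch.max(dense_q, dim=1)
# max_q_values == tensor([0.5, 1.2, 0.9])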
Code example #7
    def create_from_tensors(
        cls,
        trainer: RLTrainer,
        mdp_ids: np.ndarray,
        sequence_numbers: torch.Tensor,
        states: Union[mt.State, torch.Tensor],
        actions: Union[mt.Action, torch.Tensor],
        propensities: torch.Tensor,
        rewards: torch.Tensor,
        possible_actions_mask: torch.Tensor,
        possible_actions: Optional[mt.FeatureVector] = None,
        max_num_actions: Optional[int] = None,
        metrics: Optional[torch.Tensor] = None,
    ):
        with torch.no_grad():
            # Switch to evaluation mode for the network
            old_q_train_state = trainer.q_network.training
            old_reward_train_state = trainer.reward_network.training
            trainer.q_network.train(False)
            trainer.reward_network.train(False)

            if max_num_actions:
                # Parametric model CPE
                state_action_pairs = mt.StateAction(state=states, action=actions)
                tiled_state = mt.FeatureVector(
                    states.float_features.repeat(1, max_num_actions).reshape(
                        -1, states.float_features.shape[1]
                    )
                )
                # Get Q-value of action taken
                possible_actions_state_concat = mt.StateAction(
                    state=tiled_state, action=possible_actions
                )

                # Parametric actions
                model_values = trainer.q_network(possible_actions_state_concat).q_value
                assert (
                    model_values.shape[0] * model_values.shape[1]
                    == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
                ), (
                    "Invalid shapes: "
                    + str(model_values.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
                model_values = model_values.reshape(possible_actions_mask.shape)

                model_rewards = trainer.reward_network(
                    possible_actions_state_concat
                ).q_value
                assert (
                    model_rewards.shape[0] * model_rewards.shape[1]
                    == possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
                ), (
                    "Invalid shapes: "
                    + str(model_rewards.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
                model_rewards = model_rewards.reshape(possible_actions_mask.shape)

                model_values_for_logged_action = trainer.q_network(
                    state_action_pairs
                ).q_value
                model_rewards_for_logged_action = trainer.reward_network(
                    state_action_pairs
                ).q_value

                action_mask = (
                    torch.abs(model_values - model_values_for_logged_action) < 1e-3
                ).float()

                model_metrics = None
                model_metrics_for_logged_action = None
                model_metrics_values = None
                model_metrics_values_for_logged_action = None
            else:
                action_mask = actions.float()

                # Switch to evaluation mode for the network
                old_q_cpe_train_state = trainer.q_network_cpe.training
                trainer.q_network_cpe.train(False)

                # Discrete actions
                rewards = trainer.boost_rewards(rewards, actions)
                model_values = trainer.get_detached_q_values(states)[0]
                assert model_values.shape == actions.shape, (
                    "Invalid shape: "
                    + str(model_values.shape)
                    + " != "
                    + str(actions.shape)
                )
                assert model_values.shape == possible_actions_mask.shape, (
                    "Invalid shape: "
                    + str(model_values.shape)
                    + " != "
                    + str(possible_actions_mask.shape)
                )
                model_values_for_logged_action = torch.sum(
                    model_values * action_mask, dim=1, keepdim=True
                )

                if isinstance(states, mt.State):
                    states = mt.StateInput(state=states)

                rewards_and_metric_rewards = trainer.reward_network(states)

                # In case we reuse the modular for Q-network
                if hasattr(rewards_and_metric_rewards, "q_values"):
                    rewards_and_metric_rewards = rewards_and_metric_rewards.q_values

                num_actions = trainer.num_actions

                model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
                assert model_rewards.shape == actions.shape, (
                    "Invalid shape: "
                    + str(model_rewards.shape)
                    + " != "
                    + str(actions.shape)
                )
                model_rewards_for_logged_action = torch.sum(
                    model_rewards * action_mask, dim=1, keepdim=True
                )

                model_metrics = rewards_and_metric_rewards[:, num_actions:]

                assert model_metrics.shape[1] % num_actions == 0, (
                    "Invalid metrics shape: "
                    + str(model_metrics.shape)
                    + " "
                    + str(num_actions)
                )
                num_metrics = model_metrics.shape[1] // num_actions

                if num_metrics == 0:
                    model_metrics_values = None
                    model_metrics_for_logged_action = None
                    model_metrics_values_for_logged_action = None
                else:
                    model_metrics_values = trainer.q_network_cpe(states)
                    # Backward compatibility
                    if hasattr(model_metrics_values, "q_values"):
                        model_metrics_values = model_metrics_values.q_values
                    model_metrics_values = model_metrics_values[:, num_actions:]
                    assert model_metrics_values.shape[1] == num_actions * num_metrics, (
                        "Invalid shape: "
                        + str(model_metrics_values.shape[1])
                        + " != "
                        + str(actions.shape[1] * num_metrics)
                    )

                    model_metrics_for_logged_action_list = []
                    model_metrics_values_for_logged_action_list = []
                    for metric_index in range(num_metrics):
                        metric_start = metric_index * num_actions
                        metric_end = (metric_index + 1) * num_actions
                        model_metrics_for_logged_action_list.append(
                            torch.sum(
                                model_metrics[:, metric_start:metric_end] * action_mask,
                                dim=1,
                                keepdim=True,
                            )
                        )

                        model_metrics_values_for_logged_action_list.append(
                            torch.sum(
                                model_metrics_values[:, metric_start:metric_end]
                                * action_mask,
                                dim=1,
                                keepdim=True,
                            )
                        )
                    model_metrics_for_logged_action = torch.cat(
                        model_metrics_for_logged_action_list, dim=1
                    )
                    model_metrics_values_for_logged_action = torch.cat(
                        model_metrics_values_for_logged_action_list, dim=1
                    )

                # Switch back to the old mode
                trainer.q_network_cpe.train(old_q_cpe_train_state)

            # Switch back to the old mode
            trainer.q_network.train(old_q_train_state)
            trainer.reward_network.train(old_reward_train_state)

            return cls(
                mdp_id=mdp_ids,
                sequence_number=sequence_numbers,
                logged_propensities=propensities,
                logged_rewards=rewards,
                action_mask=action_mask,
                model_rewards=model_rewards,
                model_rewards_for_logged_action=model_rewards_for_logged_action,
                model_values=model_values,
                model_values_for_logged_action=model_values_for_logged_action,
                model_metrics_values=model_metrics_values,
                model_metrics_values_for_logged_action=model_metrics_values_for_logged_action,
                model_propensities=masked_softmax(
                    model_values, possible_actions_mask, trainer.rl_temperature
                ),
                logged_metrics=metrics,
                model_metrics=model_metrics,
                model_metrics_for_logged_action=model_metrics_for_logged_action,
                # Will compute later
                logged_values=None,
                logged_metrics_values=None,
                possible_actions_mask=possible_actions_mask,
            )
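model_propensities above is produced by masked_softmax over the model values and the possible-actions mask. A hedged sketch of what such a helper plausibly computes (not necessarily Horizon's exact implementation): a temperature-scaled softmax in which actions ruled out by the mask get zero probability.

import torch
import torch.nn.functional as F


def masked_softmax_sketch(q_values, possible_actions_mask, temperature):
    # Scale by temperature, send impossible actions to -inf, then softmax.
    scores = q_values / temperature
    scores = scores.masked_fill(possible_actions_mask == 0, float("-inf"))
    return F.softmax(scores, dim=1)


q = torch.tensor([[1.0, 2.0, 0.5]])
mask = torch.tensor([[1, 1, 0]])
print(masked_softmax_sketch(q, mask, temperature=1.0))
# tensor([[0.2689, 0.7311, 0.0000]])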
Code example #8
File: ddpg_trainer.py  Project: xavierzw/Horizon
 def input_prototype(self) -> rlt.StateAction:
     return rlt.StateAction(
         state=rlt.FeatureVector(float_features=torch.randn(1, self.state_dim)),
         action=rlt.FeatureVector(float_features=torch.randn(1, self.action_dim)),
     )
Code example #9
    def train(self, training_batch, evaluator=None) -> None:
        if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
            training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        reward = learning_input.reward
        discount_tensor = torch.full_like(reward, self.gamma)
        not_done_mask = learning_input.not_terminal

        if self.use_seq_num_diff_as_time_diff:
            # TODO: Implement this in another diff
            raise NotImplementedError

        if self.maxq_learning:
            # Compute max a' Q(s', a') over all possible actions using target network
            next_q_values = self.get_max_q_values(
                learning_input.tiled_next_state,
                learning_input.possible_next_actions,
                self.double_q_learning,
            )
        else:
            # SARSA
            next_q_values = self.get_next_action_q_values(
                learning_input.next_state, learning_input.next_action)

        filtered_max_q_vals = next_q_values.reshape(-1, 1) * not_done_mask

        if self.minibatch < self.reward_burnin:
            target_q_values = reward
        else:
            target_q_values = reward + (discount_tensor * filtered_max_q_vals)

        # Get Q-value of action taken
        current_state_action = rlt.StateAction(state=learning_input.state,
                                               action=learning_input.action)
        q_values = self.q_network(current_state_action).q_value
        self.all_action_scores = q_values.detach()

        value_loss = self.q_network_loss(q_values, target_q_values)
        self.loss = value_loss.detach()

        self.q_network_optimizer.zero_grad()
        value_loss.backward()
        if self.gradient_handler:
            self.gradient_handler(self.q_network.parameters())
        self.q_network_optimizer.step()

        # TODO: Maybe soft_update should belong to the target network
        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.q_network, self.q_network_target, 1.0)
        else:
            # Use the soft update rule to update target network
            self._soft_update(self.q_network, self.q_network_target, self.tau)

        # get reward estimates
        reward_estimates = self.reward_network(current_state_action).q_value
        reward_loss = F.mse_loss(reward_estimates, reward)
        self.reward_network_optimizer.zero_grad()
        reward_loss.backward()
        self.reward_network_optimizer.step()

        self.loss_reporter.report(td_loss=float(self.loss),
                                  reward_loss=float(reward_loss))

        if evaluator is not None:
            cpe_stats = BatchStatsForCPE(
                model_values_on_logged_actions=self.all_action_scores)
            evaluator.report(cpe_stats)
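Several of these trainers call self._soft_update(network, target_network, tau), with tau=1.0 during reward burn-in to copy the online weights outright. A minimal sketch of the standard soft (Polyak) update that such a helper typically performs; Horizon's own implementation may differ in detail.

import torch


def soft_update_sketch(network, target_network, tau):
    # target_param <- tau * param + (1 - tau) * target_param
    with torch.no_grad():
        for param, target_param in zip(
            network.parameters(), target_network.parameters()
        ):
            target_param.mul_(1.0 - tau).add_(param, alpha=tau)


online = torch.nn.Linear(4, 1)
target = torch.nn.Linear(4, 1)
soft_update_sketch(online, target, tau=0.005)  # small step toward the online net
soft_update_sketch(online, target, tau=1.0)    # hard copy, as in reward burn-in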
Code example #10
    def train(self, training_batch) -> None:
        if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
            training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        reward = learning_input.reward
        if self.multi_steps is not None:
            discount_tensor = torch.pow(self.gamma,
                                        learning_input.step.float())
        else:
            discount_tensor = torch.full_like(reward, self.gamma)
        not_done_mask = learning_input.not_terminal

        if self.use_seq_num_diff_as_time_diff:
            if self.multi_steps is not None:
                # TODO: Implement this in another diff
                pass
            else:
                discount_tensor = discount_tensor.pow(
                    learning_input.time_diff.float())

        if self.maxq_learning:
            all_next_q_values, all_next_q_values_target = self.get_detached_q_values(
                learning_input.tiled_next_state,
                learning_input.possible_next_actions)
            # Compute max a' Q(s', a') over all possible actions using target network
            next_q_values, _ = self.get_max_q_values_with_target(
                all_next_q_values.q_value,
                all_next_q_values_target.q_value,
                learning_input.possible_next_actions_mask.float(),
            )
        else:
            # SARSA (Use the target network)
            _, next_q_values = self.get_detached_q_values(
                learning_input.next_state, learning_input.next_action)
            next_q_values = next_q_values.q_value

        filtered_max_q_vals = next_q_values * not_done_mask.float()

        if self.minibatch < self.reward_burnin:
            target_q_values = reward
        else:
            target_q_values = reward + (discount_tensor * filtered_max_q_vals)

        # Get Q-value of action taken
        current_state_action = rlt.StateAction(state=learning_input.state,
                                               action=learning_input.action)
        q_values = self.q_network(current_state_action).q_value
        self.all_action_scores = q_values.detach()

        value_loss = self.q_network_loss(q_values, target_q_values)
        self.loss = value_loss.detach()

        self.q_network_optimizer.zero_grad()
        value_loss.backward()
        if self.gradient_handler:
            self.gradient_handler(self.q_network.parameters())
        self.q_network_optimizer.step()

        # TODO: Maybe soft_update should belong to the target network
        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.q_network, self.q_network_target, 1.0)
        else:
            # Use the soft update rule to update target network
            self._soft_update(self.q_network, self.q_network_target, self.tau)

        # get reward estimates
        reward_estimates = self.reward_network(current_state_action).q_value
        reward_loss = F.mse_loss(reward_estimates, reward)
        self.reward_network_optimizer.zero_grad()
        reward_loss.backward()
        self.reward_network_optimizer.step()

        self.loss_reporter.report(
            td_loss=self.loss,
            reward_loss=reward_loss,
            model_values_on_logged_actions=self.all_action_scores,
        )
Code example #11
File: ddpg_trainer.py  Project: xavierzw/Horizon
    def train(self, training_batch: rlt.TrainingBatch) -> None:
        if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
            training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        state = learning_input.state

        # As far as ddpg is concerned all actions are [-1, 1] due to actor tanh
        action = rlt.FeatureVector(
            rescale_torch_tensor(
                learning_input.action.float_features,
                new_min=self.min_action_range_tensor_training,
                new_max=self.max_action_range_tensor_training,
                prev_min=self.min_action_range_tensor_serving,
                prev_max=self.max_action_range_tensor_serving,
            )
        )

        rewards = learning_input.reward
        next_state = learning_input.next_state
        time_diffs = learning_input.time_diff
        discount_tensor = torch.full_like(rewards, self.gamma)
        not_done_mask = learning_input.not_terminal

        # Optimize the critic network subject to mean squared error:
        # L = ([r + gamma * Q(s2, a2)] - Q(s1, a1)) ^ 2
        q_s1_a1 = self.critic.forward(
            rlt.StateAction(state=state, action=action)
        ).q_value
        next_action = rlt.FeatureVector(
            float_features=self.actor_target(
                rlt.StateAction(state=next_state, action=None)
            ).action
        )

        q_s2_a2 = self.critic_target.forward(
            rlt.StateAction(state=next_state, action=next_action)
        ).q_value
        filtered_q_s2_a2 = not_done_mask.float() * q_s2_a2

        if self.use_seq_num_diff_as_time_diff:
            discount_tensor = discount_tensor.pow(time_diffs)

        target_q_values = rewards + (discount_tensor * filtered_q_s2_a2)

        # compute loss and update the critic network
        critic_predictions = q_s1_a1
        loss_critic = self.q_network_loss(critic_predictions, target_q_values.detach())
        loss_critic_for_eval = loss_critic.detach()
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()

        # Optimize the actor network subject to the following:
        # max mean(Q(s1, a1)) or min -mean(Q(s1, a1))
        actor_output = self.actor(rlt.StateAction(state=state, action=None))
        loss_actor = -(
            self.critic.forward(
                rlt.StateAction(
                    state=state,
                    action=rlt.FeatureVector(float_features=actor_output.action),
                )
            ).q_value.mean()
        )

        # Zero out both the actor and critic gradients because we need
        #   to backprop through the critic to get to the actor
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        self.actor_optimizer.step()

        # Use the soft update rule to update both target networks
        self._soft_update(self.actor, self.actor_target, self.tau)
        self._soft_update(self.critic, self.critic_target, self.tau)

        self.loss_reporter.report(
            td_loss=float(loss_critic_for_eval),
            reward_loss=None,
            model_values_on_logged_actions=critic_predictions,
        )
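rescale_torch_tensor is used here (and again in examples #18 and #22) to map actions from the serving range into the training range. A hedged sketch of a linear min-max rescaling with that signature; Horizon's actual helper may differ.

import torch


def rescale_torch_tensor_sketch(t, new_min, new_max, prev_min, prev_max):
    # Linearly map t from [prev_min, prev_max] into [new_min, new_max].
    return (t - prev_min) / (prev_max - prev_min) * (new_max - new_min) + new_min


actions = torch.tensor([[0.0], [0.5], [1.0]])
rescaled = rescale_torch_tensor_sketch(
    actions,
    new_min=torch.tensor(-1.0),
    new_max=torch.tensor(1.0),
    prev_min=torch.tensor(0.0),
    prev_max=torch.tensor(1.0),
)
# rescaled == tensor([[-1.0], [0.0], [1.0]])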
Code example #12
def create_embed_rl_dataset(
    gym_env: OpenAIGymEnvironment,
    trainer: MDNRNNTrainer,
    dataset: RLDataset,
    use_gpu: bool = False,
    seq_len: int = 5,
    num_state_embed_episodes: int = 100,
    max_steps: Optional[int] = None,
    **kwargs,
):
    old_mdnrnn_mode = trainer.mdnrnn.mdnrnn.training
    trainer.mdnrnn.mdnrnn.eval()
    num_transitions = num_state_embed_episodes * max_steps  # type: ignore
    device = torch.device("cuda") if use_gpu else torch.device("cpu")  # type: ignore

    (
        state_batch,
        action_batch,
        reward_batch,
        next_state_batch,
        next_action_batch,
        not_terminal_batch,
        step_batch,
        next_step_batch,
    ) = map(
        list,
        zip(*multi_step_sample_generator(
            gym_env=gym_env,
            num_transitions=num_transitions,
            max_steps=max_steps,
            # +1 because MDNRNN embeds the first seq_len steps and then
            # the embedded state will be concatenated with the last step
            multi_steps=seq_len + 1,
            include_shorter_samples_at_start=True,
            include_shorter_samples_at_end=False,
        )),
    )

    def concat_batch(batch):
        return torch.cat(
            [
                torch.tensor(np.expand_dims(x, axis=1),
                             dtype=torch.float,
                             device=device) for x in batch
            ],
            dim=1,
        )

    # shape: seq_len x batch_size x feature_dim
    mdnrnn_state = concat_batch(state_batch)
    next_mdnrnn_state = concat_batch(next_state_batch)
    mdnrnn_action = concat_batch(action_batch)
    next_mdnrnn_action = concat_batch(next_action_batch)

    mdnrnn_input = rlt.StateAction(
        state=rlt.FeatureVector(float_features=mdnrnn_state),
        action=rlt.FeatureVector(float_features=mdnrnn_action),
    )
    next_mdnrnn_input = rlt.StateAction(
        state=rlt.FeatureVector(float_features=next_mdnrnn_state),
        action=rlt.FeatureVector(float_features=next_mdnrnn_action),
    )
    # batch-compute state embedding
    mdnrnn_output = trainer.mdnrnn(mdnrnn_input)
    next_mdnrnn_output = trainer.mdnrnn(next_mdnrnn_input)

    for i in range(len(state_batch)):
        # Embed the state as the hidden layer's output
        # until the previous step + current state
        hidden_idx = 0 if step_batch[i] == 1 else step_batch[i] - 2  # type: ignore
        next_hidden_idx = next_step_batch[i] - 2  # type: ignore
        hidden_embed = (
            mdnrnn_output.all_steps_lstm_hidden[hidden_idx,
                                                i, :].squeeze().detach().cpu())
        state_embed = torch.cat(
            (hidden_embed, torch.tensor(state_batch[i][hidden_idx + 1])
             )  # type: ignore
        )
        next_hidden_embed = (next_mdnrnn_output.all_steps_lstm_hidden[
            next_hidden_idx, i, :].squeeze().detach().cpu())
        next_state_embed = torch.cat((
            next_hidden_embed,
            torch.tensor(next_state_batch[i][next_hidden_idx +
                                             1]),  # type: ignore
        ))

        logger.debug(
            "create_embed_rl_dataset:\nstate batch\n{}\naction batch\n{}\nlast "
            "action: {},reward: {}\nstate embed {}\nnext state embed {}\n".
            format(
                state_batch[i][:hidden_idx + 1],  # type: ignore
                action_batch[i][:hidden_idx + 1],  # type: ignore
                action_batch[i][hidden_idx + 1],  # type: ignore
                reward_batch[i][hidden_idx + 1],  # type: ignore
                state_embed,
                next_state_embed,
            ))

        terminal = 1 - not_terminal_batch[i][hidden_idx + 1]  # type: ignore
        possible_actions, possible_actions_mask = get_possible_actions(
            gym_env, ModelType.PYTORCH_PARAMETRIC_DQN.value, False)
        possible_next_actions, possible_next_actions_mask = get_possible_actions(
            gym_env, ModelType.PYTORCH_PARAMETRIC_DQN.value, terminal)
        dataset.insert(
            state=state_embed,
            action=torch.tensor(action_batch[i][hidden_idx +
                                                1]),  # type: ignore
            reward=reward_batch[i][hidden_idx + 1],  # type: ignore
            next_state=next_state_embed,
            next_action=torch.tensor(next_action_batch[i][next_hidden_idx +
                                                          1]  # type: ignore
                                     ),
            terminal=torch.tensor(terminal),
            possible_next_actions=possible_next_actions,
            possible_next_actions_mask=possible_next_actions_mask,
            time_diff=torch.tensor(1),
            possible_actions=possible_actions,
            possible_actions_mask=possible_actions_mask,
            policy_id=0,
        )
    logger.info("Insert {} transitions into a state embed dataset".format(
        len(state_batch)))
    trainer.mdnrnn.mdnrnn.train(old_mdnrnn_mode)
    return dataset
Code example #13
    def train(self, training_batch) -> None:
        """
        IMPORTANT: the input action here is assumed to be preprocessed to match the
        range of the output of the actor.
        """
        if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
            training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        state = learning_input.state
        action = learning_input.action
        next_state = learning_input.next_state
        reward = learning_input.reward
        not_done_mask = learning_input.not_terminal

        action = self._maybe_scale_action_in_train(action)

        # Compute current value estimates
        current_state_action = rlt.StateAction(state=state, action=action)
        q1_value = self.q1_network(current_state_action).q_value
        if self.q2_network:
            q2_value = self.q2_network(current_state_action).q_value
        actor_action = self.actor_network(rlt.StateInput(state=state)).action

        # Generate target = r + y * min (Q1(s',pi(s')), Q2(s',pi(s')))
        with torch.no_grad():
            next_actor = self.actor_network_target(
                rlt.StateInput(state=next_state)).action
            next_actor += (torch.randn_like(next_actor) *
                           self.target_policy_smoothing).clamp(
                               -self.noise_clip, self.noise_clip)
            next_actor = torch.max(
                torch.min(next_actor, self.max_action_range_tensor_training),
                self.min_action_range_tensor_training,
            )
            next_state_actor = rlt.StateAction(
                state=next_state,
                action=rlt.FeatureVector(float_features=next_actor))
            next_state_value = self.q1_network_target(next_state_actor).q_value

            if self.q2_network is not None:
                next_state_value = torch.min(
                    next_state_value,
                    self.q2_network_target(next_state_actor).q_value)

            target_q_value = (
                reward + self.gamma * next_state_value * not_done_mask.float())

        # Optimize Q1 and Q2
        q1_loss = F.mse_loss(q1_value, target_q_value)
        q1_loss.backward()
        self._maybe_run_optimizer(self.q1_network_optimizer,
                                  self.minibatches_per_step)
        if self.q2_network:
            q2_loss = F.mse_loss(q2_value, target_q_value)
            q2_loss.backward()
            self._maybe_run_optimizer(self.q2_network_optimizer,
                                      self.minibatches_per_step)

        # Only update actor and target networks after a fixed number of Q updates
        if self.minibatch % self.delayed_policy_update == 0:
            actor_loss = -self.q1_network(
                rlt.StateAction(
                    state=state,
                    action=rlt.FeatureVector(
                        float_features=actor_action))).q_value.mean()
            actor_loss.backward()
            self._maybe_run_optimizer(self.actor_network_optimizer,
                                      self.minibatches_per_step)

            # Use the soft update rule to update the target networks
            self._maybe_soft_update(
                self.q1_network,
                self.q1_network_target,
                self.tau,
                self.minibatches_per_step,
            )
            self._maybe_soft_update(
                self.actor_network,
                self.actor_network_target,
                self.tau,
                self.minibatches_per_step,
            )
            if self.q2_network is not None:
                self._maybe_soft_update(
                    self.q2_network,
                    self.q2_network_target,
                    self.tau,
                    self.minibatches_per_step,
                )

        # Logging at the end to schedule all the cuda operations first
        if (self.tensorboard_logging_freq is not None
                and self.minibatch % self.tensorboard_logging_freq == 0):
            SummaryWriterContext.add_histogram("q1/logged_state_value",
                                               q1_value)
            if self.q2_network:
                SummaryWriterContext.add_histogram("q2/logged_state_value",
                                                   q2_value)

            SummaryWriterContext.add_histogram("q_network/next_state_value",
                                               next_state_value)
            SummaryWriterContext.add_histogram("q_network/target_q_value",
                                               target_q_value)
            SummaryWriterContext.add_histogram("actor/loss", actor_loss)

        self.loss_reporter.report(
            td_loss=float(q1_loss),
            reward_loss=None,
            logged_rewards=reward,
            model_values_on_logged_actions=q1_value,
        )
Code example #14
File: mdnrnn_trainer.py  Project: xaxis-code/Horizon
    def get_loss(
        self,
        training_batch: rlt.TrainingBatch,
        state_dim: Optional[int] = None,
        batch_first: bool = False,
    ):
        """ Compute losses.

        The loss that is computed is:
        (GMMLoss(next_state, GMMPredicted) + MSE(reward, predicted_reward) +
             BCE(not_terminal, logit_not_terminal)) / (STATE_DIM + 2)
        The STATE_DIM + 2 factor is here to counteract the fact that the GMMLoss scales
        approximately linearly with STATE_DIM, the feature size of states. All losses
        are averaged both on the batch and the sequence dimensions (the two first
        dimensions).

        :param training_batch
        training_batch.learning_input has these fields:
            state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
            action: (BATCH_SIZE, SEQ_LEN, ACTION_DIM) torch tensor
            reward: (BATCH_SIZE, SEQ_LEN) torch tensor
            not-terminal: (BATCH_SIZE, SEQ_LEN) torch tensor
            next_state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
        :param state_dim: the dimension of states. If provided, use it to normalize loss
        :param batch_first: whether data's first dimension represents batch size. If
            FALSE, state, action, reward, not-terminal, and next_state's first
            two dimensions are SEQ_LEN and BATCH_SIZE.

        :returns: dictionary of losses, containing the gmm, the mse, the bce and
            the averaged loss.
        """
        learning_input = training_batch.training_input
        # mdnrnn's input should have seq_len as the first dimension
        if batch_first:
            state, action, next_state, reward, not_terminal = transpose(
                learning_input.state.float_features,
                learning_input.action.float_features,
                learning_input.next_state,
                learning_input.reward,
                learning_input.not_terminal,
            )
            learning_input = rlt.MemoryNetworkInput(
                state=rlt.FeatureVector(float_features=state),
                action=rlt.FeatureVector(float_features=action),
                next_state=next_state,
                reward=reward,
                not_terminal=not_terminal,
            )

        mdnrnn_input = rlt.StateAction(state=learning_input.state,
                                       action=learning_input.action)
        mdnrnn_output = self.mdnrnn(mdnrnn_input)
        mus, sigmas, logpi, rs, ds = (
            mdnrnn_output.mus,
            mdnrnn_output.sigmas,
            mdnrnn_output.logpi,
            mdnrnn_output.reward,
            mdnrnn_output.not_terminal,
        )

        gmm = gmm_loss(learning_input.next_state, mus, sigmas, logpi)
        bce = F.binary_cross_entropy_with_logits(ds,
                                                 learning_input.not_terminal)
        mse = F.mse_loss(rs, learning_input.reward)
        if state_dim is not None:
            loss = (gmm + bce + mse) / (state_dim + 2)
        else:
            loss = mse + bce + gmm
        return {"gmm": gmm, "bce": bce, "mse": mse, "loss": loss}
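The batch_first branch above relies on a transpose(...) helper to put seq_len in front before the tensors reach the MDN-RNN. A hedged sketch of such a helper (an assumption about its behavior, not Horizon's exact code): swap the first two dimensions of every tensor passed in.

import torch


def transpose_sketch(*tensors):
    # (BATCH_SIZE, SEQ_LEN, ...) -> (SEQ_LEN, BATCH_SIZE, ...) for each tensor.
    return tuple(t.transpose(0, 1) for t in tensors)


state = torch.randn(8, 5, 10)   # (BATCH_SIZE, SEQ_LEN, STATE_DIM)
action = torch.randn(8, 5, 3)   # (BATCH_SIZE, SEQ_LEN, ACTION_DIM)
state_t, action_t = transpose_sketch(state, action)
# state_t.shape == (5, 8, 10); action_t.shape == (5, 8, 3)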
Code example #15
File: mdnrnn_trainer.py  Project: Fengdan92/Horizon
    def get_loss(
        self,
        training_batch: rlt.TrainingBatch,
        state_dim: Optional[int] = None,
        batch_first: bool = False,
    ):
        """
        Compute losses.

        The loss that is computed is:
            (GMMLoss(next_state, GMMPredicted) + MSE(reward, predicted_reward) +
            BCE(not_terminal, logit_not_terminal)) / (STATE_DIM + 2)

        The STATE_DIM + 2 factor is here to counteract the fact that the GMMLoss scales
            approximately linearly with STATE_DIM, the feature size of states. All losses
            are averaged both on the batch and the sequence dimensions (the two first
            dimensions).

        :param training_batch
            training_batch.learning_input has these fields:
            - state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
            - action: (BATCH_SIZE, SEQ_LEN, ACTION_DIM) torch tensor
            - reward: (BATCH_SIZE, SEQ_LEN) torch tensor
            - not-terminal: (BATCH_SIZE, SEQ_LEN) torch tensor
            - next_state: (BATCH_SIZE, SEQ_LEN, STATE_DIM) torch tensor
            the first two dimensions may be swapped depending on batch_first

        :param state_dim: the dimension of states. If provided, use it to normalize
            gmm loss

        :param batch_first: whether data's first dimension represents batch size. If
            FALSE, state, action, reward, not-terminal, and next_state's first
            two dimensions are SEQ_LEN and BATCH_SIZE.

        :returns: dictionary of losses, containing the gmm, the mse, the bce and
            the averaged loss.
        """
        learning_input = training_batch.training_input
        # mdnrnn's input should have seq_len as the first dimension
        if batch_first:
            state, action, next_state, reward, not_terminal = transpose(
                learning_input.state.float_features,
                learning_input.action.float_features,  # type: ignore
                learning_input.next_state,
                learning_input.reward,
                learning_input.not_terminal,  # type: ignore
            )
            learning_input = rlt.MemoryNetworkInput(  # type: ignore
                state=rlt.FeatureVector(float_features=state),
                reward=reward,
                time_diff=torch.ones_like(reward).float(),
                action=rlt.FeatureVector(float_features=action),
                not_terminal=not_terminal,
                next_state=next_state,
            )

        mdnrnn_input = rlt.StateAction(
            state=learning_input.state,
            action=learning_input.action  # type: ignore
        )
        mdnrnn_output = self.mdnrnn(mdnrnn_input)
        mus, sigmas, logpi, rs, nts = (
            mdnrnn_output.mus,
            mdnrnn_output.sigmas,
            mdnrnn_output.logpi,
            mdnrnn_output.reward,
            mdnrnn_output.not_terminal,
        )

        next_state = learning_input.next_state
        not_terminal = learning_input.not_terminal  # type: ignore
        reward = learning_input.reward
        if self.params.fit_only_one_next_step:
            next_state, not_terminal, reward, mus, sigmas, logpi, nts, rs = tuple(
                map(
                    lambda x: x[-1:],
                    (next_state, not_terminal, reward, mus, sigmas, logpi, nts,
                     rs),
                ))

        gmm = (gmm_loss(next_state, mus, sigmas, logpi) *
               self.params.next_state_loss_weight)
        bce = (F.binary_cross_entropy_with_logits(nts, not_terminal) *
               self.params.not_terminal_loss_weight)
        mse = F.mse_loss(rs, reward) * self.params.reward_loss_weight
        if state_dim is not None:
            loss = gmm / (state_dim + 2) + bce + mse
        else:
            loss = gmm + bce + mse
        return {"gmm": gmm, "bce": bce, "mse": mse, "loss": loss}
Code example #16
 def get_next_action_q_values(self, state, action):
     return self.q_network_target(
         rlt.StateAction(state=state, action=action)).q_value
Code example #17
    def train(self, training_batch) -> None:
        if isinstance(training_batch, TrainingDataPage):
            if self.maxq_learning:
                training_batch = training_batch.as_parametric_maxq_training_batch()
            else:
                training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        reward = learning_input.reward
        not_done_mask = learning_input.not_terminal

        discount_tensor = torch.full_like(reward, self.gamma)
        if self.use_seq_num_diff_as_time_diff:
            assert self.multi_steps is None
            discount_tensor = torch.pow(self.gamma, learning_input.time_diff.float())
        if self.multi_steps is not None:
            discount_tensor = torch.pow(self.gamma, learning_input.step.float())

        if self.maxq_learning:
            all_next_q_values, all_next_q_values_target = self.get_detached_q_values(
                learning_input.tiled_next_state, learning_input.possible_next_actions
            )
            # Compute max a' Q(s', a') over all possible actions using target network
            next_q_values, _ = self.get_max_q_values_with_target(
                all_next_q_values.q_value,
                all_next_q_values_target.q_value,
                learning_input.possible_next_actions_mask.float(),
            )
        else:
            # SARSA (Use the target network)
            _, next_q_values = self.get_detached_q_values(
                learning_input.next_state, learning_input.next_action
            )
            next_q_values = next_q_values.q_value

        filtered_max_q_vals = next_q_values * not_done_mask.float()

        target_q_values = reward + (discount_tensor * filtered_max_q_vals)

        # Get Q-value of action taken
        current_state_action = rlt.StateAction(
            state=learning_input.state, action=learning_input.action
        )
        q_values = self.q_network(current_state_action).q_value
        self.all_action_scores = q_values.detach()

        value_loss = self.q_network_loss(q_values, target_q_values)
        self.loss = value_loss.detach()

        self.q_network_optimizer.zero_grad()
        value_loss.backward()
        self.q_network_optimizer.step()

        # Use the soft update rule to update target network
        self._soft_update(self.q_network, self.q_network_target, self.tau)

        # get reward estimates
        reward_estimates = self.reward_network(current_state_action).q_value
        reward_loss = F.mse_loss(reward_estimates, reward)
        self.reward_network_optimizer.zero_grad()
        reward_loss.backward()
        self.reward_network_optimizer.step()

        self.loss_reporter.report(
            td_loss=self.loss,
            reward_loss=reward_loss,
            model_values_on_logged_actions=self.all_action_scores,
        )
Code example #18
    def train(self, training_batch, evaluator=None) -> None:
        """
        IMPORTANT: the input action here is assumed to be preprocessed to match the
        range of the output of the actor.
        """
        if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
            training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        state = learning_input.state
        action = learning_input.action
        reward = learning_input.reward
        discount = torch.full_like(reward, self.gamma)
        not_done_mask = learning_input.not_terminal

        if self._should_scale_action_in_train():
            action = rlt.FeatureVector(
                rescale_torch_tensor(
                    action.float_features,
                    new_min=self.min_action_range_tensor_training,
                    new_max=self.max_action_range_tensor_training,
                    prev_min=self.min_action_range_tensor_serving,
                    prev_max=self.max_action_range_tensor_serving,
                ))

        current_state_action = rlt.StateAction(state=state, action=action)

        q1_value = self.q1_network(current_state_action).q_value
        min_q_value = q1_value

        if self.q2_network:
            q2_value = self.q2_network(current_state_action).q_value
            min_q_value = torch.min(q1_value, q2_value)

        # Use the minimum as target, ensure no gradient going through
        min_q_value = min_q_value.detach()

        #
        # First, optimize value network; minimizing MSE between
        # V(s) & Q(s, a) - log(pi(a|s))
        #

        state_value = self.value_network(state.float_features)  # .q_value

        if self.logged_action_uniform_prior:
            log_prob_a = torch.zeros_like(min_q_value)
            target_value = min_q_value
        else:
            with torch.no_grad():
                log_prob_a = self.actor_network.get_log_prob(
                    state, action.float_features)
                log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                target_value = min_q_value - self.entropy_temperature * log_prob_a

        value_loss = F.mse_loss(state_value, target_value)
        self.value_network_optimizer.zero_grad()
        value_loss.backward()
        self.value_network_optimizer.step()

        #
        # Second, optimize Q networks; minimizing MSE between
        # Q(s, a) & r + discount * V'(next_s)
        #

        with torch.no_grad():
            next_state_value = (self.value_network_target(
                learning_input.next_state.float_features) * not_done_mask)

            if self.minibatch < self.reward_burnin:
                target_q_value = reward
            else:
                target_q_value = reward + discount * next_state_value

        q1_loss = F.mse_loss(q1_value, target_q_value)
        self.q1_network_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_network_optimizer.step()
        if self.q2_network:
            q2_loss = F.mse_loss(q2_value, target_q_value)
            self.q2_network_optimizer.zero_grad()
            q2_loss.backward()
            self.q2_network_optimizer.step()

        #
        # Lastly, optimize the actor; minimizing KL-divergence between action propensity
        # & softmax of value. Due to reparameterization trick, it ends up being
        # log_prob(actor_action) - Q(s, actor_action)
        #

        actor_output = self.actor_network(rlt.StateInput(state=state))

        state_actor_action = rlt.StateAction(
            state=state,
            action=rlt.FeatureVector(float_features=actor_output.action))
        q1_actor_value = self.q1_network(state_actor_action).q_value
        min_q_actor_value = q1_actor_value
        if self.q2_network:
            q2_actor_value = self.q2_network(state_actor_action).q_value
            min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

        actor_loss = (self.entropy_temperature * actor_output.log_prob -
                      min_q_actor_value)
        # Do this in 2 steps so we can log histogram of actor loss
        actor_loss_mean = actor_loss.mean()
        self.actor_network_optimizer.zero_grad()
        actor_loss_mean.backward()
        self.actor_network_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.value_network, self.value_network_target,
                              1.0)
        else:
            # Use the soft update rule to update both target networks
            self._soft_update(self.value_network, self.value_network_target,
                              self.tau)

        # Logging at the end to schedule all the cuda operations first
        if (self.tensorboard_logging_freq is not None
                and self.minibatch % self.tensorboard_logging_freq == 0):
            SummaryWriterContext.add_histogram("q1/logged_state_value",
                                               q1_value)
            if self.q2_network:
                SummaryWriterContext.add_histogram("q2/logged_state_value",
                                                   q2_value)

            SummaryWriterContext.add_histogram("log_prob_a", log_prob_a)
            SummaryWriterContext.add_histogram("value_network/target",
                                               target_value)
            SummaryWriterContext.add_histogram("q_network/next_state_value",
                                               next_state_value)
            SummaryWriterContext.add_histogram("q_network/target_q_value",
                                               target_q_value)
            SummaryWriterContext.add_histogram("actor/min_q_actor_value",
                                               min_q_actor_value)
            SummaryWriterContext.add_histogram("actor/action_log_prob",
                                               actor_output.log_prob)
            SummaryWriterContext.add_histogram("actor/loss", actor_loss)

        if evaluator is not None:
            cpe_stats = BatchStatsForCPE(
                td_loss=q1_loss.detach().cpu().numpy(),
                logged_rewards=reward.detach().cpu().numpy(),
                model_values_on_logged_actions=q1_value.detach().cpu().numpy(),
                model_propensities=actor_output.log_prob.exp().detach().cpu().
                numpy(),
                model_values=min_q_actor_value.detach().cpu().numpy(),
            )
            evaluator.report(cpe_stats)
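Written out schematically (with α = entropy_temperature and following the code above), the three objectives this SAC-style trainer optimizes are:

\begin{aligned}
L_V &= \operatorname{MSE}\big(V(s),\ \min(Q_1(s,a), Q_2(s,a)) - \alpha \log \pi(a \mid s)\big)\\
L_{Q_i} &= \operatorname{MSE}\big(Q_i(s,a),\ r + \gamma\,(1-\mathrm{done})\,V_{\mathrm{target}}(s')\big)\\
L_\pi &= \mathbb{E}\big[\alpha \log \pi(\tilde a \mid s) - \min(Q_1(s,\tilde a), Q_2(s,\tilde a))\big],\qquad \tilde a \sim \pi(\cdot \mid s)
\end{aligned}

During reward burn-in the Q target reduces to the reward alone, and when logged_action_uniform_prior is set the log-probability term is dropped from the value target.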
Code example #19
 def forward(self, input):
     preprocessed_state = self.state_preprocessor(input.state)
     preprocessed_action = self.action_preprocessor(input.action)
     return self.q_network(
         rlt.StateAction(state=preprocessed_state,
                         action=preprocessed_action))
Code example #20
 def input_prototype(self):
     return rlt.StateAction(
         state=self.state_preprocessor.input_prototype(),
         action=self.action_preprocessor.input_prototype(),
     )
Code example #21
    def train(self, training_batch, evaluator=None) -> None:
        if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
            training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        s = learning_input.state
        a = learning_input.action.float_features
        reward = learning_input.reward
        discount = torch.full_like(reward, self.gamma)
        not_done_mask = learning_input.not_terminal

        current_state_action = rlt.StateAction(
            state=learning_input.state, action=learning_input.action
        )

        q1_value = self.q1_network(current_state_action).q_value
        min_q_value = q1_value

        if self.q2_network:
            q2_value = self.q2_network(current_state_action).q_value
            min_q_value = torch.min(q1_value, q2_value)

        # Use the minimum as target, ensure no gradient going through
        min_q_value = min_q_value.detach()

        #
        # First, optimize value network; minimizing MSE between
        # V(s) & Q(s, a) - log(pi(a|s))
        #

        state_value = self.value_network(s.float_features)  # .q_value

        with torch.no_grad():
            log_prob_a = self.actor_network.get_log_prob(s, a)
            target_value = min_q_value - self.entropy_temperature * log_prob_a

        value_loss = F.mse_loss(state_value, target_value)
        self.value_network_optimizer.zero_grad()
        value_loss.backward()
        self.value_network_optimizer.step()

        #
        # Second, optimize Q networks; minimizing MSE between
        # Q(s, a) & r + discount * V'(next_s)
        #

        with torch.no_grad():
            next_state_value = (
                self.value_network_target(learning_input.next_state.float_features)
                * not_done_mask
            )

            if self.minibatch < self.reward_burnin:
                target_q_value = reward
            else:
                target_q_value = reward + discount * next_state_value

        q1_loss = F.mse_loss(q1_value, target_q_value)
        self.q1_network_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_network_optimizer.step()
        if self.q2_network:
            q2_loss = F.mse_loss(q2_value, target_q_value)
            self.q2_network_optimizer.zero_grad()
            q2_loss.backward()
            self.q2_network_optimizer.step()

        #
        # Lastly, optimize the actor; minimizing KL-divergence between action propensity
        # & softmax of value. Due to reparameterization trick, it ends up being
        # log_prob(actor_action) - Q(s, actor_action)
        #

        actor_output = self.actor_network(rlt.StateInput(state=learning_input.state))

        state_actor_action = rlt.StateAction(
            state=s, action=rlt.FeatureVector(float_features=actor_output.action)
        )
        q1_actor_value = self.q1_network(state_actor_action).q_value
        min_q_actor_value = q1_actor_value
        if self.q2_network:
            q2_actor_value = self.q2_network(state_actor_action).q_value
            min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

        actor_loss = torch.mean(
            self.entropy_temperature * actor_output.log_prob - min_q_actor_value
        )
        self.actor_network_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_network_optimizer.step()

        if self.minibatch < self.reward_burnin:
            # Reward burnin: force target network
            self._soft_update(self.value_network, self.value_network_target, 1.0)
        else:
            # Use the soft update rule to update both target networks
            self._soft_update(self.value_network, self.value_network_target, self.tau)

        if evaluator is not None:
            # FIXME
            self.evaluate(evaluator)
Code example #22
    def train(self, training_batch) -> None:
        """
        IMPORTANT: the input action here is assumed to be preprocessed to match the
        range of the output of the actor.
        """
        if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
            training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        state = learning_input.state
        action = learning_input.action
        reward = learning_input.reward
        discount = torch.full_like(reward, self.gamma)
        not_done_mask = learning_input.not_terminal

        if self._should_scale_action_in_train():
            action = rlt.FeatureVector(
                rescale_torch_tensor(
                    action.float_features,
                    new_min=self.min_action_range_tensor_training,
                    new_max=self.max_action_range_tensor_training,
                    prev_min=self.min_action_range_tensor_serving,
                    prev_max=self.max_action_range_tensor_serving,
                ))

        with torch.enable_grad():
            #
            # First, optimize Q networks; minimizing MSE between
            # Q(s, a) & r + discount * V'(next_s)
            #

            current_state_action = rlt.StateAction(state=state, action=action)
            q1_value = self.q1_network(current_state_action).q_value
            if self.q2_network:
                q2_value = self.q2_network(current_state_action).q_value
            actor_output = self.actor_network(rlt.StateInput(state=state))

            # Optimize Alpha
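            # Temperature objective: J(alpha) = E[-alpha * (log_prob(a|s) + target_entropy)].
            # When policy entropy falls below target_entropy the gradient pushes alpha up,
            # strengthening the entropy bonus, and lets it decay otherwise.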
            if self.alpha_optimizer is not None:
                alpha_loss = -(self.log_alpha *
                               (actor_output.log_prob +
                                self.target_entropy).detach()).mean()
                self.alpha_optimizer.zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer.step()
                self.entropy_temperature = self.log_alpha.exp()

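            # TD target: r + gamma * V(s') * not_done. With a separate value network,
            # V(s') comes from its target copy; otherwise it is estimated as
            # min(Q1_target, Q2_target)(s', a') - alpha * log_prob(a'|s'),
            # with a' sampled from the current policy at s'.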
            with torch.no_grad():
                if self.value_network is not None:
                    next_state_value = self.value_network_target(
                        learning_input.next_state.float_features)
                else:
                    next_state_actor_output = self.actor_network(
                        rlt.StateInput(state=learning_input.next_state))
                    next_state_actor_action = rlt.StateAction(
                        state=learning_input.next_state,
                        action=rlt.FeatureVector(
                            float_features=next_state_actor_output.action),
                    )
                    next_state_value = self.q1_network_target(
                        next_state_actor_action).q_value

                    if self.q2_network is not None:
                        target_q2_value = self.q2_network_target(
                            next_state_actor_action).q_value
                        next_state_value = torch.min(next_state_value,
                                                     target_q2_value)

                    log_prob_a = self.actor_network.get_log_prob(
                        learning_input.next_state,
                        next_state_actor_output.action)
                    log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                    next_state_value -= self.entropy_temperature * log_prob_a

                target_q_value = (
                    reward +
                    discount * next_state_value * not_done_mask.float())

            q1_loss = F.mse_loss(q1_value, target_q_value)
            q1_loss.backward()
            self._maybe_run_optimizer(self.q1_network_optimizer,
                                      self.minibatches_per_step)
            if self.q2_network:
                q2_loss = F.mse_loss(q2_value, target_q_value)
                q2_loss.backward()
                self._maybe_run_optimizer(self.q2_network_optimizer,
                                          self.minibatches_per_step)

            #
            # Second, optimize the actor; minimizing KL-divergence between action propensity
            # & softmax of value. Due to reparameterization trick, it ends up being
            # log_prob(actor_action) - Q(s, actor_action)
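            # (in this implementation the log-probability term is scaled by the learned
            # temperature, i.e. alpha * log_prob(actor_action) - min(Q1, Q2)(s, actor_action))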
            #

            state_actor_action = rlt.StateAction(
                state=state,
                action=rlt.FeatureVector(float_features=actor_output.action),
            )
            q1_actor_value = self.q1_network(state_actor_action).q_value
            min_q_actor_value = q1_actor_value
            if self.q2_network:
                q2_actor_value = self.q2_network(state_actor_action).q_value
                min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

            actor_loss = (self.entropy_temperature * actor_output.log_prob -
                          min_q_actor_value)
            # Do this in 2 steps so we can log histogram of actor loss
            actor_loss_mean = actor_loss.mean()
            actor_loss_mean.backward()
            self._maybe_run_optimizer(self.actor_network_optimizer,
                                      self.minibatches_per_step)

            #
            # Lastly, if applicable, optimize value network; minimizing MSE between
            # V(s) & E_a~pi(s) [ Q(s,a) - log(pi(a|s)) ]
            #

            if self.value_network is not None:
                state_value = self.value_network(state.float_features)

                if self.logged_action_uniform_prior:
                    log_prob_a = torch.zeros_like(min_q_actor_value)
                    target_value = min_q_actor_value
                else:
                    with torch.no_grad():
                        log_prob_a = actor_output.log_prob
                        log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                        target_value = (min_q_actor_value -
                                        self.entropy_temperature * log_prob_a)

                value_loss = F.mse_loss(state_value, target_value.detach())
                value_loss.backward()
                self._maybe_run_optimizer(self.value_network_optimizer,
                                          self.minibatches_per_step)

        # Use the soft update rule to update the target networks
        if self.value_network is not None:
            self._maybe_soft_update(
                self.value_network,
                self.value_network_target,
                self.tau,
                self.minibatches_per_step,
            )
        else:
            self._maybe_soft_update(
                self.q1_network,
                self.q1_network_target,
                self.tau,
                self.minibatches_per_step,
            )
            if self.q2_network is not None:
                self._maybe_soft_update(
                    self.q2_network,
                    self.q2_network_target,
                    self.tau,
                    self.minibatches_per_step,
                )

        # Logging at the end to schedule all the cuda operations first
        if (self.tensorboard_logging_freq is not None
                and self.minibatch % self.tensorboard_logging_freq == 0):
            SummaryWriterContext.add_histogram("q1/logged_state_value",
                                               q1_value)
            if self.q2_network:
                SummaryWriterContext.add_histogram("q2/logged_state_value",
                                                   q2_value)

            SummaryWriterContext.add_histogram("log_prob_a", log_prob_a)
            if self.value_network:
                SummaryWriterContext.add_histogram("value_network/target",
                                                   target_value)

            SummaryWriterContext.add_histogram("q_network/next_state_value",
                                               next_state_value)
            SummaryWriterContext.add_histogram("q_network/target_q_value",
                                               target_q_value)
            SummaryWriterContext.add_histogram("actor/min_q_actor_value",
                                               min_q_actor_value)
            SummaryWriterContext.add_histogram("actor/action_log_prob",
                                               actor_output.log_prob)
            SummaryWriterContext.add_histogram("actor/loss", actor_loss)

        self.loss_reporter.report(
            td_loss=float(q1_loss),
            reward_loss=None,
            logged_rewards=reward,
            model_values_on_logged_actions=q1_value,
            model_propensities=actor_output.log_prob.exp(),
            model_values=min_q_actor_value,
        )
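Both excerpts rely on helpers that never appear in these examples: _maybe_run_optimizer and _maybe_soft_update, gated on minibatches_per_step. A minimal sketch of how such helpers are commonly written, assuming gradients are accumulated across minibatches_per_step calls to train() before a single optimizer step and target update (the real implementations may differ):

    def _maybe_run_optimizer(self, optimizer, minibatches_per_step):
        # Step and clear gradients only every `minibatches_per_step` minibatches;
        # in between, gradients from backward() keep accumulating.
        if self.minibatch % minibatches_per_step == 0:
            optimizer.step()
            optimizer.zero_grad()

    def _maybe_soft_update(self, network, target_network, tau, minibatches_per_step):
        # Apply the Polyak target update on the same schedule as the optimizer steps.
        if self.minibatch % minibatches_per_step == 0:
            self._soft_update(network, target_network, tau)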