Code Example #1
    def test_log_prob_gradient(self):
        """
        Same thing. Tanh term drops out since tanh has no params
        d/d mu log f_X(x) = - 2 (mu - x)
        d/d sigma log f_X(x) = 1/sigma^3 - 1/sigma
        :return:
        """
        mean_var = ptu.from_numpy(np.array([0]), requires_grad=True)
        std_var = ptu.from_numpy(np.array([0.25]), requires_grad=True)
        tanh_normal = TanhNormal(mean_var, std_var)
        z = ptu.from_numpy(np.array([1]))
        x = torch.tanh(z)
        log_prob = tanh_normal.log_prob(x, pre_tanh_value=z)

        gradient = ptu.from_numpy(np.array([1]))

        log_prob.backward(gradient)

        self.assertNpArraysEqual(
            ptu.get_numpy(mean_var.grad),
            np.array([16]),
        )
        self.assertNpArraysEqual(
            ptu.get_numpy(std_var.grad),
            np.array([4**3 - 4]),
        )
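
A minimal standalone check of the expected gradients (not part of the original test suite), using torch.distributions.Normal directly: since the tanh correction term has no parameters, the gradients of the tanh-normal log-density with respect to mean and std match those of the underlying Normal.

import torch
from torch.distributions import Normal

mean = torch.tensor([0.0], requires_grad=True)
std = torch.tensor([0.25], requires_grad=True)
z = torch.tensor([1.0])  # pre-tanh value, z = arctanh(x)

# Only the Normal term of log f_X(x) = log f_Z(z) - log(1 - tanh(z)^2)
# depends on (mean, std), so its gradients are the ones the test checks.
Normal(mean, std).log_prob(z).sum().backward()

print(mean.grad)  # tensor([16.]) == (z - mean) / std**2
print(std.grad)   # tensor([60.]) == (z - mean)**2 / std**3 - 1 / std  (= 4**3 - 4)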
Code Example #2
File: policies.py Project: johndpope/DRL
    def forward(
        self,
        obs,
        deterministic=False,
        return_log_prob=False,
    ):
        """
        :param obs: Observation
        :param deterministic: If True, do not sample
        :param return_log_prob: If True, return a sample and its log probability
        """
        h = obs
        for i, fc in enumerate(self.fcs):
            h = self.hidden_activation(fc(h))
        mean = self.last_fc(h)
        if self.std is None:
            log_std = self.last_fc_log_std(h)
            log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
            std = torch.exp(log_std)
        else:
            std = self.std
            log_std = self.log_std

        log_prob = None
        expected_log_prob = None
        mean_action_log_prob = None
        pre_tanh_value = None
        if deterministic:
            action = torch.tanh(mean)
        else:
            tanh_normal = TanhNormal(mean, std)
            if return_log_prob:
                action, pre_tanh_value = tanh_normal.sample(
                    return_pretanh_value=True)
                log_prob = tanh_normal.log_prob(action,
                                                pre_tanh_value=pre_tanh_value)
                log_prob = log_prob.sum(dim=1, keepdim=True)
            else:
                action = tanh_normal.sample()

        return (
            action,
            mean,
            log_std,
            log_prob,
            expected_log_prob,
            std,
            mean_action_log_prob,
            pre_tanh_value,
        )
Code Example #3
    def test_log_prob_value(self):
        tanh_normal = TanhNormal(0, 1)
        z = np.array([1])
        x_np = np.tanh(z)
        x = ptu.from_numpy(x_np)
        log_prob = tanh_normal.log_prob(x)

        log_prob_np = ptu.get_numpy(log_prob)
        log_prob_expected = (
            np.log(np.array([1 / np.sqrt(2 * np.pi)])) - 0.5  # from Normal
            - np.log(1 - x_np**2))
        self.assertNpArraysEqual(
            log_prob_expected,
            log_prob_np,
        )
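
An independent way to reproduce the same expected value (an illustrative sketch, not from the test) is torch.distributions with a TanhTransform, available in recent PyTorch versions, which applies the identical change-of-variables correction log f_X(x) = log f_Z(arctanh(x)) - log(1 - x^2):

import numpy as np
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

x = torch.tanh(torch.tensor([1.0]))
dist = TransformedDistribution(Normal(0.0, 1.0), [TanhTransform()])

expected = np.log(1 / np.sqrt(2 * np.pi)) - 0.5 - np.log(1 - np.tanh(1.0) ** 2)
print(dist.log_prob(x).item(), expected)  # both approximately -0.551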
Code Example #4
File: gaussian_policy.py Project: mihdalal/rlkit
    def forward(self, obs):
        h = self.obs_processor(obs)
        h = self.mean_and_log_std_net(h)
        mean, log_std = torch.split(h, self.action_dim, dim=1)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)

        tanh_normal = TanhNormal(mean, std)
        return tanh_normal
Code Example #5
    def log_prob(self, obs, actions):
        raw_actions = atanh(actions)
        h = obs
        for i, fc in enumerate(self.fcs):
            h = self.hidden_activation(fc(h))
        mean = self.last_fc(h)
        mean = torch.clamp(mean, MEAN_MIN, MEAN_MAX)
        if self.std is None:
            log_std = self.last_fc_log_std(h)
            log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
            std = torch.exp(log_std)
        else:
            std = self.std
            log_std = self.log_std

        tanh_normal = TanhNormal(mean, std)
        log_prob = tanh_normal.log_prob(value=actions, pre_tanh_value=raw_actions)
        return log_prob.sum(-1)
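
The atanh helper above is assumed to come from the surrounding module (it is not shown here); a common numerically safe definition clamps its argument away from ±1 so that inverting actions on the boundary stays finite:

import torch

def atanh(x, eps=1e-6):
    # keep the input strictly inside (-1, 1) before inverting tanh
    x = torch.clamp(x, -1 + eps, 1 - eps)
    return 0.5 * (torch.log1p(x) - torch.log1p(-x))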
Code Example #6
    def forward(
        self,
        obs,
        reparameterize=True,
        deterministic=False,
        return_log_prob=False,
    ):
        """
        :param obs: Observation
        :param deterministic: If True, do not sample
        :param return_log_prob: If True, return a sample and its log probability
        """
        h = super().forward(obs, return_features=True)
        mean = self.last_fc(h)
        if self.std is None:
            log_std = self.last_fc_log_std(h)
            log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
            std = torch.exp(log_std)
        else:
            std = self.std
            log_std = self.log_std

        log_prob = None
        entropy = None
        mean_action_log_prob = None
        pre_tanh_value = None
        if deterministic:
            action = torch.tanh(mean)
        else:
            tanh_normal = TanhNormal(mean, std)
            if return_log_prob:
                if reparameterize is True:
                    action, pre_tanh_value = tanh_normal.rsample(
                        return_pretanh_value=True)
                else:
                    action, pre_tanh_value = tanh_normal.sample(
                        return_pretanh_value=True)
                log_prob = tanh_normal.log_prob(action,
                                                pre_tanh_value=pre_tanh_value)
                log_prob = log_prob.sum(dim=1, keepdim=True)
            else:
                if reparameterize is True:
                    action = tanh_normal.rsample()
                else:
                    action = tanh_normal.sample()

        return (
            action,
            mean,
            log_std,
            log_prob,
            entropy,
            std,
            mean_action_log_prob,
            pre_tanh_value,
        )
Code Example #7
    def forward(self, obs):
        h, h_aux = super().forward(obs, return_last_main_activations=True)
        mean = self.last_fc_main(h)
        if self.std is None:
            log_std = self.last_fc_log_std(h)
            log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
            std = torch.exp(log_std)
        else:
            std = self.std

        tanh_normal = TanhNormal(mean, std)
        return tanh_normal, h_aux
Code Example #8
 def forward(self, data):
     x = self.gat_net(data)
     mean = self.last_mean_layer(x)
     if self.std is None:
         log_std = self.last_log_std(x)
         log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
         std = torch.exp(log_std)
     else:
         std = (torch.from_numpy(np.array([
             self.std,
         ])).float().to(ptu.device))
     return TanhNormal(mean, std)
Code Example #9
    def query_action_logpi(self,
                           obs,
                           action):
        h = obs
        for i, fc in enumerate(self.fcs):
            h = self.hidden_activation(fc(h))
            if self.layer_norm and i < len(self.fcs):
                h = self.layer_norms[i](h)
        mean = self.last_fc(h)
        if self.std is None:
            log_std = self.last_fc_log_std(h)
            log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
            std = torch.exp(log_std)
        else:
            std = self.std

        pre_tanh_value = None
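        # pre_tanh_value is left as None, so TanhNormal.log_prob is expected to
        # recover it internally via atanh(action) (assuming the rlkit-style API).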
        tanh_normal = TanhNormal(mean, std)
        log_prob = tanh_normal.log_prob(
            action,
            pre_tanh_value=pre_tanh_value
        )
        log_prob = log_prob.sum(dim=1, keepdim=True)
        return log_prob
Code Example #10
File: gaussian_policy.py Project: mihdalal/rlkit
    def forward(self, obs):
        h = obs
        for i, fc in enumerate(self.fcs):
            h = self.hidden_activation(fc(h))
        mean = self.last_fc(h)
        if self.std is None:
            log_std = self.last_fc_log_std(h)
            log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
            std = torch.exp(log_std)
        else:
            std = (torch.from_numpy(np.array([
                self.std,
            ])).float().to(ptu.device))

        return TanhNormal(mean, std)
Code Example #11
File: policies.py Project: maivincent/rlkit
    def forward(
        self,
        obs,
        reparameterize=True,
        deterministic=False,
        return_log_prob=False,
    ):
        """
        :param obs: Observation
        :param deterministic: If True, do not sample
        :param return_log_prob: If True, return a sample and its log probability
        """

        # TODO: clean this up.
        if obs.shape[0] == 1:  # single-image observation: flatten it
            h = torch.flatten(obs)
            h = h.view(1, -1)
        else:
            # Observations from the replay buffer are already flat and batched,
            # so do not flatten them.
            h = obs

        h = super().forward(h, None, complete=False)

        mean = self.last_fc(h)

        if self.std is None:
            log_std = self.last_fc_log_std(h)
            log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
            std = torch.exp(log_std)
        else:
            std = self.std
            log_std = self.log_std

        log_prob = None
        entropy = None
        mean_action_log_prob = None
        pre_tanh_value = None
        if deterministic:
            action = torch.tanh(mean)
        else:
            tanh_normal = TanhNormal(mean, std)
            if return_log_prob:
                if reparameterize is True:
                    action, pre_tanh_value = tanh_normal.rsample(
                        return_pretanh_value=True)
                else:
                    action, pre_tanh_value = tanh_normal.sample(
                        return_pretanh_value=True)
                log_prob = tanh_normal.log_prob(action,
                                                pre_tanh_value=pre_tanh_value)
                log_prob = log_prob.sum(dim=1, keepdim=True)
            else:
                if reparameterize is True:
                    action = tanh_normal.rsample()
                else:
                    action = tanh_normal.sample()
        return (
            action,
            mean,
            log_std,
            log_prob,
            entropy,
            std,
            mean_action_log_prob,
            pre_tanh_value,
        )
Code Example #12
    def forward(
            self,
            obs,
            reparameterize=True,
            deterministic=False,
            return_log_prob=True,
            **kwargs
    ):
        """
        :param obs: Observation
        :param deterministic: If True, do not sample
        :param return_log_prob: If True, return a sample and its log probability
        """
        # gt.stamp("Tanhnormal_forward")
        h = obs
        for i, fc in enumerate(self.fcs):
            h = self.hidden_activation(fc(h))
            if self.layer_norm and i < len(self.fcs):
                h = self.layer_norms[i](h)
        mean = self.last_fc(h)
        if self.std is None:
            log_std = self.last_fc_log_std(h)
            log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
            std = torch.exp(log_std)
        else:
            std = self.std
            log_std = self.log_std

        log_prob = None
        entropy = None
        mean_action_log_prob = None
        pre_tanh_value = None
        if deterministic:
            action = torch.tanh(mean)
        else:
            tanh_normal = TanhNormal(mean, std)

            # gt.stamp("tanhnormal_pre")
            if return_log_prob:
                if reparameterize is True:
                    action, pre_tanh_value = tanh_normal.rsample(
                        return_pretanh_value=True
                    )
                else:
                    action, pre_tanh_value = tanh_normal.sample(
                        return_pretanh_value=True
                    )
                log_prob = tanh_normal.log_prob(
                    action,
                    pre_tanh_value=pre_tanh_value
                )
                log_prob = log_prob.sum(dim=1, keepdim=True)
            else:
                if reparameterize is True:
                    action = tanh_normal.rsample()
                else:
                    action = tanh_normal.sample()
            # gt.stamp("tanhnormal_post")

        log_std = log_std.sum(dim=1, keepdim=True)

        return (
            action, mean, log_std, log_prob, entropy, std,
            mean_action_log_prob, pre_tanh_value,
        )
Code Example #13
    def forward(
        self,
        obs,
        deterministic=False,
        return_log_prob=False,
        return_entropy=False,
        return_log_prob_of_mean=False,
    ):
        """
        :param obs: Observation
        :param deterministic: If True, do not sample
        :param return_log_prob: If True, return a sample and its log probability
        :param return_entropy: If True, return the true expected log
        prob. Will not need to be differentiated through, so this can be a
        number.
        :param return_log_prob_of_mean: If True, return the true expected log
        prob. Will not need to be differentiated through, so this can be a
        number.
        """
        obs, taus = split_tau(obs)
        batch_size = obs.size()[0]
        tau_vector = Variable(
            torch.zeros((batch_size, self.tau_vector_len)) + taus.data)
        h = obs
        h1 = self.hidden_activation(self.first_input(h))
        h2 = self.hidden_activation(self.second_input(tau_vector))
        h = torch.cat((h1, h2), dim=1)
        for i, fc in enumerate(self.fcs):
            h = self.hidden_activation(fc(h))
        mean = self.last_fc(h)
        if self.std is None:
            log_std = self.last_fc_log_std(h)
            log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
            std = torch.exp(log_std)
        else:
            std = self.std
            log_std = self.log_std

        log_prob = None
        entropy = None
        mean_action_log_prob = None
        pre_tanh_value = None
        if deterministic:
            action = torch.tanh(mean)
        else:
            tanh_normal = TanhNormal(mean, std)
            if return_log_prob:
                action, pre_tanh_value = tanh_normal.sample(
                    return_pretanh_value=True)
                log_prob = tanh_normal.log_prob(action,
                                                pre_tanh_value=pre_tanh_value)
                log_prob = log_prob.sum(dim=1, keepdim=True)
            else:
                action = tanh_normal.sample()

        if return_entropy:
            entropy = log_std + 0.5 + np.log(2 * np.pi) / 2
            # Note: this is the differential entropy of the pre-tanh Gaussian Z,
            # not of tanh(Z); differential entropy is not invariant under tanh
            # (it changes by E[log(1 - tanh(Z)^2)]).
            entropy = entropy.sum(dim=1, keepdim=True)
        if return_log_prob_of_mean:
            tanh_normal = TanhNormal(mean, std)
            mean_action_log_prob = tanh_normal.log_prob(
                torch.tanh(mean),
                pre_tanh_value=mean,
            )
            mean_action_log_prob = mean_action_log_prob.sum(dim=1,
                                                            keepdim=True)
        return (
            action,
            mean,
            log_std,
            log_prob,
            entropy,
            std,
            mean_action_log_prob,
            pre_tanh_value,
        )
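
A quick check (illustrative, not from the source) that the closed form log_std + 0.5 + log(2*pi)/2 used above is the differential entropy of the pre-tanh Gaussian, matching torch.distributions.Normal.entropy():

import math
import torch
from torch.distributions import Normal

std = torch.tensor([0.25, 1.0, 2.0])
closed_form = torch.log(std) + 0.5 + math.log(2 * math.pi) / 2
print(torch.allclose(closed_form, Normal(torch.zeros_like(std), std).entropy()))  # True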
Code Example #14
    def forward(
        self,
        obs,
        reparameterize=True,
        deterministic=False,
        return_log_prob=False,
    ):
        """
        :param obs: Observation
        :param deterministic: If True, do not sample
        :param return_log_prob: If True, return a sample and its log probability
        """
        h = obs
        for i, fc in enumerate(self.fcs):
            h = self.hidden_activation(fc(h))
        out = self.last_fc(h).view(-1, self.k, (2 * self.action_dim + 1))

        log_w = out[..., 0]
        mean = out[..., 1:1 + self.action_dim]
        log_std = out[..., 1 + self.action_dim:]

        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        log_w = torch.clamp(log_w, min=LOG_W_MIN)

        self.log_w = log_w
        self.mean = mean
        self.log_std = log_std

        std = torch.exp(log_std)

        log_prob = None
        entropy = None
        mean_action_log_prob = None
        pre_tanh_value = None

        arange = torch.arange(out.shape[0])

        if deterministic:
            ks = log_w.view(-1, self.k).argmax(1)
            action = torch.tanh(mean[arange, ks])
        else:
            sample_ks = Categorical(logits=log_w.view(-1, self.k)).sample()

            tanh_normal = TanhNormal(mean[arange, sample_ks], std[arange,
                                                                  sample_ks])
            if return_log_prob:
                action, pre_tanh_value = tanh_normal.sample(
                    return_pretanh_value=True)

                # (NxKxA), (NxKxA), (Nx1xA) => (NxK)
                log_p_xz_t = log_gaussian(mean, log_std,
                                          pre_tanh_value[:, None, :].data)

                log_p_x_t = torch.logsumexp(log_p_xz_t + log_w,
                                            1) - torch.logsumexp(log_w, 1)

                # squash correction
                log_prob = log_p_x_t - torch.log(1 - action**2 + 1e-6).sum(1)

                log_prob = log_prob[:, None]
            else:
                if reparameterize is True:
                    action = tanh_normal.rsample()
                else:
                    action = tanh_normal.sample()

        return (
            action,
            mean,
            log_std,
            log_prob,
            entropy,
            std,
            mean_action_log_prob,
            pre_tanh_value,
        )
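
The mixture log-density above is computed in log space: with normalized weights w_k = softmax(log_w)_k, log p(z) = logsumexp_k(log_w_k + log N(z | mean_k, std_k)) - logsumexp_k(log_w_k). A minimal sketch of that identity (shapes and names here are hypothetical, not from the source):

import torch
from torch.distributions import Normal

torch.manual_seed(0)
N, K, A = 4, 3, 2                      # batch, mixture components, action dim
log_w = torch.randn(N, K)              # unnormalized log mixture weights
mean = torch.randn(N, K, A)
std = torch.ones(N, K, A)
z = torch.randn(N, 1, A)               # pre-tanh sample, broadcast over K

log_p_xz = Normal(mean, std).log_prob(z).sum(-1)                       # (N, K)
log_p_x = torch.logsumexp(log_p_xz + log_w, 1) - torch.logsumexp(log_w, 1)

# Equivalent: normalize the weights explicitly.
reference = torch.logsumexp(log_p_xz + torch.log_softmax(log_w, dim=1), 1)
print(torch.allclose(log_p_x, reference))  # True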
Code Example #15
 def test_log_prob_type(self):
     tanh_normal = TanhNormal(0, 1)
     x = ptu.from_numpy(np.array([0]))
     log_prob = tanh_normal.log_prob(x)
     self.assertIsInstance(log_prob, torch.autograd.Variable)
Code Example #16
File: policies.py Project: yufeiwang63/ROLL
    def forward(
        self,
        obs,
        reparameterize=True,
        deterministic=False,
        return_log_prob=False,
    ):
        """
        :param obs: Observation
        :param deterministic: If True, do not sample
        :param return_log_prob: If True, return a sample and its log probability
        """
        h = obs[:, :self.obs_preprocess_size]
        for i, fc in enumerate(self.obs_preprocess_fcs):
            h = fc(h)
            h = self.hidden_activation(h)
        obs_processed = self.obs_preprocess_last_fc(h)

        h = torch.cat([obs[:, self.obs_preprocess_size:], obs_processed],
                      dim=1)
        for i, fc in enumerate(self.fcs):
            h = fc(h)
            h = self.hidden_activation(h)
        mean = self.last_fc(h)

        if self.std is None:
            log_std = self.last_fc_log_std(h)
            log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
            std = torch.exp(log_std)
        else:
            std = self.std
            log_std = self.log_std

        log_prob = None
        entropy = None
        mean_action_log_prob = None
        pre_tanh_value = None
        if deterministic:
            action = torch.tanh(mean)
        else:
            tanh_normal = TanhNormal(mean, std)
            if return_log_prob:
                if reparameterize is True:
                    action, pre_tanh_value = tanh_normal.rsample(
                        return_pretanh_value=True)
                else:
                    action, pre_tanh_value = tanh_normal.sample(
                        return_pretanh_value=True)
                log_prob = tanh_normal.log_prob(action,
                                                pre_tanh_value=pre_tanh_value)
                log_prob = log_prob.sum(dim=1, keepdim=True)
            else:
                if reparameterize is True:
                    action = tanh_normal.rsample()
                else:
                    action = tanh_normal.sample()

        return (
            action,
            mean,
            log_std,
            log_prob,
            entropy,
            std,
            mean_action_log_prob,
            pre_tanh_value,
        )
Code Example #17
File: gaussian_policy.py Project: mihdalal/rlkit
 def logprob(self, action, mean, std):
     tanh_normal = TanhNormal(mean, std)
     log_prob = tanh_normal.log_prob(action)
     log_prob = log_prob.sum(dim=1, keepdim=True)
     return log_prob
Code Example #18
 def forward(self, x):
     mean, std, h_aux, _ = super().forward(x)
     tanh_normal = TanhNormal(mean, std)
     return tanh_normal, h_aux
Code Example #19
 def forward(self, *input):
     mean, log_std = super().forward(*input)
     std = log_std.exp()
     return TanhNormal(mean, std)
Code Example #20
    def forward(
        self,
        obs,
        reparameterize=True,
        deterministic=False,
        return_log_prob=False,
        return_entropy=False,
        return_log_prob_of_mean=False,
    ):
        """
        :param obs: Observation
        :param deterministic: If True, do not sample
        :param return_log_prob: If True, return a sample and its log probability
        :param return_entropy: If True, return the true expected log
        prob. Will not need to be differentiated through, so this can be a
        number.
        :param return_log_prob_of_mean: If True, return the true expected log
        prob. Will not need to be differentiated through, so this can be a
        number.
        """
        h = self.obs_processor(obs)
        h = self.mean_and_log_std_net(h)
        mean, log_std = torch.split(h, self.action_dim, dim=1)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)

        log_prob = None
        entropy = None
        mean_action_log_prob = None
        pre_tanh_value = None
        if deterministic:
            action = torch.tanh(mean)
        else:
            tanh_normal = TanhNormal(mean, std)
            if return_log_prob:
                if reparameterize is True:
                    action, pre_tanh_value = tanh_normal.rsample(
                        return_pretanh_value=True)
                else:
                    action, pre_tanh_value = tanh_normal.sample(
                        return_pretanh_value=True)
                log_prob = tanh_normal.log_prob(action,
                                                pre_tanh_value=pre_tanh_value)
                log_prob = log_prob.sum(dim=1, keepdim=True)
            else:
                if reparameterize is True:
                    action = tanh_normal.rsample()
                else:
                    action = tanh_normal.sample()

        if return_entropy:
            entropy = log_std + 0.5 + np.log(2 * np.pi) / 2
            # I'm not sure how to compute the (differential) entropy for a
            # tanh(Gaussian)
            entropy = entropy.sum(dim=1, keepdim=True)
            raise NotImplementedError()
        if return_log_prob_of_mean:
            tanh_normal = TanhNormal(mean, std)
            mean_action_log_prob = tanh_normal.log_prob(
                torch.tanh(mean),
                pre_tanh_value=mean,
            )
            mean_action_log_prob = mean_action_log_prob.sum(dim=1,
                                                            keepdim=True)
        return (
            action,
            mean,
            log_std,
            log_prob,
            entropy,
            std,
            mean_action_log_prob,
            pre_tanh_value,
        )
Code Example #21
    def forward(
        self,
        obs,
        reparameterize=True,
        deterministic=False,
        return_log_prob=False,
    ):
        """
        :param obs: Observation
        :param deterministic: If True, do not sample
        :param return_log_prob: If True, return a sample and its log probability
        """
        h = obs
        for i, fc in enumerate(self.fcs):
            h = self.hidden_activation(fc(h))
        mean = self.last_fc(h)  # actions * heads
        if self.std is None:
            log_std = self.last_fc_log_std(h)
            log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
            std = torch.exp(log_std)
        else:
            std = self.std
            log_std = self.log_std

        log_prob = None
        entropy = None
        mean_action_log_prob = None
        pre_tanh_value = None
        log_stds = None
        log_probs = None

        if deterministic:
            means = mean.view(-1, self.action_dim)
            actions = torch.tanh(means)
            actions = actions.view(-1, self.heads, self.action_dim)
        else:
            all_actions = []
            means = mean.view(-1, self.action_dim)
            stds = std.view(-1, self.action_dim)
            log_stds = log_std.view(-1, self.action_dim)

            tanh_normal = TanhNormal(means, stds)
            if return_log_prob:
                if reparameterize is True:
                    action, pre_tanh_value = tanh_normal.rsample(
                        return_pretanh_value=True)
                else:
                    action, pre_tanh_value = tanh_normal.sample(
                        return_pretanh_value=True)
                log_prob = tanh_normal.log_prob(action,
                                                pre_tanh_value=pre_tanh_value)
                log_prob = log_prob.sum(dim=1, keepdim=True)

                log_probs = log_prob.view(-1, self.heads, self.action_dim)
            else:
                if reparameterize is True:
                    action = tanh_normal.rsample()
                else:
                    action = tanh_normal.sample()

            actions = action.view(-1, self.heads, self.action_dim)

        return (
            actions,
            means,
            log_stds,
            log_probs,
            entropy,
            std,
            mean_action_log_prob,
            pre_tanh_value,
        )
Code Example #22
    def train_from_torch(self, batch):
        obs = batch['observations']
        old_log_pi = batch['log_prob']
        advantage = batch['advantage']
        returns = batch['returns']
        actions = batch['actions']
        """
        Policy Loss
        """
        _, policy_mean, policy_log_std, _, _, policy_std, _, _ = self.policy(
            obs)
        new_log_pi = TanhNormal(policy_mean,
                                policy_std).log_prob(actions).sum(1,
                                                                  keepdim=True)

        # Advantage Clip
        ratio = torch.exp(new_log_pi - old_log_pi)
        left = ratio * advantage
        right = torch.clamp(ratio, 1 - self.epsilon,
                            1 + self.epsilon) * advantage

        policy_loss = (-1 * torch.min(left, right)).mean()
        """
        VF Loss
        """
        v_pred = self.vf(obs)
        v_target = returns
        vf_loss = self.vf_criterion(v_pred, v_target)
        """
        Update networks
        """
        loss = policy_loss + vf_loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.last_approx_kl is None or not self._need_to_update_eval_statistics:
            self.last_approx_kl = (old_log_pi - new_log_pi).detach()

        approx_ent = -new_log_pi
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            policy_grads = torch.cat(
                [p.grad.flatten() for p in self.policy.parameters()])
            value_grads = torch.cat(
                [p.grad.flatten() for p in self.vf.parameters()])

            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'V Predictions',
                    ptu.get_numpy(v_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'V Target',
                    ptu.get_numpy(v_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Gradients',
                    ptu.get_numpy(policy_grads),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Value Gradients',
                    ptu.get_numpy(value_grads),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy KL',
                    ptu.get_numpy(self.last_approx_kl),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Entropy',
                    ptu.get_numpy(approx_ent),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'New Log Pis',
                    ptu.get_numpy(new_log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Old Log Pis',
                    ptu.get_numpy(old_log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(policy_mean),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy log std',
                    ptu.get_numpy(policy_log_std),
                ))
        self._n_train_steps_total += 1
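
The policy loss above is the standard PPO clipped surrogate, -E[min(r * A, clip(r, 1 - epsilon, 1 + epsilon) * A)] with r = exp(new_log_pi - old_log_pi). A toy numeric sketch (values hypothetical):

import torch

epsilon = 0.2
old_log_pi = torch.tensor([[-1.0], [-1.0]])
new_log_pi = torch.tensor([[-0.5], [-1.4]])
advantage = torch.tensor([[2.0], [-1.0]])

ratio = torch.exp(new_log_pi - old_log_pi)                        # ~[1.65, 0.67]
left = ratio * advantage
right = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantage
policy_loss = (-torch.min(left, right)).mean()
print(policy_loss)  # the clip caps the positive-advantage ratio at 1.2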