Example #1
    def loss_function(self, forward_ret, labels=None):
        (x, t2, qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar, pb_z1_b1_mu,
         pb_z1_b1_logvar, qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2, pt_z2_z1_mu,
         pt_z2_z1_logvar, pd_x2_z2) = forward_ret

        # replicate x multiple times
        x = x[None, ...].expand(self.flags.samples_per_seq, -1, -1,
                                -1)  # size: copy, bs, time, dim
        x2 = torch.gather(
            x, 2,
            t2[..., None, None].expand(-1, -1, -1, x.size(3))
        ).view(-1, x.size(3))
        batch_size = x2.size(0)

        kl_div_qs_pb = ops.kl_div_gaussian(qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar,
                                           pb_z1_b1_mu,
                                           pb_z1_b1_logvar).mean()

        kl_shift_qb_pt = (
            ops.gaussian_log_prob(qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2) -
            ops.gaussian_log_prob(pt_z2_z1_mu, pt_z2_z1_logvar,
                                  qb_z2_b2)).mean()

        bce = F.binary_cross_entropy(pd_x2_z2, x2,
                                     reduction='sum') / batch_size
        bce_optimal = F.binary_cross_entropy(
            x2, x2, reduction='sum').detach() / batch_size
        bce_diff = bce - bce_optimal

        loss = bce_diff + kl_div_qs_pb + kl_shift_qb_pt

        return loss, bce_diff, kl_div_qs_pb, kl_shift_qb_pt, bce_optimal
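
This example relies on two helpers from an `ops` module that are not shown. The block below is a minimal sketch of what they are assumed to compute for diagonal Gaussians parameterized by mean and log-variance; the actual `ops` implementations may differ in reduction or signature details.

import math


def kl_div_gaussian(mu1, logvar1, mu2, logvar2):
    # KL(N(mu1, exp(logvar1)) || N(mu2, exp(logvar2))) for diagonal Gaussians,
    # summed over the last (feature) dimension
    return 0.5 * (logvar2 - logvar1 +
                  (logvar1.exp() + (mu1 - mu2) ** 2) / logvar2.exp() -
                  1).sum(dim=-1)


def gaussian_log_prob(mu, logvar, x):
    # log-density of x under N(mu, exp(logvar)), summed over the last dimension
    return -0.5 * (logvar + (x - mu) ** 2 / logvar.exp() +
                   math.log(2 * math.pi)).sum(dim=-1)
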
Example #2
    def loss_function(self, forward_ret, labels=None):
        (x, qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2, pd_x2_z2) = forward_ret

        # replicate x multiple times
        # (note: x_flat is computed here but not used below; the reconstruction
        # loss operates on the flattened x directly)
        x_flat = x.flatten(2, -1)
        x_flat = x_flat.expand(self.flags.samples_per_seq, -1, -1,
                               -1)  # size: copy, bs, time, dim
        batch_size = x.size(0)

        if self.adversarial and self.model.training:
            r_in = x.view(x.shape[0], x.shape[2], x.shape[3], x.shape[4])
            f_in = pd_x2_z2.view(x.shape[0], x.shape[2], x.shape[3],
                                 x.shape[4])
            for _ in range(self.d_steps):
                d_loss, g_loss, hidden_loss = self.dnet.get_loss(r_in, f_in)
                d_loss.backward(retain_graph=True)
                # print(d_loss, g_loss)
                self.adversarial_optim.step()
                self.adversarial_optim.zero_grad()
        else:
            g_loss = 0
            hidden_loss = 0

        # log-variance of the fixed Gaussian prior, broadcast over the batch
        prior_logvar = torch.ones(qb_z2_b2.size(-1), device=qb_z2_b2.device)
        prior_logvar = prior_logvar[None, None, :].expand(-1, qb_z2_b2.size(-2),
                                                          -1)
        kl_div_qs_pb = ops.kl_div_gaussian(qb_z2_b2_mu, qb_z2_b2_logvar, 0,
                                           prior_logvar).mean()

        target = x.flatten()
        pred = pd_x2_z2.flatten()
        bce = F.binary_cross_entropy(pred, target,
                                     reduction='sum') / batch_size
        bce_optimal = F.binary_cross_entropy(target, target,
                                             reduction='sum') / batch_size
        bce_diff = bce - bce_optimal

        if self.adversarial and self.is_training():
            r_in = x.view(x.shape[0], x.shape[2], x.shape[3], x.shape[4])
            f_in = pd_x2_z2.view(x.shape[0], x.shape[2], x.shape[3],
                                 x.shape[4])
            for _ in range(self.d_steps):
                d_loss, g_loss, hidden_loss = self.dnet.get_loss(r_in, f_in)
                d_loss.backward(retain_graph=True)
                # print(d_loss, g_loss)
                self.adversarial_optim.step()
                self.adversarial_optim.zero_grad()
            bce_diff = hidden_loss  # XXX bce_diff added twice to loss?
        else:
            g_loss = 0
            hidden_loss = 0

        loss = bce_diff + hidden_loss + self.d_weight * g_loss + self.beta * kl_div_qs_pb

        return loss, bce_diff, kl_div_qs_pb, 0, bce_optimal
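
If the prior in this variant is intended to be a standard normal, the KL term also has a closed form that avoids materializing the prior's parameters; Example #4 below uses this same expression (summed over all dimensions) for its kl_div_option term. A sketch, with a hypothetical helper name:

def kl_div_standard_normal(mu, logvar):
    # closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over the last dim
    return 0.5 * (mu ** 2 + logvar.exp() - logvar - 1).sum(dim=-1)
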
Example #3
    def loss_function(self,
                      forward_ret,
                      labels=None,
                      loss=F.binary_cross_entropy):
        (x_orig, actions, rewards, done, t1, t2, qs_z1_z2_b1_mu,
         qs_z1_z2_b1_logvar, pb_z1_b1_mu, pb_z1_b1_logvar, qb_z2_b2_mu,
         qb_z2_b2_logvar, qb_z2_b2, pt_z2_z1_mu, pt_z2_z1_logvar, pd_x2_z2,
         pd_g2_z2_mu, q1, q2) = forward_ret

        # replicate x multiple times
        x = x_orig.flatten(2, -1)
        x = x[None, ...].expand(self.flags.samples_per_seq, -1, -1,
                                -1)  # size: copy, bs, time, dim
        x2 = torch.gather(
            x, 2,
            t2[..., None, None].expand(-1, -1, -1, x.size(3))
        ).view(-1, x.size(3))
        kl_div_qs_pb = ops.kl_div_gaussian(qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar,
                                           pb_z1_b1_mu, pb_z1_b1_logvar)

        kl_shift_qb_pt = (
            ops.gaussian_log_prob(qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2) -
            ops.gaussian_log_prob(pt_z2_z1_mu, pt_z2_z1_logvar, qb_z2_b2))

        pd_x2_z2 = pd_x2_z2.flatten(1, -1)
        bce = loss(pd_x2_z2, x2, reduction='none').sum(dim=1)
        bce_optimal = loss(x2, x2, reduction='none').sum(dim=1)
        bce_diff = bce - bce_optimal

        if self.rl:
            # Note: x[t], rewards[t] is a result of actions[t]
            # Q(s[t], a[t+1]) = r[t+1] + γ max_a Q(s[t+1], a)
            returns, is_weight = labels

            # use pd_g2_z2_mu for returns modeling
            returns_loss = (pd_g2_z2_mu.squeeze(1) - (10.0 * returns))**2

            # reward clipping for Atari
            clipped_rewards = rewards.clamp(-1.0, 1.0)

            t1_next = t1 + 1
            t2_next = t2 + 1

            with torch.no_grad():
                # size: bs, action_space
                q1_next_target, q2_next_target = self.target_net.q_and_z_b(
                    x_orig, actions, rewards, done, t1_next, t2_next)[:2]
                q1_next_index, q2_next_index = self.model.q_and_z_b(
                    x_orig, actions, rewards, done, t1_next, t2_next)[:2]
                q1_next_index = torch.argmax(q1_next_index,
                                             dim=1,
                                             keepdim=True)
                q2_next_index = torch.argmax(q2_next_index,
                                             dim=1,
                                             keepdim=True)

            done = done[None, ...].expand(self.flags.samples_per_seq, -1,
                                          -1)  # size: copy, bs, time
            done1_next = torch.gather(done, 2,
                                      t1_next[..., None]).view(-1)  # size: bs
            done2_next = torch.gather(done, 2,
                                      t2_next[..., None]).view(-1)  # size: bs

            # size: copy, bs, time
            clipped_rewards = clipped_rewards[None, ...].expand(
                self.flags.samples_per_seq, -1, -1)
            r1_next = torch.gather(clipped_rewards, 2,
                                   t1_next[..., None]).view(-1)  # size: bs
            r2_next = torch.gather(clipped_rewards, 2,
                                   t2_next[..., None]).view(-1)  # size: bs

            actions = actions[None, ...].expand(self.flags.samples_per_seq, -1,
                                                -1)  # size: copy, bs, time
            a1_next = torch.gather(actions, 2,
                                   t1_next[..., None]).view(-1)  # size: bs
            a2_next = torch.gather(actions, 2,
                                   t2_next[..., None]).view(-1)  # size: bs

            pred_q1 = torch.gather(q1, 1, a1_next[..., None]).view(-1)
            pred_q2 = torch.gather(q2, 1, a2_next[..., None]).view(-1)

            q1_next = torch.gather(q1_next_target, 1, q1_next_index).view(-1)
            q2_next = torch.gather(q2_next_target, 1, q2_next_index).view(-1)
            target_q1 = r1_next + self.flags.discount_factor * (
                1.0 - done1_next) * q1_next
            target_q2 = r2_next + self.flags.discount_factor * (
                1.0 - done2_next) * q2_next

            rl_loss = 0.5 * (
                F.smooth_l1_loss(pred_q1, target_q1, reduction='none') +
                F.smooth_l1_loss(pred_q2, target_q2, reduction='none'))
            # errors for prioritized experience replay
            rl_errors = 0.5 * (torch.abs(pred_q1 - target_q1) +
                               torch.abs(pred_q2 - target_q2)).detach()
        else:
            returns_loss = 0.0
            rl_loss = 0.0
            is_weight = 1.0
            rl_errors = 0.0

        # multiply is_weight separately for ease of reporting
        returns_loss = is_weight * returns_loss
        bce_optimal = is_weight * bce_optimal
        bce_diff = is_weight * bce_diff
        kl_div_qs_pb = is_weight * kl_div_qs_pb
        kl_shift_qb_pt = is_weight * kl_shift_qb_pt
        rl_loss = is_weight * rl_loss

        beta = self.beta_decay.get_y(self.get_train_steps())
        tdvae_loss = bce_diff + returns_loss + beta * (kl_div_qs_pb +
                                                       kl_shift_qb_pt)
        loss = self.flags.tdvae_weight * tdvae_loss + self.flags.rl_weight * rl_loss

        if self.rl:  # workaround to work with non-RL setting
            rl_loss = rl_loss.mean()
            returns_loss = returns_loss.mean()
        return collections.OrderedDict([('loss', loss.mean()),
                                        ('bce_diff', bce_diff.mean()),
                                        ('returns_loss', returns_loss),
                                        ('kl_div_qs_pb', kl_div_qs_pb.mean()),
                                        ('kl_shift_qb_pt',
                                         kl_shift_qb_pt.mean()),
                                        ('rl_loss', rl_loss),
                                        ('bce_optimal', bce_optimal.mean()),
                                        ('rl_errors', rl_errors)])
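
The RL branch follows the Double DQN recipe: the online network (self.model) picks the greedy next action and the target network (self.target_net) evaluates it, with rewards and done flags gathered at the shifted indices t1 + 1 and t2 + 1. A standalone sketch of just that target computation, with q_online_next and q_target_next standing in for the Q outputs of the two networks:

import torch


def double_dqn_target(q_online_next, q_target_next, r_next, done_next,
                      discount):
    # q_online_next, q_target_next: (batch, action_space)
    # r_next, done_next: (batch,)
    with torch.no_grad():
        best_action = torch.argmax(q_online_next, dim=1, keepdim=True)
        q_next = torch.gather(q_target_next, 1, best_action).view(-1)
        return r_next + discount * (1.0 - done_next) * q_next

The per-sample smooth L1 losses against these targets are then averaged over the two sampled time points, as in the code above.
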
Example #4
    def loss_function(self,
                      forward_ret,
                      labels=None,
                      loss=F.binary_cross_entropy):
        (x_orig, actions, options, rewards, done, t1, t2, t_encodings,
         qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar, pb_z1_b1_mu, pb_z1_b1_logvar,
         pb_z1_b1, qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2, pt_z2_z1_mu,
         pt_z2_z1_logvar, pd_x2_z2, pd_g2_z2_mu, q1, q2, option_recon_loss,
         o_mean, o_logvar, option) = forward_ret

        # replicate x multiple times
        x = x_orig.flatten(2, -1)
        x = x[None, ...].expand(self.flags.samples_per_seq, -1, -1,
                                -1)  # size: copy, bs, time, dim
        x2 = torch.gather(
            x, 2,
            t2[..., None, None].expand(-1, -1, -1, x.size(3))
        ).view(-1, x.size(3))
        kl_div_qs_pb = ops.kl_div_gaussian(qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar,
                                           pb_z1_b1_mu, pb_z1_b1_logvar)
        # kl_div_option = ops.kl_div_gaussian(o_mean, o_logvar)
        kl_div_option = 0.5 * torch.sum(o_mean**2 + o_logvar.exp() - o_logvar -
                                        1)

        kl_shift_qb_pt = (
            ops.gaussian_log_prob(qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2) -
            ops.gaussian_log_prob(pt_z2_z1_mu, pt_z2_z1_logvar, qb_z2_b2))

        pd_x2_z2 = pd_x2_z2.flatten(1, -1)
        bce = loss(pd_x2_z2, x2, reduction='none').sum(dim=1)
        bce_optimal = loss(x2, x2, reduction='none').sum(dim=1)
        bce_diff = bce - bce_optimal

        if self.adversarial and self.is_training():
            # reshape the flattened frames back to x_orig's trailing (image)
            # dimensions for the discriminator
            r_in = x2.view(x2.shape[0], *x_orig.shape[2:])
            f_in = pd_x2_z2.view(x2.shape[0], *x_orig.shape[2:])
            for _ in range(self.d_steps):
                d_loss, g_loss, hidden_loss = self.dnet.get_loss(r_in, f_in)
                d_loss.backward(retain_graph=True)

                self.adversarial_optim.step()
                self.adversarial_optim.zero_grad()
            bce_diff = hidden_loss  # XXX bce_diff added twice to loss?
        else:
            g_loss = 0
            hidden_loss = 0

        if self.model_based:
            # pred_z2, pred_g = self.model.predict_forward(pb_z1_b1, options, t_encodings)
            # with torch.no_grad():
            #     # size: bs, action_space
            #     t1_next = t1 + 1
            #     t2_next = t2 + 1
            #     _, pred_values = self.target_net.q_and_z_b(x_orig, actions, rewards, done, t1_next,
            #                                                t2_next)[:2]
            #
            # target_q2 = r2_next + self.flags.discount_factor * (1.0 - done2_next) * q2_next

            # Note: x[t], rewards[t] is a result of actions[t]
            # Q(s[t], a[t+1]) = r[t+1] + γ max_a Q(s[t+1], a)
            returns, is_weight = labels

            # use pd_g2_z2_mu for returns modeling
            returns_loss = (pd_g2_z2_mu.squeeze(1) - (10.0 * returns))**2

            # XXX reward clipping hardcoded for Seaquest
            clipped_rewards = (rewards / 10.0).clamp(0.0, 2.0)

            t1_next = t1 + 1
            t2_next = t2 + 1

            with torch.no_grad():
                # size: bs, action_space
                q1_next_target, q2_next_target = self.target_net.q_and_z_b(
                    x_orig, actions, rewards, done, t1_next, t2_next)[:2]
                q1_next_index, q2_next_index = self.model.q_and_z_b(
                    x_orig, actions, rewards, done, t1_next, t2_next)[:2]
                q1_next_index = torch.argmax(q1_next_index,
                                             dim=1,
                                             keepdim=True)
                q2_next_index = torch.argmax(q2_next_index,
                                             dim=1,
                                             keepdim=True)

            done = done[None, ...].expand(self.flags.samples_per_seq, -1,
                                          -1)  # size: copy, bs, time
            done1_next = torch.gather(done, 2,
                                      t1_next[..., None]).view(-1)  # size: bs
            done2_next = torch.gather(done, 2,
                                      t2_next[..., None]).view(-1)  # size: bs

            # size: copy, bs, time
            clipped_rewards = clipped_rewards[None, ...].expand(
                self.flags.samples_per_seq, -1, -1)
            r1_next = torch.gather(clipped_rewards, 2,
                                   t1_next[..., None]).view(-1)  # size: bs
            r2_next = torch.gather(clipped_rewards, 2,
                                   t2_next[..., None]).view(-1)  # size: bs

            # actions = actions[None, ...].expand(self.flags.samples_per_seq, -1, -1)  # size: copy, bs, time
            # a1_next = torch.gather(actions, 2, t1_next[..., None]).view(-1)  # size: bs
            # a2_next = torch.gather(actions, 2, t2_next[..., None]).view(-1)  # size: bs
            #
            # pred_q1 = torch.gather(q1, 1, a1_next[..., None]).view(-1)
            # pred_q2 = torch.gather(q2, 1, a2_next[..., None]).view(-1)

            q1 = q1.squeeze(-1)
            q2 = q2.squeeze(-1)
            q1_next = torch.gather(q1_next_target, 1, q1_next_index).view(-1)
            q2_next = torch.gather(q2_next_target, 1, q2_next_index).view(-1)
            target_q1 = r1_next + self.flags.discount_factor * (
                1.0 - done1_next) * q1_next
            target_q2 = r2_next + self.flags.discount_factor * (
                1.0 - done2_next) * q2_next
            rl_loss = 0.5 * (
                F.smooth_l1_loss(q1, target_q1, reduction='none') +
                F.smooth_l1_loss(q2, target_q2, reduction='none'))
            # errors for prioritized experience replay
            rl_errors = 0.5 * (torch.abs(q1 - target_q1) +
                               torch.abs(q2 - target_q2)).detach()

        else:
            returns_loss = 0.0
            rl_loss = 0.0
            is_weight = 1.0
            rl_errors = 0.0

        # multiply is_weight separately for ease of reporting
        if torch.is_tensor(is_weight):
            is_weight = is_weight.float()
        returns_loss = is_weight * returns_loss
        bce_optimal = is_weight * bce_optimal
        bce_diff = is_weight * bce_diff
        hidden_loss = is_weight * hidden_loss
        g_loss = is_weight * g_loss
        kl_div_qs_pb = is_weight * kl_div_qs_pb
        kl_shift_qb_pt = is_weight * kl_shift_qb_pt
        rl_loss = is_weight * rl_loss

        beta = self.beta_decay.get_y(self.get_train_steps())
        tdvae_loss = bce_diff + returns_loss + hidden_loss + self.d_weight * g_loss + beta * (
            kl_div_qs_pb + kl_shift_qb_pt)
        option_loss = option_recon_loss + beta * kl_div_option * 0.001
        loss = self.flags.tdvae_weight * tdvae_loss + self.flags.rl_weight * rl_loss + option_loss

        if self.model_based:  # workaround to work with non-RL setting
            rl_loss = rl_loss.mean()
            returns_loss = returns_loss.mean()
        return collections.OrderedDict([
            ('loss', loss.mean()), ('bce_diff', bce_diff.mean()),
            ('returns_loss', returns_loss),
            ('kl_div_qs_pb', kl_div_qs_pb.mean()),
            ('kl_shift_qb_pt', kl_shift_qb_pt.mean()),
            ('kl_div_option', kl_div_option.mean()),
            ('reconstruction_option', option_recon_loss.mean()),
            ('rl_loss', rl_loss), ('bce_optimal', bce_optimal.mean()),
            ('rl_errors', rl_errors)
        ])
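
rl_errors is returned so the caller can refresh prioritized-replay priorities, and is_weight (unpacked from labels) is the importance-sampling correction for that non-uniform sampling. A sketch of how the weights and priorities are typically derived on the replay-buffer side; the helper and its beta/eps parameters are assumptions, not part of this code:

def per_weights_and_priorities(td_errors, sample_probs, buffer_size,
                               beta=0.4, eps=1e-6):
    # td_errors: |TD error| per sampled transition (the rl_errors above)
    # sample_probs: probability with which each transition was drawn
    new_priorities = td_errors.abs() + eps  # written back to the buffer
    is_weight = (buffer_size * sample_probs) ** (-beta)
    is_weight = is_weight / is_weight.max()  # normalize for stability
    return is_weight, new_priorities
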
Example #5
    def loss_function(self, forward_ret, labels=None):
        (x_orig, actions, rewards, done, t1, t2, qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar, pb_z1_b1_mu,
         pb_z1_b1_logvar, qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2, pt_z2_z1_mu, pt_z2_z1_logvar, pd_x2_z2, pd_g2_z2_mu,
         q1, q2) = forward_ret

        # replicate x multiple times
        x = x_orig.flatten(3, -1)
        x = x[None, ...].expand(self.flags.samples_per_seq, -1, -1, -1, -1)  # size: copy, bs, time, dim
        x2 = torch.gather(x, 2, t2[..., None, None, None].expand(-1, -1, -1, x.size(3), x.size(4)))
        x2 = x2.long().view(-1, x.size(3), x.size(4))
        kl_div_qs_pb = ops.kl_div_gaussian(qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar, pb_z1_b1_mu, pb_z1_b1_logvar)

        kl_shift_qb_pt = (ops.gaussian_log_prob(qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2) -
                          ops.gaussian_log_prob(pt_z2_z1_mu, pt_z2_z1_logvar, qb_z2_b2))

        ce_1 = F.cross_entropy(pd_x2_z2[0], x2[:, 0])
        ce_2 = F.cross_entropy(pd_x2_z2[1], x2[:, 1])
        ce_3 = F.cross_entropy(pd_x2_z2[2], x2[:, 2])
        obs_ce = F.cross_entropy(pd_x2_z2[3], x2[:, 3]) / x_orig.shape[1]

        total_ce = ce_1 + ce_2 + ce_3 + obs_ce

        if self.adversarial and self.is_training():
            r_in = x2.view(x2.shape[0], x.shape[2], x.shape[3], x.shape[4])
            f_in = pd_x2_z2.view(x2.shape[0], x.shape[2], x.shape[3], x.shape[4])
            for _ in range(self.d_steps):
                d_loss, g_loss, hidden_loss = self.dnet.get_loss(r_in, f_in)
                d_loss.backward(retain_graph=True)
                # print(d_loss, g_loss)
                self.adversarial_optim.step()
                self.adversarial_optim.zero_grad()
            bce_diff = hidden_loss  # XXX bce_diff is not used below in this variant; hidden_loss enters the loss directly
        else:
            g_loss = 0
            hidden_loss = 0

        if self.rl:
            # Note: x[t], rewards[t] is a result of actions[t]
            # Q(s[t], a[t+1]) = r[t+1] + γ max_a Q(s[t+1], a)
            returns, is_weight = labels

            # use pd_g2_z2_mu for returns modeling
            returns_loss = (pd_g2_z2_mu.squeeze(1) - (10.0 * returns)) ** 2

            # reward clipping for Atari
            clipped_rewards = rewards.clamp(-1.0, 1.0)

            t1_next = t1 + 1
            t2_next = t2 + 1

            with torch.no_grad():
                # size: bs, action_space
                q1_next_target, q2_next_target = self.target_net.q_and_z_b(x_orig, actions, rewards, done, t1_next,
                                                                           t2_next)[:2]
                q1_next_index, q2_next_index = self.model.q_and_z_b(x_orig, actions, rewards, done, t1_next,
                                                                    t2_next)[:2]
                q1_next_index = torch.argmax(q1_next_index, dim=1, keepdim=True)
                q2_next_index = torch.argmax(q2_next_index, dim=1, keepdim=True)

            done = done[None, ...].expand(self.flags.samples_per_seq, -1, -1)  # size: copy, bs, time
            done1_next = torch.gather(done, 2, t1_next[..., None]).view(-1)  # size: bs
            done2_next = torch.gather(done, 2, t2_next[..., None]).view(-1)  # size: bs

            # size: copy, bs, time
            clipped_rewards = clipped_rewards[None, ...].expand(self.flags.samples_per_seq, -1, -1)
            r1_next = torch.gather(clipped_rewards, 2, t1_next[..., None]).view(-1)  # size: bs
            r2_next = torch.gather(clipped_rewards, 2, t2_next[..., None]).view(-1)  # size: bs

            actions = actions[None, ...].expand(self.flags.samples_per_seq, -1, -1)  # size: copy, bs, time
            a1_next = torch.gather(actions, 2, t1_next[..., None]).view(-1)  # size: bs
            a2_next = torch.gather(actions, 2, t2_next[..., None]).view(-1)  # size: bs

            pred_q1 = torch.gather(q1, 1, a1_next[..., None]).view(-1)
            pred_q2 = torch.gather(q2, 1, a2_next[..., None]).view(-1)

            q1_next = torch.gather(q1_next_target, 1, q1_next_index).view(-1)
            q2_next = torch.gather(q2_next_target, 1, q2_next_index).view(-1)
            target_q1 = r1_next + self.flags.discount_factor * (1.0 - done1_next) * q1_next
            target_q2 = r2_next + self.flags.discount_factor * (1.0 - done2_next) * q2_next

            rl_loss = 0.5 * (F.smooth_l1_loss(pred_q1, target_q1, reduction='none') +
                             F.smooth_l1_loss(pred_q2, target_q2, reduction='none'))
            # errors for prioritized experience replay
            rl_errors = 0.5 * (torch.abs(pred_q1 - target_q1) + torch.abs(pred_q2 - target_q2)).detach()
        else:
            returns_loss = 0.0
            rl_loss = 0.0
            is_weight = 1.0
            rl_errors = 0.0

        # multiply is_weight separately for ease of reporting
        returns_loss = is_weight * returns_loss
        total_ce = is_weight * total_ce
        hidden_loss = is_weight * hidden_loss
        g_loss = is_weight * g_loss
        kl_div_qs_pb = is_weight * kl_div_qs_pb
        kl_shift_qb_pt = is_weight * kl_shift_qb_pt
        rl_loss = is_weight * rl_loss

        beta = self.beta_decay.get_y(self.get_train_steps())
        tdvae_loss = total_ce + returns_loss + hidden_loss + self.d_weight * g_loss + beta * (kl_div_qs_pb +
                                                                                              kl_shift_qb_pt)
        loss = self.flags.tdvae_weight * tdvae_loss + self.flags.rl_weight * rl_loss

        if self.rl:  # workaround to work with non-RL setting
            rl_loss = rl_loss.mean()
            returns_loss = returns_loss.mean()
        return collections.OrderedDict([('loss', loss.mean()),
                                        ('total_ce', total_ce.mean()),
                                        ('returns_loss', returns_loss),
                                        ('kl_div_qs_pb', kl_div_qs_pb.mean()),
                                        ('kl_shift_qb_pt', kl_shift_qb_pt.mean()),
                                        ('rl_loss', rl_loss),
                                        # ('bce_optimal', bce_optimal.mean()),
                                        ('rl_errors', rl_errors)])
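
Unlike the earlier examples, this variant decodes into four separate heads and scores reconstruction with cross-entropy over integer targets rather than BCE: pd_x2_z2 is indexed per head and x2[:, h] supplies the class indices for head h, with the observation head additionally scaled by x_orig.shape[1]. A minimal sketch of the unscaled per-head sum, assuming head h produces logits of shape (batch, classes_h, d) for targets of shape (batch, num_heads, d):

import torch.nn.functional as F


def multi_head_cross_entropy(logits_per_head, targets):
    # logits_per_head: sequence of tensors, head h shaped (batch, classes_h, d)
    # targets: (batch, num_heads, d) long tensor of class indices
    return sum(F.cross_entropy(logits, targets[:, h])
               for h, logits in enumerate(logits_per_head))
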