Example #1
    def interpolate(self, x_a, x_b, N):
        # Build an N x N grid of decoded sequences by linearly interpolating
        # the motion code (rows) and the body/view codes (columns) between
        # x_a and x_b. Requires N >= 2 so the weights span 0, ..., 1.
        step_size = 1. / (N - 1)
        batch_size, _, seq_len = x_a.size()

        motion_a = self.encode_motion(x_a)
        body_a, _ = self.encode_body(x_a)
        view_a, _ = self.encode_view(x_a)

        motion_b = self.encode_motion(x_b)
        body_b, _ = self.encode_body(x_b)
        view_b, _ = self.encode_view(x_b)

        # Allocate the output on the inputs' device so the assignment below
        # does not mix CPU and GPU tensors.
        batch_out = torch.zeros([batch_size, N, N, 2 * self.n_joints, seq_len],
                                device=x_a.device)

        for i in range(N):
            motion_weight = i * step_size
            for j in range(N):
                body_weight = j * step_size
                motion = (1. - motion_weight) * motion_a + motion_weight * motion_b
                body = (1. - body_weight) * body_a + body_weight * body_b
                # The view code is interpolated with the same weight as the body.
                view = (1. - body_weight) * view_a + body_weight * view_b
                out = self.decode(motion, body, view)
                out = rotate_and_maybe_project(
                    out, body_reference=self.body_reference, project_2d=True)
                batch_out[:, i, j, :, :] = out

        return batch_out
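
A minimal usage sketch for the grid above; the variable names and the [batch, 2 * n_joints, seq_len] input layout are assumptions inferred from the size() unpacking and the output shape, not stated in the listing:

    # ae: a hypothetical instance of the autoencoder class these methods belong to
    # x_a, x_b: pose sequences, assumed shape [batch, 2 * n_joints, seq_len]
    grid = ae.interpolate(x_a, x_b, N=5)
    # grid[:, i, j] was decoded with motion weight i/4 and body/view weight j/4,
    # so grid[:, 0, 0] is a reconstruction of x_a and grid[:, 4, 4] of x_b.
    # grid shape: [batch, 5, 5, 2 * n_joints, seq_len]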
Example #2
    def reconstruct2d(self, x):
        # Encode x into motion, body and view codes, decode, and project
        # the result back to 2D.
        motion_code = self.encode_motion(x)
        body_code, _ = self.encode_body(x)
        view_code, _ = self.encode_view(x)
        out = self.decode(motion_code, body_code, view_code)
        out = rotate_and_maybe_project(out,
                                       body_reference=self.body_reference,
                                       project_2d=True)
        return out
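
Usage is a single call; ae is the same hypothetical autoencoder instance as above:

    x_hat = ae.reconstruct2d(x)  # x_hat should approximate the input sequence x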
Example #3
    def cross2d(self, x_a, x_b, x_c):
        # Cross reconstruction: motion taken from x_a, body from x_b,
        # view from x_c.
        motion_a = self.encode_motion(x_a)
        body_b, _ = self.encode_body(x_b)
        view_c, _ = self.encode_view(x_c)
        out = self.decode(motion_a, body_b, view_c)
        out = rotate_and_maybe_project(out,
                                       body_reference=self.body_reference,
                                       project_2d=True)
        return out
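
cross2d is the retargeting entry point: it recombines codes from three different clips. A hedged sketch (the variable names are hypothetical, with the same input layout assumed as above):

    # Transfer the motion of x_src onto the skeleton of x_tgt, rendered
    # under the camera view of x_view.
    retargeted = ae.cross2d(x_src, x_tgt, x_view)
    # cross2d(x, x, x) degenerates to reconstruct2d(x).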
Example #4
    def dis_update(self, data, config):

        x_a = data["x"].detach()
        x_s = data["x_s"].detach() # the limb-scaled version of x_a

        self.dis_opt.zero_grad()

        # encode
        motion_a = self.autoencoder.encode_motion(x_a)
        body_a, _ = self.autoencoder.encode_body(x_a)
        view_a, _ = self.autoencoder.encode_view(x_a)

        motion_s = self.autoencoder.encode_motion(x_s)
        body_s, _ = self.autoencoder.encode_body(x_s)
        view_s, _ = self.autoencoder.encode_view(x_s)

        # decode (reconstruct, transform)
        # Sample K of the precomputed view angles and jitter them with
        # Gaussian noise along the allowed rotation axes.
        inds = random.sample(list(range(self.angles.size(0))), config.K)
        angles = self.angles[inds].clone().detach()  # [K, 3]
        angles += self.angle_unit * self.rotation_axes * torch.randn([3], device=x_a.device)
        angles = angles.unsqueeze(0).unsqueeze(2)  # [1, K, 1, 3]: broadcast over batch and time

        X_a_recon = self.autoencoder.decode(motion_a, body_a, view_a)
        x_a_trans = rotate_and_maybe_project(X_a_recon, angles=angles,
                                             body_reference=config.autoencoder.body_reference,
                                             project_2d=True)

        # Each sample now has K rotated projections, so repeat the real data
        # K times to align real and fake batches for the discriminator.
        x_a_exp = x_a.repeat_interleave(config.K, dim=0)

        self.loss_dis_trans = self.discriminator.calc_dis_loss(x_a_trans.detach(), x_a_exp)

        if config.trans_gan_ls_w > 0:
            X_s_recon = self.autoencoder.decode(motion_s, body_s, view_s)
            x_s_trans = rotate_and_maybe_project(X_s_recon, angles=angles,
                                                 body_reference=config.autoencoder.body_reference, project_2d=True)
            x_s_exp = x_s.repeat_interleave(config.K, dim=0)
            self.loss_dis_trans_ls = self.discriminator.calc_dis_loss(x_s_trans.detach(), x_s_exp)
        else:
            self.loss_dis_trans_ls = 0

        self.loss_dis_total = config.trans_gan_w * self.loss_dis_trans + \
                              config.trans_gan_ls_w * self.loss_dis_trans_ls

        self.loss_dis_total.backward()
        self.dis_opt.step()
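
The repeat_interleave alignment is worth spelling out: repeating each real sample K consecutive times matches a fake batch in which the K rotated copies of a sample sit contiguously, which is what the [1, K, 1, 3] angle shape suggests. A small self-contained check, independent of the classes in this listing:

    import torch

    x = torch.tensor([[1.], [2.]])         # a batch of 2 samples
    K = 3
    x_exp = x.repeat_interleave(K, dim=0)   # rows: 1, 1, 1, 2, 2, 2
    assert x_exp.shape == (2 * K, 1)
    # This pairs sample i with fakes i*K, ..., i*K + K - 1, unlike
    # x.repeat(K, 1), which would order the batch as 1, 2, 1, 2, 1, 2.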
Example #5
    def ae_update(self, data, config):

        x_a = data["x"].detach()
        x_s = data["x_s"].detach()
        self.ae_opt.zero_grad()

        # encode
        motion_a = self.autoencoder.encode_motion(x_a)
        body_a, body_a_seq = self.autoencoder.encode_body(x_a)
        view_a, view_a_seq = self.autoencoder.encode_view(x_a)

        motion_s = self.autoencoder.encode_motion(x_s)
        body_s, body_s_seq = self.autoencoder.encode_body(x_s)
        view_s, _ = self.autoencoder.encode_view(x_s)

        # invariance losses: the view and motion codes of x_a and its
        # limb-scaled version x_s should agree
        self.loss_inv_v_ls = self.recon_criterion(view_a, view_s) if config.inv_v_ls_w > 0 else 0
        self.loss_inv_m_ls = self.recon_criterion(motion_a, motion_s) if config.inv_m_ls_w > 0 else 0

        # body triplet loss
        if config.triplet_b_w > 0:
            self.loss_triplet_b = triplet_margin_loss(
                body_a_seq, body_s_seq,
                neg_range=config.triplet_neg_range,
                margin=config.triplet_margin)
        else:
            self.loss_triplet_b = 0

        # reconstruction
        X_a_recon = self.autoencoder.decode(motion_a, body_a, view_a)
        x_a_recon = rotate_and_maybe_project(X_a_recon, angles=None, body_reference=config.autoencoder.body_reference, project_2d=True)

        X_s_recon = self.autoencoder.decode(motion_s, body_s, view_s)
        x_s_recon = rotate_and_maybe_project(X_s_recon, angles=None, body_reference=config.autoencoder.body_reference, project_2d=True)

        self.loss_recon_x = 0.5 * self.recon_criterion(x_a_recon, x_a) + \
                            0.5 * self.recon_criterion(x_s_recon, x_s)

        # cross reconstruction
        X_as_recon = self.autoencoder.decode(motion_a, body_s, view_s)
        x_as_recon = rotate_and_maybe_project(X_as_recon, angles=None, body_reference=config.autoencoder.body_reference, project_2d=True)

        X_sa_recon = self.autoencoder.decode(motion_s, body_a, view_a)
        x_sa_recon = rotate_and_maybe_project(X_sa_recon, angles=None, body_reference=config.autoencoder.body_reference, project_2d=True)

        self.loss_cross_x = 0.5 * self.recon_criterion(x_as_recon, x_s) + 0.5 * self.recon_criterion(x_sa_recon, x_a)

        # apply transformation: sample K jittered view angles, as in dis_update
        inds = random.sample(list(range(self.angles.size(0))), config.K)
        angles = self.angles[inds].clone().detach()  # [K, 3]
        angles += self.angle_unit * self.rotation_axes * torch.randn([3], device=x_a.device)
        angles = angles.unsqueeze(0).unsqueeze(2)  # [1, K, 1, 3]

        x_a_trans = rotate_and_maybe_project(X_a_recon, angles=angles, body_reference=config.autoencoder.body_reference, project_2d=True)
        x_s_trans = rotate_and_maybe_project(X_s_recon, angles=angles, body_reference=config.autoencoder.body_reference, project_2d=True)

        # GAN loss
        self.loss_gan_trans = self.discriminator.calc_gen_loss(x_a_trans)
        self.loss_gan_trans_ls = self.discriminator.calc_gen_loss(x_s_trans) if config.trans_gan_ls_w > 0 else 0

        # encode again: codes of the rotated projections should match the originals
        motion_a_trans = self.autoencoder.encode_motion(x_a_trans)
        body_a_trans, _ = self.autoencoder.encode_body(x_a_trans)
        _, view_a_trans_seq = self.autoencoder.encode_view(x_a_trans)

        motion_s_trans = self.autoencoder.encode_motion(x_s_trans)
        body_s_trans, _ = self.autoencoder.encode_body(x_s_trans)

        self.loss_inv_m_trans = 0.5 * self.recon_criterion(motion_a_trans, motion_a.repeat_interleave(config.K, dim=0)) + \
                                0.5 * self.recon_criterion(motion_s_trans, motion_s.repeat_interleave(config.K, dim=0))
        self.loss_inv_b_trans = 0.5 * self.recon_criterion(body_a_trans, body_a.repeat_interleave(config.K, dim=0)) + \
                                0.5 * self.recon_criterion(body_s_trans, body_s.repeat_interleave(config.K, dim=0))

        # view triplet loss
        if config.triplet_v_w > 0:
            view_a_seq_exp = view_a_seq.repeat_interleave(config.K, dim=0)
            self.loss_triplet_v = triplet_margin_loss(
                view_a_seq_exp, view_a_trans_seq,
                neg_range=config.triplet_neg_range, margin=config.triplet_margin)
        else:
            self.loss_triplet_v = 0

        # add all weighted losses; accumulate on the inputs' device rather
        # than hard-coding CUDA
        self.loss_total = torch.tensor(0., device=x_a.device)
        self.loss_total += config.recon_x_w * self.loss_recon_x
        self.loss_total += config.cross_x_w * self.loss_cross_x
        self.loss_total += config.inv_v_ls_w * self.loss_inv_v_ls
        self.loss_total += config.inv_m_ls_w * self.loss_inv_m_ls
        self.loss_total += config.inv_b_trans_w * self.loss_inv_b_trans
        self.loss_total += config.inv_m_trans_w * self.loss_inv_m_trans
        self.loss_total += config.trans_gan_w * self.loss_gan_trans
        self.loss_total += config.trans_gan_ls_w * self.loss_gan_trans_ls
        self.loss_total += config.triplet_b_w * self.loss_triplet_b
        self.loss_total += config.triplet_v_w * self.loss_triplet_v

        self.loss_total.backward()
        self.ae_opt.step()
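
Putting the two updates together, a typical adversarial schedule alternates the discriminator and autoencoder steps once per batch. A minimal driver sketch; trainer, loader, and config.n_epochs are assumptions, only dis_update/ae_update and the "x"/"x_s" keys come from the listing:

    # trainer: an instance of the trainer class that owns dis_update/ae_update
    # loader: yields dicts with keys "x" and "x_s" (the limb-scaled version of "x")
    for epoch in range(config.n_epochs):      # n_epochs is a hypothetical field
        for data in loader:
            trainer.dis_update(data, config)  # discriminator step on real vs. rotated fakes
            trainer.ae_update(data, config)   # autoencoder / generator step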