예제 #1
0
파일: model.py 프로젝트: mfkiwl/fmr
 def eval_1__write(self, fout, ig_gt, g_hat):
     """Write one comma-separated row per batch element to ``fout``.

     Every row holds the ``se3.log`` of the estimate followed by the
     negated ``se3.log`` of the ground truth. The stream is flushed once
     all rows have been written.

     Args:
         fout: Writable text stream that receives the rows.
         ig_gt: Ground-truth transforms handed to ``se3.log``
             (presumably [B, 4, 4] -- confirm against the caller).
         g_hat: Estimated transforms handed to ``se3.log``; the size of
             the result's first dimension sets the number of rows.
     """
     est_twists = se3.log(g_hat)  # expected [B, 6] -- TODO confirm
     gt_twists = se3.log(ig_gt)  # expected [B, 6] -- TODO confirm
     num_rows = est_twists.size(0)
     for row in range(num_rows):
         # Estimate first, then the sign-flipped ground truth.
         record = torch.cat((est_twists[row], -gt_twists[row]))
         fields = [str(value) for value in record.cpu().numpy().tolist()]
         print(','.join(fields), file=fout)
     fout.flush()
예제 #2
0
파일: model.py 프로젝트: mfkiwl/fmr
    def evaluate(self, solver, testloader, device):
        """Run ``solver`` over ``testloader`` and log results to a file.

        Opens ``self.filename`` for writing, emits the header through
        ``self.eval_1__header`` and then, for every batch, estimates a
        transform, writes it next to the ground truth via
        ``self.eval_1__write`` and prints a progress line with the mean
        residual norm.

        Args:
            solver: Object providing ``eval()``, ``estimate_t(p0, p1,
                max_iter)`` and, after estimation, the attribute ``g``.
            testloader: Iterable of ``(p0, p1, igt)`` triples that also
                supports ``len()``.
            device: Target passed to ``Tensor.to`` for both point sets.
        """
        solver.eval()
        with open(self.filename, 'w') as fout:
            self.eval_1__header(fout)
            with torch.no_grad():
                # igt is annotated in the data as the p0->p1 transform.
                # NOTE: a step that inverted igt (se3.exp of the negated
                # se3.log) was disabled here; igt is used as loaded.
                for batch_idx, (p0, p1, igt) in enumerate(testloader):
                    p0, p1 = self.ablation_study(p0, p1)

                    template = p0.to(device)  # presumably (1, N, 3)
                    source = p1.to(device)  # presumably (1, M, 3)
                    solver.estimate_t(template, source, self.max_iter)

                    # Bring both transforms to CPU as batches of 4x4.
                    ig_gt = igt.cpu().contiguous().view(-1, 4, 4)
                    g_hat = solver.g.cpu().contiguous().view(-1, 4, 4)

                    # If the estimate is correct the product is the
                    # identity, so its log is the zero vector.
                    residual = se3.log(g_hat.bmm(ig_gt))
                    mean_err = residual.norm(p=2, dim=1).mean()

                    self.eval_1__write(fout, ig_gt, g_hat)
                    print('test, %d/%d, %f'
                          % (batch_idx, len(testloader), mean_err))
예제 #3
0
def main(args):
    """Generate one random perturbation per dataset item and save as CSV.

    A random direction is scaled to ``args.deg`` degrees (converted to
    radians) and paired with a translation whose components are drawn
    uniformly from ``[0, args.max_trans)``. When ``args.format`` is
    ``'wv'`` the pair is assembled into 4x4 matrices and converted with
    ``se3.log``; otherwise the two vectors are concatenated directly.
    The rows are written to ``args.outfile``.

    Args:
        args: Namespace providing ``deg``, ``max_trans``, ``format`` and
            ``outfile``, plus whatever ``get_datasets`` reads.
    """
    # The dataset is only consulted for the number of rows to generate.
    num_samples = len(get_datasets(args))

    angle = args.deg * math.pi / 180.0
    axis = torch.randn(num_samples, 3)
    rot_vec = axis / axis.norm(p=2, dim=1, keepdim=True) * angle
    trans_vec = torch.rand(num_samples, 3) * args.max_trans

    if args.format != 'wv':
        # Plain rotation-vector followed by translation-vector.
        rows = torch.cat((rot_vec, trans_vec), dim=1)
    else:
        # Twist vectors: build [R | t; 0 0 0 1] and take its log.
        rotation = so3.exp(rot_vec)  # (N, 3) --> (N, 3, 3)
        transform = torch.zeros(num_samples, 4, 4)
        transform[:, 0:3, 0:3] = rotation
        transform[:, 0:3, 3] = trans_vec
        transform[:, 3, 3] = 1
        rows = se3.log(transform)  # --> (N, 6)

    numpy.savetxt(args.outfile, rows, delimiter=',')