Example #1
0
def vcrnetIter(net, src, tgt, iter=1):
    """Run `net` for `iter` refinement passes and compose the transforms.

    Each pass registers the progressively-transformed source onto `tgt`;
    the per-pass rigid transforms are folded (with gradients cut via
    detach()) into one cumulative ab transform and its ba inverse.
    Requires iter >= 1; srcK/src_corrK from the final pass are returned.
    """
    moving = src

    for step in range(iter):
        srcK, src_corrK, rot_step, trans_step, _, _ = net(moving, tgt)
        # Advance the source by this pass's (non-detached) prediction.
        moving = transform_point_cloud(moving, rot_step, trans_step)

        rot_step = rot_step.detach()
        trans_step = trans_step.detach()
        if step == 0:
            rotation_ab_pred_final = rot_step
            translation_ab_pred_final = trans_step
        else:
            # t_total = R_step @ t_prev + t_step (uses the previous totals).
            translation_ab_pred_final = torch.matmul(
                rot_step,
                translation_ab_pred_final.unsqueeze(2)).squeeze(2) + trans_step
            # R_total = R_step @ R_prev.
            rotation_ab_pred_final = torch.matmul(rot_step,
                                                  rotation_ab_pred_final)

    # Inverse transform (b -> a): R_ba = R_ab^T, t_ba = -R_ba @ t_ab.
    rotation_ba_pred_final = rotation_ab_pred_final.transpose(2,
                                                              1).contiguous()
    translation_ba_pred_final = -torch.matmul(
        rotation_ba_pred_final,
        translation_ab_pred_final.unsqueeze(2)).squeeze(2)

    return srcK, src_corrK, rotation_ab_pred_final, translation_ab_pred_final, rotation_ba_pred_final, translation_ba_pred_final
Example #2
0
    def forward(self, srcInit, dst):
        """Align `srcInit` onto `dst` with iterative closest point (ICP).

        Each iteration matches every current source point to its nearest
        destination point, fits a rigid transform onto those matches, and
        applies it; the loop stops when the mean matching error changes by
        less than `self.tolerance` or `self.max_iterations` is reached.

        Returns (srcInit, aligned source, rotation_ab, translation_ab,
        rotation_ba, translation_ba), where the ab transform is fit between
        the untouched input and the converged cloud, and ba is its inverse.
        """
        # NOTE(review): assumes `time` is `time.time` imported at file level.
        icp_start = time()
        src = srcInit
        prev_error = 0
        for i in range(self.max_iterations):
            # find the nearest neighbors between the current source and destination points
            mean_error, src_corr = self.nearest_neighbor(src, dst)
            # compute the transformation between the current source and nearest destination points
            rotation_ab, translation_ab = self.best_fit_transform(
                src, src_corr)
            src = transform_point_cloud(src, rotation_ab, translation_ab)

            # Converged: matching error no longer improves meaningfully.
            if torch.abs(prev_error - mean_error) < self.tolerance:
                break
            prev_error = mean_error

        # calculate final transformation: a single fit between the original
        # input and the converged result, rather than composing the
        # per-iteration transforms.
        rotation_ab, translation_ab = self.best_fit_transform(srcInit, src)

        # Inverse transform: R_ba = R_ab^T, t_ba = -R_ba @ t_ab.
        rotation_ba = rotation_ab.transpose(2, 1).contiguous()
        translation_ba = -torch.matmul(rotation_ba,
                                       translation_ab.unsqueeze(2)).squeeze(2)

        print("icp: ", time() - icp_start)
        return srcInit, src, rotation_ab, translation_ab, rotation_ba, translation_ba
Example #3
0
def showBad(args, net, test_loader, idTopKRot, idTopKTrans):
    """Re-run and inspect the worst test samples.

    `idTopKRot` / `idTopKTrans` hold flat sample indices (over the whole
    test set) of the worst rotation / translation errors. Each is decoded
    into a (batch index, in-batch index) pair; only batches containing a
    flagged sample are re-run through `net`, and per-sample euler /
    translation error summaries are computed (the plot-saving calls are
    currently commented out).
    """
    net.eval()
    # Rotation.from_dcm was renamed to from_matrix in SciPy 1.4 and removed
    # in SciPy 1.6; resolve whichever constructor this version provides.
    dcm_to_rotation = getattr(Rotation, 'from_matrix', None) or Rotation.from_dcm

    # Decode flat sample ids into (batch index, index within batch) pairs.
    # Vectorized and decoded independently: the original single loop over
    # idTopKRot.shape[0] broke when the two id arrays had different lengths.
    idTopKRotPair = np.zeros((idTopKRot.shape[0], 2), dtype=np.int32)
    idTopKTransPair = np.zeros((idTopKTrans.shape[0], 2), dtype=np.int32)
    idTopKRotPair[:, 0] = idTopKRot // args.test_batch_size
    idTopKRotPair[:, 1] = idTopKRot % args.test_batch_size
    idTopKTransPair[:, 0] = idTopKTrans // args.test_batch_size
    idTopKTransPair[:, 1] = idTopKTrans % args.test_batch_size

    batch_id = -1
    with torch.no_grad():
        for src, target, rotation_ab, translation_ab, rotation_ba, translation_ba, euler_ab, euler_ba, label in tqdm(
                test_loader):
            src = src.cuda()
            target = target.cuda()
            batch_id = batch_id + 1
            # Only re-run batches that contain at least one flagged sample.
            if (batch_id in idTopKRotPair[:, 0]) or (batch_id in idTopKTransPair[:, 0]):

                rotation_ab_pred, translation_ab_pred, rotation_ba_pred, translation_ba_pred = net(src, target)

                # Kept for the (commented-out) savePlt visualization below.
                transformed_src = transform_point_cloud(src, rotation_ab_pred, translation_ab_pred)

                # In-batch indices of the flagged samples in this batch.
                idR = idTopKRotPair[idTopKRotPair[:, 0] == batch_id, 1]
                idt = idTopKTransPair[idTopKTransPair[:, 0] == batch_id, 1]

                rotation_ab_pred = rotation_ab_pred.detach().cpu().numpy()
                translation_ab_pred = translation_ab_pred.detach().cpu().numpy()
                euler_ab = euler_ab.detach().cpu().numpy()
                translation_ab = translation_ab.detach().cpu().numpy()

                for id_R in idR:
                    # Predicted euler angles vs ground truth (degrees).
                    r = dcm_to_rotation(rotation_ab_pred[id_R, :, :])
                    eulers_ab_pred = r.as_euler('zyx', degrees=True)
                    euler_ab_i = euler_ab[id_R, :]
                    euler_delta = eulers_ab_pred - np.degrees(euler_ab_i)
                    mae = np.mean(np.abs(euler_delta))
                    title = 'euler error: ' + str(euler_delta) + ' mae: ' + str(mae)
                    filename = str(batch_id * args.test_batch_size + id_R) + '-R'
                    # savePlt(args, target[id_R,:,:], transformed_src[id_R,:,:], folderName='bad_cp', title=title,filename=filename)

                for id_t in idt:
                    # Translation error (plot saving disabled, as above).
                    translation_ab_pred_i = translation_ab_pred[id_t, :]
                    translation_ab_i = translation_ab[id_t, :]
                    translation_delta = translation_ab_pred_i - translation_ab_i
                    mae = np.mean(np.abs(translation_delta))
                    title = 'trans error: ' + str(translation_delta) + ' mae: ' + str(mae)
                    filename = str(batch_id * args.test_batch_size + id_t) + '-T'
Example #4
0
def vcrnetIcpNet(args, net, src, tgt):
    """Coarse-to-fine registration: VCRNet prediction refined by ICP.

    The network's predicted transform roughly aligns `src` onto `tgt`; an
    ICP pass then refines the residual misalignment, and the two rigid
    transforms are composed into one ab transform plus its ba inverse.
    Returns the coarsely-transformed source, tgt, and the four transforms.
    """
    refiner = ICP(max_iterations=args.max_iterations).cuda()

    # Coarse alignment from the learned model.
    srcK, src_corrK, rot_net, trans_net, _, _ = net(src, tgt)
    coarse_src = transform_point_cloud(src, rot_net, trans_net)

    # Fine alignment of the coarsely-registered cloud.
    _, _, rot_icp, trans_icp, _, _ = refiner(coarse_src, tgt)

    # Compose: R = R_icp @ R_net, t = R_icp @ t_net + t_icp.
    rotation_ab_pred = torch.matmul(rot_icp, rot_net)
    translation_ab_pred = torch.matmul(
        rot_icp, trans_net.unsqueeze(2)).squeeze(2) + trans_icp

    # Inverse transform (b -> a): R_ba = R_ab^T, t_ba = -R_ba @ t_ab.
    rotation_ba_pred = rotation_ab_pred.transpose(2, 1).contiguous()
    translation_ba_pred = -torch.matmul(
        rotation_ba_pred, translation_ab_pred.unsqueeze(2)).squeeze(2)

    return coarse_src, tgt, rotation_ab_pred, translation_ab_pred, rotation_ba_pred, translation_ba_pred
Example #5
0
def train_one_epoch(args, net, train_loader, opt):
    """Train `net` for one epoch over `train_loader` with optimizer `opt`.

    Only `loss_VCRNet` — selected by ``args.loss`` ('pose', 'point', or a
    combined pose + 0.1*point default) — is backpropagated. `loss_pose`
    and the optional cycle-consistency loss are computed purely for
    reporting; no backward pass runs on them.

    Returns per-example averages (pose loss, cycle loss, mse/mae in both
    directions), the concatenated ground-truth and predicted transforms
    and euler angles for the whole epoch, and the average VCRNet loss.
    """
    net.train()

    # Point-cloud alignment error accumulators (a->b and b->a directions).
    mse_ab = 0
    mae_ab = 0
    mse_ba = 0
    mae_ba = 0

    # Accumulator for the loss actually used for training.
    total_loss_VCRNet = 0

    total_loss = 0
    total_cycle_loss = 0
    num_examples = 0
    # Per-batch ground-truth / predicted transforms, concatenated at the end.
    rotations_ab = []
    translations_ab = []
    rotations_ab_pred = []
    translations_ab_pred = []

    rotations_ba = []
    translations_ba = []
    rotations_ba_pred = []
    translations_ba_pred = []

    eulers_ab = []
    eulers_ba = []

    for src, target, rotation_ab, translation_ab, rotation_ba, translation_ba, euler_ab, euler_ba, label in tqdm(
            train_loader):
        src = src.cuda()
        target = target.cuda()
        rotation_ab = rotation_ab.cuda()
        translation_ab = translation_ab.cuda()
        rotation_ba = rotation_ba.cuda()
        translation_ba = translation_ba.cuda()

        batch_size = src.size(0)
        opt.zero_grad()
        num_examples += batch_size
        srcK, src_corrK, rotation_ab_pred, translation_ab_pred, rotation_ba_pred, translation_ba_pred = net(
            src, target)

        ## save rotation and translation (detached copies for epoch stats)
        rotations_ab.append(rotation_ab.detach().cpu().numpy())
        translations_ab.append(translation_ab.detach().cpu().numpy())
        rotations_ab_pred.append(rotation_ab_pred.detach().cpu().numpy())
        translations_ab_pred.append(translation_ab_pred.detach().cpu().numpy())
        eulers_ab.append(euler_ab.numpy())
        ##
        rotations_ba.append(rotation_ba.detach().cpu().numpy())
        translations_ba.append(translation_ba.detach().cpu().numpy())
        rotations_ba_pred.append(rotation_ba_pred.detach().cpu().numpy())
        translations_ba_pred.append(translation_ba_pred.detach().cpu().numpy())
        eulers_ba.append(euler_ba.numpy())

        # Apply predicted transforms for the point-based loss/metrics.
        transformed_src = transform_point_cloud(src, rotation_ab_pred,
                                                translation_ab_pred)
        transformed_target = transform_point_cloud(target, rotation_ba_pred,
                                                   translation_ba_pred)

        # srcK moved by the *ground-truth* transform; compared against the
        # predicted correspondences src_corrK below.
        transformed_srcK = transform_point_cloud(srcK, rotation_ab,
                                                 translation_ab)
        ###########################
        identity = torch.eye(3).cuda().unsqueeze(0).repeat(batch_size, 1, 1)
        # Training loss: R_pred^T @ R_gt should be identity, plus the
        # translation MSE; or the point-to-point MSE; or both combined.
        if args.loss == 'pose':
            loss_VCRNet = F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) \
                          + F.mse_loss(translation_ab_pred, translation_ab)
        elif args.loss == 'point':
            loss_VCRNet = torch.nn.functional.mse_loss(transformed_srcK,
                                                       src_corrK)
        else:
            lossPose = F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) \
                       + F.mse_loss(translation_ab_pred, translation_ab)
            lossPoint = torch.nn.functional.mse_loss(transformed_src, target)
            loss_VCRNet = lossPose + 0.1 * lossPoint

        # Only loss_VCRNet produces gradients.
        loss_VCRNet.backward()
        total_loss_VCRNet += loss_VCRNet.item() * batch_size

        # Reporting-only pose loss (never backpropagated).
        loss_pose = F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) \
                    + F.mse_loss(translation_ab_pred, translation_ab)
        if args.cycle:
            # Cycle consistency: the ba prediction should invert ab.
            rotation_loss = F.mse_loss(
                torch.matmul(rotation_ba_pred, rotation_ab_pred),
                identity.clone())
            translation_loss = torch.mean(
                (torch.matmul(rotation_ba_pred.transpose(2, 1),
                              translation_ab_pred.view(batch_size, 3, 1)).view(
                                  batch_size, 3) + translation_ba_pred)**2,
                dim=[0, 1])
            cycle_loss = rotation_loss + translation_loss

            loss_pose = loss_pose + cycle_loss * 0.1

        # Gradients come from loss_VCRNet.backward() above.
        opt.step()
        total_loss += loss_pose.item() * batch_size

        if args.cycle:
            total_cycle_loss = total_cycle_loss + cycle_loss.item(
            ) * 0.1 * batch_size

        mse_ab += torch.mean((transformed_srcK - src_corrK)**2,
                             dim=[0, 1, 2]).item() * batch_size
        mae_ab += torch.mean(torch.abs(transformed_srcK - src_corrK),
                             dim=[0, 1, 2]).item() * batch_size

        mse_ba += torch.mean(
            (transformed_target - src)**2, dim=[0, 1, 2]).item() * batch_size
        mae_ba += torch.mean(torch.abs(transformed_target - src),
                             dim=[0, 1, 2]).item() * batch_size

    # Stack the per-batch arrays into epoch-level arrays.
    rotations_ab = np.concatenate(rotations_ab, axis=0)
    translations_ab = np.concatenate(translations_ab, axis=0)
    rotations_ab_pred = np.concatenate(rotations_ab_pred, axis=0)
    translations_ab_pred = np.concatenate(translations_ab_pred, axis=0)

    rotations_ba = np.concatenate(rotations_ba, axis=0)
    translations_ba = np.concatenate(translations_ba, axis=0)
    rotations_ba_pred = np.concatenate(rotations_ba_pred, axis=0)
    translations_ba_pred = np.concatenate(translations_ba_pred, axis=0)

    eulers_ab = np.concatenate(eulers_ab, axis=0)
    eulers_ba = np.concatenate(eulers_ba, axis=0)

    return total_loss * 1.0 / num_examples, total_cycle_loss / num_examples, \
           mse_ab * 1.0 / num_examples, mae_ab * 1.0 / num_examples, \
           mse_ba * 1.0 / num_examples, mae_ba * 1.0 / num_examples, rotations_ab, \
           translations_ab, rotations_ab_pred, translations_ab_pred, rotations_ba, \
           translations_ba, rotations_ba_pred, translations_ba_pred, eulers_ab, eulers_ba, total_loss_VCRNet * 1.0 / num_examples
Example #6
0
def test_one_epoch(args, net, test_loader):
    """Evaluate `net` over `test_loader` without gradient tracking.

    The registration routine is chosen by ``args.iter``: > 0 runs
    `vcrnetIter` for that many refinement passes, == 0 runs the
    `vcrnetIcpNet` network+ICP pipeline, and a negative value raises.
    All losses mirror `train_one_epoch` but are reporting-only.

    Returns per-example averages (pose loss, cycle loss, mse/mae in both
    directions), the concatenated ground-truth and predicted transforms
    and euler angles, and the average VCRNet loss.
    """
    net.eval()
    # Point-cloud alignment error accumulators (a->b and b->a directions).
    mse_ab = 0
    mae_ab = 0
    mse_ba = 0
    mae_ba = 0

    total_loss_VCRNet = 0

    total_loss = 0
    total_cycle_loss = 0
    num_examples = 0
    # Per-batch ground-truth / predicted transforms, concatenated at the end.
    rotations_ab = []
    translations_ab = []
    rotations_ab_pred = []
    translations_ab_pred = []

    rotations_ba = []
    translations_ba = []
    rotations_ba_pred = []
    translations_ba_pred = []

    eulers_ab = []
    eulers_ba = []

    with torch.no_grad():
        for src, target, rotation_ab, translation_ab, rotation_ba, translation_ba, euler_ab, euler_ba, label in tqdm(
                test_loader):

            src = src.cuda()
            target = target.cuda()
            rotation_ab = rotation_ab.cuda()
            translation_ab = translation_ab.cuda()
            rotation_ba = rotation_ba.cuda()
            translation_ba = translation_ba.cuda()

            batch_size = src.size(0)
            num_examples += batch_size

            # Select the registration routine by args.iter (see docstring).
            if args.iter > 0:
                srcK, src_corrK, rotation_ab_pred, translation_ab_pred, rotation_ba_pred, translation_ba_pred = vcrnetIter(
                    net, src, target, iter=args.iter)
            elif args.iter == 0:
                srcK, src_corrK, rotation_ab_pred, translation_ab_pred, rotation_ba_pred, translation_ba_pred = vcrnetIcpNet(
                    args, net, src, target)
            else:
                raise RuntimeError('args.iter')

            ## save rotation and translation (detached copies for stats)
            rotations_ab.append(rotation_ab.detach().cpu().numpy())
            translations_ab.append(translation_ab.detach().cpu().numpy())
            rotations_ab_pred.append(rotation_ab_pred.detach().cpu().numpy())
            translations_ab_pred.append(
                translation_ab_pred.detach().cpu().numpy())
            eulers_ab.append(euler_ab.numpy())
            ##
            rotations_ba.append(rotation_ba.detach().cpu().numpy())
            translations_ba.append(translation_ba.detach().cpu().numpy())
            rotations_ba_pred.append(rotation_ba_pred.detach().cpu().numpy())
            translations_ba_pred.append(
                translation_ba_pred.detach().cpu().numpy())
            eulers_ba.append(euler_ba.numpy())

            # Predicted point cloud
            transformed_target = transform_point_cloud(target,
                                                       rotation_ba_pred,
                                                       translation_ba_pred)
            # Real point cloud: srcK moved by the ground-truth transform,
            # compared against the predicted correspondences src_corrK.
            transformed_srcK = transform_point_cloud(srcK, rotation_ab,
                                                     translation_ab)

            # transformed_src = transform_point_cloud(src, rotation_ab_pred, translation_ab_pred)
            # from PC_reg_gif.draw import plot3d2
            # plot3d2(transformed_src[1], target[1])
            # plot3d2(transformed_target[1], src[1])

            ###########################
            identity = torch.eye(3).cuda().unsqueeze(0).repeat(
                batch_size, 1, 1)
            # Same loss selection as training: pose, point, or combined.
            if args.loss == 'pose':
                loss_VCRNet = F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) \
                              + F.mse_loss(translation_ab_pred, translation_ab)
            elif args.loss == 'point':
                loss_VCRNet = torch.nn.functional.mse_loss(
                    transformed_srcK, src_corrK)
            else:
                lossPose = F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) \
                           + F.mse_loss(translation_ab_pred, translation_ab)
                transformed_src = transform_point_cloud(
                    src, rotation_ab_pred, translation_ab_pred)
                lossPoint = torch.nn.functional.mse_loss(
                    transformed_src, target)
                loss_VCRNet = lossPose + 0.1 * lossPoint

            # Reporting-only pose loss.
            loss_pose = F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) \
                        + F.mse_loss(translation_ab_pred, translation_ab)

            total_loss_VCRNet += loss_VCRNet.item() * batch_size

            if args.cycle:
                # Cycle consistency: the ba prediction should invert ab.
                rotation_loss = F.mse_loss(
                    torch.matmul(rotation_ba_pred, rotation_ab_pred),
                    identity.clone())
                translation_loss = torch.mean(
                    (torch.matmul(rotation_ba_pred.transpose(2, 1),
                                  translation_ab_pred.view(
                                      batch_size, 3, 1)).view(batch_size, 3) +
                     translation_ba_pred)**2,
                    dim=[0, 1])
                cycle_loss = rotation_loss + translation_loss

                loss_pose = loss_pose + cycle_loss * 0.1

            total_loss += loss_pose.item() * batch_size

            if args.cycle:
                total_cycle_loss = total_cycle_loss + cycle_loss.item(
                ) * 0.1 * batch_size

            mse_ab += torch.mean((transformed_srcK - src_corrK)**2,
                                 dim=[0, 1, 2]).item() * batch_size
            mae_ab += torch.mean(torch.abs(transformed_srcK - src_corrK),
                                 dim=[0, 1, 2]).item() * batch_size

            mse_ba += torch.mean((transformed_target - src)**2,
                                 dim=[0, 1, 2]).item() * batch_size
            mae_ba += torch.mean(torch.abs(transformed_target - src),
                                 dim=[0, 1, 2]).item() * batch_size

    # Stack the per-batch arrays into epoch-level arrays.
    rotations_ab = np.concatenate(rotations_ab, axis=0)
    translations_ab = np.concatenate(translations_ab, axis=0)
    rotations_ab_pred = np.concatenate(rotations_ab_pred, axis=0)
    translations_ab_pred = np.concatenate(translations_ab_pred, axis=0)

    rotations_ba = np.concatenate(rotations_ba, axis=0)
    translations_ba = np.concatenate(translations_ba, axis=0)
    rotations_ba_pred = np.concatenate(rotations_ba_pred, axis=0)
    translations_ba_pred = np.concatenate(translations_ba_pred, axis=0)

    eulers_ab = np.concatenate(eulers_ab, axis=0)
    eulers_ba = np.concatenate(eulers_ba, axis=0)

    return total_loss * 1.0 / num_examples, total_cycle_loss / num_examples, \
           mse_ab * 1.0 / num_examples, mae_ab * 1.0 / num_examples, \
           mse_ba * 1.0 / num_examples, mae_ba * 1.0 / num_examples, rotations_ab, \
           translations_ab, rotations_ab_pred, translations_ab_pred, rotations_ba, \
           translations_ba, rotations_ba_pred, translations_ba_pred, eulers_ab, eulers_ba, total_loss_VCRNet * 1.0 / num_examples
Example #7
0
def test_one_epoch(args, net, test_loader):
    """Evaluate `net` over `test_loader` without gradient tracking.

    Here `net` follows the ICP-style contract and returns
    (srcInit, aligned src, rotation_ab, translation_ab, rotation_ba,
    translation_ba). Loss selection via ``args.loss`` and the optional
    cycle-consistency term mirror the training loop; everything is
    reporting-only under torch.no_grad().

    Returns per-example averages (loss, cycle loss, mse/mae in both
    directions) plus the concatenated ground-truth and predicted
    transforms and euler angles for the whole epoch.
    """
    net.eval()
    # Point-cloud alignment error accumulators (a->b and b->a directions).
    mse_ab = 0
    mae_ab = 0
    mse_ba = 0
    mae_ba = 0

    total_loss = 0
    total_cycle_loss = 0
    num_examples = 0
    # Per-batch ground-truth / predicted transforms, concatenated at the end.
    rotations_ab = []
    translations_ab = []
    rotations_ab_pred = []
    translations_ab_pred = []

    rotations_ba = []
    translations_ba = []
    rotations_ba_pred = []
    translations_ba_pred = []

    eulers_ab = []
    eulers_ba = []

    with torch.no_grad():
        for src, target, rotation_ab, translation_ab, rotation_ba, translation_ba, euler_ab, euler_ba, label in tqdm(
                test_loader):
            src = src.cuda()
            target = target.cuda()
            rotation_ab = rotation_ab.cuda()
            translation_ab = translation_ab.cuda()
            rotation_ba = rotation_ba.cuda()
            translation_ba = translation_ba.cuda()

            batch_size = src.size(0)
            num_examples += batch_size
            # NOTE: `src` is rebound to the network's aligned output here.
            srcInit, src, rotation_ab_pred, translation_ab_pred, rotation_ba_pred, translation_ba_pred = net(
                src, target)

            if args.use_mFea:
                # Split the channel dim [3, 5] and keep the first 3 channels
                # (presumably xyz, dropping 5 extra features — TODO confirm).
                src = src.transpose(2, 1).split([3, 5],
                                                dim=2)[0].transpose(2, 1)
                target = target.transpose(2,
                                          1).split([3, 5],
                                                   dim=2)[0].transpose(2, 1)

            ## save rotation and translation (detached copies for stats)
            rotations_ab.append(rotation_ab.detach().cpu().numpy())
            translations_ab.append(translation_ab.detach().cpu().numpy())
            rotations_ab_pred.append(rotation_ab_pred.detach().cpu().numpy())
            translations_ab_pred.append(
                translation_ab_pred.detach().cpu().numpy())
            eulers_ab.append(euler_ab.numpy())

            rotations_ba.append(rotation_ba.detach().cpu().numpy())
            translations_ba.append(translation_ba.detach().cpu().numpy())
            rotations_ba_pred.append(rotation_ba_pred.detach().cpu().numpy())
            translations_ba_pred.append(
                translation_ba_pred.detach().cpu().numpy())
            eulers_ba.append(euler_ba.numpy())

            # Apply predicted transforms for the point-based loss/metrics.
            transformed_src = transform_point_cloud(src, rotation_ab_pred,
                                                    translation_ab_pred)

            transformed_target = transform_point_cloud(target,
                                                       rotation_ba_pred,
                                                       translation_ba_pred)

            ###########################
            identity = torch.eye(3).cuda().unsqueeze(0).repeat(
                batch_size, 1, 1)
            # Loss selection: pose, point, or combined pose + 0.1 * point.
            if args.loss == 'pose':
                loss = F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) \
                       + F.mse_loss(translation_ab_pred, translation_ab)
            elif args.loss == 'point':
                loss = torch.mean((transformed_src - target)**2, dim=[0, 1, 2])
            else:
                lossPose = F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) \
                           + F.mse_loss(translation_ab_pred, translation_ab)
                lossPoint = torch.mean((transformed_src - target)**2,
                                       dim=[0, 1, 2])
                loss = lossPose + 0.1 * lossPoint
            if args.cycle:
                # Cycle consistency: the ba prediction should invert ab.
                rotation_loss = F.mse_loss(
                    torch.matmul(rotation_ba_pred, rotation_ab_pred),
                    identity.clone())
                translation_loss = torch.mean(
                    (torch.matmul(rotation_ba_pred.transpose(2, 1),
                                  translation_ab_pred.view(
                                      batch_size, 3, 1)).view(batch_size, 3) +
                     translation_ba_pred)**2,
                    dim=[0, 1])
                cycle_loss = rotation_loss + translation_loss

                loss = loss + cycle_loss * 0.1

            total_loss += loss.item() * batch_size

            if args.cycle:
                total_cycle_loss = total_cycle_loss + cycle_loss.item(
                ) * 0.1 * batch_size

            mse_ab += torch.mean((transformed_src - target)**2,
                                 dim=[0, 1, 2]).item() * batch_size
            mae_ab += torch.mean(torch.abs(transformed_src - target),
                                 dim=[0, 1, 2]).item() * batch_size

            mse_ba += torch.mean((transformed_target - src)**2,
                                 dim=[0, 1, 2]).item() * batch_size
            mae_ba += torch.mean(torch.abs(transformed_target - src),
                                 dim=[0, 1, 2]).item() * batch_size

    # Stack the per-batch arrays into epoch-level arrays.
    rotations_ab = np.concatenate(rotations_ab, axis=0)
    translations_ab = np.concatenate(translations_ab, axis=0)
    rotations_ab_pred = np.concatenate(rotations_ab_pred, axis=0)
    translations_ab_pred = np.concatenate(translations_ab_pred, axis=0)

    rotations_ba = np.concatenate(rotations_ba, axis=0)
    translations_ba = np.concatenate(translations_ba, axis=0)
    rotations_ba_pred = np.concatenate(rotations_ba_pred, axis=0)
    translations_ba_pred = np.concatenate(translations_ba_pred, axis=0)

    eulers_ab = np.concatenate(eulers_ab, axis=0)
    eulers_ba = np.concatenate(eulers_ba, axis=0)

    return total_loss * 1.0 / num_examples, total_cycle_loss / num_examples, \
           mse_ab * 1.0 / num_examples, mae_ab * 1.0 / num_examples, \
           mse_ba * 1.0 / num_examples, mae_ba * 1.0 / num_examples, rotations_ab, \
           translations_ab, rotations_ab_pred, translations_ab_pred, rotations_ba, \
           translations_ba, rotations_ba_pred, translations_ba_pred, eulers_ab, eulers_ba