Beispiel #1
0
 def _train_one_batch(self, src, tgt, rotation_ab, translation_ab, opt, temp):
     """Run one optimization step on a single batch.

     Iteratively refines the predicted rigid transform over
     ``self.num_iters`` passes, accumulating a bidirectional entropy loss
     on the match scores plus an MSE pose loss, then backpropagates once.

     Args:
         src: source point cloud tensor; assumed (batch, 3, num_points) -- TODO confirm.
         tgt: target point cloud tensor, same layout as ``src``.
         rotation_ab: ground-truth rotation, (batch, 3, 3).
         translation_ab: ground-truth translation, (batch, 3).
         opt: optimizer stepped after the backward pass.
         temp: scalar temperature, broadcast to the batch on the GPU.

     Returns:
         (total_loss_value, rotation_ab_pred, translation_ab_pred)
     """
     opt.zero_grad()
     batch_size = src.size(0)
     identity = torch.eye(3, device=src.device).unsqueeze(0).repeat(batch_size, 1, 1)
     # Running pose estimate, initialized to the identity transform.
     rotation_ab_pred = torch.eye(3, device=src.device, dtype=torch.float32).view(1, 3, 3).repeat(batch_size, 1, 1)
     translation_ab_pred = torch.zeros(3, device=src.device, dtype=torch.float32).view(1, 3).repeat(batch_size, 1)
     total_loss = 0
     temp = torch.tensor(temp).cuda().repeat(batch_size)
     for i in range(self.num_iters):
         rotation_ab_pred_i, translation_ab_pred_i, scores = self.forward(src, tgt, temp, i)
         # Residual pose: the transform still separating the running
         # prediction from the ground truth.
         res_rotation_ab = torch.matmul(rotation_ab, rotation_ab_pred.transpose(2, 1))
         res_translation_ab = translation_ab - torch.matmul(res_rotation_ab,
                                                            translation_ab_pred.unsqueeze(2)).squeeze(2)
         # Accumulated pose: compose this iteration's increment with the running estimate.
         rotation_ab_pred = torch.matmul(rotation_ab_pred_i, rotation_ab_pred)
         translation_ab_pred = torch.matmul(rotation_ab_pred_i, translation_ab_pred.unsqueeze(2)).squeeze(2) \
                               + translation_ab_pred_i
         # Entropy loss on the score map, evaluated in both directions
         # (src->tgt, and tgt->src with the inverted residual pose).
         src_entropy_loss = self.compute_loss(scores, src, res_rotation_ab, res_translation_ab, tgt)
         res_rotation_ab_t = res_rotation_ab.transpose(2, 1).contiguous()
         res_translation_ab_t = - torch.matmul(res_rotation_ab_t, res_translation_ab.unsqueeze(2)).squeeze(2)
         tgt_entropy_loss = self.compute_loss(scores.transpose(2, 1).contiguous(), tgt, res_rotation_ab_t,
                                              res_translation_ab_t, src)
         entropy_loss = src_entropy_loss + tgt_entropy_loss
         # Pose loss: R_pred^T R_gt should equal the identity, t_pred should match t_gt.
         pose_loss = (F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity)\
                      + F.mse_loss(translation_ab_pred, translation_ab))
         total_loss = total_loss + entropy_loss + pose_loss
         # Warp the source cloud by this iteration's increment for the next pass.
         src = transform_point_cloud(src, rotation_ab_pred_i, translation_ab_pred_i)
     total_loss.backward()
     opt.step()
     return total_loss.item(), rotation_ab_pred, translation_ab_pred
Beispiel #2
0
 def forward(self, src, R_pred, t_pred, R_gt, t_gt, sigma_):
     """Correntropy-style alignment loss.

     Applies both the predicted and the ground-truth rigid transform to
     ``src`` and pushes point-wise distances through a Gaussian kernel of
     bandwidth ``sigma_``, normalized by ``beta`` so a zero-distance
     alignment gives zero loss; returns the mean over the batch.
     """
     bs = R_pred.size(0)
     kernel_sigma = sigma_.view(bs, 1)
     # (bs, 1) -- 1 / (2 sigma^2), shared by the kernel and its normalizer.
     half_inv_var = 0.5 * kernel_sigma ** (-2)
     beta = (1 - torch.exp(-half_inv_var)) ** (-1)
     aligned_pred = transform_point_cloud(src, R_pred, t_pred)
     aligned_gt = transform_point_cloud(src, R_gt, t_gt)
     # (bs, np) Euclidean distance between the two alignments of each point.
     point_dist = torch.sqrt(torch.sum((aligned_pred - aligned_gt) ** 2, dim=1))
     kernel_val = torch.exp(-(point_dist * half_inv_var))
     # (bs,) one minus the mean kernel response, scaled by the normalizer.
     per_sample = (1 - torch.mean(kernel_val, dim=-1)) * beta.squeeze(-1)
     return torch.mean(per_sample, dim=0)
Beispiel #3
0
    def forward(self, srcInit, dst):
        """Classic ICP loop: repeatedly match nearest neighbors and solve
        the best-fit rigid transform until the change in mean error drops
        below ``self.tolerance`` or ``self.max_iterations`` is reached.

        Returns:
            (srcInit, aligned_src, rotation_ab, translation_ab,
             rotation_ba, translation_ba), where the *_ba pose is the
            analytic inverse of the *_ab pose.
        """
        icp_start = time()
        src = srcInit
        prev_error = 0
        for i in range(self.max_iterations):
            # find the nearest neighbors between the current source and destination points
            mean_error, src_corr = self.nearest_neighbor(src, dst)
            # compute the transformation between the current source and nearest destination points
            rotation_ab, translation_ab = self.best_fit_transform(
                src, src_corr)
            src = transform_point_cloud(src, rotation_ab, translation_ab)

            # Stop once the error has stopped improving meaningfully.
            if torch.abs(prev_error - mean_error) < self.tolerance:
                # print('iteration: '+str(i))
                break
            prev_error = mean_error

        # calculate final transformation
        rotation_ab, translation_ab = self.best_fit_transform(srcInit, src)

        # Inverse pose: R_ba = R_ab^T, t_ba = -R_ab^T t_ab.
        rotation_ba = rotation_ab.transpose(2, 1).contiguous()
        translation_ba = -torch.matmul(rotation_ba,
                                       translation_ab.unsqueeze(2)).squeeze(2)

        print("icp: ", time() - icp_start)  # debug timing output
        return srcInit, src, rotation_ab, translation_ab, rotation_ba, translation_ba
Beispiel #4
0
    def _test_one_batch(self, src, tgt, rotation_ab, translation_ab, temp):
        """Evaluate one batch (no gradient step).

        Same iterative pose refinement as training, accumulating a
        discounted MSE pose loss per iteration.

        Returns:
            (total_loss_value, rotation_ab_pred, translation_ab_pred)
        """
        batch_size = src.size(0)
        identity = torch.eye(3, device=src.device).unsqueeze(0).repeat(
            batch_size, 1, 1)
        # Running pose estimate, initialized to the identity transform.
        rotation_ab_pred = torch.eye(3, device=src.device,
                                     dtype=torch.float32).view(1, 3, 3).repeat(
                                         batch_size, 1, 1)
        translation_ab_pred = torch.zeros(3,
                                          device=src.device,
                                          dtype=torch.float32).view(
                                              1, 3).repeat(batch_size, 1)
        total_loss = 0
        temp = torch.tensor(temp).cuda().repeat(batch_size)
        for i in range(self.num_iters):
            rotation_ab_pred_i, translation_ab_pred_i = self.forward(
                src, tgt, temp, i)
            # Compose this iteration's increment with the running pose.
            rotation_ab_pred = torch.matmul(rotation_ab_pred_i,
                                            rotation_ab_pred)
            translation_ab_pred = torch.matmul(rotation_ab_pred_i, translation_ab_pred.unsqueeze(2)).squeeze(2) \
                                  + translation_ab_pred_i

            # Discounted pose loss: later iterations are weighted down.
            loss = (F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) \
                    + F.mse_loss(translation_ab_pred, translation_ab)) * self.discount_factor ** i
            total_loss = total_loss + loss
            # Warp the source cloud for the next refinement pass.
            src = transform_point_cloud(src, rotation_ab_pred_i,
                                        translation_ab_pred_i)
        return total_loss.item(), rotation_ab_pred, translation_ab_pred
Beispiel #5
0
 def _train_one_batch(self, src, tgt, rotation_ab, translation_ab, opt, temp, epoch):
     """One optimization step using a discounted MCC (correntropy) loss.

     The kernel bandwidth ``sigmal_`` is derived from the point-set
     density (median nearest-neighbor distance) and decayed every
     refinement iteration.

     Returns:
         (total_loss_value, rotation_ab_pred, translation_ab_pred)
     """
     opt.zero_grad()
     batch_size = src.size(0)
     # NOTE(review): identity, total_cycle_consistency_loss and gamma_ are
     # currently unused (gamma_ fed the commented-out loss blend below).
     identity = torch.eye(3, device=src.device).unsqueeze(0).repeat(batch_size, 1, 1)
     rotation_ab_pred = torch.eye(3, device=src.device, dtype=torch.float32).view(1, 3, 3).repeat(batch_size, 1, 1)
     translation_ab_pred = torch.zeros(3, device=src.device, dtype=torch.float32).view(1, 3).repeat(batch_size, 1)
     total_loss = 0
     total_cycle_consistency_loss = 0
     temp = torch.tensor(temp).cuda().repeat(batch_size)
     # Point-set density estimate: median distance from each point to its
     # nearest neighbor, averaged over the batch (no gradients needed).
     with torch.no_grad():
         distance = pairwise_distance(src, src)
         sort_distance, _ = torch.sort(distance, dim=-1)
         # Column 0 is each point's zero distance to itself; column 1 is
         # the true nearest neighbor.
         nearest_distance = sort_distance[:, :, 1].squeeze()
         median = torch.median(nearest_distance, dim=-1)[0]
         meanDist = torch.mean(median)
         sigmal_ = meanDist * self.sigma_times
         sigmal_ = sigmal_.repeat(batch_size)
         gamma_ = float(epoch / self.epochs)
     for i in range(self.num_iters):
         # Shrink the kernel bandwidth each iteration.
         sigmal_ = sigmal_ * self.DecayPram
         rotation_ab_pred_i, translation_ab_pred_i = self.forward(src, tgt, temp, i, sigmal_)
         # Compose the incremental pose with the running estimate.
         rotation_ab_pred = torch.matmul(rotation_ab_pred_i, rotation_ab_pred)
         translation_ab_pred = torch.matmul(rotation_ab_pred_i, translation_ab_pred.unsqueeze(2)).squeeze(2) \
                               + translation_ab_pred_i
         mcc_loss = self.mcc_loss(src, rotation_ab_pred, translation_ab_pred, rotation_ab, translation_ab,
                                  sigmal_)
         # mse_loss = (F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) \
         #             + F.mse_loss(translation_ab_pred, translation_ab))
         # loss = mcc_loss * gamma_ + mse_loss * (1 - gamma_)
         total_loss = total_loss + mcc_loss * self.discount_factor ** i
         src = transform_point_cloud(src, rotation_ab_pred_i, translation_ab_pred_i)
     total_loss.backward()
     opt.step()
     return total_loss.item(), rotation_ab_pred, translation_ab_pred
Beispiel #6
0
 def _train_one_batch(self, src, tgt, rotation_ab, translation_ab, opt,
                      temp):
     """One optimization step blending an MSE pose loss with an entropy
     loss on the sampling scores, weighted by ``self.sigma_``.

     Returns:
         (total_loss_value, rotation_ab_pred, translation_ab_pred)
     """
     opt.zero_grad()
     batch_size = src.size(0)
     identity = torch.eye(3, device=src.device).unsqueeze(0).repeat(
         batch_size, 1, 1)
     # Running pose estimate, initialized to the identity transform.
     rotation_ab_pred = torch.eye(3, device=src.device,
                                  dtype=torch.float32).view(1, 3, 3).repeat(
                                      batch_size, 1, 1)
     translation_ab_pred = torch.zeros(3,
                                       device=src.device,
                                       dtype=torch.float32).view(
                                           1, 3).repeat(batch_size, 1)
     total_loss = 0
     temp = torch.tensor(temp).cuda().repeat(batch_size)
     for i in range(self.num_iters):
         rotation_ab_pred_i, translation_ab_pred_i, sampling_scores = self.forward(
             src, tgt, temp, i)
         # Compose the incremental pose with the running estimate.
         rotation_ab_pred = torch.matmul(rotation_ab_pred_i,
                                         rotation_ab_pred)
         translation_ab_pred = torch.matmul(rotation_ab_pred_i, translation_ab_pred.unsqueeze(2)).squeeze(2) \
                               + translation_ab_pred_i
         # Discounted pose loss against the ground truth.
         mse_loss = (F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) \
                 + F.mse_loss(translation_ab_pred, translation_ab)) * self.discount_factor ** i
         # Supervise the sampling scores against one-hot ground truth.
         gt_scores = compute_scores_gt(sampling_scores, src, tgt,
                                       rotation_ab, translation_ab)
         entropy_loss = self.entropy_loss(
             sampling_scores, gt_scores) * self.discount_factor**i
         # sigma_ weights score supervision vs. pose supervision.
         loss = mse_loss * (1 - self.sigma_) + entropy_loss * self.sigma_
         total_loss = total_loss + loss
         # Warp the source cloud for the next pass.
         src = transform_point_cloud(src, rotation_ab_pred_i,
                                     translation_ab_pred_i)
     total_loss.backward()
     opt.step()
     return total_loss.item(), rotation_ab_pred, translation_ab_pred
Beispiel #7
0
 def _test_one_batch(self, src, tgt, rotation_ab, translation_ab, temp, epoch):
     """Evaluate one batch, blending MCC and MSE pose losses by the
     training-progress weight ``gamma_ = epoch / self.epochs``.

     Returns:
         (total_loss_value, rotation_ab_pred, translation_ab_pred)
     """
     batch_size = src.size(0)
     identity = torch.eye(3, device=src.device).unsqueeze(0).repeat(batch_size, 1, 1)
     # Running pose estimate, initialized to the identity transform.
     rotation_ab_pred = torch.eye(3, device=src.device, dtype=torch.float32).view(1, 3, 3).repeat(batch_size, 1, 1)
     translation_ab_pred = torch.zeros(3, device=src.device, dtype=torch.float32).view(1, 3).repeat(batch_size, 1)
     total_loss = 0
     temp = torch.tensor(temp).cuda().repeat(batch_size)
     # Point-set density estimate: median nearest-neighbor distance.
     distance = pairwise_distance(src, src)
     sort_distance, _ = torch.sort(distance, dim=-1)
     # Column 1 skips each point's zero distance to itself.
     nearest_distance = sort_distance[:, :, 1].squeeze(-1)
     meanDist = torch.median(nearest_distance, dim=-1)[0]
     # meanDist = torch.mean(median)
     sigmal_ = meanDist * self.sigma_times
     # NOTE(review): .item() requires a scalar, which implies batch_size == 1
     # here -- confirm against the caller.
     self.sigma_ = sigmal_.item()
     # sigmal_ = sigmal_.repeat(batch_size)
     gamma_ = epoch / self.epochs
     for i in range(self.num_iters):
         # Shrink the kernel bandwidth each refinement iteration.
         sigmal_ = sigmal_ * self.DecayPram
         rotation_ab_pred_i, translation_ab_pred_i, src_k, src_corr = self.forward(src, tgt, temp, i, sigmal_)
         # Compose the incremental pose with the running estimate.
         rotation_ab_pred = torch.matmul(rotation_ab_pred_i, rotation_ab_pred)
         translation_ab_pred = torch.matmul(rotation_ab_pred_i, translation_ab_pred.unsqueeze(2)).squeeze(2) \
                               + translation_ab_pred_i
         # mcc_loss = self.mcc_loss(src, rotation_ab_pred, translation_ab_pred, rotation_ab, translation_ab,
         #                      sigmal_)
         mcc_loss = self.mcc_loss(src_k, rotation_ab, translation_ab, src_corr, sigmal_)
         mse_loss = (F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) \
                 + F.mse_loss(translation_ab_pred, translation_ab))
         # Weight the MCC term more heavily as training progresses.
         loss = mcc_loss * gamma_ + mse_loss * (1 - gamma_)
         total_loss = total_loss + loss * self.discount_factor ** i
         src = transform_point_cloud(src, rotation_ab_pred_i, translation_ab_pred_i)
     return total_loss.item(), rotation_ab_pred, translation_ab_pred
def train_one_epoch(args, net, train_loader, opt):
    """Train ``net`` for one epoch; return the size-weighted mean EPE loss."""
    net.train()

    loss_sum = 0
    seen = 0

    for points_a, points_b, _gt_flow in tqdm(train_loader):
        points_a = points_a.cuda()
        points_b = points_b.cuda()

        n = points_a.size(0)
        seen += n
        opt.zero_grad()
        rot_pred, trans_pred = net(points_a, points_b)

        ###########################
        warped = transform_point_cloud(points_a, rot_pred, trans_pred)
        batch_loss = EPE(warped, points_b)

        batch_loss.backward()
        opt.step()
        # Weight by batch size so the final mean is per-example.
        loss_sum += batch_loss.item() * n

    return loss_sum * 1.0 / seen
Beispiel #9
0
 def _test_one_batch(self, src, tgt, rotation_ab, translation_ab, temp):
     """Evaluate one batch using only the entropy loss on the match scores.

     Returns:
         (total_loss_value, rotation_ab_pred, translation_ab_pred)
     """
     batch_size = src.size(0)
     # NOTE(review): identity is unused in this evaluation loop.
     identity = torch.eye(3, device=src.device).unsqueeze(0).repeat(
         batch_size, 1, 1)
     # Running pose estimate, initialized to the identity transform.
     rotation_ab_pred = torch.eye(3, device=src.device,
                                  dtype=torch.float32).view(1, 3, 3).repeat(
                                      batch_size, 1, 1)
     translation_ab_pred = torch.zeros(3,
                                       device=src.device,
                                       dtype=torch.float32).view(
                                           1, 3).repeat(batch_size, 1)
     total_loss = 0
     temp = torch.tensor(temp).cuda().repeat(batch_size)
     for i in range(self.num_iters):
         rotation_ab_pred_i, translation_ab_pred_i, scores = self.forward(
             src, tgt, temp, i)
         # Residual pose: what still separates the running prediction
         # from the ground truth.
         res_rotation_ab = torch.matmul(rotation_ab,
                                        rotation_ab_pred.transpose(2, 1))
         res_translation_ab = translation_ab - torch.matmul(
             res_rotation_ab, translation_ab_pred.unsqueeze(2)).squeeze(2)
         # Accumulated pose: compose the increment with the running estimate.
         rotation_ab_pred = torch.matmul(rotation_ab_pred_i,
                                         rotation_ab_pred)
         translation_ab_pred = torch.matmul(rotation_ab_pred_i, translation_ab_pred.unsqueeze(2)).squeeze(2) \
                               + translation_ab_pred_i
         # Entropy loss on the score map against the residual pose.
         entropy_loss = self.compute_loss(scores, src, res_rotation_ab,
                                          res_translation_ab, tgt)
         total_loss = total_loss + entropy_loss
         # Warp the source cloud for the next pass.
         src = transform_point_cloud(src, rotation_ab_pred_i,
                                     translation_ab_pred_i)
     return total_loss.item(), rotation_ab_pred, translation_ab_pred
Beispiel #10
0
    def forward(self, *input):
        """Iterative registration forward pass.

        Expects ``input`` as (src, tgt). Runs the pose head
        ``self.iterations`` times, warping and re-embedding ``src`` after
        each pass and composing the poses. The reverse pose is either
        predicted with a second head call (``self.cycle``) or obtained by
        analytic inversion. Also returns a chamfer-distance loss between
        the final warped source and the target.

        Returns:
            (rotation_ab, translation_ab, rotation_ba, translation_ba, loss)
        """
        src = input[0]
        tgt = input[1]
        src_embedding = self.emb_nn(src)
        tgt_embedding = self.emb_nn(tgt)

        if self.iterations > 1:
            rotation_ab_temp, translation_ab_temp = self.head(
                src_embedding, tgt_embedding, src, tgt)
            rotation_ab, translation_ab = rotation_ab_temp.clone(
            ), translation_ab_temp.clone()

            for itr in range(self.iterations - 1):
                # Warp the source by the latest increment and re-embed it.
                src = transform_point_cloud(src, rotation_ab_temp,
                                            translation_ab_temp)
                src_embedding = self.emb_nn(src)

                rotation_ab_temp, translation_ab_temp = self.head(
                    src_embedding, tgt_embedding, src, tgt)
                # Fold the new increment into the accumulated pose.
                rotation_ab, translation_ab = combine_transformations(
                    rotation_ab_temp, translation_ab_temp, rotation_ab,
                    translation_ab)

            src = transform_point_cloud(src, rotation_ab_temp,
                                        translation_ab_temp)
        else:
            rotation_ab, translation_ab = self.head(src_embedding,
                                                    tgt_embedding, src, tgt)
            src = transform_point_cloud(src, rotation_ab, translation_ab)

        if self.cycle:
            # Predict the reverse pose with the head as well.
            rotation_ba, translation_ba = self.head(tgt_embedding,
                                                    src_embedding, tgt, src)
        else:
            # Analytic inverse: R_ba = R_ab^T, t_ba = -R_ab^T t_ab.
            rotation_ba = rotation_ab.transpose(2, 1).contiguous()
            translation_ba = -torch.matmul(
                rotation_ba, translation_ab.unsqueeze(2)).squeeze(2)

        # Chamfer distance between the aligned source and the target.
        dist1, dist2 = self.chamfer(src.permute(0, 2, 1), tgt.permute(0, 2, 1))
        loss = (torch.mean(torch.sqrt(dist1)) +
                torch.mean(torch.sqrt(dist2))) / 2.0
        return rotation_ab, translation_ab, rotation_ba, translation_ba, loss
Beispiel #11
0
def test_one_pair(args, net):
    """Run the network on one source/template pair loaded from disk.

    Returns:
        (source_np, template_np, transformed_template_np,
         rotation_ab_pred, translation_ab_pred)
    """
    source, template = read_data()
    net.eval()
    src = source.to(args.device)
    target = template.to(args.device)
    # Network expects channels-first point clouds.
    src = src.permute(0, 2, 1)
    target = target.permute(0, 2, 1)
    batch_size = src.size(0)
    rotation_ab_pred, translation_ab_pred, rotation_ba_pred, translation_ba_pred, _ = net(
        src, target)
    transformed_src = transform_point_cloud(src, rotation_ab_pred,
                                            translation_ab_pred)
    transformed_target = transform_point_cloud(target, rotation_ba_pred,
                                               translation_ba_pred)

    transformed_src = transformed_src.permute(0, 2, 1)
    transformed_target = transformed_target.permute(0, 2, 1)
    print(rotation_ab_pred, translation_ab_pred)

    # NOTE(review): transformed_src is computed but never returned; only the
    # reverse-transformed template is -- confirm this is intended.
    return source.numpy(), template.numpy(), transformed_target.cpu().detach(
    ).numpy(), rotation_ab_pred, translation_ab_pred
Beispiel #12
0
 def forward(self, *input):
     """Coarse-then-refine registration pass.

     Expects ``input`` as (src, tgt, temp): a first-stage matcher
     produces an initial pose used to pre-align ``src``, then
     self-attended embeddings of both clouds feed the head that emits
     the final rotation, translation and matching scores.
     """
     src, tgt, temp = input[0], input[1], input[2]
     # Stage 1: coarse alignment from the first matcher.
     init_rot, init_trans, _ = self.match_net_0(src, tgt, temp)
     src = transform_point_cloud(src, init_rot, init_trans)
     # Stage 2: embed each cloud and apply per-cloud self-attention.
     tgt_feat = self.tgt_emb_nn(tgt)
     src_feat = self.src_emb_nn(src)
     src_feat = self.src_attn(src_feat, src_feat)
     tgt_feat = self.tgt_attn(tgt_feat, tgt_feat)
     rotation_ab, translation_ab, scores = self.head(
         src_feat, tgt_feat, src, tgt, temp)
     return rotation_ab, translation_ab, scores
Beispiel #13
0
 def compute_loss(self, scores, src, rotation_ab, translation_ab, tgt):
     """Negative log-likelihood of the score each source point assigns to
     its true nearest target point under the ground-truth pose.

     Args:
         scores: matching probabilities; last dim indexes target points.
         src: source point cloud, moved by the ground-truth pose first.
         rotation_ab, translation_ab: ground-truth pose src -> tgt.
         tgt: target point cloud.

     Returns:
         Scalar mean of -log(score at nearest target) over all points.
     """
     src_gt = transform_point_cloud(src, rotation_ab, translation_ab)
     dists = pairwise_distance(src_gt, tgt)
     # (bs, np, 1) index of the nearest target point. torch.min replaces
     # the original full O(n log n) sort, which was used only to read
     # column 0 (and produced an unused distance tensor).
     _, TD = torch.min(dists, dim=-1, keepdim=True)
     # Gather -log(score) at the true-match index; 1e-8 guards log(0).
     S = torch.gather(-torch.log(scores + 1e-8), index=TD, dim=-1)
     S_loss = torch.mean(S)
     return S_loss
Beispiel #14
0
    def predict(self, src, tgt, n_iters=3):
        """Iteratively estimate the rigid transform taking ``src`` to ``tgt``.

        Runs ``forward`` ``n_iters`` times, composing each incremental
        pose into a running estimate and warping ``src`` before the next
        pass.

        Returns:
            (rotation_ab_pred, translation_ab_pred) accumulated over all
            iterations.
        """
        bs = src.size(0)
        # Running pose, initialized to the identity transform.
        rot_acc = torch.eye(3, device=src.device, dtype=torch.float32).view(1, 3, 3).repeat(bs, 1, 1)
        trans_acc = torch.zeros(3, device=src.device, dtype=torch.float32).view(1, 3).repeat(bs, 1)
        for _ in range(n_iters):
            rot_i, trans_i, _, _, _ = self.forward(src, tgt)
            # Compose the increment: R <- R_i R, t <- R_i t + t_i.
            rot_acc = torch.matmul(rot_i, rot_acc)
            trans_acc = torch.matmul(rot_i, trans_acc.unsqueeze(2)).squeeze(2) + trans_i
            src = transform_point_cloud(src, rot_i, trans_i)

        return rot_acc, trans_acc
Beispiel #15
0
 def sampling_loss(self, src, rotation_ab, translation_ab, tgt,
                   corr_scores):
     """Reward correspondence scores of source points that land near a
     target point under the ground-truth pose.

     Points whose nearest-target distance exceeds 0.08 (hand-tuned
     threshold) contribute zero; the rest contribute ``-corr_scores``,
     so confident scores on true inliers lower the loss.

     Returns:
         Scalar mean over all points.
     """
     src_gt = transform_point_cloud(src, rotation_ab, translation_ab)
     dists = pairwise_distance(src_gt, tgt)
     # (bs, np, 1) distance to the nearest target point. torch.min
     # replaces the original full sort (whose index tensor was unused).
     nearest_dist = torch.min(dists, dim=-1, keepdim=True)[0]
     # Zero out scores for points with no sufficiently close target.
     S_zeros = torch.zeros_like(corr_scores)
     ind_S = torch.where(nearest_dist > 0.08, S_zeros, -corr_scores)
     dists_loss = torch.mean(ind_S)
     return dists_loss
def test_one_epoch(args, net, test_loader):
    """Evaluate ``net`` over the test set and return the mean EPE loss.

    When ``args.eval_full`` is set, concatenates every batch's source,
    target and predicted clouds for an optional whole-scene visualization.
    """
    net.eval()

    total_loss = 0
    num_examples = 0

    if args.eval_full:
        # Accumulators for whole-dataset visualization.
        target_full = None
        pred_transformed = None
        src_full = None

    for src, target, gt_flow in tqdm(test_loader):
        src = src.cuda()
        target = target.cuda()

        batch_size = src.size(0)
        num_examples += batch_size
        rotation_ab_pred, translation_ab_pred = net(src, target)

        ###########################
        transformed_src = transform_point_cloud(src, rotation_ab_pred,
                                                translation_ab_pred)
        loss = EPE(transformed_src, target)

        # Weight by batch size so the final mean is per-example.
        total_loss += loss.item() * batch_size

        if args.eval_full:
            if target_full is None:
                target_full = target.cpu().detach().numpy()
                pred_transformed = transformed_src.cpu().detach().numpy()
                src_full = src.cpu().detach().numpy()
            else:
                # Concatenate along axis 2 (the points axis).
                target_full = np.concatenate(
                    [target_full, target.cpu().detach().numpy()], axis=2)
                pred_transformed = np.concatenate(
                    [pred_transformed,
                     transformed_src.cpu().detach().numpy()],
                    axis=2)
                src_full = np.concatenate(
                    [src_full, src.cpu().detach().numpy()], axis=2)

    if args.display_scene_flow and args.eval and args.eval_full:
        visualize_transformed(src_full.squeeze(), target_full.squeeze(),
                              pred_transformed.squeeze())
    elif args.display_scene_flow and args.eval:
        # Without eval_full, only the last batch is visualized.
        visualize_transformed(src.cpu().detach().numpy(),
                              target.cpu().detach().numpy(),
                              transformed_src.cpu().detach().numpy())

    return total_loss * 1.0 / num_examples
Beispiel #17
0
 def forward(self, *input):
     """Two-stage registration: coarse pose from ``match_net_1``, then a
     refinement head driven by cross-attended embeddings.

     Expects ``input`` as (src, tgt, temp).
     """
     src, tgt, temp = input[0], input[1], input[2]
     # Stage 1: coarse alignment.
     coarse_rot, coarse_trans, _ = self.match_net_1(src, tgt, temp)
     src = transform_point_cloud(src, coarse_rot, coarse_trans)
     # Stage 2 (fine-tuning): embed both clouds and exchange information
     # via shared cross-attention.
     tgt_feat = self.tgt_emb_nn(tgt)
     src_feat = self.src_emb_nn(src)
     src_feat, tgt_feat = self.share_attn(src_feat, tgt_feat, cross=True)
     rotation_ab, translation_ab, scores = self.head(
         src_feat, tgt_feat, src, tgt, temp)
     return rotation_ab, translation_ab, scores
Beispiel #18
0
    def _train_one_batch(self, src, tgt, rotation_ab, translation_ab, opt,
                         temp):
        """One optimization step with a discounted MCC loss computed
        against the *residual* pose left over from previous iterations.

        The kernel bandwidth ``sigmal_`` comes from the point-set density
        (median nearest-neighbor distance) and decays each iteration.

        Returns:
            (total_loss_value, rotation_ab_pred, translation_ab_pred)
        """
        opt.zero_grad()
        batch_size = src.size(0)
        # NOTE(review): identity is unused in this variant.
        identity = torch.eye(3, device=src.device).unsqueeze(0).repeat(
            batch_size, 1, 1)

        # Running pose estimate, initialized to the identity transform.
        rotation_ab_pred = torch.eye(3, device=src.device,
                                     dtype=torch.float32).view(1, 3, 3).repeat(
                                         batch_size, 1, 1)
        translation_ab_pred = torch.zeros(3,
                                          device=src.device,
                                          dtype=torch.float32).view(
                                              1, 3).repeat(batch_size, 1)
        total_loss = 0
        # Density estimate (no gradients): median nearest-neighbor distance.
        with torch.no_grad():
            distance = pairwise_distance(src, src)
            sort_distance, _ = torch.sort(distance, dim=-1)
            # Column 1 skips each point's zero distance to itself.
            nearest_distance = sort_distance[:, :, 1].squeeze()
            median = torch.median(nearest_distance, dim=-1)[0]
            meanDist = torch.mean(median)
            sigmal_ = meanDist * self.sigma_times
            sigmal_ = sigmal_.repeat(batch_size)
        temp = torch.tensor(temp).cuda().repeat(batch_size)
        # Residual target pose: initially the full ground-truth pose.
        res_rotation_ab = rotation_ab
        res_translation_ab = translation_ab
        for i in range(self.num_iters):
            # Shrink the kernel bandwidth each iteration.
            sigmal_ = sigmal_ * self.DecayPram
            rotation_ab_pred_i, translation_ab_pred_i = self.forward(
                src, tgt, temp, i, sigmal_)
            # Compose the increment with the running estimate.
            rotation_ab_pred = torch.matmul(rotation_ab_pred_i,
                                            rotation_ab_pred)
            translation_ab_pred = torch.matmul(rotation_ab_pred_i, translation_ab_pred.unsqueeze(2)).squeeze(2) \
                                  + translation_ab_pred_i
            # The increment is supervised against the residual pose from
            # *before* this iteration's update.
            mcc_loss = self.mcc_loss(src, rotation_ab_pred_i,
                                     translation_ab_pred_i, res_rotation_ab,
                                     res_translation_ab, sigmal_)
            total_loss = total_loss + mcc_loss * self.discount_factor**i
            # Update the residual pose for the next iteration.
            res_rotation_ab = torch.matmul(rotation_ab,
                                           rotation_ab_pred.transpose(2, 1))
            res_translation_ab = translation_ab - torch.matmul(
                res_rotation_ab, translation_ab_pred.unsqueeze(2)).squeeze(2)
            src = transform_point_cloud(src, rotation_ab_pred_i,
                                        translation_ab_pred_i)
        total_loss.backward()
        opt.step()
        return total_loss.item(), rotation_ab_pred, translation_ab_pred
Beispiel #19
0
    def _train_one_batch(self, src, tgt, rotation_ab, translation_ab, opt):
        """One optimization step combining pose MSE, feature-alignment and
        cycle-consistency losses, each discounted per iteration.

        Returns:
            (total_loss, total_feature_alignment_loss,
             total_cycle_consistency_loss, total_scale_consensus_loss,
             rotation_ab_pred, translation_ab_pred)
        """
        opt.zero_grad()
        batch_size = src.size(0)
        identity = torch.eye(3, device=src.device).unsqueeze(0).repeat(batch_size, 1, 1)

        # Forward (a->b) running pose, initialized to identity.
        rotation_ab_pred = torch.eye(3, device=src.device, dtype=torch.float32).view(1, 3, 3).repeat(batch_size, 1, 1)
        translation_ab_pred = torch.zeros(3, device=src.device, dtype=torch.float32).view(1, 3).repeat(batch_size, 1)

        # Reverse (b->a) running pose, used for cycle consistency.
        rotation_ba_pred = torch.eye(3, device=src.device, dtype=torch.float32).view(1, 3, 3).repeat(batch_size, 1, 1)
        translation_ba_pred = torch.zeros(3, device=src.device, dtype=torch.float32).view(1, 3).repeat(batch_size, 1)

        total_loss = 0
        total_feature_alignment_loss = 0
        total_cycle_consistency_loss = 0
        # Scale consensus is currently disabled (stays 0 throughout).
        total_scale_consensus_loss = 0
        for i in range(self.num_iters):
            rotation_ab_pred_i, translation_ab_pred_i, rotation_ba_pred_i, translation_ba_pred_i, \
            feature_disparity = self.forward(src, tgt)

            # Compose both incremental poses with their running estimates.
            rotation_ab_pred = torch.matmul(rotation_ab_pred_i, rotation_ab_pred)
            translation_ab_pred = torch.matmul(rotation_ab_pred_i, translation_ab_pred.unsqueeze(2)).squeeze(2) \
                                  + translation_ab_pred_i
            rotation_ba_pred = torch.matmul(rotation_ba_pred_i, rotation_ba_pred)
            translation_ba_pred = torch.matmul(rotation_ba_pred_i, translation_ba_pred.unsqueeze(2)).squeeze(2) \
                                  + translation_ba_pred_i

            # Pose loss for this iteration.
            loss = (F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) \
                   + F.mse_loss(translation_ab_pred, translation_ab)) * self.discount_factor**i
            # Global feature-alignment loss.
            feature_alignment_loss = feature_disparity.mean() * self.feature_alignment_loss * self.discount_factor**i
            # Round-trip (a->b->a) consistency loss.
            cycle_consistency_loss = cycle_consistency(rotation_ab_pred_i, translation_ab_pred_i,
                                                       rotation_ba_pred_i, translation_ba_pred_i) \
                                     * self.cycle_consistency_loss * self.discount_factor**i
            scale_consensus_loss = 0
            total_feature_alignment_loss += feature_alignment_loss
            total_cycle_consistency_loss += cycle_consistency_loss
            total_loss = total_loss + loss + feature_alignment_loss + cycle_consistency_loss + scale_consensus_loss
            # Warp the source point cloud for the next pass.
            src = transform_point_cloud(src, rotation_ab_pred_i, translation_ab_pred_i)
        total_loss.backward()
        opt.step()
        return total_loss.item(), total_feature_alignment_loss.item(), total_cycle_consistency_loss.item(), \
               total_scale_consensus_loss, rotation_ab_pred, translation_ab_pred
Beispiel #20
0
 def compute_loss(self, scores, src_k, rotation_ab, translation_ab, tgt):
     """Entropy loss on keypoint score distributions, restricted to
     keypoints that actually have a close target match.

     ``src_k`` is moved by the ground-truth pose; keypoints whose nearest
     target lies farther than 0.08 (hand-tuned threshold) are masked to
     zero, the rest contribute the Shannon entropy of their score row —
     low entropy means a confident, peaked match.

     Returns:
         Scalar mean entropy over all keypoints (masked ones count as 0).
     """
     src_k_gt = transform_point_cloud(src_k, rotation_ab, translation_ab)
     dists = pairwise_distance(src_k_gt, tgt)
     # (bs, k, 1) nearest-target distance. torch.min replaces the original
     # full sort, whose sorted index tensor was computed but never used.
     nearest_dist = torch.min(dists, dim=-1, keepdim=True)[0]
     # (bs, k, 1) per-keypoint entropy of the score distribution;
     # 1e-8 guards log(0).
     S = -torch.sum(
         torch.mul(scores, torch.log(scores + 1e-8)), dim=-1, keepdim=True)
     S_zeros = torch.zeros_like(S)
     # Hand-tuned distance threshold for what counts as a keypoint match.
     ind_S = torch.where(nearest_dist > 0.08, S_zeros, S)
     S_loss = torch.mean(ind_S)
     return S_loss
Beispiel #21
0
    def _train_one_batch(self, src, tgt, rotation_ab, translation_ab, opt,
                         temp):
        """One optimization step combining a discounted pose MSE loss with
        a keypoint entropy loss (weighted by 0.2).

        Returns:
            (total_loss_value, rotation_ab_pred, translation_ab_pred)
        """
        opt.zero_grad()
        batch_size = src.size(0)
        identity = torch.eye(3, device=src.device).unsqueeze(0).repeat(
            batch_size, 1, 1)

        # Running pose estimate, initialized to the identity transform.
        rotation_ab_pred = torch.eye(3, device=src.device,
                                     dtype=torch.float32).view(1, 3, 3).repeat(
                                         batch_size, 1, 1)
        translation_ab_pred = torch.zeros(3,
                                          device=src.device,
                                          dtype=torch.float32).view(
                                              1, 3).repeat(batch_size, 1)
        total_loss = 0
        temp = torch.tensor(temp).cuda().repeat(batch_size)
        for i in range(self.num_iters):
            rotation_ab_pred_i, translation_ab_pred_i, scores, src_k = self.forward(
                src, tgt, temp, i)
            # src_k has not been rotated yet here; compute the residual
            # pose still separating prediction from ground truth.
            res_rotation_ab = torch.matmul(rotation_ab,
                                           rotation_ab_pred.transpose(2, 1))
            res_translation_ab = translation_ab - torch.matmul(
                res_rotation_ab, translation_ab_pred.unsqueeze(2)).squeeze(2)
            # Accumulated pose.
            rotation_ab_pred = torch.matmul(rotation_ab_pred_i,
                                            rotation_ab_pred)
            translation_ab_pred = torch.matmul(rotation_ab_pred_i, translation_ab_pred.unsqueeze(2)).squeeze(2) \
                                  + translation_ab_pred_i
            # Pose loss.
            loss = (F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity)\
                    + F.mse_loss(translation_ab_pred, translation_ab)) * self.discount_factor ** i
            # Keypoint entropy loss against the residual pose.
            entropy_loss = self.compute_loss(scores, src_k, res_rotation_ab,
                                             res_translation_ab,
                                             tgt) * self.discount_factor**i
            # Total loss: entropy term weighted by 0.2.
            total_loss = total_loss + loss + entropy_loss * 0.2
            # Only now is the source point cloud actually moved.
            src = transform_point_cloud(src, rotation_ab_pred_i,
                                        translation_ab_pred_i)
        total_loss.backward()
        opt.step()
        return total_loss.item(), rotation_ab_pred, translation_ab_pred
Beispiel #22
0
def compute_scores_gt(sampling_scores, src, tgt, rotation_ab, translation_ab):
    """Build ground-truth sampling scores: one-hot rows marking the k source
    points whose ground-truth-transformed positions lie closest to the target.

    Args:
        sampling_scores: (bs, k, num_points) predicted scores; only its shape
            and device are used here.
        src: (bs, 3, num_points) source point cloud.
        tgt: (bs, 3, num_points_tgt) target point cloud.
        rotation_ab: (bs, 3, 3) ground-truth rotation src -> tgt.
        translation_ab: (bs, 3) ground-truth translation src -> tgt.

    Returns:
        (bs, k, num_points) tensor whose row i is one-hot at the index of the
        i-th closest source point.
    """
    bs, k, num_points = sampling_scores.size()
    # Move src into the target frame using the ground-truth pose.
    src_corr = transform_point_cloud(src, rotation_ab, translation_ab)
    # Pairwise squared distances ||a||^2 - 2 a.b + ||b||^2 -> (bs, np, np_tgt).
    inner = -2 * torch.matmul(src_corr.transpose(2, 1).contiguous(), tgt)
    xx = torch.sum(src_corr**2, dim=1, keepdim=True)
    yy = torch.sum(tgt**2, dim=1, keepdim=True)
    distance = xx.transpose(2, 1).contiguous() + inner + yy
    # Distance of each src point to its nearest target neighbour, (bs, np).
    # min() replaces the original full sort: O(n) instead of O(n log n), and
    # the original no-op .squeeze(-1) (shape was already (bs, np)) is dropped.
    nearst_dist = distance.min(dim=-1)[0]
    # Indices of the k source points closest to the target cloud, (bs, k),
    # ordered by increasing distance (topk replaces sorting all points).
    idx_k = torch.topk(nearst_dist, k, dim=1, largest=False)[1]
    gt_scores = torch.zeros((bs, k, num_points), device=sampling_scores.device)
    # Vectorized one-hot assignment replacing the original Python double loop:
    # for each sample o, row i gets a 1 at column idx_k[o, i].
    gt_scores.scatter_(2, idx_k.unsqueeze(-1), 1)
    return gt_scores
Beispiel #23
0
@torch.no_grad()  # pure evaluation: no backward pass, so skip autograd graph construction
def test_one_epoch(args, net, test_loader):
    """Evaluate ``net`` for one epoch over ``test_loader``.

    Args:
        args: config namespace; only ``args.cycle`` (bool) is used, enabling
            the cycle-consistency loss term.
        net: model mapping (src, target) to predicted
            (rotation_ab, translation_ab, rotation_ba, translation_ba).
        test_loader: yields (src, target, rotation_ab, translation_ab,
            rotation_ba, translation_ba, euler_ab, euler_ba) batches.

    Returns:
        (avg_loss, avg_cycle_loss, mse_ab, mae_ab, mse_ba, mae_ba) averaged
        per example, followed by the stacked ground-truth and predicted
        rotations/translations (ab and ba) and Euler angles as numpy arrays.
    """
    net.eval()
    mse_ab = 0
    mae_ab = 0
    mse_ba = 0
    mae_ba = 0

    total_loss = 0
    total_cycle_loss = 0
    num_examples = 0
    rotations_ab = []
    translations_ab = []
    rotations_ab_pred = []
    translations_ab_pred = []

    rotations_ba = []
    translations_ba = []
    rotations_ba_pred = []
    translations_ba_pred = []

    eulers_ab = []
    eulers_ba = []

    for src, target, rotation_ab, translation_ab, rotation_ba, translation_ba, euler_ab, euler_ba in tqdm(
            test_loader):
        src = src.cuda()
        target = target.cuda()
        rotation_ab = rotation_ab.cuda()
        translation_ab = translation_ab.cuda()
        rotation_ba = rotation_ba.cuda()
        translation_ba = translation_ba.cuda()

        batch_size = src.size(0)
        num_examples += batch_size
        rotation_ab_pred, translation_ab_pred, rotation_ba_pred, translation_ba_pred = net(
            src, target)

        ## save ground-truth and predicted rotation/translation (a -> b)
        rotations_ab.append(rotation_ab.detach().cpu().numpy())
        translations_ab.append(translation_ab.detach().cpu().numpy())
        rotations_ab_pred.append(rotation_ab_pred.detach().cpu().numpy())
        translations_ab_pred.append(translation_ab_pred.detach().cpu().numpy())
        eulers_ab.append(euler_ab.numpy())
        ## same for the reverse direction (b -> a)
        rotations_ba.append(rotation_ba.detach().cpu().numpy())
        translations_ba.append(translation_ba.detach().cpu().numpy())
        rotations_ba_pred.append(rotation_ba_pred.detach().cpu().numpy())
        translations_ba_pred.append(translation_ba_pred.detach().cpu().numpy())
        eulers_ba.append(euler_ba.numpy())

        transformed_src = transform_point_cloud(src, rotation_ab_pred,
                                                translation_ab_pred)

        transformed_target = transform_point_cloud(target, rotation_ba_pred,
                                                   translation_ba_pred)

        ###########################
        # Pose loss: rotation error against identity (R_pred^T R_gt should be I)
        # plus MSE on the translation.
        identity = torch.eye(3).cuda().unsqueeze(0).repeat(batch_size, 1, 1)
        loss = F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) \
               + F.mse_loss(translation_ab_pred, translation_ab)
        if args.cycle:
            # Cycle consistency: composing ba and ab predictions should give
            # the identity transform.
            rotation_loss = F.mse_loss(
                torch.matmul(rotation_ba_pred, rotation_ab_pred),
                identity.clone())
            translation_loss = torch.mean(
                (torch.matmul(rotation_ba_pred.transpose(2, 1),
                              translation_ab_pred.view(batch_size, 3, 1)).view(
                                  batch_size, 3) + translation_ba_pred)**2,
                dim=[0, 1])
            cycle_loss = rotation_loss + translation_loss

            loss = loss + cycle_loss * 0.1

        total_loss += loss.item() * batch_size

        if args.cycle:
            total_cycle_loss = total_cycle_loss + cycle_loss.item(
            ) * 0.1 * batch_size

        # Point-wise alignment metrics in both directions.
        mse_ab += torch.mean(
            (transformed_src - target)**2, dim=[0, 1, 2]).item() * batch_size
        mae_ab += torch.mean(torch.abs(transformed_src - target),
                             dim=[0, 1, 2]).item() * batch_size

        mse_ba += torch.mean(
            (transformed_target - src)**2, dim=[0, 1, 2]).item() * batch_size
        mae_ba += torch.mean(torch.abs(transformed_target - src),
                             dim=[0, 1, 2]).item() * batch_size

    rotations_ab = np.concatenate(rotations_ab, axis=0)
    translations_ab = np.concatenate(translations_ab, axis=0)
    rotations_ab_pred = np.concatenate(rotations_ab_pred, axis=0)
    translations_ab_pred = np.concatenate(translations_ab_pred, axis=0)

    rotations_ba = np.concatenate(rotations_ba, axis=0)
    translations_ba = np.concatenate(translations_ba, axis=0)
    rotations_ba_pred = np.concatenate(rotations_ba_pred, axis=0)
    translations_ba_pred = np.concatenate(translations_ba_pred, axis=0)

    eulers_ab = np.concatenate(eulers_ab, axis=0)
    eulers_ba = np.concatenate(eulers_ba, axis=0)

    return total_loss * 1.0 / num_examples, total_cycle_loss / num_examples, \
           mse_ab * 1.0 / num_examples, mae_ab * 1.0 / num_examples, \
           mse_ba * 1.0 / num_examples, mae_ba * 1.0 / num_examples, rotations_ab, \
           translations_ab, rotations_ab_pred, translations_ab_pred, rotations_ba, \
           translations_ba, rotations_ba_pred, translations_ba_pred, eulers_ab, eulers_ba
Beispiel #24
0
@torch.no_grad()  # pure evaluation: no backward pass, so skip autograd graph construction
def test_one_epoch(args, net, test_loader):
    """Evaluate ``net`` for one epoch, also measuring correspondence accuracy.

    Args:
        args: config namespace; only ``args.cycle`` (bool) is used, enabling
            cycle-loss tracking.
        net: model mapping (src, target) to
            (rotation_ab, translation_ab, rotation_ba, translation_ba, raw_scores),
            where raw_scores is a (b, m, n) correspondence score matrix.
        test_loader: yields (src, target, rotation_ab, translation_ab,
            rotation_ba, translation_ba, euler_ab, euler_ba, gt_corr_mat) batches.

    Returns:
        Per-example averaged losses and MSE/MAE metrics, the stacked
        ground-truth/predicted poses and Euler angles, and the correspondence
        accuracy in percent.
    """
    net.eval()
    mse_ab = 0
    mae_ab = 0
    mse_ba = 0
    mae_ba = 0

    total_loss = 0
    total_cycle_loss = 0
    num_examples = 0
    num_correct_corr = 0
    num_total_corr = 0

    rotations_ab = []
    translations_ab = []
    rotations_ab_pred = []
    translations_ab_pred = []

    rotations_ba = []
    translations_ba = []
    rotations_ba_pred = []
    translations_ba_pred = []

    eulers_ab = []
    eulers_ba = []

    # Loss module is stateless here; construct once instead of once per batch.
    corr_loss_fn = CorrLoss()

    for src, target, rotation_ab, translation_ab, rotation_ba, translation_ba, euler_ab, euler_ba, gt_corr_mat in tqdm(test_loader):
        src = src.cuda()
        target = target.cuda()
        rotation_ab = rotation_ab.cuda()
        translation_ab = translation_ab.cuda()
        rotation_ba = rotation_ba.cuda()
        translation_ba = translation_ba.cuda()
        gt_corr_mat = gt_corr_mat.cuda()

        batch_size = src.size(0)
        num_examples += batch_size
        rotation_ab_pred, translation_ab_pred, rotation_ba_pred, translation_ba_pred, raw_scores = net(src, target)

        # Correspondence accuracy: a prediction is correct when the argmax
        # column of the predicted score row matches the ground-truth argmax.
        corr_mat_pred = raw_scores.detach().cpu().numpy()     # b,m,n
        col_idx_pred = np.argmax(corr_mat_pred, axis=-1)
        corr_mat_gt = gt_corr_mat.detach().cpu().numpy()     # b, m (src), n (tgt)
        col_idx_gt = np.argmax(corr_mat_gt, axis=-1)
        correct_mask = col_idx_gt == col_idx_pred
        num_correct_corr += correct_mask.sum()
        num_total_corr += np.prod(correct_mask.shape)

        ## save ground-truth and predicted rotation/translation (a -> b)
        rotations_ab.append(rotation_ab.detach().cpu().numpy())
        translations_ab.append(translation_ab.detach().cpu().numpy())
        rotations_ab_pred.append(rotation_ab_pred.detach().cpu().numpy())
        translations_ab_pred.append(translation_ab_pred.detach().cpu().numpy())
        eulers_ab.append(euler_ab.numpy())
        ## same for the reverse direction (b -> a)
        rotations_ba.append(rotation_ba.detach().cpu().numpy())
        translations_ba.append(translation_ba.detach().cpu().numpy())
        rotations_ba_pred.append(rotation_ba_pred.detach().cpu().numpy())
        translations_ba_pred.append(translation_ba_pred.detach().cpu().numpy())
        eulers_ba.append(euler_ba.numpy())

        transformed_src = transform_point_cloud(src, rotation_ab_pred, translation_ab_pred)

        transformed_target = transform_point_cloud(target, rotation_ba_pred, translation_ba_pred)

        ###########################
        identity = torch.eye(3).cuda().unsqueeze(0).repeat(batch_size, 1, 1)
        # The correspondence loss replaces the DCP-style pose MSE loss in this
        # variant.
        loss = corr_loss_fn(target, src, raw_scores, gt_corr_mat)
        if args.cycle:
            rotation_loss = F.mse_loss(torch.matmul(rotation_ba_pred, rotation_ab_pred), identity.clone())
            translation_loss = torch.mean((torch.matmul(rotation_ba_pred.transpose(2, 1),
                                                        translation_ab_pred.view(batch_size, 3, 1)).view(batch_size, 3)
                                           + translation_ba_pred) ** 2, dim=[0, 1])
            # NOTE: the cycle loss is tracked for reporting only; it is
            # intentionally not added into ``loss`` in this variant.
            cycle_loss = rotation_loss + translation_loss

        total_loss += loss.item() * batch_size

        if args.cycle:
            total_cycle_loss = total_cycle_loss + cycle_loss.item() * 0.1 * batch_size

        # Point-wise alignment metrics in both directions.
        mse_ab += torch.mean((transformed_src - target) ** 2, dim=[0, 1, 2]).item() * batch_size
        mae_ab += torch.mean(torch.abs(transformed_src - target), dim=[0, 1, 2]).item() * batch_size

        mse_ba += torch.mean((transformed_target - src) ** 2, dim=[0, 1, 2]).item() * batch_size
        mae_ba += torch.mean(torch.abs(transformed_target - src), dim=[0, 1, 2]).item() * batch_size

    rotations_ab = np.concatenate(rotations_ab, axis=0)
    translations_ab = np.concatenate(translations_ab, axis=0)
    rotations_ab_pred = np.concatenate(rotations_ab_pred, axis=0)
    translations_ab_pred = np.concatenate(translations_ab_pred, axis=0)

    rotations_ba = np.concatenate(rotations_ba, axis=0)
    translations_ba = np.concatenate(translations_ba, axis=0)
    rotations_ba_pred = np.concatenate(rotations_ba_pred, axis=0)
    translations_ba_pred = np.concatenate(translations_ba_pred, axis=0)

    eulers_ab = np.concatenate(eulers_ab, axis=0)
    eulers_ba = np.concatenate(eulers_ba, axis=0)

    corr_accuracy = (num_correct_corr / num_total_corr)*100

    return total_loss * 1.0 / num_examples, total_cycle_loss / num_examples, \
           mse_ab * 1.0 / num_examples, mae_ab * 1.0 / num_examples, \
           mse_ba * 1.0 / num_examples, mae_ba * 1.0 / num_examples, rotations_ab, \
           translations_ab, rotations_ab_pred, translations_ab_pred, rotations_ba, \
           translations_ba, rotations_ba_pred, translations_ba_pred, eulers_ab, eulers_ba, corr_accuracy
Beispiel #25
0
def train_one_epoch(args, net, train_loader, opt):
    """Train ``net`` for one epoch over ``train_loader``.

    Args:
        args: config namespace; only ``args.cycle`` (bool) is used, enabling
            the cycle-consistency loss term.
        net: model mapping (src, target) to predicted
            (rotation_ab, translation_ab, rotation_ba, translation_ba).
        train_loader: yields (src, target, rotation_ab, translation_ab,
            rotation_ba, translation_ba, euler_ab, euler_ba) batches.
        opt: optimizer stepped once per batch.

    Returns:
        Per-example averaged losses and MSE/MAE metrics followed by the
        stacked ground-truth and predicted rotations/translations (ab and ba)
        and Euler angles as numpy arrays.
    """
    net.train()

    mse_ab = 0
    mae_ab = 0
    mse_ba = 0
    mae_ba = 0

    total_loss = 0
    total_cycle_loss = 0
    num_examples = 0
    rotations_ab = []
    translations_ab = []
    rotations_ab_pred = []
    translations_ab_pred = []

    rotations_ba = []
    translations_ba = []
    rotations_ba_pred = []
    translations_ba_pred = []

    eulers_ab = []
    eulers_ba = []

    for src, target, rotation_ab, translation_ab, rotation_ba, translation_ba, euler_ab, euler_ba in tqdm(
            train_loader):
        src = src.cuda()
        target = target.cuda()
        rotation_ab = rotation_ab.cuda()
        translation_ab = translation_ab.cuda()
        rotation_ba = rotation_ba.cuda()
        translation_ba = translation_ba.cuda()

        batch_size = src.size(0)
        opt.zero_grad()
        num_examples += batch_size
        rotation_ab_pred, translation_ab_pred, rotation_ba_pred, translation_ba_pred = net(
            src, target)

        ## save ground-truth and predicted rotation/translation (a -> b)
        rotations_ab.append(rotation_ab.detach().cpu().numpy())
        translations_ab.append(translation_ab.detach().cpu().numpy())
        rotations_ab_pred.append(rotation_ab_pred.detach().cpu().numpy())
        translations_ab_pred.append(translation_ab_pred.detach().cpu().numpy())
        eulers_ab.append(euler_ab.numpy())
        ## same for the reverse direction (b -> a)
        rotations_ba.append(rotation_ba.detach().cpu().numpy())
        translations_ba.append(translation_ba.detach().cpu().numpy())
        rotations_ba_pred.append(rotation_ba_pred.detach().cpu().numpy())
        translations_ba_pred.append(translation_ba_pred.detach().cpu().numpy())
        eulers_ba.append(euler_ba.numpy())

        transformed_src = transform_point_cloud(src, rotation_ab_pred,
                                                translation_ab_pred)

        transformed_target = transform_point_cloud(target, rotation_ba_pred,
                                                   translation_ba_pred)
        ###########################
        # Pose loss: rotation error against identity (R_pred^T R_gt should be
        # I) plus MSE on the translation.
        identity = torch.eye(3).cuda().unsqueeze(0).repeat(batch_size, 1, 1)
        loss = F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) \
               + F.mse_loss(translation_ab_pred, translation_ab)
        if args.cycle:
            # Cycle consistency: composing ba and ab predictions should give
            # the identity transform.
            rotation_loss = F.mse_loss(
                torch.matmul(rotation_ba_pred, rotation_ab_pred),
                identity.clone())
            translation_loss = torch.mean(
                (torch.matmul(rotation_ba_pred.transpose(2, 1),
                              translation_ab_pred.view(batch_size, 3, 1)).view(
                                  batch_size, 3) + translation_ba_pred)**2,
                dim=[0, 1])
            cycle_loss = rotation_loss + translation_loss

            loss = loss + cycle_loss * 0.1

        loss.backward()
        opt.step()
        total_loss += loss.item() * batch_size

        if args.cycle:
            total_cycle_loss = total_cycle_loss + cycle_loss.item(
            ) * 0.1 * batch_size

        # Point-wise alignment metrics in both directions.
        mse_ab += torch.mean(
            (transformed_src - target)**2, dim=[0, 1, 2]).item() * batch_size
        mae_ab += torch.mean(torch.abs(transformed_src - target),
                             dim=[0, 1, 2]).item() * batch_size

        mse_ba += torch.mean(
            (transformed_target - src)**2, dim=[0, 1, 2]).item() * batch_size
        mae_ba += torch.mean(torch.abs(transformed_target - src),
                             dim=[0, 1, 2]).item() * batch_size

    rotations_ab = np.concatenate(rotations_ab, axis=0)
    translations_ab = np.concatenate(translations_ab, axis=0)
    rotations_ab_pred = np.concatenate(rotations_ab_pred, axis=0)
    translations_ab_pred = np.concatenate(translations_ab_pred, axis=0)

    rotations_ba = np.concatenate(rotations_ba, axis=0)
    translations_ba = np.concatenate(translations_ba, axis=0)
    rotations_ba_pred = np.concatenate(rotations_ba_pred, axis=0)
    translations_ba_pred = np.concatenate(translations_ba_pred, axis=0)

    eulers_ab = np.concatenate(eulers_ab, axis=0)
    eulers_ba = np.concatenate(eulers_ba, axis=0)

    return total_loss * 1.0 / num_examples, total_cycle_loss / num_examples, \
           mse_ab * 1.0 / num_examples, mae_ab * 1.0 / num_examples, \
           mse_ba * 1.0 / num_examples, mae_ba * 1.0 / num_examples, rotations_ab, \
           translations_ab, rotations_ab_pred, translations_ab_pred, rotations_ba, \
           translations_ba, rotations_ba_pred, translations_ba_pred, eulers_ab, eulers_ba
Beispiel #26
0
@torch.no_grad()  # pure evaluation: no backward pass, so skip autograd graph construction
def test_one_epoch(args, net, test_loader):
    """Evaluate one epoch, also measuring point-correspondence accuracy.

    Args:
        args: config namespace; uses ``args.debug`` (dump tensors to disk and
            stop after ~100 batches), ``args.loss`` ('cross_entropy_corr' or
            'mse_transf'), and ``args.cycle`` (not implemented here).
        net: model mapping (src, target) to
            (rotation_ab, translation_ab, rotation_ba, translation_ba,
             corr_mat_ab_pred).
        test_loader: yields (src, target, rotation_ab, translation_ab,
            rotation_ba, translation_ba, euler_ab, euler_ba, col_idx,
            corr_mat_ab) batches.

    Returns:
        Per-example averaged losses and MSE/MAE metrics, the stacked
        ground-truth/predicted poses and Euler angles, and the percentage of
        incorrect point correspondences.

    Raises:
        Exception: for an unknown ``args.loss`` or if ``args.cycle`` is set.
    """
    net.eval()

    # initialization
    mse_ab = 0
    mae_ab = 0
    mse_ba = 0
    mae_ba = 0

    total_loss = 0
    total_cycle_loss = 0
    num_examples = 0
    rotations_ab = []
    translations_ab = []
    rotations_ab_pred = []
    translations_ab_pred = []

    rotations_ba = []
    translations_ba = []
    rotations_ba_pred = []
    translations_ba_pred = []

    eulers_ab = []
    eulers_ba = []

    batch_idx = 0
    total_correct_pred = 0
    for src, target, rotation_ab, translation_ab, rotation_ba, translation_ba, euler_ab, euler_ba, col_idx, corr_mat_ab in tqdm(
            test_loader):
        batch_size = src.size(0)
        num_points = src.size(-1)
        num_points_target = target.size(-1)

        if args.debug:  # if debugging, dump raw inputs for offline inspection
            for i in range(batch_size):
                np.savetxt(
                    "variables_storage/src_batch_{}_sample_{}".format(
                        batch_idx, i), src[i, :, :])
                np.savetxt(
                    "variables_storage/target_batch_{}_sample_{}".format(
                        batch_idx, i), target[i, :, :])
                np.savetxt(
                    "variables_storage/rotation_ab_batch_{}_sample_{}".format(
                        batch_idx, i), rotation_ab[i, :, :])
                np.savetxt(
                    "variables_storage/translation_ab_batch_{}_sample_{}".
                    format(batch_idx, i), translation_ab[i, :])
                np.savetxt(
                    "variables_storage/euler_ab_batch_{}_sample_{}".format(
                        batch_idx, i), euler_ab[i, :])
                np.savetxt(
                    "variables_storage/col_idx_batch_{}_sample_{}".format(
                        batch_idx, i), col_idx[i, :])
            if batch_idx >= 100:
                break

        src = src.cuda()
        target = target.cuda()
        rotation_ab = rotation_ab.cuda()
        translation_ab = translation_ab.cuda()
        rotation_ba = rotation_ba.cuda()
        translation_ba = translation_ba.cuda()
        col_idx = col_idx.cuda()
        corr_mat_ab = corr_mat_ab.cuda()

        num_examples += batch_size

        # model output
        rotation_ab_pred, translation_ab_pred, rotation_ba_pred, translation_ba_pred, corr_mat_ab_pred = net(
            src, target)

        ## save ground-truth and predicted rotation/translation
        rotations_ab.append(rotation_ab.detach().cpu().numpy())
        translations_ab.append(translation_ab.detach().cpu().numpy())
        rotations_ab_pred.append(rotation_ab_pred.detach().cpu().numpy())
        translations_ab_pred.append(translation_ab_pred.detach().cpu().numpy())
        eulers_ab.append(euler_ab.numpy())

        rotations_ba.append(rotation_ba.detach().cpu().numpy())
        translations_ba.append(translation_ba.detach().cpu().numpy())
        rotations_ba_pred.append(rotation_ba_pred.detach().cpu().numpy())
        translations_ba_pred.append(translation_ba_pred.detach().cpu().numpy())
        eulers_ba.append(euler_ba.numpy())

        # transforming the point cloud according to given rotation and translation
        transformed_src = transform_point_cloud(src, rotation_ab_pred,
                                                translation_ab_pred)

        transformed_target = transform_point_cloud(target, rotation_ba_pred,
                                                   translation_ba_pred)

        identity = torch.eye(3).cuda().unsqueeze(0).repeat(batch_size, 1, 1)

        #  correspondence loss, as proposed in our paper
        if args.loss == 'cross_entropy_corr':
            # corr_mat_ab: ground truth correspondence matrix
            # corr_mat_ab_pred: predicted correspondence matrix
            # (the original also computed an unused DCP-style pose loss here;
            # that dead computation has been removed)
            loss = F.cross_entropy(
                corr_mat_ab_pred.view(batch_size * num_points,
                                      num_points_target),
                torch.argmax(corr_mat_ab.transpose(1, 2).reshape(
                    -1, num_points_target),
                             axis=1))

        # translation loss, as proposed in DCP
        elif args.loss == 'mse_transf':
            loss = (F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) \
               + F.mse_loss(translation_ab_pred, translation_ab))
        else:
            raise Exception("please verify the input loss function")

        if args.cycle:
            raise Exception("cycle for corr_mat_ab not implemented yet")
            rotation_loss = F.mse_loss(
                torch.matmul(rotation_ba_pred, rotation_ab_pred),
                identity.clone())
            translation_loss = torch.mean(
                (torch.matmul(rotation_ba_pred.transpose(2, 1),
                              translation_ab_pred.view(batch_size, 3, 1)).view(
                                  batch_size, 3) + translation_ba_pred)**2,
                dim=[0, 1])
            cycle_loss = rotation_loss + translation_loss

            loss = loss + cycle_loss * 0.1

        total_loss += loss.item() * batch_size

        if args.cycle:
            total_cycle_loss = total_cycle_loss + cycle_loss.item(
            ) * 0.1 * batch_size

        gt_idx = torch.argmax(
            corr_mat_ab.transpose(1, 2).reshape(-1, num_points_target),
            axis=1)  # ground-truth index of the corresponding target point
        pred_idx = torch.argmax(
            corr_mat_ab_pred.view(-1, num_points_target),
            axis=1)  # predicted index of the corresponding target point

        # if the indices match, then the predicted corresponding target point is correct
        correct_pred_idx = torch.where(gt_idx - pred_idx == 0)
        total_correct_pred += len(correct_pred_idx[0])

        try:
            mse_ab += torch.mean((transformed_src - target)**2,
                                 dim=[0, 1, 2]).item() * batch_size
            mae_ab += torch.mean(torch.abs(transformed_src - target),
                                 dim=[0, 1, 2]).item() * batch_size

            mse_ba += torch.mean((transformed_target - src)**2,
                                 dim=[0, 1, 2]).item() * batch_size
            mae_ba += torch.mean(torch.abs(transformed_target - src),
                                 dim=[0, 1, 2]).item() * batch_size
        except RuntimeError:  # partial clouds: src/target point counts differ,
            # so the element-wise difference is undefined.
            # (Narrowed from a bare ``except:`` which also swallowed
            # KeyboardInterrupt/SystemExit.)
            mse_ab += 0
            mae_ab += 0
            mse_ba += 0
            mae_ba += 0

        if args.debug:
            corr_mat_ab_pred_np = torch.clone(
                corr_mat_ab_pred).detach().cpu().numpy()
            corr_mat_ab_gt_np = torch.clone(corr_mat_ab).detach().cpu().numpy()
            rotation_ab_pred_np = torch.clone(
                rotation_ab_pred).detach().cpu().numpy()
            translation_ab_pred_np = torch.clone(
                translation_ab_pred).detach().cpu().numpy()
            col_idx_pred = np.argmax(corr_mat_ab_pred_np, axis=1)

            for i in range(batch_size):
                np.savetxt(
                    "variables_storage/corr_mat_ab_pred_batch_{}_sample_{}".
                    format(batch_idx, i), corr_mat_ab_pred_np[i, :, :])
                np.savetxt(
                    "variables_storage/col_idx_pred_batch_{}_sample_{}".format(
                        batch_idx, i), col_idx_pred[i, :])
                np.savetxt(
                    "variables_storage/rotation_ab_pred_batch_{}_sample_{}".
                    format(batch_idx, i), rotation_ab_pred_np[i, :])
                np.savetxt(
                    "variables_storage/corr_mat_ab_gt_{}_sample_{}".format(
                        batch_idx, i), corr_mat_ab_gt_np[i, :, :])

        batch_idx += 1

    # computing percentage of incorrect point correspondences
    incorrect_correspondences = (1 - total_correct_pred /
                                 (num_examples * num_points)) * 100

    rotations_ab = np.concatenate(rotations_ab, axis=0)
    translations_ab = np.concatenate(translations_ab, axis=0)
    rotations_ab_pred = np.concatenate(rotations_ab_pred, axis=0)
    translations_ab_pred = np.concatenate(translations_ab_pred, axis=0)

    rotations_ba = np.concatenate(rotations_ba, axis=0)
    translations_ba = np.concatenate(translations_ba, axis=0)
    rotations_ba_pred = np.concatenate(rotations_ba_pred, axis=0)
    translations_ba_pred = np.concatenate(translations_ba_pred, axis=0)

    eulers_ab = np.concatenate(eulers_ab, axis=0)
    eulers_ba = np.concatenate(eulers_ba, axis=0)

    return total_loss * 1.0 / num_examples, total_cycle_loss / num_examples, \
        mse_ab * 1.0 / num_examples, mae_ab * 1.0 / num_examples, \
        mse_ba * 1.0 / num_examples, mae_ba * 1.0 / num_examples, rotations_ab, \
        translations_ab, rotations_ab_pred, translations_ab_pred, rotations_ba, \
        translations_ba, rotations_ba_pred, translations_ba_pred, eulers_ab, eulers_ba,\
        incorrect_correspondences
Beispiel #27
0
def train_one_epoch(args, net, train_loader, opt):
    """Train ``net`` for one epoch, tracking correspondence accuracy.

    Args:
        args: config namespace; uses ``args.loss`` ('cross_entropy_corr' or
            'mse_transf') and ``args.cycle`` (not implemented here).
        net: model mapping (src, target) to
            (rotation_ab, translation_ab, rotation_ba, translation_ba,
             corr_mat_ab_pred).
        train_loader: yields (src, target, rotation_ab, translation_ab,
            rotation_ba, translation_ba, euler_ab, euler_ba, col_idx,
            corr_mat_ab) batches.
        opt: optimizer stepped once per batch.

    Returns:
        Per-example averaged losses and MSE/MAE metrics, the stacked
        ground-truth/predicted poses and Euler angles, and the percentage of
        incorrect point correspondences.

    Raises:
        Exception: for an unknown ``args.loss`` or if ``args.cycle`` is set.
    """
    net.train()
    mse_ab = 0
    mae_ab = 0
    mse_ba = 0
    mae_ba = 0

    total_loss = 0
    total_loss_dcp_rot = 0
    total_loss_dcp_t = 0
    total_cycle_loss = 0
    num_examples = 0
    rotations_ab = []
    translations_ab = []
    rotations_ab_pred = []
    translations_ab_pred = []

    rotations_ba = []
    translations_ba = []
    rotations_ba_pred = []
    translations_ba_pred = []

    eulers_ab = []
    eulers_ba = []

    total_correct_pred = 0
    for src, target, rotation_ab, translation_ab, rotation_ba, translation_ba, euler_ab, euler_ba, col_idx, corr_mat_ab in tqdm(
            train_loader):
        src = src.cuda()
        target = target.cuda()
        rotation_ab = rotation_ab.cuda()
        translation_ab = translation_ab.cuda()
        rotation_ba = rotation_ba.cuda()
        translation_ba = translation_ba.cuda()
        col_idx = col_idx.cuda()
        corr_mat_ab = corr_mat_ab.cuda()

        batch_size = src.size(0)
        num_points = src.size(-1)
        num_points_target = target.size(-1)

        opt.zero_grad()
        num_examples += batch_size

        # model output
        rotation_ab_pred, translation_ab_pred, rotation_ba_pred, translation_ba_pred, corr_mat_ab_pred = net(
            src, target)

        ## save ground-truth and predicted rotation/translation
        rotations_ab.append(rotation_ab.detach().cpu().numpy())
        translations_ab.append(translation_ab.detach().cpu().numpy())
        rotations_ab_pred.append(rotation_ab_pred.detach().cpu().numpy())
        translations_ab_pred.append(translation_ab_pred.detach().cpu().numpy())
        eulers_ab.append(euler_ab.numpy())

        rotations_ba.append(rotation_ba.detach().cpu().numpy())
        translations_ba.append(translation_ba.detach().cpu().numpy())
        rotations_ba_pred.append(rotation_ba_pred.detach().cpu().numpy())
        translations_ba_pred.append(translation_ba_pred.detach().cpu().numpy())
        eulers_ba.append(euler_ba.numpy())

        # transforming the point cloud according to given rotation and translation
        transformed_src = transform_point_cloud(src, rotation_ab_pred,
                                                translation_ab_pred)

        transformed_target = transform_point_cloud(target, rotation_ba_pred,
                                                   translation_ba_pred)

        identity = torch.eye(3).cuda().unsqueeze(0).repeat(batch_size, 1, 1)

        # DCP-style pose errors are always computed for monitoring, even when
        # they are not part of the training loss.
        loss_dcp_rot = F.mse_loss(
            torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab),
            identity)
        loss_dcp_t = F.mse_loss(translation_ab_pred, translation_ab)

        #  correspondence loss, as proposed in our paper
        if args.loss == 'cross_entropy_corr':
            # (the original also computed an unused DCP-style pose loss here;
            # that dead computation has been removed)
            loss = F.cross_entropy(
                corr_mat_ab_pred.view(batch_size * num_points,
                                      num_points_target),
                torch.argmax(corr_mat_ab.transpose(1, 2).reshape(
                    -1, num_points_target),
                             axis=1))

        # translation loss, as proposed in DCP
        elif args.loss == 'mse_transf':
            loss = (F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) \
               + F.mse_loss(translation_ab_pred, translation_ab))
        else:
            raise Exception("please verify the input loss function")

        if args.cycle:
            raise Exception("cycle for corr_mat_ab not implemented yet")
            rotation_loss = F.mse_loss(
                torch.matmul(rotation_ba_pred, rotation_ab_pred),
                identity.clone())
            translation_loss = torch.mean(
                (torch.matmul(rotation_ba_pred.transpose(2, 1),
                              translation_ab_pred.view(batch_size, 3, 1)).view(
                                  batch_size, 3) + translation_ba_pred)**2,
                dim=[0, 1])
            cycle_loss = rotation_loss + translation_loss

            loss = loss + cycle_loss * 0.1

        loss.backward()
        opt.step()
        total_loss += loss.item() * batch_size
        total_loss_dcp_rot += loss_dcp_rot.item() * batch_size
        total_loss_dcp_t += loss_dcp_t.item() * batch_size

        gt_idx = torch.argmax(
            corr_mat_ab.transpose(1, 2).reshape(-1, num_points_target),
            axis=1)  # ground-truth index of the corresponding target point
        pred_idx = torch.argmax(
            corr_mat_ab_pred.view(-1, num_points_target),
            axis=1)  # predicted index of the corresponding target point

        # if the indices match, then the predicted corresponding target point is correct
        correct_pred_idx = torch.where(gt_idx - pred_idx == 0)
        total_correct_pred += len(correct_pred_idx[0])

        if args.cycle:
            total_cycle_loss = total_cycle_loss + cycle_loss.item(
            ) * 0.1 * batch_size

        try:
            mse_ab += torch.mean((transformed_src - target)**2,
                                 dim=[0, 1, 2]).item() * batch_size
            mae_ab += torch.mean(torch.abs(transformed_src - target),
                                 dim=[0, 1, 2]).item() * batch_size

            mse_ba += torch.mean((transformed_target - src)**2,
                                 dim=[0, 1, 2]).item() * batch_size
            mae_ba += torch.mean(torch.abs(transformed_target - src),
                                 dim=[0, 1, 2]).item() * batch_size
        except RuntimeError:  # partial clouds: src/target point counts differ,
            # so the element-wise difference is undefined.
            # (Narrowed from a bare ``except:`` which also swallowed
            # KeyboardInterrupt/SystemExit.)
            mse_ab += 0
            mae_ab += 0
            mse_ba += 0
            mae_ba += 0

    # computing percentage of incorrect point correspondences
    incorrect_correspondences = (1 - total_correct_pred /
                                 (num_examples * num_points)) * 100

    rotations_ab = np.concatenate(rotations_ab, axis=0)
    translations_ab = np.concatenate(translations_ab, axis=0)
    rotations_ab_pred = np.concatenate(rotations_ab_pred, axis=0)
    translations_ab_pred = np.concatenate(translations_ab_pred, axis=0)

    rotations_ba = np.concatenate(rotations_ba, axis=0)
    translations_ba = np.concatenate(translations_ba, axis=0)
    rotations_ba_pred = np.concatenate(rotations_ba_pred, axis=0)
    translations_ba_pred = np.concatenate(translations_ba_pred, axis=0)

    eulers_ab = np.concatenate(eulers_ab, axis=0)
    eulers_ba = np.concatenate(eulers_ba, axis=0)

    return total_loss * 1.0 / num_examples, total_cycle_loss / num_examples, \
        mse_ab * 1.0 / num_examples, mae_ab * 1.0 / num_examples, \
        mse_ba * 1.0 / num_examples, mae_ba * 1.0 / num_examples, rotations_ab, \
        translations_ab, rotations_ab_pred, translations_ab_pred, rotations_ba, \
        translations_ba, rotations_ba_pred, translations_ba_pred, eulers_ab, eulers_ba,\
        incorrect_correspondences
Beispiel #28
0
def train_one_epoch(args, net, train_loader, opt):
	"""Train ``net`` for one epoch over ``train_loader``.

	For every batch the network predicts the forward (ab) and backward (ba)
	rigid transforms between the two input point clouds.  Behaviour flags:

	* ``args.loss == 'frobenius_norm'`` replaces the network-provided loss
	  with ``||R_pred^T R_gt - I||^2 + ||t_pred - t_gt||^2``.
	* ``args.cycle`` adds a cycle-consistency penalty (composing the ab and
	  ba predictions should give the identity transform), weighted by 0.1.

	Returns a 16-tuple: per-example averaged total loss and cycle loss,
	mse/mae for ab and ba, then the stacked ground-truth / predicted
	rotations, translations and Euler angles (see the return statement for
	the exact order).
	"""
	net.train()

	# Per-epoch accumulators; each term is weighted by batch size and
	# normalised by num_examples at the end.
	mse_ab = 0
	mae_ab = 0
	mse_ba = 0
	mae_ba = 0

	total_loss = 0
	total_cycle_loss = 0
	num_examples = 0
	rotations_ab = []
	translations_ab = []
	rotations_ab_pred = []
	translations_ab_pred = []

	rotations_ba = []
	translations_ba = []
	rotations_ba_pred = []
	translations_ba_pred = []

	eulers_ab = []
	eulers_ba = []

	for src, target, rotation_ab, translation_ab, rotation_ba, translation_ba, euler_ab, euler_ba in tqdm(train_loader):
		# target = rotation_ab @ src + translation_ab (and vice versa for ba).
		src = src.cuda()
		target = target.cuda()
		rotation_ab = rotation_ab.cuda()
		translation_ab = translation_ab.cuda()
		rotation_ba = rotation_ba.cuda()
		translation_ba = translation_ba.cuda()

		batch_size = src.size(0)
		opt.zero_grad()
		num_examples += batch_size
		rotation_ab_pred, translation_ab_pred, rotation_ba_pred, translation_ba_pred, loss = net(src, target)

		# Record ground-truth and predicted transforms for epoch statistics.
		rotations_ab.append(rotation_ab.detach().cpu().numpy())
		translations_ab.append(translation_ab.detach().cpu().numpy())
		rotations_ab_pred.append(rotation_ab_pred.detach().cpu().numpy())
		translations_ab_pred.append(translation_ab_pred.detach().cpu().numpy())
		eulers_ab.append(euler_ab.numpy())

		rotations_ba.append(rotation_ba.detach().cpu().numpy())
		translations_ba.append(translation_ba.detach().cpu().numpy())
		rotations_ba_pred.append(rotation_ba_pred.detach().cpu().numpy())
		translations_ba_pred.append(translation_ba_pred.detach().cpu().numpy())
		eulers_ba.append(euler_ba.numpy())

		transformed_src = transform_point_cloud(src, rotation_ab_pred, translation_ab_pred)

		transformed_target = transform_point_cloud(target, rotation_ba_pred, translation_ba_pred)

		###########################
		# Hoisted out of the frobenius branch so the cycle branch below can
		# reuse it as well.
		identity = torch.eye(3).cuda().unsqueeze(0).repeat(batch_size, 1, 1)
		if args.loss == 'frobenius_norm':
			# Override the network loss with the closed-form pose loss.
			loss = F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) \
				   + F.mse_loss(translation_ab_pred, translation_ab)

		if args.cycle:
			# BUG FIX: ``cycle_loss`` was read further down without ever being
			# computed, raising NameError whenever args.cycle was set.
			# Compute it the same way as test_one_epoch does: ab followed by
			# ba must be the identity transform.
			rotation_loss = F.mse_loss(
				torch.matmul(rotation_ba_pred, rotation_ab_pred),
				identity.clone())
			translation_loss = torch.mean(
				(torch.matmul(rotation_ba_pred.transpose(2, 1),
							  translation_ab_pred.view(batch_size, 3, 1)).view(batch_size, 3)
				 + translation_ba_pred) ** 2,
				dim=[0, 1])
			cycle_loss = rotation_loss + translation_loss
			loss = loss + cycle_loss * 0.1

		loss.backward()
		opt.step()
		total_loss += loss.item() * batch_size

		if args.cycle:
			total_cycle_loss = total_cycle_loss + cycle_loss.item() * 0.1 * batch_size

		mse_ab += torch.mean((transformed_src - target) ** 2, dim=[0, 1, 2]).item() * batch_size
		mae_ab += torch.mean(torch.abs(transformed_src - target), dim=[0, 1, 2]).item() * batch_size

		mse_ba += torch.mean((transformed_target - src) ** 2, dim=[0, 1, 2]).item() * batch_size
		mae_ba += torch.mean(torch.abs(transformed_target - src), dim=[0, 1, 2]).item() * batch_size

	rotations_ab = np.concatenate(rotations_ab, axis=0)
	translations_ab = np.concatenate(translations_ab, axis=0)
	rotations_ab_pred = np.concatenate(rotations_ab_pred, axis=0)
	translations_ab_pred = np.concatenate(translations_ab_pred, axis=0)

	rotations_ba = np.concatenate(rotations_ba, axis=0)
	translations_ba = np.concatenate(translations_ba, axis=0)
	rotations_ba_pred = np.concatenate(rotations_ba_pred, axis=0)
	translations_ba_pred = np.concatenate(translations_ba_pred, axis=0)

	eulers_ab = np.concatenate(eulers_ab, axis=0)
	eulers_ba = np.concatenate(eulers_ba, axis=0)

	return total_loss * 1.0 / num_examples, total_cycle_loss / num_examples, \
		   mse_ab * 1.0 / num_examples, mae_ab * 1.0 / num_examples, \
		   mse_ba * 1.0 / num_examples, mae_ba * 1.0 / num_examples, rotations_ab, \
		   translations_ab, rotations_ab_pred, translations_ab_pred, rotations_ba, \
		   translations_ba, rotations_ba_pred, translations_ba_pred, eulers_ab, eulers_ba
Beispiel #29
0
def test_one_epoch(args, net, test_loader):
    """Evaluate ``net`` for one pass over ``test_loader``.

    Points may first be sub-sampled by ``net.sampler``; when the sampler is a
    SampleNet, its simplification and projection regularizers are folded into
    the reported loss.  The pose loss is
    ``||R_pred^T R_gt - I||^2 + ||t_pred - t_gt||^2``, optionally augmented
    with a 0.1-weighted cycle-consistency term when ``args.cycle`` is set.

    Returns a 17-tuple: averaged total / cycle losses, mse/mae for both
    directions, the stacked ground-truth and predicted transforms and Euler
    angles, and the averaged SampleNet loss (see the return statement for
    the exact ordering).
    """
    net.eval()
    # Running sums, each weighted by batch size and normalised at the end.
    mse_ab, mae_ab, mse_ba, mae_ba = 0, 0, 0, 0
    total_loss = 0
    total_samplenet_loss = 0
    total_cycle_loss = 0
    num_examples = 0

    # Per-batch transform records, concatenated after the loop.
    rotations_ab, translations_ab = [], []
    rotations_ab_pred, translations_ab_pred = [], []
    rotations_ba, translations_ba = [], []
    rotations_ba_pred, translations_ba_pred = [], []
    eulers_ab, eulers_ba = [], []

    for batch in tqdm(test_loader):
        (src, target, rotation_ab, translation_ab,
         rotation_ba, translation_ba, euler_ab, euler_ba) = batch
        src, target = src.cuda(), target.cuda()
        rotation_ab, translation_ab = rotation_ab.cuda(), translation_ab.cuda()
        rotation_ba, translation_ba = rotation_ba.cuda(), translation_ba.cuda()

        batch_size = src.size(0)
        num_examples += batch_size

        # Optionally sub-sample both clouds; SampleNet contributes extra
        # regularization terms to the loss.
        samplenet_loss = torch.tensor([0.0]).cuda()
        sampler = net.sampler
        if sampler is None:
            src_sampled, target_sampled = src, target
        else:
            src_simplified, src_sampled = sampler(src)
            target_simplifed, target_sampled = sampler(target)
            if isinstance(sampler, SampleNet):
                samplenet_loss = 0.5 * sampler.alpha * sampler.get_simplification_loss(
                    src, src_simplified, 0)
                samplenet_loss += 0.5 * sampler.alpha * sampler.get_simplification_loss(
                    target, target_simplifed, 0)
                samplenet_loss += sampler.lmbda * sampler.get_projection_loss()

        # No gradients are needed for the registration forward pass.
        with torch.no_grad():
            (rotation_ab_pred, translation_ab_pred,
             rotation_ba_pred, translation_ba_pred) = net(src_sampled, target_sampled)

        # Record ground truth and predictions for epoch statistics.
        rotations_ab.append(rotation_ab.detach().cpu().numpy())
        translations_ab.append(translation_ab.detach().cpu().numpy())
        rotations_ab_pred.append(rotation_ab_pred.detach().cpu().numpy())
        translations_ab_pred.append(translation_ab_pred.detach().cpu().numpy())
        eulers_ab.append(euler_ab.numpy())

        rotations_ba.append(rotation_ba.detach().cpu().numpy())
        translations_ba.append(translation_ba.detach().cpu().numpy())
        rotations_ba_pred.append(rotation_ba_pred.detach().cpu().numpy())
        translations_ba_pred.append(translation_ba_pred.detach().cpu().numpy())
        eulers_ba.append(euler_ba.numpy())

        transformed_src = transform_point_cloud(
            src, rotation_ab_pred, translation_ab_pred)
        transformed_target = transform_point_cloud(
            target, rotation_ba_pred, translation_ba_pred)

        # Pose loss on the ab direction.
        identity = torch.eye(3).cuda().unsqueeze(0).repeat(batch_size, 1, 1)
        rot_term = F.mse_loss(
            torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity)
        loss = rot_term + F.mse_loss(translation_ab_pred, translation_ab)

        if args.cycle:
            # ab followed by ba should be the identity transform.
            rotation_loss = F.mse_loss(
                torch.matmul(rotation_ba_pred, rotation_ab_pred),
                identity.clone())
            back_projected = torch.matmul(
                rotation_ba_pred.transpose(2, 1),
                translation_ab_pred.view(batch_size, 3, 1)).view(batch_size, 3)
            translation_loss = torch.mean(
                (back_projected + translation_ba_pred) ** 2, dim=[0, 1])
            cycle_loss = rotation_loss + translation_loss
            loss = loss + cycle_loss * 0.1

        loss = loss + samplenet_loss
        total_loss += loss.item() * batch_size
        total_samplenet_loss += samplenet_loss.item() * batch_size

        if args.cycle:
            total_cycle_loss = total_cycle_loss + cycle_loss.item() * 0.1 * batch_size

        mse_ab += torch.mean((transformed_src - target) ** 2,
                             dim=[0, 1, 2]).item() * batch_size
        mae_ab += torch.mean(torch.abs(transformed_src - target),
                             dim=[0, 1, 2]).item() * batch_size
        mse_ba += torch.mean((transformed_target - src) ** 2,
                             dim=[0, 1, 2]).item() * batch_size
        mae_ba += torch.mean(torch.abs(transformed_target - src),
                             dim=[0, 1, 2]).item() * batch_size

    def _cat(chunks):
        # Stack the per-batch arrays along the batch axis.
        return np.concatenate(chunks, axis=0)

    rotations_ab = _cat(rotations_ab)
    translations_ab = _cat(translations_ab)
    rotations_ab_pred = _cat(rotations_ab_pred)
    translations_ab_pred = _cat(translations_ab_pred)

    rotations_ba = _cat(rotations_ba)
    translations_ba = _cat(translations_ba)
    rotations_ba_pred = _cat(rotations_ba_pred)
    translations_ba_pred = _cat(translations_ba_pred)

    eulers_ab = _cat(eulers_ab)
    eulers_ba = _cat(eulers_ba)

    return total_loss * 1.0 / num_examples, total_cycle_loss / num_examples, \
           mse_ab * 1.0 / num_examples, mae_ab * 1.0 / num_examples, \
           mse_ba * 1.0 / num_examples, mae_ba * 1.0 / num_examples, rotations_ab, \
           translations_ab, rotations_ab_pred, translations_ab_pred, rotations_ba, \
           translations_ba, rotations_ba_pred, translations_ba_pred, eulers_ab, eulers_ba, total_samplenet_loss * 1.0 / num_examples