Example 1
    def evaluate():
        mean_source_acc, mean_target_acc, mean_match_acc, process = 0, 0, 0, 0
        with torch.no_grad():
            for i in range(len(modelnet_test)):
                source_pts, source_normal, source_pre_encode, target_pts, target_normal, target_pre_encode, match, raw_pts, T, idx = modelnet_test[
                    i]
                source_pts_norm, target_pts_norm = pc_normalize(
                    deepcopy(source_pts)), pc_normalize(deepcopy(target_pts))
                source_inp = torch.Tensor(
                    np.concatenate(
                        [source_pts_norm, source_pts_norm, source_pre_encode],
                        axis=1)).to(device)
                target_inp = torch.Tensor(
                    np.concatenate(
                        [target_pts_norm, target_pts_norm, target_pre_encode],
                        axis=1)).to(device)
                match_label = torch.Tensor(match).to(device)
                # prepare the coordinate distances for the circle loss
                source_pc = o3d.PointCloud()
                source_pc.points = o3d.Vector3dVector(source_pts)
                coords_dist = utils.square_distance(
                    torch.Tensor(
                        np.asarray(deepcopy(source_pc).transform(
                            T).points)).unsqueeze(0),
                    torch.Tensor(target_pts).unsqueeze(0))[0].to(device)
                coords_dist = torch.sqrt(coords_dist)
                # print(((coords_dist < 0.04).sum(1) > 0).sum(), (match_label.sum(1) > 0).sum())
                loss, source_acc, target_acc, match_acc = loss_fn(
                    net, source_inp, target_inp, match_label, coords_dist)

                process += 1
                mean_source_acc += source_acc
                mean_target_acc += target_acc
                mean_match_acc += match_acc
                print(
                    "\rtest process: %s   loss: %.5f   source overlap acc: %.5f  target overlap acc: %.5f  match acc: %.5f"
                    % (processbar(process, len(modelnet_test)), loss.item(),
                       source_acc, target_acc, match_acc),
                    end="")
        mean_source_acc /= len(modelnet_test)
        mean_target_acc /= len(modelnet_test)
        mean_match_acc /= len(modelnet_test)
        print(
            "\ntest finish  mean source overlap acc: %.5f  mean target overlap acc: %.5f  mean match acc: %.5f"
            % (mean_source_acc, mean_target_acc, mean_match_acc))
        return mean_match_acc
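The evaluation loop above calls a pc_normalize helper that is not shown in this example. A minimal sketch of what it might look like, assuming the usual zero-centering and unit-sphere scaling convention for ModelNet point clouds (the project's actual helper may differ):

import numpy as np

def pc_normalize(pc):
    # Shift the cloud to its centroid and scale it into the unit sphere.
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    furthest = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    return pc / furthest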
Example 2
    def __getitem__(self, item):
        file = self.files[item]
        ref_cloud = readpcd(file, rtype='npy')
        ref_cloud = random_select_points(ref_cloud, m=self.npts)
        ref_cloud = pc_normalize(ref_cloud)
        R, t = generate_random_rotation_matrix(-20, 20), \
               generate_random_tranlation_vector(-0.5, 0.5)
        src_cloud = transform(ref_cloud, R, t)
        if self.train:
            ref_cloud = jitter_point_cloud(ref_cloud)
            src_cloud = jitter_point_cloud(src_cloud)
        return ref_cloud, src_cloud, R, t
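generate_random_rotation_matrix, generate_random_tranlation_vector and transform are project helpers that are not included in this snippet. A hedged sketch of plausible implementations, assuming Euler-angle sampling in degrees and row-vector (N, 3) clouds; the real helpers may parameterize the rotation differently:

import numpy as np
from scipy.spatial.transform import Rotation

def generate_random_rotation_matrix(angle_min=-20, angle_max=20):
    # Sample three Euler angles (degrees) uniformly and build a rotation matrix.
    angles = np.random.uniform(angle_min, angle_max, size=3)
    return Rotation.from_euler('zyx', angles, degrees=True).as_matrix().astype(np.float32)

def generate_random_tranlation_vector(t_min=-0.5, t_max=0.5):
    # Sample a translation vector with each component in [t_min, t_max].
    return np.random.uniform(t_min, t_max, size=3).astype(np.float32)

def transform(pc, R, t):
    # Apply the rigid transform (R, t) to an (N, 3) cloud of row vectors.
    return pc @ R.T + t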
Example 3
    def __getitem__(self, item):
        if item in self.caches:
            return self.caches[item]
        file = self.files[item]
        ref_cloud = readpcd(file, rtype='npy')
        ref_cloud = random_select_points(ref_cloud, m=self.npts)
        ref_cloud = pc_normalize(ref_cloud)
        # if self.train:
        #     ref_cloud = shift_point_cloud(ref_cloud)
        R, t = self.Rs[item], self.ts[item]
        src_cloud = transform(ref_cloud, R, t)
        if self.train:
            ref_cloud = jitter_point_cloud(ref_cloud)
            src_cloud = jitter_point_cloud(src_cloud)
        self.caches[item] = [ref_cloud, src_cloud, R, t]
        return ref_cloud, src_cloud, R, t
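Both dataset variants augment the clouds with jitter_point_cloud during training. A minimal sketch, assuming the common clipped-Gaussian jitter used in PointNet-style augmentation (the sigma and clip defaults are assumptions):

import numpy as np

def jitter_point_cloud(pc, sigma=0.01, clip=0.05):
    # Add clipped Gaussian noise independently to every coordinate.
    noise = np.clip(sigma * np.random.randn(*pc.shape), -clip, clip)
    return (pc + noise).astype(np.float32)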
Example 4
def test():
    modelnet_set = Modelnet40Pair(path,
                                  mode="test",
                                  descriptor="BS",
                                  categories=categories,
                                  corr_pts_dis_thresh=corr_pts_dis_thresh,
                                  get_np=True,
                                  support_r=support_r)
    print(len(modelnet_set))
    net.load_state_dict(torch.load(param_save_path))
    net.eval()
    test_chamfer_dis = 0
    with torch.no_grad():
        for i in range(len(modelnet_set)):
            source_pts, source_normal, source_f, source_key_pts_idx, \
            target_pts, target_normal, target_f, target_key_pts_idx, \
            source_to_target_min_idx, target_to_source_min_idx, raw_pts, T = modelnet_set[i]

            source_pc = get_pc(source_pts, source_normal, [1, 0.706, 0])
            target_pc = get_pc(target_pts, target_normal, [0, 0.651, 0.929])

            source_inp = torch.cat([
                torch.Tensor(pc_normalize(deepcopy(source_pts))),
                torch.Tensor(source_normal),
                torch.Tensor(source_f)
            ],
                                   dim=1).to(device)
            target_inp = torch.cat([
                torch.Tensor(pc_normalize(deepcopy(target_pts))),
                torch.Tensor(target_normal),
                torch.Tensor(target_f)
            ],
                                   dim=1).to(device)
            source_to_target_min_idx = torch.LongTensor(source_to_target_min_idx).unsqueeze(0)
            target_to_source_min_idx = torch.LongTensor(target_to_source_min_idx).unsqueeze(0)

            source_key_score, target_key_score = net(
                source_inp.unsqueeze(0), target_inp.unsqueeze(0),
                source_to_target_min_idx, target_to_source_min_idx)
            source_key_score = source_key_score[0].cpu().numpy()
            target_key_score = target_key_score[0].cpu().numpy()
            score_thresh = 0.3
            source_key_pts_idx_pred = np.nonzero(source_key_score >= score_thresh)[0]
            target_key_pts_idx_pred = np.nonzero(target_key_score >= score_thresh)[0]
            # registration on the predicted key points
            ransac_T = ransac_pose_estimation(
                source_pts[source_key_pts_idx_pred],
                target_pts[target_key_pts_idx_pred],
                source_f[source_key_pts_idx_pred],
                target_f[target_key_pts_idx_pred],
                distance_threshold=0.06,
                max_iter=500000,
                max_valid=100000)
            icp_result = o3d.registration_icp(source_pc,
                                              target_pc,
                                              0.06,
                                              init=ransac_T)
            icp_T = icp_result.transformation

            # evaluation
            chamfer_dist = chamfer_distance(source_pts, target_pts, raw_pts,
                                            icp_T, T)
            print(chamfer_dist)
            test_chamfer_dis += chamfer_dist.item()
            source_pc_key = get_pc(source_pts, None, [1, 0.706, 0])
            target_pc_key = get_pc(target_pts, None, [0, 0.651, 0.929])
            source_pc_key_color, target_pc_key_color = np.asarray(
                source_pc_key.colors), np.asarray(target_pc_key.colors)
            source_pc_key_color[source_key_pts_idx_pred] = np.array([1, 0, 0])
            target_pc_key_color[target_key_pts_idx_pred] = np.array([1, 0, 0])

            o3d.draw_geometries([target_pc_key],
                                window_name="test key pts",
                                width=1000,
                                height=800)
            o3d.draw_geometries([source_pc_key],
                                window_name="test key pts",
                                width=1000,
                                height=800)
            o3d.draw_geometries([source_pc.transform(icp_T), target_pc],
                                window_name="test registration",
                                width=1000,
                                height=800)
    print("test charmfer: %.5f" % (test_charmfer_dis / len(modelnet_set)))
Example 5
    def __getitem__(self, index):
        x = self.data[index]
        source_pts, source_normal = x["points_src"][:, :3], x["points_src"][:,
                                                                            3:]
        target_pts, target_normal = x["points_ref"][:, :3], x["points_ref"][:,
                                                                            3:]
        raw_pts = x["points_raw"][:, :3]
        idx = x["idx"]
        T = np.concatenate(
            [x["transform_gt"], np.array([[0, 0, 0, 1]])], axis=0)
        source_pc = get_pc(source_pts, source_normal, [1, 0.706, 0])
        target_pc = get_pc(target_pts, target_normal, [0, 0.651, 0.929])

        # source_f, target_f = self.descriptor.get(source_pc, self.support_r), self.descriptor.get(target_pc, self.support_r)
        source_f = self.descriptor.get(source_pc, self.support_r, 1 / 9)
        target_f = self.descriptor.get(target_pc, self.support_r, 1 / 9)
        f_dis = square_distance(
            torch.Tensor(source_f).unsqueeze(0),
            torch.Tensor(target_f).unsqueeze(0))[0]
        source_to_target_min_idx = f_dis.min(dim=1)[1].numpy()
        target_to_source_min_idx = f_dis.min(dim=0)[1].numpy()

        rotated_source_pts = np.asarray(
            deepcopy(source_pc).transform(T).points)
        source_corr_pts = target_pts[source_to_target_min_idx]
        target_corr_pts = rotated_source_pts[target_to_source_min_idx]
        source_pts_to_target_pts_dis = np.sqrt(
            np.sum((rotated_source_pts - source_corr_pts)**2, axis=1))
        target_pts_to_source_pts_dis = np.sqrt(
            np.sum((target_pts - target_corr_pts)**2, axis=1))

        dis_thresh = self.dis_thresh
        source_key_pts_idx = np.nonzero(source_pts_to_target_pts_dis <= dis_thresh)[0]
        target_key_pts_idx = np.nonzero(target_pts_to_source_pts_dis <= dis_thresh)[0]
        source_key_label = np.zeros((source_pts.shape[0], ))
        target_key_label = np.zeros((target_pts.shape[0], ))
        source_key_label[source_key_pts_idx] = 1
        target_key_label[target_key_pts_idx] = 1

        # visualize the key points
        source_pc = get_pc(source_pts, None, [1, 0.706, 0])
        target_pc = get_pc(target_pts, None, [0, 0.651, 0.929])
        source_color, target_color = np.asarray(source_pc.colors), np.asarray(
            target_pc.colors)
        source_color[source_key_pts_idx] = np.array([1, 0, 0])
        target_color[target_key_pts_idx] = np.array([1, 0, 0])
        # o3d.draw_geometries([source_pc, target_pc], window_name="test", width=1000, height=800)

        if self.get_np:
            return source_pts, source_normal, source_f, source_key_pts_idx, \
                   target_pts, target_normal, target_f, target_key_pts_idx, \
                   source_to_target_min_idx, target_to_source_min_idx, raw_pts, T
        source_pts, target_pts = pc_normalize(source_pts), pc_normalize(
            target_pts)
        if self.get_raw_and_T:
            return torch.Tensor(source_pts), torch.Tensor(source_normal), torch.Tensor(source_f), torch.Tensor(source_key_label), \
                   torch.Tensor(target_pts), torch.Tensor(target_normal), torch.Tensor(target_f), torch.Tensor(target_key_label), \
                   torch.Tensor(raw_pts), torch.Tensor(T)
        else:
            return torch.Tensor(source_pts), torch.Tensor(source_normal), torch.Tensor(source_f), torch.Tensor(source_key_label), \
                   torch.Tensor(target_pts), torch.Tensor(target_normal), torch.Tensor(target_f), torch.Tensor(target_key_label), \
                   torch.LongTensor(source_to_target_min_idx), torch.LongTensor(target_to_source_min_idx)
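square_distance here (and utils.square_distance in the other examples) computes pairwise squared Euclidean distances between two batched point sets. The widely used PointNet++-style formulation, assumed to be equivalent to the helper used in this project:

import torch

def square_distance(src, dst):
    # src: (B, N, C), dst: (B, M, C) -> (B, N, M) squared distances,
    # using ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2.
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, dim=-1).view(B, N, 1)
    dist += torch.sum(dst ** 2, dim=-1).view(B, 1, M)
    return dist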
Example 6
def test_crop():
    modelnet_test = Modelnet40Pair(path=dataset_path,
                                   r_lrf=r_lrf,
                                   theta_space=theta_space,
                                   mode="test",
                                   pre_encode_path=pre_encode_path,
                                   noise_type="crop")
    net = SAO()
    net.to(device)
    sao_param_path = "./params/sao-modelnet-6space-circle-in.pth"
    net.load_state_dict(torch.load(sao_param_path))
    net.eval()
    chamfer_test, test_cnt = 0, 0
    valid_idx = list(np.load("./test_idx.npy"))
    with torch.no_grad():
        for i in range(len(modelnet_test)):
            source_pts, source_normal, source_pre_encode, target_pts, target_normal, target_pre_encode, match, raw_pts, T, idx = modelnet_test[
                i]
            if idx not in valid_idx:
                continue
            test_cnt += 1
            source_pts_norm, target_pts_norm = pc_normalize(
                deepcopy(source_pts)), pc_normalize(deepcopy(target_pts))
            source_inp = torch.Tensor(
                np.concatenate(
                    [source_pts_norm, source_pts_norm, source_pre_encode],
                    axis=1)).to(device)
            target_inp = torch.Tensor(
                np.concatenate(
                    [target_pts_norm, target_pts_norm, target_pre_encode],
                    axis=1)).to(device)
            # network predict
            source_f, target_f, source_overlap, target_overlap = net(
                source_inp, target_inp)
            source_f = source_f.cpu().numpy()
            target_f = target_f.cpu().numpy()
            source_overlap = source_overlap.cpu().numpy()
            target_overlap = target_overlap.cpu().numpy()
            thresh = 0.4
            source_overlap_idx = np.nonzero(source_overlap >= thresh)[0]
            target_overlap_idx = np.nonzero(target_overlap >= thresh)[0]

            # # plot the predicted overlap regions to check how accurate they are
            # source_pc = get_pc(source_pts, source_normal, [1, 0.706, 0])
            # target_pc = get_pc(target_pts, target_normal, [0, 0.651, 0.929])
            # source_colors, target_colors = np.asarray(source_pc.colors), np.asarray(target_pc.colors)
            # source_colors[source_overlap_idx, :], target_colors[target_overlap_idx, :] = np.array([[1, 0, 0]] * source_overlap_idx.shape[0]), np.asarray([[0, 0, 1]] * target_overlap_idx.shape[0])
            # o3d.draw_geometries([deepcopy(source_pc).transform(T), target_pc], window_name="test", width=1000, height=800)

            source_pc = get_pc(source_pts, source_normal, [1, 0.706, 0])
            target_pc = get_pc(target_pts, target_normal, [0, 0.651, 0.929])
            # register the predicted overlapping regions
            source_overlap_pc, target_overlap_pc = o3d.PointCloud(), o3d.PointCloud()
            source_overlap_pc.points = o3d.Vector3dVector(
                np.asarray(source_pc.points)[source_overlap_idx])
            target_overlap_pc.points = o3d.Vector3dVector(
                np.asarray(target_pc.points)[target_overlap_idx])

            ransac_T = ransac_pose_estimation(source_pts[source_overlap_idx],
                                              target_pts[target_overlap_idx],
                                              source_f[source_overlap_idx],
                                              target_f[target_overlap_idx],
                                              distance_threshold=0.07)
            icp_result = o3d.registration_icp(source_overlap_pc,
                                              target_overlap_pc,
                                              0.06,
                                              init=ransac_T)
            icp_T = icp_result.transformation

            # evaluation
            chamfer_dist = chamfer_distance(source_pts, target_pts, raw_pts,
                                            icp_T, T)
            chamfer_test += chamfer_dist.item()

            print(
                "\rtest process: %s  cur chamfer dis: %.5f   item dis: %.5f" %
                (processbar(test_cnt, len(valid_idx)), chamfer_test / test_cnt,
                 chamfer_dist.item()),
                end="")
            # o3d.draw_geometries([source_pc, target_pc], window_name="before registration", width=1000, height=800)
            # o3d.draw_geometries([source_pc.transform(icp_T), target_pc], window_name="registration", width=1000, height=800)
    print("\ntest finish, charmer dis: %.5f" % (chamfer_test / len(valid_idx)))
Example 7
def train():
    epoch = 251
    lr = 0.001
    min_lr = 0.00001
    lr_update_step = 20
    loss_fn = SAOLoss()
    net = SAO()
    net.to(device)
    sao_param_path = "./params/sao-modelnet-6space-circle-in.pth"
    net.load_state_dict(torch.load(sao_param_path))
    # optimizer = torch.optim.SGD(params=net.parameters(), lr=lr, weight_decay=0)
    optimizer = torch.optim.Adam(params=net.parameters(),
                                 lr=lr,
                                 weight_decay=0)

    def update_lr(optimizer, gamma=0.5):
        lr = 0
        for param_group in optimizer.param_groups:
            lr = param_group['lr']
        lr = max(lr * gamma, min_lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        print("lr update finished  cur lr: %.5f" % lr)

    def evaluate():
        mean_source_acc, mean_target_acc, mean_match_acc, process = 0, 0, 0, 0
        with torch.no_grad():
            for i in range(len(modelnet_test)):
                source_pts, source_normal, source_pre_encode, target_pts, target_normal, target_pre_encode, match, raw_pts, T, idx = modelnet_test[
                    i]
                source_pts_norm, target_pts_norm = pc_normalize(
                    deepcopy(source_pts)), pc_normalize(deepcopy(target_pts))
                source_inp = torch.Tensor(
                    np.concatenate(
                        [source_pts_norm, source_pts_norm, source_pre_encode],
                        axis=1)).to(device)
                target_inp = torch.Tensor(
                    np.concatenate(
                        [target_pts_norm, target_pts_norm, target_pre_encode],
                        axis=1)).to(device)
                match_label = torch.Tensor(match).to(device)
                # prepare the coordinate distances for the circle loss
                source_pc = o3d.PointCloud()
                source_pc.points = o3d.Vector3dVector(source_pts)
                coords_dist = utils.square_distance(
                    torch.Tensor(
                        np.asarray(deepcopy(source_pc).transform(
                            T).points)).unsqueeze(0),
                    torch.Tensor(target_pts).unsqueeze(0))[0].to(device)
                coords_dist = torch.sqrt(coords_dist)
                # print(((coords_dist < 0.04).sum(1) > 0).sum(), (match_label.sum(1) > 0).sum())
                loss, source_acc, target_acc, match_acc = loss_fn(
                    net, source_inp, target_inp, match_label, coords_dist)

                process += 1
                mean_source_acc += source_acc
                mean_target_acc += target_acc
                mean_match_acc += match_acc
                print(
                    "\rtest process: %s   loss: %.5f   source overlap acc: %.5f  target overlap acc: %.5f  match acc: %.5f"
                    % (processbar(process, len(modelnet_test)), loss.item(),
                       source_acc, target_acc, match_acc),
                    end="")
        mean_source_acc /= len(modelnet_test)
        mean_target_acc /= len(modelnet_test)
        mean_match_acc /= len(modelnet_test)
        print(
            "\ntest finish  mean source overlap acc: %.5f  mean target overlap acc: %.5f  mean match acc: %.5f"
            % (mean_source_acc, mean_target_acc, mean_match_acc))
        return mean_match_acc

    max_acc = 0
    for epoch_count in range(1, epoch + 1):
        mean_source_acc, mean_target_acc, mean_match_acc, process = 0, 0, 0, 0
        loss_val = 0
        rand_idx = np.random.permutation(len(modelnet_train))
        for i in range(len(modelnet_train)):
            source_pts, source_normal, source_pre_encode, target_pts, target_normal, target_pre_encode, match, raw_pts, T, idx = modelnet_train[
                rand_idx[i]]
            source_pts_norm, target_pts_norm = pc_normalize(
                deepcopy(source_pts)), pc_normalize(deepcopy(target_pts))
            source_inp = torch.Tensor(
                np.concatenate(
                    [source_pts_norm, source_pts_norm, source_pre_encode],
                    axis=1)).to(device)
            target_inp = torch.Tensor(
                np.concatenate(
                    [target_pts_norm, target_pts_norm, target_pre_encode],
                    axis=1)).to(device)
            match_label = torch.Tensor(match).to(device)
            # prepare the coordinate distances for the circle loss
            source_pc = o3d.PointCloud()
            source_pc.points = o3d.Vector3dVector(source_pts)
            coords_dist = utils.square_distance(
                torch.Tensor(
                    np.asarray(
                        deepcopy(source_pc).transform(T).points)).unsqueeze(0),
                torch.Tensor(target_pts).unsqueeze(0))[0].to(device)
            coords_dist = torch.sqrt(coords_dist)
            # print(((coords_dist < 0.04).sum(1) > 0).sum(), (match_label.sum(1) > 0).sum())
            loss, source_acc, target_acc, match_acc = loss_fn(
                net, source_inp, target_inp, match_label, coords_dist)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_val += loss.item()
            process += 1
            mean_source_acc += source_acc
            mean_target_acc += target_acc
            mean_match_acc += match_acc
            print(
                "\rprocess: %s   loss: %.5f   source overlap acc: %.5f  target overlap acc: %.5f  match acc: %.5f"
                % (processbar(process, len(modelnet_train)), loss.item(),
                   source_acc, target_acc, match_acc),
                end="")
        mean_source_acc /= len(modelnet_train)
        mean_target_acc /= len(modelnet_train)
        mean_match_acc /= len(modelnet_train)
        print(
            "\nepoch: %d  loss: %.5f  mean source overlap acc: %.5f  mean target overlap acc: %.5f  mean match acc: %.5f"
            % (epoch_count, loss_val, mean_source_acc, mean_target_acc,
               mean_match_acc))
        test_match_acc = evaluate()
        if max_acc < test_match_acc:
            max_acc = test_match_acc
            print("save ....")
            torch.save(net.state_dict(), sao_param_path)
            print("finish !!!!!!")
        if epoch_count % lr_update_step == 0:
            update_lr(optimizer, 0.5)
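processbar(done, total), used in the training and test loops, just renders a textual progress bar. A hypothetical sketch, since the project's implementation is not shown:

def processbar(done, total, bar_len=30):
    # Render something like "[=========>           ]  42/100".
    filled = int(bar_len * done / total)
    return "[%s>%s] %3d/%d" % ("=" * filled, " " * (bar_len - filled), done, total)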