Exemple #1
0
def data_process(args, model, data, label):
    """Build a mixed pair of point clouds, run the model on it and print errors.

    The batch is split into two clouds according to ``args.task``, the two
    clouds are blended according to ``args.mixup`` and the blend is passed
    through ``model`` twice: once conditioned on a random projection of the
    first cloud and once on a random projection of the second. Symmetric
    chamfer distances between predictions and inputs are printed.

    Args:
        args: options object; ``task``, ``mixup``, ``cuda`` and
            ``use_one_hot`` are read.
        model: callable invoked as ``model(mixup_data, projection, labels)``.
        data: batch of point clouds handed to the pairing helper.
        label: labels belonging to ``data``; ``label[0]`` is printed.

    Returns:
        ``(data1, data2, mixup_data, pred1, pred2)`` where ``mixup_data`` and
        both predictions have been permuted back with ``permute(0, 2, 1)``.

    Raises:
        SystemExit: with status 1 when ``args.task`` is not supported.
        ValueError: when ``args.mixup`` is not supported.
    """
    if args.task == '1obj_rotate':
        data1, data2, label1, label2 = obj_rotate_perm(data, label,
                                                       args.cuda)  # (B, N, 3)
    elif args.task == '2obj':
        data1, data2, label1, label2 = test_obj_2_perm(data, label,
                                                       args.cuda)  # (B, N, 3)
    else:
        print('Task not implemented!')
        # Raise SystemExit directly: the ``exit`` helper only exists when the
        # ``site`` module is loaded, and a failure must not report status 0.
        raise SystemExit(1)

    if args.mixup == 'emd':
        mixup_data = emd_mixup(data1, data2)  # (B, N, 3)
    elif args.mixup == 'add':
        mixup_data = add_mixup(data1, data2, args.cuda)  # (B, N, 3)
    else:
        # Without this branch ``mixup_data`` stays unbound and the failure
        # surfaces one line later as an unrelated NameError.
        raise ValueError('Unknown mixup mode: %r' % (args.mixup, ))

    mixup_data = mixup_data.permute(0, 2, 1)  # (B, 3, N)
    batch_size = mixup_data.size()[0]

    # torch.set_printoptions(profile="full")
    print(label[0])

    if args.use_one_hot:
        # 16-wide one-hot vector per sample, indexed by the sample's label.
        label_one_hot1 = np.zeros((batch_size, 16), dtype=np.float32)
        label_one_hot2 = np.zeros((batch_size, 16), dtype=np.float32)
        for idx in range(batch_size):
            label_one_hot1[idx, label1[idx]] = 1
            label_one_hot2[idx, label2[idx]] = 1

        label_one_hot1 = torch.from_numpy(label_one_hot1)
        label_one_hot2 = torch.from_numpy(label_one_hot2)
    else:
        # Random vectors replace the label information entirely.
        label_one_hot1 = torch.rand(batch_size, 16)
        label_one_hot2 = torch.rand(batch_size, 16)

    device = torch.device('cuda') if args.cuda else torch.device('cpu')
    label_one_hot1 = label_one_hot1.to(device)
    label_one_hot2 = label_one_hot2.to(device)

    pred1 = model(mixup_data, rand_proj(data1), label_one_hot1)
    pred2 = model(mixup_data, rand_proj(data2), label_one_hot2)

    # Undo the channel-first layout used for the forward pass.
    mixup_data = mixup_data.permute(0, 2, 1)
    pred1, pred2 = pred1.permute(0, 2, 1), pred2.permute(0, 2, 1)

    print('diff of a and b',
          chamfer_distance(pred1, pred2) + chamfer_distance(pred2, pred1))
    print('loss for a and a',
          chamfer_distance(data1, pred1) + chamfer_distance(pred1, data1))
    print('loss for b and b',
          chamfer_distance(data2, pred2) + chamfer_distance(pred2, data2))

    return data1, data2, mixup_data, pred1, pred2
    def validate(self, data, writer):
        """Evaluate the encoder on every validation loader and log the result.

        For each loader in ``data`` the scaled chamfer loss is averaged over
        all examples; the per-loader averages are then divided by
        ``len(self.classes)`` and summed. The sum is written to ``writer``
        under ``'valid_ptp'`` and stored in ``self.current_loss``.

        Note that the values printed as ``f1`` and ``Validation Accuracy``
        are this chamfer-based loss.

        Args:
            data: iterable of validation loaders; each batch is a dict with
                the keys ``'img_occ'``, ``'img_unocc'``, ``'gt_points'`` and
                ``'class'``.
            writer: object exposing ``add_scalars(tag, values, step)``.
        """
        total_loss = 0

        self.encoder.eval()
        # Evaluation needs no gradients. Without no_grad every batch loss
        # added to class_loss would keep its autograd graph alive until the
        # loader is exhausted.
        with torch.no_grad():
            for valid_loader in data:
                num_examples = 0
                class_loss = 0
                for batch in tqdm(valid_loader):
                    # initialize data
                    img_occ = batch['img_occ'].cuda()
                    img_unocc = batch['img_unocc'].cuda()
                    gt_points = batch['gt_points'].cuda()
                    batch_size = img_occ.shape[0]
                    obj_class = batch['class'][0]

                    # model prediction
                    verts = self.encoder(img_occ, img_unocc, batch)

                    # losses
                    loss = utils.chamfer_distance(verts,
                                                  self.adj_info['faces'],
                                                  gt_points,
                                                  num=self.num_samples)

                    # weight by batch size so the division below yields a
                    # per-example mean even with a smaller last batch
                    loss = self.args.loss_coeff * loss.mean() * batch_size

                    # logs
                    num_examples += float(batch_size)
                    class_loss += loss

                if num_examples == 0:
                    # An empty loader has no class name and no mean to report.
                    continue

                print_loss = (class_loss / num_examples)
                message = f'Valid || Epoch: {self.epoch}, class: {obj_class}, f1: {print_loss:.2f}'
                tqdm.write(message)
                total_loss += (print_loss / float(len(self.classes)))

        print('*******************************************************')
        print(f'Validation Accuracy: {total_loss}')
        print('*******************************************************')

        writer.add_scalars('valid_ptp', {self.args.exp_id: total_loss},
                           self.epoch)
        self.current_loss = total_loss
    def train(self, data, writer):
        """Run one optimisation pass over ``data`` and log the mean loss.

        Every batch is pushed through ``self.encoder``, the chamfer loss
        scaled by ``self.args.loss_coeff`` is back-propagated and the
        optimizer takes a step. The average of the per-batch losses is
        written to ``writer`` under ``'train_loss'`` at ``self.epoch``.

        Args:
            data: iterable of batches; each batch is a dict with the keys
                ``'img_occ'``, ``'img_unocc'`` and ``'gt_points'``.
            writer: object exposing ``add_scalars(tag, values, step)``.
        """
        self.encoder.train()

        running_loss = 0
        num_batches = 0
        for batch in tqdm(data):
            self.optimizer.zero_grad()

            # move this batch's inputs to the GPU
            occluded = batch['img_occ'].cuda()
            unoccluded = batch['img_unocc'].cuda()
            target_points = batch['gt_points'].cuda()

            # forward pass
            predicted_verts = self.encoder(occluded, unoccluded, batch)

            # scaled mean chamfer loss
            chamfer = utils.chamfer_distance(predicted_verts,
                                             self.adj_info['faces'],
                                             target_points,
                                             num=self.num_samples)
            loss = self.args.loss_coeff * chamfer.mean()

            # backward pass and parameter update
            loss.backward()
            self.optimizer.step()

            # progress line
            tqdm.write(
                f'Train || Epoch: {self.epoch}, loss: {loss.item():.2f}, b_ptp:  {self.best_loss:.2f}'
            )
            running_loss += loss.item()
            num_batches += 1.

        mean_loss = running_loss / num_batches
        writer.add_scalars('train_loss', {self.args.exp_id: mean_loss},
                           self.epoch)
Exemple #4
0
def test():
    """Register every ModelNet40 test pair from predicted key points.

    For each pair the network scores the points of both clouds, points whose
    score reaches 0.3 are kept as key points, a RANSAC pose is estimated from
    them and refined with ICP on the full clouds. The chamfer distance of
    each result is printed, the key points and the registration are shown in
    viewer windows, and the mean chamfer distance is printed at the end.

    Relies on module-level names such as ``net``, ``device``, ``path`` and
    ``param_save_path``.
    """
    modelnet_set = Modelnet40Pair(path,
                                  mode="test",
                                  descriptor="BS",
                                  categories=categories,
                                  corr_pts_dis_thresh=corr_pts_dis_thresh,
                                  get_np=True,
                                  support_r=support_r)
    print(len(modelnet_set))

    net.load_state_dict(torch.load(param_save_path))
    net.eval()

    def build_input(points, normals, features):
        # Normalised coordinates, normals and descriptors side by side.
        stacked = torch.cat([
            torch.Tensor(pc_normalize(deepcopy(points))),
            torch.Tensor(normals),
            torch.Tensor(features)
        ],
                            dim=1)
        return stacked.to(device)

    chamfer_total = 0
    with torch.no_grad():
        for i in range(len(modelnet_set)):
            (source_pts, source_normal, source_f, source_key_pts_idx,
             target_pts, target_normal, target_f, target_key_pts_idx,
             source_to_target_min_idx, target_to_source_min_idx, raw_pts,
             T) = modelnet_set[i]

            source_pc = get_pc(source_pts, source_normal, [1, 0.706, 0])
            target_pc = get_pc(target_pts, target_normal, [0, 0.651, 0.929])

            source_inp = build_input(source_pts, source_normal, source_f)
            target_inp = build_input(target_pts, target_normal, target_f)

            s2t_idx = torch.LongTensor(source_to_target_min_idx).unsqueeze(0)
            t2s_idx = torch.LongTensor(target_to_source_min_idx).unsqueeze(0)

            source_scores, target_scores = net(source_inp.unsqueeze(0),
                                               target_inp.unsqueeze(0),
                                               s2t_idx, t2s_idx)
            source_scores = source_scores[0].cpu().numpy()
            target_scores = target_scores[0].cpu().numpy()

            # keep the points whose key-point score reaches the threshold
            score_thresh = 0.3
            source_pred_idx = np.nonzero(source_scores >= score_thresh)[0]
            target_pred_idx = np.nonzero(target_scores >= score_thresh)[0]

            # registration on the key parts
            ransac_T = ransac_pose_estimation(source_pts[source_pred_idx],
                                              target_pts[target_pred_idx],
                                              source_f[source_pred_idx],
                                              target_f[target_pred_idx],
                                              distance_threshold=0.06,
                                              max_iter=500000,
                                              max_valid=100000)
            icp_result = o3d.registration_icp(source_pc,
                                              target_pc,
                                              0.06,
                                              init=ransac_T)
            icp_T = icp_result.transformation

            # evaluation
            chamfer_dist = chamfer_distance(source_pts, target_pts, raw_pts,
                                            icp_T, T)
            print(chamfer_dist)
            chamfer_total += chamfer_dist.item()

            # paint the predicted key points red on fresh copies of the clouds
            source_pc_key = get_pc(source_pts, None, [1, 0.706, 0])
            target_pc_key = get_pc(target_pts, None, [0, 0.651, 0.929])
            source_colors = np.asarray(source_pc_key.colors)
            target_colors = np.asarray(target_pc_key.colors)
            source_colors[source_pred_idx] = np.array([1, 0, 0])
            target_colors[target_pred_idx] = np.array([1, 0, 0])

            o3d.draw_geometries([target_pc_key],
                                window_name="test key pts",
                                width=1000,
                                height=800)
            o3d.draw_geometries([source_pc_key],
                                window_name="test key pts",
                                width=1000,
                                height=800)
            o3d.draw_geometries([source_pc.transform(icp_T), target_pc],
                                window_name="test registration",
                                width=1000,
                                height=800)

    print("test charmfer: %.5f" % (chamfer_total / len(modelnet_set)))
Exemple #5
0
def _make_pair(args, data, label, epoch):
    """Split a batch into two clouds and their labels according to ``args.task``."""
    if args.task == '1obj_rotate':
        return obj_rotate_perm(data, label)  # (B, N, 3)
    if args.task == '2obj':
        return obj_2_perm(data, label)  # (B, N, 3)
    if args.task == 'alter':
        # even epochs use the rotation pairing, odd epochs the two-object one
        if epoch % 2 == 0:
            return obj_rotate_perm(data, label)  # (B, N, 3)
        return obj_2_perm(data, label)  # (B, N, 3)
    print('Task not implemented!')
    # A failure must not report exit status 0, and ``exit`` only exists when
    # the ``site`` module is loaded.
    raise SystemExit(1)


def _mix_pair(args, data1, data2):
    """Blend two clouds according to ``args.mixup``."""
    if args.mixup == 'emd':
        return emd_mixup(data1, data2)  # (B, N, 3)
    if args.mixup == 'add':
        return add_mixup(data1, data2)  # (B, N, 3)
    raise ValueError('Unknown mixup mode: %r' % (args.mixup, ))


def _one_hot_labels(labels, batch_size, num_classes=16):
    """Return a float32 ``(batch_size, num_classes)`` one-hot tensor for ``labels``."""
    one_hot = np.zeros((batch_size, num_classes), dtype=np.float32)
    for idx in range(batch_size):
        one_hot[idx, labels[idx]] = 1
    return torch.from_numpy(one_hot)


def _reconstruction_loss(args, pred1, data1, pred2, data2):
    """Sum the ``args.loss`` reconstruction error of both predictions."""
    if args.loss == 'l1loss':
        return L1_loss(pred1, data1) + L1_loss(pred2, data2)
    if args.loss == 'chamfer':
        # chamfer_distance is applied in both directions for each cloud
        loss1 = chamfer_distance(pred1, data1) + chamfer_distance(
            data1, pred1)
        loss2 = chamfer_distance(pred2, data2) + chamfer_distance(
            data2, pred2)
        return loss1 + loss2
    if args.loss == 'emd':
        return emd_loss(pred1, data1) + emd_loss(pred2, data2)
    if args.loss == 'emd2':
        return emd_loss_2(pred1, data1) + emd_loss_2(pred2, data2)
    raise NotImplementedError


def train(args, configpath):
    """Train ConsNet to separate mixed ShapeNetPart point clouds.

    Each batch is split into two clouds (``args.task``), blended
    (``args.mixup``) and the model is asked to recover each cloud from the
    blend, given a random projection of that cloud and a 16-wide label
    vector. When ``args.valid`` is set the test split is evaluated after
    every epoch and the best weights are saved. Weights are also saved every
    tenth epoch and once more when training ends.

    Args:
        args: options object. Reads, among others, ``num_points``,
            ``batch_size``, ``test_batch_size``, ``cuda``, ``model``,
            ``pretrain_path``, ``parallel``, ``use_sgd``, ``lr``,
            ``momentum``, ``scheduler``, ``loss``, ``l2loss``, ``l2_param``,
            ``task``, ``mixup``, ``use_one_hot``, ``valid``, ``epochs`` and
            ``exp_name``.
        configpath: forwarded unchanged to ``init``.

    Raises:
        Exception: when ``args.model`` is not supported.
        SystemExit: with status 1 when ``args.task`` is not supported.
        ValueError: when ``args.mixup`` is not supported.
        NotImplementedError: when ``args.loss`` is not supported.
    """
    io = init(args, configpath)
    train_dataset = ShapeNetPart(partition='trainval',
                                 num_points=args.num_points)
    # Keep the incomplete last batch only for very small datasets.
    drop_last = len(train_dataset) >= 100
    train_loader = DataLoader(train_dataset,
                              num_workers=8,
                              batch_size=args.batch_size,
                              shuffle=True,
                              drop_last=drop_last)
    test_loader = DataLoader(ShapeNetPart(partition='test',
                                          num_points=args.num_points),
                             num_workers=8,
                             batch_size=args.test_batch_size,
                             shuffle=True,
                             drop_last=False)

    seg_num_all = train_loader.dataset.seg_num_all

    device = torch.device("cuda" if args.cuda else "cpu")

    if args.model == 'consnet':
        model = ConsNet(args, seg_num_all).to(device)
    elif args.model == 'pretrain':
        model = ConsNet(args, seg_num_all).to(device)
        model.load_state_dict(torch.load(args.pretrain_path))
    else:
        raise Exception("Not implemented")

    if args.parallel:
        model = nn.DataParallel(model)

    print(str(model))

    if args.use_sgd:
        print("Use SGD")
        opt = optim.SGD(model.parameters(),
                        lr=args.lr * 100,
                        momentum=args.momentum,
                        weight_decay=1e-4)
        cur_lr = args.lr * 100
    else:
        print("Use Adam")
        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
        cur_lr = args.lr

    if args.scheduler == 'cos':
        scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=1e-3)
        print('Use COS')
    elif args.scheduler == 'step':
        scheduler = StepLR(opt, step_size=20, gamma=0.7)
        print('Use Step')

    if args.loss == 'l1loss':
        print('Use L1 Loss')
    elif args.loss == 'chamfer':
        print('Use Chamfer Distance')
    elif args.loss in ('emd', 'emd2'):
        # Both are handled by the loss computation, so do not report them as
        # unimplemented.
        print('Use EMD Loss')
    else:
        print('Not implemented')

    io.cprint('Experiment: %s' % args.exp_name)

    # Start from infinity so the first validated epoch always stores a best
    # checkpoint, whatever the scale of the loss.
    min_loss = float('inf')
    io.cprint('Begin to train...')
    for epoch in range(args.epochs):
        io.cprint(
            '=====================================Epoch %d========================================'
            % epoch)
        io.cprint('*****Train*****')
        # Train
        model.train()
        train_loss = 0
        for i, (data, label, _seg) in enumerate(train_loader):
            if epoch < 5:
                # linear warm-up over the first five epochs
                lr = 0.18 * cur_lr * epoch + 0.1 * cur_lr
                adjust_learning_rate(opt, lr)

            data1, data2, label1, label2 = _make_pair(args, data, label,
                                                      epoch)
            mixup_data = _mix_pair(args, data1, data2)
            mixup_data = mixup_data.permute(0, 2, 1)  # (B, 3, N)
            batch_size = mixup_data.size()[0]

            if args.use_one_hot:
                label_one_hot1 = _one_hot_labels(label1, batch_size)
                label_one_hot2 = _one_hot_labels(label2, batch_size)
            else:
                # random vectors replace the label information
                label_one_hot1 = torch.rand(batch_size, 16)
                label_one_hot2 = torch.rand(batch_size, 16)
            label_one_hot1 = label_one_hot1.to(device)
            label_one_hot2 = label_one_hot2.to(device)

            # Project
            proj1 = rand_proj(data1)
            proj2 = rand_proj(data2)

            # Train
            opt.zero_grad()

            pred1 = model(mixup_data, proj1,
                          label_one_hot1).permute(0, 2, 1)  # (B, N, 3)
            pred2 = model(mixup_data, proj2,
                          label_one_hot2).permute(0, 2, 1)  # (B, N, 3)

            loss = _reconstruction_loss(args, pred1, data1, pred2, data2)

            if args.l2loss:
                l2_loss = nn.MSELoss()(pred1, data1) + nn.MSELoss()(pred2,
                                                                    data2)
                loss = loss + args.l2_param * l2_loss

            loss.backward()

            train_loss = train_loss + loss.item()
            opt.step()

            if (i + 1) % 100 == 0:
                # i + 1 batches have been accumulated at this point
                io.cprint('iters %d, tarin loss: %.6f' %
                          (i, train_loss / (i + 1)))

        io.cprint('Learning rate: %.6f' % (opt.param_groups[0]['lr']))

        if args.scheduler == 'cos':
            scheduler.step()
        elif args.scheduler == 'step':
            # decay until the floor of 1e-5 is reached, then clamp to it
            if opt.param_groups[0]['lr'] > 1e-5:
                scheduler.step()
            if opt.param_groups[0]['lr'] < 1e-5:
                for param_group in opt.param_groups:
                    param_group['lr'] = 1e-5

        # Test
        if args.valid:
            io.cprint('*****Test*****')
            test_loss = 0
            model.eval()
            with torch.no_grad():
                for data, label, _seg in test_loader:
                    data1, data2, label1, label2 = _make_pair(
                        args, data, label, epoch)
                    mixup_data = _mix_pair(args, data1, data2)
                    mixup_data = mixup_data.permute(0, 2, 1)  # (B, 3, N)
                    batch_size = mixup_data.size()[0]

                    # NOTE(review): evaluation always feeds one-hot labels,
                    # even when args.use_one_hot is off during training;
                    # confirm that this is intended.
                    label_one_hot1 = _one_hot_labels(label1,
                                                     batch_size).to(device)
                    label_one_hot2 = _one_hot_labels(label2,
                                                     batch_size).to(device)

                    proj1 = rand_proj(data1)
                    proj2 = rand_proj(data2)

                    pred1 = model(mixup_data, proj1,
                                  label_one_hot1).permute(0, 2, 1)  # (B, N, 3)
                    pred2 = model(mixup_data, proj2,
                                  label_one_hot2).permute(0, 2, 1)  # (B, N, 3)

                    loss = _reconstruction_loss(args, pred1, data1, pred2,
                                                data2)
                    test_loss = test_loss + loss.item()
            io.cprint(
                'Train loss: %.6f, Test loss: %.6f' %
                (train_loss / len(train_loader), test_loss / len(test_loader)))
            cur_loss = test_loss / len(test_loader)
            if cur_loss < min_loss:
                min_loss = cur_loss
                torch.save(
                    model.state_dict(), 'checkpoints/%s/best_%s.pkl' %
                    (args.exp_name, args.exp_name))
        if (epoch + 1) % 10 == 0:
            torch.save(
                model.state_dict(), 'checkpoints/%s/%s_epoch_%s.pkl' %
                (args.exp_name, args.exp_name, str(epoch)))
    torch.save(model.state_dict(),
               'checkpoints/%s/%s.pkl' % (args.exp_name, args.exp_name))
        reg_result = o3d.registration.registration_icp(
            pc1_o3d,
            pc2_o3d,
            max_corr_dist,
            criteria=o3d.registration.ICPConvergenceCriteria(
                max_iteration=1000))
        end = time.time()
        opt_time = end - start

        # The transformation matrix is a 4x4 np array.
        icp_trans = reg_result.transformation

        # Transform p1 with ICP transformation, visualize target p2 and transformed p1.
        trans_pc1 = utils.transform_pointcloud(pc1, icp_trans)

        c_d = utils.chamfer_distance(pc2, trans_pc1, device=device)

        gt_cloud_distance = torch.nn.MSELoss()(torch.from_numpy(pc2),
                                               torch.from_numpy(trans_pc1))

        times.append(opt_time)
        chamfer_distances.append(c_d.item())
        gt_cloud_distances.append(gt_cloud_distance)
        rmse_distances.append(reg_result.inlier_rmse)

        if verbose:
            point_sets = np.array([trans_pc1, pc2])
            vis.visualize_points_overlay(point_sets,
                                         out_file=os.path.join(
                                             vis_dir_i, 'result_icp.png'))
        target_pc = get_pc(target_pts, target_normal, [0, 0.651, 0.929])
        # 重叠部分配准
        source_overlap_pc, target_overlap_pc = o3d.PointCloud(
        ), o3d.PointCloud()
        source_overlap_pc.points, target_overlap_pc.points = o3d.Vector3dVector(
            np.asarray(
                source_pc.points)[source_key_pts_idx]), o3d.Vector3dVector(
                    np.asarray(target_pc.points)[target_key_pts_idx])

        ransac_T = ransac_pose_estimation(source_pts[source_key_pts_idx],
                                          target_pts[target_key_pts_idx],
                                          source_f[source_key_pts_idx],
                                          target_f[target_key_pts_idx],
                                          distance_threshold=0.06,
                                          max_iter=50000,
                                          max_valid=1000)
        icp_result = o3d.registration_icp(source_overlap_pc,
                                          target_overlap_pc,
                                          0.04,
                                          init=ransac_T)
        icp_T = icp_result.transformation

        # 评估
        chamfer_dist = chamfer_distance(source_pts, target_pts, raw_pts, icp_T,
                                        T)
        print(chamfer_dist)
        o3d.draw_geometries([source_pc.transform(T), target_pc],
                            window_name="test registration",
                            width=1000,
                            height=800)
    def __call__(self) -> None:
        """Fit a touch chart to every predicted touch cloud and save it.

        A trained encoder is loaded from ``self.args.save_directory`` and run
        on each touch in the selected slice of the dataset. For every
        predicted point set whose target file does not exist yet, a copy of
        ``self.verts`` is placed at the sensor pose and its vertices are
        optimised with Adam against a chamfer term plus an edge-length term.
        The best vertices found are written with ``np.save``. A touch whose
        predicted points are all zero is stored as ``np.zeros(1)``.

        Nothing is returned; the results are the files written to the paths
        in ``batch['save_dir']``.
        """
        self.encoder = models.Encoder(self.args)
        self.encoder.load_state_dict(torch.load(self.args.save_directory))
        self.encoder.cuda()
        self.encoder.eval()

        train_data = data_loaders.mesh_loader_touch(self.classes,
                                                    self.args,
                                                    produce_sheets=True)
        train_data.names = train_data.names[self.args.start:self.args.end]
        train_loader = DataLoader(train_data,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=16,
                                  collate_fn=train_data.collate)

        for batch in tqdm(train_loader, smoothing=0):
            # initialize data
            sim_touch = batch['sim_touch'].cuda()
            depth = batch['depth'].cuda()
            ref_frame = batch['ref']

            # predict point cloud
            with torch.no_grad():
                _, sampled_points = self.encoder(sim_touch,
                                                 depth,
                                                 ref_frame,
                                                 empty=batch['empty'].cuda())

            # optimize touch chart
            for points, save_path in zip(sampled_points, batch['save_dir']):
                # results that already exist are not recomputed
                if os.path.exists(save_path):
                    continue
                # dirname copes with paths that have no directory part or end
                # in a slash, where slicing off the last component yields ''
                directory = os.path.dirname(save_path)
                if directory:
                    os.makedirs(directory, exist_ok=True)

                # if not a successful touch
                if torch.abs(points).sum() == 0:
                    np.save(save_path, np.zeros(1))
                    continue

                # make initial mesh match touch sensor when touch occurred
                initial = self.verts.clone().unsqueeze(0)
                pos = ref_frame['pos'].cuda().view(1, -1)
                rot = ref_frame['rot_M'].cuda().view(1, 3, 3)
                initial = torch.bmm(rot, initial.permute(0, 2,
                                                         1)).permute(0, 2, 1)
                initial += pos.view(1, 1, 3)
                initial = initial[0]

                # set up optimization
                updates = torch.zeros(self.verts.shape,
                                      requires_grad=True,
                                      device="cuda")
                optimizer = optim.Adam([updates], lr=0.003, weight_decay=0)
                last_improvement = 0
                best_loss = 10000
                # Start from the posed template so a result always exists and
                # can never leak over from the previous touch.
                best_verts = initial.clone()

                while True:
                    # update
                    optimizer.zero_grad()
                    verts = initial + updates

                    # losses
                    surf_loss = utils.chamfer_distance(
                        verts.unsqueeze(0),
                        self.faces,
                        points.unsqueeze(0),
                        num=self.args.num_samples)
                    edge_lengths = utils.batch_calc_edge(
                        verts.unsqueeze(0), self.faces)
                    loss = self.args.surf_co * surf_loss + 70 * edge_lengths

                    # optimize
                    loss.backward()
                    optimizer.step()

                    # Compare plain floats so no autograd history is kept
                    # between iterations.
                    loss_value = loss.item()

                    # Record the candidate before testing for convergence;
                    # otherwise the vertices that reached the threshold would
                    # be thrown away.
                    if loss_value < best_loss:
                        best_loss = loss_value
                        best_verts = verts.detach().clone()
                        last_improvement = 0
                    else:
                        last_improvement += 1

                    # stop once converged or after 50 steps without progress
                    if loss_value < 0.0006 or last_improvement > 50:
                        break

                np.save(save_path, best_verts.data.cpu().numpy())