コード例 #1
0
def load_net_model():
    """Read checkpoint paths from the command line and set up the shared networks.

    Parses ``--seg_model``, ``--model`` and ``--refine_model`` (each has a
    hard-coded default path), then publishes three module-level globals:

    * ``estimator`` - a ``PoseNet`` on the GPU, weights loaded, in eval mode.
    * ``refiner`` - a ``PoseRefineNet`` on the GPU, weights loaded, in eval mode.
    * ``testdataset`` - an ``sp_Dataset`` built from the segmentation checkpoint.

    ``num_points`` and ``num_objects`` are read from module scope.
    """
    global estimator, refiner, testdataset

    # (flag, default checkpoint path, help text) for every CLI option.
    cli_options = (
        ('--seg_model',
         'D:/DenseFusion_1.0_20210308/Seg_Pose/Seg/trained_models/model_2_0.0033125885542725133.pth',
         'resume Seg model'),
        ('--model',
         'D:/DenseFusion_1.0_20210308/Seg_Pose/Pose/trained_models/pose_model_2_0.09233990000864985.pth',
         'resume PoseNet model'),
        ('--refine_model',
         'D:/DenseFusion_1.0_20210308/Seg_Pose/Pose/trained_models/pose_refine_model_4_0.06547849291308346.pth',
         'resume PoseRefineNet model'),
    )
    cli = argparse.ArgumentParser()
    for flag, default_path, description in cli_options:
        cli.add_argument(flag, type=str, default=default_path,
                         help=description)
    opt = cli.parse_args()

    # Both networks are created and moved to the GPU before any weights are read.
    estimator = PoseNet(num_points=num_points, num_obj=num_objects)
    estimator.cuda()
    refiner = PoseRefineNet(num_points=num_points, num_obj=num_objects)
    refiner.cuda()

    for net, checkpoint in ((estimator, opt.model),
                            (refiner, opt.refine_model)):
        net.load_state_dict(torch.load(checkpoint))
    for net in (estimator, refiner):
        net.eval()

    testdataset = sp_Dataset(opt.seg_model, num_points, False, 0.0, True)
コード例 #2
0
def main(args):
    """Load the networks for the selected part and run the ROS pose-estimation node.

    The part is chosen by the module-level ``opt.model``: ``'flansch'``,
    ``'schaltgabel'``, or anything else (treated as the "stift" part).  Each
    part has its own segmentation checkpoint, class index and 3D bounding box.

    Args:
        args: Not used by this function; the selection comes from ``opt``.
    """
    seg_model = segnet()
    seg_model.cuda()

    # Pick the per-part settings first; the checkpoint is loaded and the box
    # is scaled once below.  Pruned checkpoints ('./<part>_pruned.pth') were
    # previously tried in place of the './seg<part>.pth' files.
    if opt.model == 'flansch':
        seg_weights = './segflansch.pth'
        idx = 0
        bbox_corners = [[59, -42, 59], [59, -42, -59], [59, 20, -59],
                        [59, 20, 59], [-59, -42, 59], [-59, -42, -59],
                        [-59, 20, -59], [-59, 20, 59]]  # flansch_3d_bbox
    elif opt.model == 'schaltgabel':
        seg_weights = './segschaltgabel.pth'
        idx = 1
        bbox_corners = [[-54.7478, -16.7500, 23], [-54.7478, -16.7500, 0],
                        [54.7478, -16.7500, 0], [54.7478, -16.7500, 23],
                        [-54.7478, 130.85, 23], [-54.7478, 130.85, 0],
                        [54.7478, 130.85, 0],
                        [54.7478, 130.85, 23]]  # schaltgabel_3d_bbox
    else:
        seg_weights = './segstift.pth'
        idx = 2
        bbox_corners = [[-20, -35, 19.9878], [-20, -35, -19.9878],
                        [20, -35, -19.9878], [20, -35, 19.9878],
                        [-20, 87, 19.9878], [-20, 87, -19.9878],
                        [20, 87, -19.9878],
                        [20, 87, 19.9878]]  # stift_3d_bbox

    seg_model.load_state_dict(torch.load(seg_weights))
    # Corner coordinates are divided by 1000 (presumably mm -> m; confirm).
    scaled = np.array(bbox_corners) / 1000
    seg_model.eval()

    estimator = PoseNet(num_points, num_objects)
    estimator.cuda()
    refiner = PoseRefineNet(num_points, num_objects)
    refiner.cuda()
    for net, checkpoint in ((estimator, './pose_model_current.pth'),
                            (refiner, './pose_refine_model_current.pth')):
        net.load_state_dict(torch.load(checkpoint))
    for net in (estimator, refiner):
        net.eval()

    # Keep a reference to the pose-estimation object while the node spins.
    pose_node = pose_estimation(seg_model, estimator, refiner, idx, scaled)
    rospy.init_node('pose_estimation', anonymous=True)
    try:
        rospy.spin()
    except KeyboardInterrupt:
        print('Shutting down ROS pose estimation module')
    cv2.destroyAllWindows()
コード例 #3
0
def main():
    """Prepare the upload folder and both networks, then start the web server.

    Publishes the module-level globals ``refiner`` and ``estimator`` (each on
    the GPU, weights loaded from ``opt``, in eval mode) and runs ``app`` on
    all interfaces at ``PORT``.
    """
    global refiner, estimator

    def _activate(net, checkpoint):
        # GPU placement, weight loading and eval mode, in that order.
        net.cuda()
        net.load_state_dict(torch.load(checkpoint))
        net.eval()

    # The upload folder is created on first start.
    if not os.path.exists(UPLOAD_FOLDER):
        os.makedirs(UPLOAD_FOLDER)

    # The refiner is brought up before the estimator.
    refiner = PoseRefineNet(num_points=num_points, num_obj=num_obj)
    _activate(refiner, opt.refine_model)

    estimator = PoseNet(num_points=num_points, num_obj=num_obj)
    _activate(estimator, opt.model)

    app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
    app.run(port=PORT, host='0.0.0.0', debug=False)
コード例 #4
0
class DenseFusionRefine:
    """Iteratively refines a 4x4 pose estimate with a ``PoseRefineNet``."""

    def __init__(self, dataset):
        """Create the refiner network from the settings carried by ``dataset``.

        Reads ``dataset.df_num_points``, ``dataset.obj_ids`` and
        ``dataset.df_model_path``; the network is moved to the GPU, loaded
        from the checkpoint and switched to eval mode.
        """
        self.num_points = dataset.df_num_points
        self.obj_list = dataset.obj_ids

        net = PoseRefineNet(num_points=self.num_points,
                            num_obj=len(self.obj_list))
        self.refiner = net
        net.cuda()
        net.load_state_dict(torch.load(dataset.df_model_path))
        net.eval()

    def refine(self, obj_id, pose, emb, cloud, iterations=1):
        """Apply ``iterations`` refinement steps to ``pose`` and return the result.

        Args:
            obj_id: Object id; its position in ``self.obj_list`` is the class
                index handed to the network.
            pose: Pose matrix (indexed as ``[:3, :3]`` rotation and
                ``[:3, 3]`` translation); it is copied, not modified.
            emb: Embedding passed through to the refiner unchanged.
            cloud: Point cloud tensor; must be compatible with a
                ``(1, num_points, 3)`` translation tensor on the GPU.
            iterations: Number of refinement passes.

        Returns:
            The refined pose matrix (a new array).
        """
        label = torch.LongTensor([self.obj_list.index(obj_id)])
        index = Variable(label).cuda()

        current = pose.copy()
        for _ in range(iterations):
            # Express the cloud in the frame of the current pose estimate.
            rotation = current[:3, :3].astype(np.float32)
            R = Variable(torch.from_numpy(rotation)).cuda().view(1, 3, 3)
            translation = current[:3, 3].astype(np.float32)
            ts = Variable(torch.from_numpy(translation)).cuda().view(1, 3)
            ts = ts.repeat(self.num_points, 1).contiguous()
            ts = ts.view(1, self.num_points, 3)
            new_cloud = torch.bmm((cloud - ts), R).contiguous()

            # Network output: a quaternion (normalised here) and a translation.
            quat, trans = self.refiner(new_cloud, emb, index)
            quat = quat.view(1, 1, -1)
            quat = quat / (torch.norm(quat, dim=2).view(1, 1, 1))
            quat_np = quat.view(-1).cpu().data.numpy()
            trans_np = trans.view(-1).cpu().data.numpy()

            # Compose the predicted delta onto the running estimate.
            delta = quaternion_matrix(quat_np)
            delta[:3, 3] = trans_np
            current = current @ delta

        return current
コード例 #5
0
    model_points_list.append(model_points)

# Checkpoints used for pose estimation (PoseNet) and refinement (PoseRefineNet).
model = "/home/tomo/densefusion/trained_models/tomo_blisters_2/pose_model_1_0.005430417195209802.pth"
refine_model = "/home/tomo/densefusion/trained_models/tomo_blisters_2/pose_refine_model_63_0.0005920357997828468.pth"

# Alternative checkpoint pair ("tomo_blisters_invert"); currently disabled.
#model = "/home/tomo/densefusion/trained_models/tomo_blisters_invert/pose_model_1_0.005256925600131684.pth"
#refine_model = "/home/tomo/densefusion/trained_models/tomo_blisters_invert/pose_refine_model_29_0.0005998273272780352.pth"

# Build both networks at import time, move them to the GPU, load the weights
# above and switch to eval mode. num_points / num_objects come from earlier
# in the module (not visible in this excerpt).
estimator = PoseNet(num_points=num_points, num_obj=num_objects)
estimator.cuda()
refiner = PoseRefineNet(num_points=num_points, num_obj=num_objects)
refiner.cuda()
estimator.load_state_dict(torch.load(model))
refiner.load_state_dict(torch.load(refine_model))
estimator.eval()
refiner.eval()


def setup_config():
    """Build the configuration for the Mask R-CNN segmentation model.

    Starts from ``get_cfg()``, merges the model-zoo file
    ``COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml`` and then
    overrides test-time input sizes, RPN/ROI settings, anchor sizes and the
    weights path.

    NOTE(review): as shown here the function configures ``cfg`` but never
    returns it - the excerpt may be truncated; confirm against the full file.
    """
    # Locally trained weights that replace the model-zoo defaults.
    model_path = "/home/tomo/densefusion/MASK_RCNN_TOMO/trained_models/model_final_new.pth"
    cfg = get_cfg()
    cfg.merge_from_file(
        model_zoo.get_config_file(
            "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"))
    # NOTE(review): a size of 0 presumably disables test-time resizing - confirm.
    cfg.INPUT.MIN_SIZE_TEST = 0
    cfg.INPUT.MAX_SIZE_TEST = 0
    cfg.MODEL.RPN.PRE_NMS_TOPK_TEST = 2000
    cfg.MODEL.WEIGHTS = model_path
    cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[8, 16, 32, 64, 128]]
    cfg.MODEL.RPN.BATCH_SIZE_PER_IMAGE = 512
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 256
コード例 #6
0
ファイル: train.py プロジェクト: TianXu1991/KUKA_IIWA_System
def _elapsed(start_time):
    """Format the wall-clock time since ``start_time`` as ``"HHh MMm SSs"``."""
    return time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - start_time))


def _batch_to_cuda(data):
    """Wrap each tensor of a loader batch in ``Variable`` and move it to the GPU."""
    return tuple(Variable(item).cuda() for item in data)


def _load_data_and_losses(opt):
    """Build data loaders and loss objects for the current ``opt.refine_start``.

    Side effects: stores the training set's symmetry list and mesh point
    count in ``opt.sym_list`` / ``opt.num_points_mesh`` and prints a summary.

    Returns:
        Tuple ``(dataloader, testdataloader, criterion, criterion_refine)``.

    Raises:
        ValueError: if ``opt.dataset`` is neither ``'ycb'`` nor ``'linemod'``.
    """
    if opt.dataset == 'ycb':
        dataset_cls = PoseDataset_ycb
    elif opt.dataset == 'linemod':
        dataset_cls = PoseDataset_linemod
    else:
        raise ValueError('Unknown dataset')

    dataset = dataset_cls('train', opt.num_points, True, opt.dataset_root,
                          opt.noise_trans, opt.refine_start)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=opt.workers)
    # The test split is built without noise (0.0) and is not shuffled.
    test_dataset = dataset_cls('test', opt.num_points, False,
                               opt.dataset_root, 0.0, opt.refine_start)
    testdataloader = torch.utils.data.DataLoader(test_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=opt.workers)

    opt.sym_list = dataset.get_sym_list()
    opt.num_points_mesh = dataset.get_num_points_mesh()

    print(
        '>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}\nsymmetry object list: {3}'
        .format(len(dataset), len(test_dataset), opt.num_points_mesh,
                opt.sym_list))

    criterion = Loss(opt.num_points_mesh, opt.sym_list)
    criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)
    return dataloader, testdataloader, criterion, criterion_refine


def main():
    """Train PoseNet and, once it is good enough, PoseRefineNet.

    All settings are read from (and partly written back to) the module-level
    ``opt``.  Per epoch the function trains, evaluates on the test split,
    saves the best model so far, and then checks two thresholds:

    * ``best_test < opt.decay_margin``: learning rate and ``opt.w`` are
      decayed once and the estimator optimizer is recreated.
    * ``best_test < opt.refine_margin``: training switches to the refiner;
      the effective batch size is divided by ``opt.iteration`` and the
      datasets, loaders and losses are rebuilt with ``refine_start=True``.

    Returns early (after printing a message) for an unknown ``opt.dataset``.
    """
    opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    if opt.dataset == 'ycb':
        opt.num_objects = 21  #number of object classes in the dataset
        opt.num_points = 1000  #number of points on the input pointcloud
        opt.outf = 'trained_models/ycb'  #folder to save trained models
        opt.log_dir = 'experiments/logs/ycb'  #folder to save logs
        opt.repeat_epoch = 1  #number of repeat times for one epoch training
    elif opt.dataset == 'linemod':
        opt.num_objects = 13
        opt.num_points = 500
        opt.outf = 'trained_models/linemod'
        opt.log_dir = 'experiments/logs/linemod'
        opt.repeat_epoch = 20
    else:
        print('Unknown dataset')
        return

    # Checkpoints and logs are written below; without these folders the
    # os.listdir() / torch.save() calls fail on a fresh checkout.
    for folder in (opt.outf, opt.log_dir):
        if not os.path.exists(folder):
            os.makedirs(folder)

    estimator = PoseNet(num_points=opt.num_points, num_obj=opt.num_objects)
    estimator.cuda()
    refiner = PoseRefineNet(num_points=opt.num_points, num_obj=opt.num_objects)
    refiner.cuda()

    if opt.resume_posenet != '':
        estimator.load_state_dict(
            torch.load('{0}/{1}'.format(opt.outf, opt.resume_posenet)))

    if opt.resume_refinenet != '':
        # Resuming a refiner checkpoint means going straight to the
        # refinement phase with the decayed hyper-parameters.
        refiner.load_state_dict(
            torch.load('{0}/{1}'.format(opt.outf, opt.resume_refinenet)))
        opt.refine_start = True
        opt.decay_start = True
        opt.lr *= opt.lr_rate
        opt.w *= opt.w_rate
        opt.batch_size = int(opt.batch_size / opt.iteration)
        optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
    else:
        opt.refine_start = False
        opt.decay_start = False
        optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

    dataloader, testdataloader, criterion, criterion_refine = \
        _load_data_and_losses(opt)

    # np.inf instead of np.Inf: the capitalised alias was removed in NumPy 2.0.
    best_test = np.inf

    # A fresh run (start_epoch == 1) clears old log files first.
    if opt.start_epoch == 1:
        for log in os.listdir(opt.log_dir):
            os.remove(os.path.join(opt.log_dir, log))
    st_time = time.time()

    for epoch in range(opt.start_epoch, opt.nepoch):
        logger = setup_logger(
            'epoch%d' % epoch,
            os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
        logger.info('Train time {0}'.format(
            _elapsed(st_time) + ', ' + 'Training started'))
        train_count = 0
        train_dis_avg = 0.0
        if opt.refine_start:
            estimator.eval()
            refiner.train()
        else:
            estimator.train()
        optimizer.zero_grad()

        for rep in range(opt.repeat_epoch):
            for i, data in enumerate(dataloader, 0):
                points, choose, img, target, model_points, idx = \
                    _batch_to_cuda(data)
                pred_r, pred_t, pred_c, emb = estimator(
                    img, points, choose, idx)
                loss, dis, new_points, new_target = criterion(
                    pred_r, pred_t, pred_c, target, model_points, idx, points,
                    opt.w, opt.refine_start)

                if opt.refine_start:
                    for ite in range(0, opt.iteration):
                        pred_r, pred_t = refiner(new_points, emb, idx)
                        dis, new_points, new_target = criterion_refine(
                            pred_r, pred_t, new_target, model_points, idx,
                            new_points)
                        dis.backward()
                else:
                    loss.backward()

                train_dis_avg += dis.item()
                train_count += 1

                # Loader batch size is 1; gradients are accumulated and the
                # optimizer steps once every opt.batch_size frames.
                if train_count % opt.batch_size == 0:
                    logger.info(
                        'Train time {0} Epoch {1} Batch {2} Frame {3} Avg_dis:{4}'
                        .format(
                            _elapsed(st_time),
                            epoch, int(train_count / opt.batch_size),
                            train_count, train_dis_avg / opt.batch_size))
                    optimizer.step()
                    optimizer.zero_grad()
                    train_dis_avg = 0

                # Rolling "current" checkpoint every 1000 frames.
                if train_count != 0 and train_count % 1000 == 0:
                    if opt.refine_start:
                        torch.save(
                            refiner.state_dict(),
                            '{0}/pose_refine_model_current.pth'.format(
                                opt.outf))
                    else:
                        torch.save(
                            estimator.state_dict(),
                            '{0}/pose_model_current.pth'.format(opt.outf))

        print(
            '>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.format(
                epoch))

        logger = setup_logger(
            'epoch%d_test' % epoch,
            os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(
            _elapsed(st_time) + ', ' + 'Testing started'))
        test_dis = 0.0
        test_count = 0
        estimator.eval()
        refiner.eval()

        for j, data in enumerate(testdataloader, 0):
            points, choose, img, target, model_points, idx = \
                _batch_to_cuda(data)
            pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)
            _, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c,
                                                       target, model_points,
                                                       idx, points, opt.w,
                                                       opt.refine_start)

            if opt.refine_start:
                for ite in range(0, opt.iteration):
                    pred_r, pred_t = refiner(new_points, emb, idx)
                    dis, new_points, new_target = criterion_refine(
                        pred_r, pred_t, new_target, model_points, idx,
                        new_points)

            test_dis += dis.item()
            logger.info('Test time {0} Test Frame No.{1} dis:{2}'.format(
                _elapsed(st_time), test_count, dis))

            test_count += 1

        test_dis = test_dis / test_count
        logger.info('Test time {0} Epoch {1} TEST FINISH Avg dis: {2}'.format(
            _elapsed(st_time), epoch, test_dis))
        if test_dis <= best_test:
            best_test = test_dis
            if opt.refine_start:
                torch.save(
                    refiner.state_dict(),
                    '{0}/pose_refine_model_{1}_{2}.pth'.format(
                        opt.outf, epoch, test_dis))
            else:
                torch.save(
                    estimator.state_dict(),
                    '{0}/pose_model_{1}_{2}.pth'.format(
                        opt.outf, epoch, test_dis))
            print(epoch,
                  '>>>>>>>>----------BEST TEST MODEL SAVED---------<<<<<<<<')

        # One-time decay of learning rate and loss weight.
        if best_test < opt.decay_margin and not opt.decay_start:
            opt.decay_start = True
            opt.lr *= opt.lr_rate
            opt.w *= opt.w_rate
            optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

        # One-time switch to refiner training; data and losses are rebuilt
        # because the datasets take refine_start as a constructor argument.
        if best_test < opt.refine_margin and not opt.refine_start:
            opt.refine_start = True
            opt.batch_size = int(opt.batch_size / opt.iteration)
            optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)

            dataloader, testdataloader, criterion, criterion_refine = \
                _load_data_and_losses(opt)
コード例 #7
0
ファイル: evaluate.py プロジェクト: r-pad/DenseFusion
class DenseFusionEstimator(object):
    """Pose estimator built from a ``PoseNet`` and an optional ``PoseRefineNet``.

    The networks run on the GPU when ``torch.cuda.is_available()`` and on the
    CPU otherwise.
    """

    def __init__(self,
                 num_points,
                 num_obj,
                 estimator_weights_file,
                 refiner_weights_file=None,
                 batch_size=1):
        """Create the networks and load their weights.

        Args:
            num_points: Number of points per object cloud.
            num_obj: Number of object classes.
            estimator_weights_file: Checkpoint for the ``PoseNet``.
            refiner_weights_file: Checkpoint for the ``PoseRefineNet``;
                ``None`` disables refinement.
            batch_size: Batch size used when reshaping network outputs.
        """
        self.num_points = num_points
        self.bs = batch_size
        self.estimator = self._load_weights(
            PoseNet(num_points=num_points, num_obj=num_obj),
            estimator_weights_file)

        if (refiner_weights_file is not None):
            self.refiner = self._load_weights(
                PoseRefineNet(num_points=num_points, num_obj=num_obj),
                refiner_weights_file)
        else:
            self.refiner = None

    @staticmethod
    def _load_weights(net, weights_file):
        """Place ``net`` on the available device, load weights, set eval mode."""
        if (torch.cuda.is_available()):
            net.cuda()
            net.load_state_dict(torch.load(weights_file))
        else:
            # Remap GPU-saved tensors onto the CPU; the refiner previously
            # skipped this and could not be loaded on a CPU-only machine.
            net.load_state_dict(
                torch.load(weights_file,
                           map_location=lambda storage, loc: storage))
        net.eval()
        return net

    def __call__(self, img, depth, mask, bbox, object_label, return_all=False):
        """Estimate (and, if a refiner is configured, refine) an object pose.

        Returns:
            * ``(None, None)`` when ``preprocessData`` returns ``None``.
            * ``(refined_r, refined_t)`` when a refiner is configured
              (``return_all`` is ignored in that case).
            * the 7-tuple of ``estimatePose`` when ``return_all`` is true.
            * ``(max_r, max_t)`` otherwise.
        """
        object_label = Variable(torch.LongTensor([object_label]))
        if (torch.cuda.is_available()):
            object_label = object_label.cuda()
        data = preprocessData(img,
                              depth,
                              mask,
                              bbox,
                              num_points=self.num_points)
        if (data is None):
            return None, None
        img_masked, cloud, choose = data
        max_r, max_t, max_c, pred_r, pred_t, pred_c, emb = self.estimatePose(
            img_masked, cloud, choose, object_label, return_all=True)
        if (self.refiner is not None):
            # Keywords are used because refinePose takes the translation
            # before the rotation; the positional call had them swapped.
            refined_r, refined_t = self.refinePose(emb,
                                                   cloud,
                                                   object_label,
                                                   init_t=max_t,
                                                   init_r=max_r)
            return refined_r, refined_t

        if (return_all):
            return max_r, max_t, max_c, pred_r, pred_t, pred_c, emb

        return max_r, max_t

    def globalFeature(self, img_masked, cloud, choose, object_label):
        """Return the estimator network's ``globalFeature`` for the inputs."""
        feat = self.estimator.globalFeature(img_masked, cloud, choose,
                                            object_label)
        return feat

    def estimatePose(self,
                     img_masked,
                     cloud,
                     choose,
                     object_label,
                     return_all=True):
        """Run the estimator and pick the most confident per-point prediction.

        Returns:
            ``(max_r, max_t, max_c)`` - rotation, translation and confidence
            of the highest-confidence point - followed by the normalised
            ``pred_r``, reshaped ``pred_t`` / ``pred_c`` and ``emb`` when
            ``return_all`` is true.
        """
        pred_r, pred_t, pred_c, emb = self.estimator(img_masked, cloud, choose,
                                                     object_label)
        # Normalise every per-point rotation prediction along dim 2.
        pred_r = pred_r / torch.norm(pred_r, dim=2).view(1, self.num_points, 1)
        pred_c = pred_c.view(self.bs, self.num_points)
        max_c, which_max = torch.max(pred_c, 1)
        pred_t = pred_t.view(self.bs * self.num_points, 1, 3)
        points = cloud.view(self.bs * self.num_points, 1, 3)

        # The translation prediction is an offset from the chosen point.
        max_r = pred_r[0][which_max[0]].view(-1)
        max_t = (points + pred_t)[which_max[0]].view(-1)

        if (return_all):
            return max_r, max_t, max_c, pred_r, pred_t, pred_c, emb

        return max_r, max_t, max_c

    def refinePose(self,
                   emb,
                   cloud,
                   object_label,
                   init_t,
                   init_r,
                   iterations=2):
        """Iteratively refine an initial pose with the refiner network.

        Args:
            emb: Embedding returned by ``estimatePose``.
            cloud: Point cloud tensor; intermediate tensors are created on
                its device.
            object_label: Class index tensor passed to the refiner.
            init_t: Initial translation tensor (3 values).
            init_r: Initial rotation quaternion tensor.
            iterations: Number of refinement passes; with 0 the initial pose
                is returned unchanged (as numpy arrays).

        Returns:
            ``(refined_r, refined_t)`` as numpy arrays, in the order that
            ``__call__`` unpacks them.
        """
        init_t = init_t.cpu().data.numpy()
        init_r = init_r.cpu().data.numpy()

        refined_r, refined_t = init_r, init_t
        for _ in range(iterations):
            device = cloud.device
            T = Variable(torch.from_numpy(init_t.astype(
                np.float32))).to(device).view(1, 3).repeat(
                    self.num_points, 1).contiguous().view(
                        1, self.num_points, 3)
            init_mat = quaternion_matrix(init_r)
            R = Variable(torch.from_numpy(init_mat[:3, :3].astype(
                np.float32))).to(device).view(1, 3, 3)
            init_mat[0:3, 3] = init_t

            # Express the cloud in the frame of the current estimate.
            new_cloud = torch.bmm((cloud - T), R).contiguous()
            pred_r, pred_t = self.refiner(new_cloud, emb, object_label)
            pred_r = pred_r.view(1, 1, -1)
            pred_r = pred_r / (torch.norm(pred_r, dim=2).view(1, 1, 1))

            delta_r = pred_r.view(-1).cpu().data.numpy()
            delta_t = pred_t.view(-1).cpu().data.numpy()
            delta_mat = quaternion_matrix(delta_r)
            delta_mat[0:3, 3] = delta_t

            # Compose the delta onto the current pose, then split the result
            # back into quaternion and translation.
            refined_mat = np.dot(init_mat, delta_mat)
            refined_rot = copy.deepcopy(refined_mat)
            refined_rot[0:3, 3] = 0
            refined_r = quaternion_from_matrix(refined_rot, True)
            refined_t = np.array(
                [refined_mat[0][3], refined_mat[1][3], refined_mat[2][3]])

            # The refined pose seeds the next pass.
            init_r = refined_r
            init_t = refined_t
        return refined_r, refined_t
コード例 #8
0
class Inference:
    """Single-object pose inference with a ``PoseNet`` and a ``PoseRefineNet``.

    Both networks are created for 21 object classes, moved to the GPU and put
    in eval mode at construction time.
    """

    def __init__(self,
                 model,
                 model_with_refinement,
                 width,
                 height,
                 fx,
                 fy,
                 cx,
                 cy,
                 number_points=1000,
                 number_iterations=2):
        """Store camera parameters and load both networks.

        Args:
            model: Checkpoint path for the ``PoseNet``.
            model_with_refinement: Checkpoint path for the ``PoseRefineNet``.
            width, height: Image size in pixels.
            fx, fy, cx, cy: Camera intrinsics used to back-project depth.
            number_points: Number of points sampled from the object mask.
            number_iterations: Number of refinement passes per inference.
        """
        self.width = width
        self.height = height
        self.fx = fx
        self.fy = fy
        self.cx = cx
        self.cy = cy

        self.number_iterations = number_iterations

        self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        # Allowed crop sizes; bounding boxes are rounded up to one of these.
        self.border_list = [
            -1, 40, 80, 120, 160, 200, 240, 280, 320, 360, 400, 440, 480, 520,
            560, 600, 640, 680
        ]

        # xmap[r, c] == r (row index) and ymap[r, c] == c (column index).
        self.xmap, self.ymap = numpy.indices((self.height, self.width))
        self.number_points = number_points

        self.estimator = PoseNet(num_points=self.number_points, num_obj=21)
        self.estimator.cuda()
        self.estimator.load_state_dict(torch.load(model))
        self.estimator.eval()

        self.refiner = PoseRefineNet(num_points=self.number_points, num_obj=21)
        self.refiner.cuda()
        self.refiner.load_state_dict(torch.load(model_with_refinement))
        self.refiner.eval()

    def _snap_to_border(self, size):
        """Round ``size`` up to the next ``border_list`` entry.

        A size that equals a border value, or exceeds the largest one, is
        returned unchanged.  (Indexing ``border_list[tt + 1]`` over the full
        list length used to raise ``IndexError`` for exactly those sizes.)
        """
        for lower, upper in zip(self.border_list, self.border_list[1:]):
            if lower < size < upper:
                return upper
        return size

    @staticmethod
    def _axis_angle_pose(rotation, translation):
        """Return ``[tx, ty, tz, ax, ay, az, angle]`` for a 3x3 rotation."""
        quat = Quaternion(matrix=rotation)
        axis_angle = numpy.append(quat.axis.squeeze(), quat.angle)
        return numpy.append(translation, axis_angle)

    def evaluate_bbox(self, mask):
        """Compute a crop window around the non-zero pixels of ``mask``.

        The tight bounding rectangle is enlarged to the next size in
        ``border_list`` (per axis), re-centred, and shifted back inside the
        image where it crosses an edge.

        Returns:
            ``(rmin, rmax, cmin, cmax)`` row/column bounds.
        """
        points = cv2.findNonZero(mask)
        x, y, w, h = cv2.boundingRect(points)
        rmin = y
        rmax = y + h
        cmin = x
        cmax = x + w

        r_b = self._snap_to_border(rmax - rmin)
        c_b = self._snap_to_border(cmax - cmin)

        center = [int((rmin + rmax) / 2), int((cmin + cmax) / 2)]
        rmin = center[0] - int(r_b / 2)
        rmax = center[0] + int(r_b / 2)
        cmin = center[1] - int(c_b / 2)
        cmax = center[1] + int(c_b / 2)

        # Shift the window back into the image, keeping its size.
        if rmin < 0:
            delt = -rmin
            rmin = 0
            rmax += delt
        if cmin < 0:
            delt = -cmin
            cmin = 0
            cmax += delt
        if rmax > self.height:
            delt = rmax - self.height
            rmax = self.height
            rmin -= delt
        if cmax > self.width:
            delt = cmax - self.width
            cmax = self.width
            cmin -= delt

        return rmin, rmax, cmin, cmax

    def inference(self, object_id, img, depth, mask):
        """Estimate and refine the pose of one object.

        Args:
            object_id: 1-based class id (``object_id - 1`` is fed to the nets).
            img: Colour image; only the first three channels are used.
            depth: Depth image; zero-valued pixels are ignored.
            mask: Object mask; pixels equal to 255 belong to the object.

        Returns:
            numpy array ``[tx, ty, tz, ax, ay, az, angle]`` (translation
            followed by axis-angle rotation), or ``[]`` when a ``ValueError``
            occurs while processing.
        """
        try:
            rmin, rmax, cmin, cmax = self.evaluate_bbox(mask)

            # Keep only pixels that are on the object and have valid depth.
            mask_depth = numpy_ma.getmaskarray(
                numpy_ma.masked_not_equal(depth, 0))
            mask_label = (mask == 255)
            mask = mask_label * mask_depth

            # Sample exactly number_points pixel indices inside the crop:
            # random subset when there are too many, wrap-padding otherwise.
            choose = mask[rmin:rmax, cmin:cmax].flatten().nonzero()[0]
            if len(choose) > self.number_points:
                c_mask = numpy.zeros(len(choose), dtype=int)
                c_mask[:self.number_points] = 1
                numpy.random.shuffle(c_mask)
                choose = choose[c_mask.nonzero()]
            else:
                choose = numpy.pad(choose,
                                   (0, self.number_points - len(choose)),
                                   'wrap')

            depth_masked = depth[rmin:rmax, cmin:cmax].flatten()[choose][
                :, numpy.newaxis].astype(numpy.float32)
            xmap_masked = self.xmap[rmin:rmax, cmin:cmax].flatten()[choose][
                :, numpy.newaxis].astype(numpy.float32)
            ymap_masked = self.ymap[rmin:rmax, cmin:cmax].flatten()[choose][
                :, numpy.newaxis].astype(numpy.float32)
            choose = numpy.array([choose])

            # Back-project the sampled pixels using the camera intrinsics.
            pt2 = depth_masked
            pt0 = (ymap_masked - self.cx) * pt2 / self.fx
            pt1 = (xmap_masked - self.cy) * pt2 / self.fy
            cloud = numpy.concatenate((pt0, pt1, pt2), axis=1)

            img_masked = numpy.array(img)[:, :, :3]
            img_masked = numpy.transpose(img_masked, (2, 0, 1))
            img_masked = img_masked[:, rmin:rmax, cmin:cmax]

            cloud = torch.from_numpy(cloud.astype(numpy.float32))
            choose = torch.LongTensor(choose.astype(numpy.int32))
            img_masked = self.norm(
                torch.from_numpy(img_masked.astype(numpy.float32)))
            index = torch.LongTensor([object_id - 1])

            cloud = Variable(cloud).cuda()
            choose = Variable(choose).cuda()
            img_masked = Variable(img_masked).cuda()
            index = Variable(index).cuda()

            cloud = cloud.view(1, self.number_points, 3)
            img_masked = img_masked.view(1, 3,
                                         img_masked.size()[1],
                                         img_masked.size()[2])

            pred_r, pred_t, pred_c, emb = self.estimator(
                img_masked, cloud, choose, index)
            pred_r = pred_r / torch.norm(pred_r, dim=2).view(
                1, self.number_points, 1)

            # Use the prediction of the most confident point.
            pred_c = pred_c.view(1, self.number_points)
            _, which_max = torch.max(pred_c, 1)
            pred_t = pred_t.view(1 * self.number_points, 1, 3)
            points = cloud.view(1 * self.number_points, 1, 3)

            my_r = pred_r[0][which_max[0]].view(-1).cpu().data.numpy()
            my_t = (points + pred_t)[which_max[0]].view(-1).cpu().data.numpy()

            my_pred_alt = None
            for ite in range(0, self.number_iterations):
                T = Variable(torch.from_numpy(my_t.astype(
                    numpy.float32))).cuda().view(1, 3).repeat(
                        self.number_points,
                        1).contiguous().view(1, self.number_points, 3)
                my_mat = quaternion_matrix(my_r)
                R = Variable(
                    torch.from_numpy(my_mat[:3, :3].astype(
                        numpy.float32))).cuda().view(1, 3, 3)
                my_mat[0:3, 3] = my_t

                new_cloud = torch.bmm((cloud - T), R).contiguous()
                pred_r, pred_t = self.refiner(new_cloud, emb, index)
                pred_r = pred_r.view(1, 1, -1)
                pred_r = pred_r / (torch.norm(pred_r, dim=2).view(1, 1, 1))
                my_r_2 = pred_r.view(-1).cpu().data.numpy()
                my_t_2 = pred_t.view(-1).cpu().data.numpy()
                my_mat_2 = quaternion_matrix(my_r_2)

                my_mat_2[0:3, 3] = my_t_2

                # Compose the delta and split into rotation / translation.
                my_mat_final = numpy.dot(my_mat, my_mat_2)
                my_r_final = copy.deepcopy(my_mat_final)
                my_r_final[0:3, 3] = 0
                my_r_final = quaternion_from_matrix(my_r_final, True)
                my_t_final = numpy.array([
                    my_mat_final[0][3], my_mat_final[1][3], my_mat_final[2][3]
                ])

                my_pred_alt = self._axis_angle_pose(my_mat_final[0:3, 0:3],
                                                    my_t_final)
                my_r = my_r_final
                my_t = my_t_final

            if my_pred_alt is None:
                # No refinement pass ran (number_iterations <= 0): report the
                # unrefined estimate instead of hitting an unbound local.
                my_pred_alt = self._axis_angle_pose(
                    quaternion_matrix(my_r)[0:3, 0:3], my_t)

            return my_pred_alt
        except ValueError:
            print(
                "Inference::inference. Error: an error occured at inference time."
            )
            return []
コード例 #9
0
def main():
    """Train ``PoseNetGlobal`` and then ``PoseRefineNet`` on the YCB dataset.

    Configuration is read from, and written back to, the module-level ``opt``
    namespace (seed, output folders, phase flags, learning rate, batch size).

    Training has two phases:

    1. The estimator is optimised until the best validation distance drops
       below ``opt.refine_margin`` (or immediately skipped when
       ``opt.resume_refinenet`` names a checkpoint).
    2. The estimator is put in eval mode and only the refiner's parameters are
       optimised; the datasets, loaders and loss objects are rebuilt with
       ``refine=True``.

    The data loaders yield one frame at a time; gradients are accumulated over
    ``opt.batch_size`` frames before each optimizer step.  Checkpoints go to
    ``opt.outf`` and per-epoch log files to ``opt.log_dir``.
    """
    opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    opt.num_objects = 21  # number of object classes in the dataset
    opt.num_points = 1000  # number of points on the input pointcloud
    opt.outf = 'trained_models/ycb_global_mnorm'  # folder to save trained models
    opt.log_dir = 'experiments/logs/ycb_global_mnorm'  # folder to save logs
    opt.repeat_epoch = 1  # number of repeat times for one epoch training

    for out_dir in (opt.outf, opt.log_dir):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

    estimator = PoseNetGlobal(num_points=opt.num_points, num_obj=opt.num_objects)
    estimator.cuda()
    refiner = PoseRefineNet(num_points=opt.num_points, num_obj=opt.num_objects)
    refiner.cuda()

    if opt.resume_posenet != '':
        estimator.load_state_dict(torch.load('{0}/{1}'.format(opt.outf, opt.resume_posenet)))

    if opt.resume_refinenet != '':
        # Resuming a refiner checkpoint means we start directly in phase 2.
        refiner.load_state_dict(torch.load('{0}/{1}'.format(opt.outf, opt.resume_refinenet)))
        opt.refine_start = True
        opt.decay_start = True
        opt.lr *= opt.lr_rate
        opt.w *= opt.w_rate
        # Keep at least 1: the value is used as a modulus in the training loop.
        opt.batch_size = max(1, int(opt.batch_size / opt.iteration))
        optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
    else:
        opt.refine_start = False
        opt.decay_start = False
        optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

    object_list = list(range(1, 22))
    output_format = [otypes.DEPTH_POINTS_MASKED_AND_INDEXES,
                     otypes.IMAGE_CROPPED,
                     otypes.MODEL_POINTS_TRANSFORMED,
                     otypes.MODEL_POINTS,
                     otypes.OBJECT_LABEL,
                     ]

    dataset = YCBDataset(opt.dataset_root, mode='train_syn_grid',
                         object_list=object_list,
                         output_data=output_format,
                         resample_on_error=True,
                         preprocessors=[YCBOcclusionAugmentor(opt.dataset_root),
                                        ColorJitter(),
                                        InplaneRotator()],
                         postprocessors=[ImageNormalizer(), PointMeanNormalizer(0)],  # PointShifter()],
                         refine=opt.refine_start,
                         image_size=[640, 480], num_points=1000)
    # DataLoader rejects a negative worker count, so never go below 0.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True,
                                             num_workers=max(opt.workers - 1, 0))

    test_dataset = YCBDataset(opt.dataset_root, mode='valid',
                              object_list=object_list,
                              output_data=output_format,
                              resample_on_error=True,
                              preprocessors=[],
                              postprocessors=[ImageNormalizer(), PointMeanNormalizer(0)],
                              refine=opt.refine_start,
                              image_size=[640, 480], num_points=1000)
    testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)

    opt.sym_list = [12, 15, 18, 19, 20]
    opt.num_points_mesh = dataset.num_pt_mesh_small

    print('>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}\nsymmetry object list: {3}'.format(len(dataset), len(test_dataset), opt.num_points_mesh, opt.sym_list))

    criterion = Loss(opt.num_points_mesh, opt.sym_list)
    criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)

    # np.Inf was removed in NumPy 2.0; np.inf works on every NumPy version.
    best_test = np.inf

    if opt.start_epoch == 1:
        # A fresh run starts with an empty log directory.
        for log in os.listdir(opt.log_dir):
            os.remove(os.path.join(opt.log_dir, log))
    st_time = time.time()

    def elapsed():
        """Return the wall-clock time since training started, formatted for the logs."""
        return time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time))

    def to_gpu(batch):
        """Shift the object label by -1 and move every tensor of a batch to the GPU."""
        points, choose, img, target, model_points, idx = batch
        idx = idx - 1
        return tuple(Variable(t).cuda()
                     for t in (points, choose, img, target, model_points, idx))

    for epoch in range(opt.start_epoch, opt.nepoch):
        # ------------------------------ training ------------------------------
        logger = setup_logger('epoch%d' % epoch, os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
        logger.info('Train time {0}'.format(elapsed() + ', ' + 'Training started'))
        train_count = 0
        train_dis_avg = 0.0
        if opt.refine_start:
            estimator.eval()
            refiner.train()
        else:
            estimator.train()
        optimizer.zero_grad()

        for rep in range(opt.repeat_epoch):
            for i, data in enumerate(dataloader, 0):
                points, choose, img, target, model_points, idx = to_gpu(data)
                pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)
                loss, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c, target, model_points, idx, points, opt.w, opt.refine_start)

                if opt.refine_start:
                    for ite in range(0, opt.iteration):
                        pred_r, pred_t = refiner(new_points, emb, idx)
                        dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_target, model_points, idx, new_points)
                        dis.backward()
                else:
                    loss.backward()

                train_dis_avg += dis.item()
                train_count += 1

                # Gradients accumulate across frames; step once per virtual batch.
                if train_count % opt.batch_size == 0:
                    logger.info('Train time {0} Epoch {1} Batch {2} Frame {3} Avg_dis:{4}'.format(elapsed(), epoch, int(train_count / opt.batch_size), train_count, train_dis_avg / opt.batch_size))
                    optimizer.step()
                    optimizer.zero_grad()
                    train_dis_avg = 0

                # Rolling checkpoint every 1000 frames.
                if train_count != 0 and train_count % 1000 == 0:
                    if opt.refine_start:
                        torch.save(refiner.state_dict(), '{0}/pose_refine_model_current.pth'.format(opt.outf))
                    else:
                        torch.save(estimator.state_dict(), '{0}/pose_model_current.pth'.format(opt.outf))

        print('>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.format(epoch))

        # ----------------------------- validation -----------------------------
        # NOTE(review): this loop still builds autograd graphs; wrapping it in
        # torch.no_grad() would save memory but changes how `dis` prints in the log.
        logger = setup_logger('epoch%d_test' % epoch, os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(elapsed() + ', ' + 'Testing started'))
        test_dis = 0.0
        test_count = 0
        estimator.eval()
        refiner.eval()

        for j, data in enumerate(testdataloader, 0):
            points, choose, img, target, model_points, idx = to_gpu(data)
            pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)
            _, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c, target, model_points, idx, points, opt.w, opt.refine_start)

            if opt.refine_start:
                for ite in range(0, opt.iteration):
                    pred_r, pred_t = refiner(new_points, emb, idx)
                    dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_target, model_points, idx, new_points)

            test_dis += dis.item()
            logger.info('Test time {0} Test Frame No.{1} dis:{2}'.format(elapsed(), test_count, dis))

            test_count += 1

        test_dis = test_dis / test_count
        logger.info('Test time {0} Epoch {1} TEST FINISH Avg dis: {2}'.format(elapsed(), epoch, test_dis))
        if test_dis <= best_test:
            best_test = test_dis
            if opt.refine_start:
                torch.save(refiner.state_dict(), '{0}/pose_refine_model_{1}_{2}.pth'.format(opt.outf, epoch, test_dis))
            else:
                torch.save(estimator.state_dict(), '{0}/pose_model_{1}_{2}.pth'.format(opt.outf, epoch, test_dis))
            print(epoch, '>>>>>>>>----------BEST TEST MODEL SAVED---------<<<<<<<<')

        # ------------------------- phase transitions --------------------------
        if best_test < opt.decay_margin and not opt.decay_start:
            opt.decay_start = True
            opt.lr *= opt.lr_rate
            opt.w *= opt.w_rate
            optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

        if best_test < opt.refine_margin and not opt.refine_start:
            opt.refine_start = True
            # Keep at least 1: the value is used as a modulus in the training loop.
            opt.batch_size = max(1, int(opt.batch_size / opt.iteration))
            optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)

            # Rebuild datasets and loaders in refine mode.
            dataset = YCBDataset(opt.dataset_root, mode='train_syn_grid',
                                 object_list=object_list,
                                 output_data=output_format,
                                 resample_on_error=True,
                                 preprocessors=[YCBOcclusionAugmentor(opt.dataset_root),
                                                ColorJitter(),
                                                InplaneRotator()],
                                 postprocessors=[ImageNormalizer(), PointShifter()],
                                 refine=opt.refine_start,
                                 image_size=[640, 480], num_points=1000)
            dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=opt.workers)

            test_dataset = YCBDataset(opt.dataset_root, mode='valid',
                                      object_list=object_list,
                                      output_data=output_format,
                                      resample_on_error=True,
                                      preprocessors=[],
                                      postprocessors=[ImageNormalizer()],
                                      refine=opt.refine_start,
                                      image_size=[640, 480], num_points=1000)
            testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=opt.workers)
            opt.num_points_mesh = dataset.num_pt_mesh_large

            print('>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}\nsymmetry object list: {3}'.format(len(dataset), len(test_dataset), opt.num_points_mesh, opt.sym_list))

            criterion = Loss(opt.num_points_mesh, opt.sym_list)
            criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)
コード例 #10
0
class DenseFusion_YW(object):
    """DenseFusion pose estimation for a single object class on top of a segmenter.

    The constructor loads a ``PoseNet`` estimator and a ``PoseRefineNet`` refiner
    from fixed checkpoint paths onto the GPU; :meth:`DenseFusion` turns one
    image / depth / segmentation triple into a quaternion and a translation.
    """

    def __init__(self, posecnn):
        """Store the segmenter and load both pose networks in eval mode.

        Args:
            posecnn: Segmentation front end; only its ``get_box_rcwh`` method is
                called by this class.
        """
        # posecnn
        self.posecnn = posecnn

        # df
        # Camera parameters.
        # NOTE(review): cx/cy look like the principal point in pixels — confirm.
        self.cam_cx = 312.9869
        self.cam_cy = 241.3109
        self.cam_fx = 1066.778  # focal length along x (horizontal)
        self.cam_fy = 1067.487  # focal length along y (vertical)
        self.cam_scale = 10000.0  # raw depth values are divided by this
        # main
        self.num_points = 1000  # default
        # The YCB model is already trained with 21 classes; this cannot change.
        self.num_obj = 21
        # xmap[r, c] == r (row index) and ymap[r, c] == c (column index), both 480x640.
        self.xmap, self.ymap = np.indices((480, 640))
        self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225
                                              ])  # For img transform in the df
        self.bs = 1
        self.iteration = 2
        self.esti_model_file = "src/pose_model_26_0.012863246640872631.pth"
        self.refine_model_file = "src/pose_refine_model_69_0.009449292959118935.pth"
        self.estimator = PoseNet(num_points=self.num_points,
                                 num_obj=self.num_obj)
        self.estimator.cuda()
        self.estimator.load_state_dict(torch.load(self.esti_model_file))
        self.estimator.eval()
        self.refiner = PoseRefineNet(num_points=self.num_points,
                                     num_obj=self.num_obj)
        self.refiner.cuda()
        self.refiner.load_state_dict(torch.load(self.refine_model_file))
        self.refiner.eval()
        # Assume only banana.  Designed for the multi-label case: each detected
        # label gets its own id (id == 1 -> label 1, id == 2 -> label 2, ...),
        # so a known id tells which object (banana, apple, ...) it is.
        self.lst = [
            1
        ]
        self.my_result_wo_refine = []  # not written by DenseFusion()
        self.my_result = []

    def DenseFusion(self, img, depth, posecnn_res):
        """Estimate and iteratively refine the pose of the segmented object.

        Args:
            img: Image convertible with ``np.array``; the first three channels
                are used.
            depth: Depth image convertible with ``np.array`` with three
                dimensions; channel index 1 is used as the depth map.
            posecnn_res: Mapping with a ``"box"`` entry (four numbers) and a
                ``"mask"`` entry (values above 0.5 count as object).

        Returns:
            tuple: ``(quaternion, translation)`` as NumPy arrays of length 4
            and 3, taken after ``self.iteration`` refinement passes.
            NOTE(review): quaternion component order follows
            ``quaternion_matrix`` / ``quaternion_from_matrix`` — confirm.

        Raises:
            ValueError: If no pixel inside the box is both in the mask and has
                non-zero depth.
        """
        itemid = 1  # simplified for single-label detection; see DFYW3.py for multi-label

        depth = np.array(depth)
        seg_res = posecnn_res

        x1, y1, x2, y2 = seg_res["box"]
        # Return value is unused here; call kept in case it validates the box.
        # NOTE(review): drop it if get_box_rcwh is side-effect free.
        self.posecnn.get_box_rcwh(seg_res["box"])
        # NOTE(review): the first/third box entries are used as the ROW range and
        # the second/fourth as the COLUMN range — confirm the box ordering.
        rmin, rmax, cmin, cmax = int(x1), int(x2), int(y1), int(y2)
        # The depth image has 3 identical channels; keep one of them.
        depth = depth[:, :, 1]
        mask_depth = ma.getmaskarray(ma.masked_not_equal(depth, 0))

        label_banana = np.squeeze(seg_res["mask"])
        label_banana = ma.getmaskarray(ma.masked_greater(label_banana, 0.5))

        mask_label = ma.getmaskarray(ma.masked_equal(
            label_banana, itemid))  # label from banana label
        mask = mask_label * mask_depth

        # Flat indices (within the box crop) of usable pixels.
        choose = mask[rmin:rmax, cmin:cmax].flatten().nonzero()[0]
        if len(choose) == 0:
            # np.pad(..., 'wrap') cannot extend an empty array; fail clearly.
            raise ValueError(
                "DenseFusion: no valid masked depth pixels inside the detection box")
        if len(choose) > self.num_points:
            # Randomly keep exactly num_points of the candidates.
            c_mask = np.zeros(len(choose), dtype=int)
            c_mask[:self.num_points] = 1
            np.random.shuffle(c_mask)
            choose = choose[c_mask.nonzero()]
        else:
            # Too few candidates: repeat them cyclically up to num_points.
            choose = np.pad(choose, (0, self.num_points - len(choose)), 'wrap')

        depth_masked = depth[rmin:rmax, cmin:cmax].flatten()[choose][
            :, np.newaxis].astype(np.float32)
        xmap_masked = self.xmap[rmin:rmax, cmin:cmax].flatten()[choose][
            :, np.newaxis].astype(np.float32)
        ymap_masked = self.ymap[rmin:rmax, cmin:cmax].flatten()[choose][
            :, np.newaxis].astype(np.float32)
        choose = np.array([choose])

        # Back-project the selected pixels with the pinhole parameters above.
        pt2 = depth_masked / self.cam_scale
        pt0 = (ymap_masked - self.cam_cx) * pt2 / self.cam_fx
        pt1 = (xmap_masked - self.cam_cy) * pt2 / self.cam_fy
        cloud = np.concatenate((pt0, pt1, pt2), axis=1)

        img_masked = np.array(img)[:, :, :3]
        img_masked = np.transpose(img_masked, (2, 0, 1))
        img_masked = img_masked[:, rmin:rmax, cmin:cmax]

        cloud = torch.from_numpy(cloud.astype(np.float32))
        choose = torch.LongTensor(choose.astype(np.int32))
        img_masked = self.norm(torch.from_numpy(img_masked.astype(np.float32)))
        index = torch.LongTensor([itemid - 1])

        cloud = Variable(cloud).cuda()
        choose = Variable(choose).cuda()
        img_masked = Variable(img_masked).cuda()
        index = Variable(index).cuda()

        cloud = cloud.view(1, self.num_points, 3)
        img_masked = img_masked.view(1, 3,
                                     img_masked.size()[1],
                                     img_masked.size()[2])

        pred_r, pred_t, pred_c, emb = self.estimator(img_masked, cloud, choose,
                                                     index)
        pred_r = pred_r / torch.norm(pred_r, dim=2).view(1, self.num_points, 1)

        # Keep the per-point prediction with the highest confidence.
        pred_c = pred_c.view(self.bs, self.num_points)
        _, which_max = torch.max(pred_c, 1)
        pred_t = pred_t.view(self.bs * self.num_points, 1, 3)
        points = cloud.view(self.bs * self.num_points, 1, 3)

        my_r = pred_r[0][which_max[0]].view(-1).cpu().data.numpy()
        my_t = (points + pred_t)[which_max[0]].view(-1).cpu().data.numpy()

        for ite in range(0, self.iteration):
            T = Variable(torch.from_numpy(
                my_t.astype(np.float32))).cuda().view(1, 3).repeat(
                    self.num_points,
                    1).contiguous().view(1, self.num_points, 3)
            my_mat = quaternion_matrix(my_r)
            R = Variable(torch.from_numpy(my_mat[:3, :3].astype(
                np.float32))).cuda().view(1, 3, 3)
            my_mat[0:3, 3] = my_t

            # Express the cloud in the current pose estimate's frame and ask
            # the refiner for a residual pose.
            new_cloud = torch.bmm((cloud - T), R).contiguous()
            pred_r, pred_t = self.refiner(new_cloud, emb, index)
            pred_r = pred_r.view(1, 1, -1)
            pred_r = pred_r / (torch.norm(pred_r, dim=2).view(1, 1, 1))
            my_r_2 = pred_r.view(-1).cpu().data.numpy()
            my_t_2 = pred_t.view(-1).cpu().data.numpy()
            my_mat_2 = quaternion_matrix(my_r_2)
            my_mat_2[0:3, 3] = my_t_2

            # Compose current pose with the residual.
            my_mat_final = np.dot(my_mat, my_mat_2)
            my_r_final = copy.deepcopy(my_mat_final)
            my_r_final[0:3, 3] = 0
            my_r_final = quaternion_from_matrix(my_r_final, True)
            my_t_final = np.array(
                [my_mat_final[0][3], my_mat_final[1][3], my_mat_final[2][3]])

            # Feed the refined pose into the next pass.  Previously every pass
            # restarted from the initial estimate and the identical results
            # were averaged component-wise, so extra iterations had no effect.
            my_r = my_r_final
            my_t = my_t_final

        return my_r, my_t
コード例 #11
0
def main():
    """Draw DenseFusion pose predictions on the first few LineMOD test frames.

    Loads the estimator and refiner checkpoints, runs the estimator followed by
    ``opt.iteration`` refinement passes on at most ``t1_total_eval_num`` test
    samples, and for each sample saves the RGB frame under ``verify_img/`` with
    the ``obj_bb`` box from ``gt.yml`` and three projected axis lines drawn on
    it.  Samples whose points tensor is 2-D are reported as lost detections on
    stdout and in ``experiments/eval_result/linemod/t1_eval_result_logs.txt``.

    Configuration is read from, and written back to, the module-level ``opt``
    namespace.
    """
    # g13: parameter setting -------------------
    opt.dataset = 'linemod'
    opt.dataset_root = './datasets/linemod/Linemod_preprocessed'
    estimator_path = 'trained_checkpoints/linemod/pose_model_9_0.01310166542980859.pth'
    refiner_path = 'trained_checkpoints/linemod/pose_refine_model_493_0.006761023565178073.pth'
    opt.resume_posenet = estimator_path
    # This used to be a second assignment to opt.resume_posenet, which left
    # opt.resume_refinenet at its default so the refiner checkpoint could be
    # skipped while the refiner was still used below.
    opt.resume_refinenet = refiner_path
    output_result_dir = 'experiments/eval_result/linemod'
    bs = 1  # fixed because of the default setting in torch.utils.data.DataLoader
    opt.iteration = 2  # default is 4 in eval_linemod.py
    t1_total_eval_num = 3

    axis_range = 0.1  # the length of X, Y, and Z axis in 3D
    vimg_dir = 'verify_img'
    # The log file below is opened for writing, so its folder must exist too.
    for out_dir in (vimg_dir, output_result_dir):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
    # -------------------------------------------

    if opt.dataset == 'ycb':
        opt.num_objects = 21  # number of object classes in the dataset
        opt.num_points = 1000  # number of points on the input pointcloud
        opt.outf = 'trained_models/ycb'  # folder to save trained models
        opt.log_dir = 'experiments/logs/ycb'  # folder to save logs
        opt.repeat_epoch = 1  # number of repeat times for one epoch training
    elif opt.dataset == 'linemod':
        opt.num_objects = 13
        opt.num_points = 500
        opt.outf = 'trained_models/linemod'
        opt.log_dir = 'experiments/logs/linemod'
        opt.repeat_epoch = 20
    else:
        print('Unknown dataset')
        return

    estimator = PoseNet(num_points=opt.num_points, num_obj=opt.num_objects)
    estimator.cuda()
    refiner = PoseRefineNet(num_points=opt.num_points, num_obj=opt.num_objects)
    refiner.cuda()

    if opt.resume_posenet != '':
        estimator.load_state_dict(torch.load(estimator_path))

    # No optimizer or learning-rate bookkeeping here: this script only runs
    # inference, so only the phase flags the dataset reads are kept.
    opt.refine_start = opt.resume_refinenet != ''
    opt.decay_start = opt.refine_start
    if opt.refine_start:
        refiner.load_state_dict(torch.load(refiner_path))

    if opt.dataset == 'ycb':
        test_dataset = PoseDataset_ycb('test', opt.num_points, False, opt.dataset_root, 0.0, opt.refine_start)
    elif opt.dataset == 'linemod':
        test_dataset = PoseDataset_linemod('test', opt.num_points, False, opt.dataset_root, 0.0, opt.refine_start)
    testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=opt.workers)
    print('complete loading testing loader\n')
    opt.sym_list = test_dataset.get_sym_list()
    opt.num_points_mesh = test_dataset.get_num_points_mesh()

    print('>>>>>>>>----------Dataset loaded!---------<<<<<<<<\n\
        length of the testing set: {0}\nnumber of sample points on mesh: {1}\n\
        symmetry object list: {2}'\
        .format( len(test_dataset), opt.num_points_mesh, opt.sym_list))

    # load pytorch model
    estimator.eval()
    refiner.eval()

    # Pinhole intrinsics are the same for every frame, so build them once.
    # (ndarray.itemset, used previously, no longer exists in NumPy 2.)
    cam_intrinsic = np.array([
        [test_dataset.cam_fx, 0.0, test_dataset.cam_cx],
        [0.0, test_dataset.cam_fy, test_dataset.cam_cy],
        [0.0, 0.0, 1.0],
    ])

    # Axis tip offsets (homogeneous, added to the centre point) and line colours.
    axes = (
        ([[axis_range], [0], [0], [0]], (255, 255, 0)),
        ([[0], [axis_range], [0], [0]], (0, 255, 0)),
        ([[0], [0], [axis_range], [0]], (0, 0, 255)),
    )

    # gt.yml is parsed once per object instead of once per frame.
    gt_cache = {}

    with open('{0}/t1_eval_result_logs.txt'.format(output_result_dir), 'w') as fw:
        # Pose estimation
        for j, data in enumerate(testdataloader, 0):
            # g13: modify this part for evaluation target--------------------
            if j == t1_total_eval_num:
                break
            # ----------------------------------------------------------------
            points, choose, img, target, model_points, idx = data
            if len(points.size()) == 2:
                print('No.{0} NOT Pass! Lost detection!'.format(j))
                fw.write('No.{0} NOT Pass! Lost detection!\n'.format(j))
                continue
            points, choose, img, idx = (
                Variable(t).cuda() for t in (points, choose, img, idx))
            pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)

            # Keep the per-point prediction with the highest confidence.
            pred_r = pred_r / torch.norm(pred_r, dim=2).view(1, opt.num_points, 1)
            pred_c = pred_c.view(bs, opt.num_points)
            _, which_max = torch.max(pred_c, 1)
            pred_t = pred_t.view(bs * opt.num_points, 1, 3)

            my_r = pred_r[0][which_max[0]].view(-1).cpu().data.numpy()
            my_t = (points.view(bs * opt.num_points, 1, 3) + pred_t)[which_max[0]].view(-1).cpu().data.numpy()

            for ite in range(0, opt.iteration):
                T = Variable(torch.from_numpy(my_t.astype(np.float32))).cuda().view(1, 3).repeat(opt.num_points, 1).contiguous().view(1, opt.num_points, 3)
                my_mat = quaternion_matrix(my_r)
                R = Variable(torch.from_numpy(my_mat[:3, :3].astype(np.float32))).cuda().view(1, 3, 3)
                my_mat[0:3, 3] = my_t

                new_points = torch.bmm((points - T), R).contiguous()
                pred_r, pred_t = refiner(new_points, emb, idx)
                pred_r = pred_r.view(1, 1, -1)
                pred_r = pred_r / (torch.norm(pred_r, dim=2).view(1, 1, 1))
                my_r_2 = pred_r.view(-1).cpu().data.numpy()
                my_t_2 = pred_t.view(-1).cpu().data.numpy()
                my_mat_2 = quaternion_matrix(my_r_2)
                my_mat_2[0:3, 3] = my_t_2

                # Compose the current pose with the predicted residual.
                my_mat_final = np.dot(my_mat, my_mat_2)
                my_r_final = copy.deepcopy(my_mat_final)
                my_r_final[0:3, 3] = 0
                my_r_final = quaternion_from_matrix(my_r_final, True)
                my_t_final = np.array([my_mat_final[0][3], my_mat_final[1][3], my_mat_final[2][3]])

                my_r = my_r_final
                my_t = my_t_final

            # g13: start drawing pose on image------------------------------------
            print("index {0}: {1}".format(j, test_dataset.list_rgb[j]))

            # pick up center position by bbox
            obj_id = test_dataset.list_obj[j]
            if obj_id not in gt_cache:
                gt_path = '{0}/data/{1}/gt.yml'.format(opt.dataset_root, '%02d' % obj_id)
                with open(gt_path, 'r') as meta_file:
                    # safe_load: plain yaml.load() without a Loader is rejected
                    # by current PyYAML and can construct arbitrary objects.
                    gt_cache[obj_id] = yaml.safe_load(meta_file)
            meta = gt_cache[obj_id]
            which_item = test_dataset.list_rank[j]
            bbx = meta[which_item][0]['obj_bb']

            with Image.open(test_dataset.list_rgb[j]) as frame:
                draw = ImageDraw.Draw(frame)

                # draw box (ensure this is the right object)
                draw.line((bbx[0], bbx[1], bbx[0], bbx[1] + bbx[3]), fill=(255, 0, 0), width=5)
                draw.line((bbx[0], bbx[1], bbx[0] + bbx[2], bbx[1]), fill=(255, 0, 0), width=5)
                draw.line((bbx[0], bbx[1] + bbx[3], bbx[0] + bbx[2], bbx[1] + bbx[3]), fill=(255, 0, 0), width=5)
                draw.line((bbx[0] + bbx[2], bbx[1], bbx[0] + bbx[2], bbx[1] + bbx[3]), fill=(255, 0, 0), width=5)

                # get center
                c_x = bbx[0] + int(bbx[2] / 2)
                c_y = bbx[1] + int(bbx[3] / 2)
                draw.point((c_x, c_y), fill=(255, 255, 0))

                # get the 3D position of center: a homogeneous point that the
                # 3x4 projection maps back onto (c_x, c_y, 1)
                cam_extrinsic = my_mat_final[0:3, :]
                cam2d_3d = np.matmul(cam_intrinsic, cam_extrinsic)
                cen_3d = np.matmul(np.linalg.pinv(cam2d_3d), [[c_x], [c_y], [1]])

                # project the three axis tips into the image and draw them
                for offset, color in axes:
                    tip = np.matmul(cam2d_3d, cen_3d + offset)
                    # Divide by the homogeneous coordinate to get pixels, and
                    # hand PIL plain floats rather than 1-element arrays.
                    tip_x = float(tip[0, 0] / tip[2, 0])
                    tip_y = float(tip[1, 0] / tip[2, 0])
                    draw.line((c_x, c_y, tip_x, tip_y), fill=color, width=5)

                # save file under file
                img_file_name = '{0}/pred_obj{1}_pic{2}.png'.format(vimg_dir, test_dataset.list_obj[j], which_item)
                frame.save(img_file_name, "PNG")
コード例 #12
0
def main():
    """Evaluate DenseFusion on selected LineMOD test frames and save pose overlays.

    For every evaluated frame the function:

    1. predicts a pose with ``PoseNet`` and refines it ``opt.iteration`` times
       with ``PoseRefineNet``;
    2. measures the mean distance between the transformed model points and the
       target points (nearest-neighbour matched for objects in
       ``opt.sym_list``) and logs pass/fail against the per-object threshold
       derived from ``models_info.yml``;
    3. writes two PNG files into ``verify_img``: one with the predicted pose
       and projected model points, one with the pose stored in ``gt.yml``.

    All settings are hard-coded below and stored on the module-level ``opt``
    namespace. The function takes no arguments and returns ``None``.
    """
    # g13: parameter setting -------------------
    # pose model:   trained_checkpoints/linemod/pose_model_9_0.01310166542980859.pth
    # refine model: trained_checkpoints/linemod/pose_refine_model_493_0.006761023565178073.pth
    objlist = [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15]
    knn = KNearestNeighbor(1)
    opt.dataset = 'linemod'
    opt.dataset_root = './datasets/linemod/Linemod_preprocessed'
    estimator_path = 'trained_checkpoints/linemod/pose_model_9_0.01310166542980859.pth'
    refiner_path = 'trained_checkpoints/linemod/pose_refine_model_493_0.006761023565178073.pth'
    opt.model = estimator_path
    opt.refine_model = refiner_path
    dataset_config_dir = 'datasets/linemod/dataset_config'
    output_result_dir = 'experiments/eval_result/linemod'
    opt.refine_start = True
    bs = 1  # fixed because of the default setting in torch.utils.data.DataLoader
    opt.iteration = 2  # default is 4 in eval_linemod.py
    t1_start = True  # when True, stop once t1_total_eval_num frames were visited
    t1_total_eval_num = 3
    t2_start = False  # when True, only frames listed in t2_target_list are processed
    t2_target_list = [22, 30, 172, 187, 267, 363, 410, 471, 472, 605, 644, 712, 1046, 1116, 1129, 1135, 1263]
    axis_range = 0.1   # the length of X, Y, and Z axis in 3D
    vimg_dir = 'verify_img'

    # Per-object pass threshold: 10% of the model diameter, with the stored
    # value divided by 1000. safe_load is used because yaml.load() without an
    # explicit Loader is rejected by recent PyYAML releases.
    with open('{0}/models_info.yml'.format(dataset_config_dir), 'r') as meta_file:
        meta_d = yaml.safe_load(meta_file)
    diameter = [meta_d[obj]['diameter'] / 1000.0 * 0.1 for obj in objlist]
    print(diameter)
    if not os.path.exists(vimg_dir):
        os.makedirs(vimg_dir)
    # The result log is written here; make sure the folder exists before opening it.
    if not os.path.exists(output_result_dir):
        os.makedirs(output_result_dir)
    #-------------------------------------------

    if opt.dataset == 'ycb':
        opt.num_objects = 21 #number of object classes in the dataset
        opt.num_points = 1000 #number of points on the input pointcloud
        opt.outf = 'trained_models/ycb' #folder to save trained models
        opt.log_dir = 'experiments/logs/ycb' #folder to save logs
        opt.repeat_epoch = 1 #number of repeat times for one epoch training
    elif opt.dataset == 'linemod':
        opt.num_objects = 13
        opt.num_points = 500
        opt.outf = 'trained_models/linemod'
        opt.log_dir = 'experiments/logs/linemod'
        opt.repeat_epoch = 20
    else:
        print('Unknown dataset')
        return

    estimator = PoseNet(num_points = opt.num_points, num_obj = opt.num_objects)
    estimator.cuda()
    refiner = PoseRefineNet(num_points = opt.num_points, num_obj = opt.num_objects)
    refiner.cuda()

    estimator.load_state_dict(torch.load(estimator_path))
    refiner.load_state_dict(torch.load(refiner_path))

    test_dataset = PoseDataset_linemod('test', opt.num_points, False, opt.dataset_root, 0.0, opt.refine_start)
    testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=opt.workers)

    opt.sym_list = test_dataset.get_sym_list()
    opt.num_points_mesh = test_dataset.get_num_points_mesh()

    print('>>>>>>>>----------Dataset loaded!---------<<<<<<<<\n\
        length of the testing set: {0}\nnumber of sample points on mesh: {1}\n\
        symmetry object list: {2}'\
        .format( len(test_dataset), opt.num_points_mesh, opt.sym_list))

    #load pytorch model
    estimator.eval()
    refiner.eval()
    criterion = Loss(opt.num_points_mesh, opt.sym_list)

    # 3x3 camera matrix, identical for every frame. Built directly because
    # ndarray.itemset() is no longer available in NumPy 2.
    cam_intrinsic = np.array([[test_dataset.cam_fx, 0.0, test_dataset.cam_cx],
                              [0.0, test_dataset.cam_fy, test_dataset.cam_cy],
                              [0.0, 0.0, 1.0]])

    def _draw_bbox(draw, bbx):
        # Outline the box; bbx is treated as (x, y, width, height).
        left, top = bbx[0], bbx[1]
        right, bottom = bbx[0] + bbx[2], bbx[1] + bbx[3]
        draw.line((left, top, left, bottom), fill=(255,0,0), width=5)
        draw.line((left, top, right, top), fill=(255,0,0), width=5)
        draw.line((left, bottom, right, bottom), fill=(255,0,0), width=5)
        draw.line((right, top, right, bottom), fill=(255,0,0), width=5)

    def _backproject_centre(projection, c_x, c_y):
        # Pseudo-inverse of the 3x4 projection applied to the homogeneous pixel.
        return np.matmul(np.linalg.pinv(projection), [[c_x],[c_y],[1]])

    def _draw_axes(draw, c_x, c_y, projection, origin):
        # Draw X (yellow), Y (green) and Z (blue) segments from the box centre
        # to the projection of `origin` shifted by axis_range along each axis.
        # NOTE(review): as in the original script the projected tip is not
        # normalised by its third component - confirm this is intended.
        colours = ((255,255,0), (0,255,0), (0,0,255))
        for axis, colour in enumerate(colours):
            tip = origin.copy()
            tip[axis, 0] += axis_range
            tip_2d = np.matmul(projection, tip)
            # Pass plain floats: tip_2d rows are 1-element arrays.
            draw.line((c_x, c_y, float(tip_2d[0, 0]), float(tip_2d[1, 0])), fill=colour, width=5)

    # Counters live outside the loop so they accumulate over all frames
    # instead of being reset for every frame.
    success_count = [0 for _ in range(opt.num_objects)]
    num_count = [0 for _ in range(opt.num_objects)]
    gt_cache = {}  # object id -> parsed gt.yml, so each file is read once

    with open('{0}/t1_eval_result_logs.txt'.format(output_result_dir), 'w') as fw:
        #Pose estimation
        for j, data in enumerate(testdataloader, 0):
            # g13: modify this part for evaluation target--------------------
            if t1_start and j == t1_total_eval_num:
                break
            if t2_start and j not in t2_target_list:
                continue
            #----------------------------------------------------------------
            points, choose, img, target, model_points, idx = data
            if len(points.size()) == 2:
                print('No.{0} NOT Pass! Lost detection!'.format(j))
                fw.write('No.{0} NOT Pass! Lost detection!\n'.format(j))
                continue
            points, choose, img, target, model_points, idx = Variable(points).cuda(), \
                                                             Variable(choose).cuda(), \
                                                             Variable(img).cuda(), \
                                                             Variable(target).cuda(), \
                                                             Variable(model_points).cuda(), \
                                                             Variable(idx).cuda()
            pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)
            # The loss outputs are not used: the distance is recomputed from
            # the refined pose further down.
            criterion(pred_r, pred_t, pred_c, target, model_points, idx, points, opt.w, opt.refine_start)

            # Normalise the per-point quaternions and keep the most confident one.
            pred_r = pred_r / torch.norm(pred_r, dim=2).view(1, opt.num_points, 1)
            pred_c = pred_c.view(bs, opt.num_points)
            _, which_max = torch.max(pred_c, 1)
            pred_t = pred_t.view(bs * opt.num_points, 1, 3)

            my_r = pred_r[0][which_max[0]].view(-1).cpu().data.numpy()
            my_t = (points.view(bs * opt.num_points, 1, 3) + pred_t)[which_max[0]].view(-1).cpu().data.numpy()

            # Start from the unrefined pose so my_mat_final is defined even
            # when opt.iteration is 0.
            my_mat_final = quaternion_matrix(my_r)
            my_mat_final[0:3, 3] = my_t

            for ite in range(0, opt.iteration):
                T = Variable(torch.from_numpy(my_t.astype(np.float32))).cuda().view(1, 3).repeat(opt.num_points, 1).contiguous().view(1, opt.num_points, 3)
                my_mat = quaternion_matrix(my_r)
                R = Variable(torch.from_numpy(my_mat[:3, :3].astype(np.float32))).cuda().view(1, 3, 3)
                my_mat[0:3, 3] = my_t

                # Express the cloud in the current pose estimate and predict a residual pose.
                new_points = torch.bmm((points - T), R).contiguous()
                pred_r, pred_t = refiner(new_points, emb, idx)
                pred_r = pred_r.view(1, 1, -1)
                pred_r = pred_r / (torch.norm(pred_r, dim=2).view(1, 1, 1))
                my_r_2 = pred_r.view(-1).cpu().data.numpy()
                my_t_2 = pred_t.view(-1).cpu().data.numpy()
                my_mat_2 = quaternion_matrix(my_r_2)
                my_mat_2[0:3, 3] = my_t_2

                # Compose the current pose with the residual.
                my_mat_final = np.dot(my_mat, my_mat_2)
                my_r_final = copy.deepcopy(my_mat_final)
                my_r_final[0:3, 3] = 0
                my_r_final = quaternion_from_matrix(my_r_final, True)
                my_t_final = np.array([my_mat_final[0][3], my_mat_final[1][3], my_mat_final[2][3]])

                my_r = my_r_final
                my_t = my_t_final
                # Here (my_r, my_t) is the pose after this refinement step
                # ('my_r': quaternion, 'my_t': translation)

            #g13: checking the dis value
            obj_idx = idx[0].item()
            model_points = model_points[0].cpu().detach().numpy()
            my_r = quaternion_matrix(my_r)[:3, :3]
            pred = np.dot(model_points, my_r.T) + my_t
            target = target[0].cpu().detach().numpy()

            if obj_idx in opt.sym_list:
                # Symmetric object: compare each predicted point with its
                # nearest target point. Separate tensor names keep `pred` a
                # NumPy array for the drawing code below.
                pred_cuda = torch.from_numpy(pred.astype(np.float32)).cuda().transpose(1, 0).contiguous()
                target_cuda = torch.from_numpy(target.astype(np.float32)).cuda().transpose(1, 0).contiguous()
                inds = knn(target_cuda.unsqueeze(0), pred_cuda.unsqueeze(0))
                target_cuda = torch.index_select(target_cuda, 1, inds.view(-1) - 1)
                dis = torch.mean(torch.norm((pred_cuda.transpose(1, 0) - target_cuda.transpose(1, 0)), dim=1), dim=0).item()
            else:
                # Plain float in both branches, so no later .item() call is needed.
                dis = float(np.mean(np.linalg.norm(pred - target, axis=1)))

            if dis < diameter[obj_idx]:
                success_count[obj_idx] += 1
                print('No.{0} Pass! Distance: {1}'.format(j, dis))
                fw.write('No.{0} Pass! Distance: {1}\n'.format(j, dis))
            else:
                print('No.{0} NOT Pass! Distance: {1}'.format(j, dis))
                fw.write('No.{0} NOT Pass! Distance: {1}\n'.format(j, dis))
            num_count[obj_idx] += 1

            # g13: start drawing pose on image------------------------------------
            print('{0}:\nmy_r is {1}\nmy_t is {2}\ndis:{3}'.format(j, my_r, my_t, dis))
            print("index {0}: {1}".format(j, test_dataset.list_rgb[j]))

            # Look up the ground-truth annotation of this object in this frame.
            which_item = test_dataset.list_rank[j]
            which_obj = test_dataset.list_obj[j]
            if which_obj not in gt_cache:
                with open('{0}/data/{1}/gt.yml'.format(opt.dataset_root, '%02d' % which_obj), 'r') as gt_file:
                    gt_cache[which_obj] = yaml.safe_load(gt_file)
            meta = gt_cache[which_obj]
            gt_entry = None
            for entry in meta[which_item]:
                if entry['obj_id'] == which_obj:
                    gt_entry = entry
                    break
            if gt_entry is None:
                # Without an annotation there is no box to draw; skip the
                # images instead of running off the end of the list.
                print('No.{0} has no gt.yml entry for obj {1}; drawing skipped'.format(j, which_obj))
                fw.write('No.{0} has no gt.yml entry for obj {1}; drawing skipped\n'.format(j, which_obj))
                continue

            bbx = gt_entry['obj_bb']
            img = Image.open(test_dataset.list_rgb[j])
            draw = ImageDraw.Draw(img)

            # draw box (ensure this is the right object)
            _draw_bbox(draw, bbx)

            #get center
            c_x = bbx[0]+int(bbx[2]/2)
            c_y = bbx[1]+int(bbx[3]/2)
            draw.point((c_x,c_y), fill=(255,255,0))
            print('center:({0},{1})'.format(c_x, c_y))

            #get the 3D position of center
            cam_extrinsic = my_mat_final[0:3, :]
            cam2d_3d = np.matmul(cam_intrinsic, cam_extrinsic)
            cen_3d = _backproject_centre(cam2d_3d, c_x, c_y)

            #draw the axis on 2D
            _draw_axes(draw, c_x, c_y, cam2d_3d, cen_3d)

            #g13: draw the estimate pred obj
            # Pinhole projection: the pixel is (K @ X)[:2] divided by the depth
            # (K @ X)[2]. Points at non-positive depth cannot be projected.
            projected = np.matmul(pred, cam_intrinsic.T)
            in_front = projected[:, 2] > 0
            pixels = projected[in_front, :2] / projected[in_front, 2:3]
            for u, v in pixels:
                draw.point([int(u), int(v)], fill=(255,255,0))

            #save file under file
            img_file_name = '{0}/batch{1}_pred_obj{2}_pic{3}.png'.format(vimg_dir, j, test_dataset.list_obj[j], which_item)
            img.save( img_file_name, "PNG" )
            img.close()

            # plot ground truth ----------------------------
            img = Image.open(test_dataset.list_rgb[j])
            draw = ImageDraw.Draw(img)
            _draw_bbox(draw, bbx)
            # NOTE(review): cam_t_m2c is used exactly as stored in gt.yml -
            # confirm its units match my_t and axis_range.
            target_r = np.resize(np.array(gt_entry['cam_R_m2c']), (3, 3))
            target_t = np.array(gt_entry['cam_t_m2c'])
            target_t = target_t[np.newaxis, :]
            cam_extrinsic_GT = np.concatenate((target_r, target_t.T), axis=1)

            #get center 3D
            cam2d_3d_GT = np.matmul(cam_intrinsic, cam_extrinsic_GT)
            cen_3d_GT = _backproject_centre(cam2d_3d_GT, c_x, c_y)

            #draw the axis on 2D
            _draw_axes(draw, c_x, c_y, cam2d_3d_GT, cen_3d_GT)

            print('pred:\n{0}\nGT:\n{1}\n'.format(cam_extrinsic,cam_extrinsic_GT))
            print('pred 3D:{0}\nGT 3D:{1}\n'.format(cen_3d, cen_3d_GT))
            img_file_name = '{0}/batch{1}_pred_obj{2}_pic{3}_gt.png'.format(vimg_dir, j, test_dataset.list_obj[j], which_item)
            img.save( img_file_name, "PNG" )
            img.close()
    print('\nplot_result_img.py completed the task\n')
コード例 #13
0
ファイル: train_linemod.py プロジェクト: HuaWeitong/REDE
def main():
    """Train and evaluate the pose estimator / refiner for one LineMOD object.

    Reads its configuration from the module-level ``opt`` namespace. Only
    ``opt.dataset == 'linemod'`` is supported; any other value prints a message
    and returns. For each epoch in ``range(opt.start_epoch, opt.nepoch)`` the
    function:

    * trains the estimator (or, once ``opt.refine_start`` is set, the refiner)
      with gradients accumulated over ``opt.batch_size`` frames;
    * evaluates on the LineMOD test split and, for objects listed in
      ``opt.occ_list_obj``, on the occlusion test split;
    * logs scalars to TensorBoard and saves a checkpoint named after the epoch
      and the mean test distance;
    * switches to refiner training when the best test distance drops below
      ``opt.refine_margin``.

    Returns ``None``.
    """
    if opt.dataset == 'linemod':
        opt.num_obj = 1
        opt.list_obj = [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15]
        opt.occ_list_obj = [1, 5, 6, 8, 9, 10, 11, 12]
        opt.list_name = ['ape', 'benchvise', 'cam', 'can', 'cat', 'driller', 'duck', 'eggbox', 'glue', 'holepuncher', 'iron', 'lamp', 'phone']
        obj_name = opt.list_name[opt.list_obj.index(opt.obj_id)]
        opt.sym_list = [10, 11]
        opt.num_points = 500
        # safe_load: yaml.load() without an explicit Loader is rejected by
        # recent PyYAML releases; the context manager closes the file handle.
        with open('{0}/models/models_info.yml'.format(opt.dataset_root), 'r') as meta_file:
            meta = yaml.safe_load(meta_file)
        # Success threshold: 10% of the model diameter, stored value / 1000.
        diameter = meta[opt.obj_id]['diameter'] / 1000.0 * 0.1
        if opt.render:
            opt.repeat_num = 1
        elif opt.fuse:
            opt.repeat_num = 1
        else:
            opt.repeat_num = 5
        writer = SummaryWriter('experiments/runs/linemod/{}{}'.format(obj_name, opt.experiment_name))
        opt.outf = 'trained_models/linemod/{}{}'.format(obj_name, opt.experiment_name)
        opt.log_dir = 'experiments/logs/linemod/{}{}'.format(obj_name, opt.experiment_name)
        # makedirs also creates missing parent folders, which os.mkdir does not.
        if not os.path.exists(opt.outf):
            os.makedirs(opt.outf)
        if not os.path.exists(opt.log_dir):
            os.makedirs(opt.log_dir)
    else:
        print('Unknown dataset')
        return

    estimator = PoseNet(num_points = opt.num_points, num_vote = 9, num_obj = opt.num_obj)
    estimator.cuda()
    refiner = PoseRefineNet(num_points = opt.num_points, num_obj = opt.num_obj)
    refiner.cuda()

    # Optionally resume from checkpoints stored in opt.outf. Resuming a refiner
    # checkpoint switches straight to refiner training.
    if opt.resume_posenet != '':
        estimator.load_state_dict(torch.load('{0}/{1}'.format(opt.outf, opt.resume_posenet)))
    if opt.resume_refinenet != '':
        refiner.load_state_dict(torch.load('{0}/{1}'.format(opt.outf, opt.resume_refinenet)))
        opt.refine_start = True
        opt.lr = opt.lr_refine
        opt.batch_size = int(opt.batch_size / opt.iteration)
        optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
    else:
        opt.refine_start = False
        optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

    dataset = PoseDataset_linemod('train', opt.num_points, opt.dataset_root, opt.real, opt.render, opt.fuse, opt.obj_id)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=opt.workers)
    test_dataset = PoseDataset_linemod('test', opt.num_points, opt.dataset_root, True, False, False, opt.obj_id)
    testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=opt.workers)

    print('>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}'.format(len(dataset), len(test_dataset), opt.num_points))
    if opt.obj_id in opt.occ_list_obj:
        occ_test_dataset = PoseDataset_occ('test', opt.num_points, opt.occ_dataset_root, opt.obj_id)
        occtestdataloader = torch.utils.data.DataLoader(occ_test_dataset, batch_size=1, shuffle=False, num_workers=opt.workers)
        print('length of the occ testing set: {}'.format(len(occ_test_dataset)))

    criterion = Loss(opt.num_points, opt.sym_list)
    criterion_refine = Loss_refine(opt.num_points, opt.sym_list)
    # np.inf rather than the np.Inf alias, which NumPy 2 no longer provides.
    best_test = np.inf

    # A run starting at epoch 1 clears the logs left in the log folder.
    if opt.start_epoch == 1:
        for log in os.listdir(opt.log_dir):
            os.remove(os.path.join(opt.log_dir, log))
    st_time = time.time()
    train_scalar = 0

    for epoch in range(opt.start_epoch, opt.nepoch):
        logger = setup_logger('epoch%d' % epoch, os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
        logger.info('Train time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Training started'))
        train_count = 0
        train_loss_avg = 0.0
        train_loss = 0.0
        train_dis_avg = 0.0
        train_dis = 0.0
        if opt.refine_start:
            estimator.eval()
            refiner.train()
        else:
            estimator.train()
        optimizer.zero_grad()
        for rep in range(opt.repeat_num):
            for i, data in enumerate(dataloader, 0):
                points, choose, img, target, model_points, model_kp, vertex_gt, idx, target_r, target_t = data
                # A 2-D `points` tensor marks a frame that is skipped.
                if len(points.size()) == 2:
                    print('pass')
                    continue
                points, choose, img, target, model_points, model_kp, vertex_gt, idx, target_r, target_t = points.cuda(), choose.cuda(), img.cuda(), target.cuda(), model_points.cuda(), model_kp.cuda(), vertex_gt.cuda(), idx.cuda(), target_r.cuda(), target_t.cuda()
                vertex_pred, c_pred, emb = estimator(img, points, choose, idx)
                vertex_loss, pose_loss, dis, new_points, new_target = criterion(vertex_pred, vertex_gt, c_pred, points, target, model_points, model_kp, opt.obj_id, target_r, target_t)
                loss = 10 * vertex_loss + pose_loss
                if opt.refine_start:
                    # Refiner training back-propagates the distance after each refinement step.
                    for ite in range(0, opt.iteration):
                        pred_r, pred_t = refiner(new_points, emb, idx)
                        dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_points, new_target, model_points, opt.obj_id)
                        dis.backward()
                else:
                    loss.backward()

                train_loss_avg += loss.item()
                train_loss += loss.item()
                train_dis_avg += dis.item()
                train_dis += dis.item()
                train_count += 1
                train_scalar += 1

                # The DataLoader yields single frames; gradients are accumulated
                # and applied once every opt.batch_size frames.
                if train_count % opt.batch_size == 0:
                    logger.info('Train time {0} Epoch {1} Batch {2} Frame {3} Avg_loss:{4} Avg_diss:{5}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), epoch, int(train_count / opt.batch_size), train_count, train_loss_avg / opt.batch_size, train_dis_avg / opt.batch_size))
                    writer.add_scalar('linemod training loss', train_loss_avg / opt.batch_size, train_scalar)
                    writer.add_scalar('linemod training dis', train_dis_avg / opt.batch_size, train_scalar)
                    optimizer.step()
                    optimizer.zero_grad()
                    train_loss_avg = 0
                    train_dis_avg = 0

                # Rolling checkpoint every 1000 processed frames.
                if train_count != 0 and train_count % 1000 == 0:
                    if opt.refine_start:
                        torch.save(refiner.state_dict(), '{0}/pose_refine_model_current.pth'.format(opt.outf))
                    else:
                        torch.save(estimator.state_dict(), '{0}/pose_model_current.pth'.format(opt.outf))

        print('>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.format(epoch))
        train_loss = train_loss / train_count
        train_dis = train_dis / train_count
        logger.info('Train time {0} Epoch {1} TRAIN FINISH Avg loss: {2} Avg dis: {3}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), epoch, train_loss, train_dis))

        # ---- evaluation on the LineMOD test split ----
        logger = setup_logger('epoch%d_test' % epoch, os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Testing started'))
        test_loss = 0.0
        test_vertex_loss = 0.0
        test_pose_loss = 0.0
        test_dis = 0.0
        test_count = 0
        success_count = 0
        estimator.eval()
        refiner.eval()

        for j, data in enumerate(testdataloader, 0):
            points, choose, img, target, model_points, model_kp, vertex_gt, idx, target_r, target_t = data
            if len(points.size()) == 2:
                logger.info('Test time {0} Lost detection!'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time))))
                continue
            points, choose, img, target, model_points, model_kp, vertex_gt, idx, target_r, target_t = points.cuda(), choose.cuda(), img.cuda(), target.cuda(), model_points.cuda(), model_kp.cuda(), vertex_gt.cuda(), idx.cuda(), target_r.cuda(), target_t.cuda()
            vertex_pred, c_pred, emb = estimator(img, points, choose, idx)
            vertex_loss, pose_loss, dis, new_points, new_target = criterion(vertex_pred, vertex_gt, c_pred, points, target, model_points, model_kp, opt.obj_id, target_r, target_t)
            loss = 10 * vertex_loss + pose_loss
            if opt.refine_start:
                for ite in range(0, opt.iteration):
                    pred_r, pred_t = refiner(new_points, emb, idx)
                    dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_points, new_target, model_points, opt.obj_id)

            test_loss += loss.item()
            test_vertex_loss += vertex_loss.item()
            test_pose_loss += pose_loss.item()
            test_dis += dis.item()
            logger.info('Test time {0} Test Frame No.{1} loss:{2} dis:{3}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), test_count, loss, dis))
            if dis.item() < diameter:
                success_count += 1
            test_count += 1

        test_loss = test_loss / test_count
        test_vertex_loss = test_vertex_loss / test_count
        test_pose_loss = test_pose_loss / test_count
        test_dis = test_dis / test_count
        success_rate = float(success_count) / test_count
        logger.info('Test time {0} Epoch {1} TEST FINISH Avg loss: {2} Avg dis: {3} Success rate: {4}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), epoch, test_loss, test_dis, success_rate))
        writer.add_scalar('linemod test loss', test_loss, epoch)
        writer.add_scalar('linemod test vertex loss', test_vertex_loss, epoch)
        writer.add_scalar('linemod test pose loss', test_pose_loss, epoch)
        writer.add_scalar('linemod test dis', test_dis, epoch)
        writer.add_scalar('linemod success rate', success_rate, epoch)
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
        if test_dis <= best_test:
            best_test = test_dis
        # A checkpoint is written after every epoch, not only on improvement.
        if opt.refine_start:
            torch.save(refiner.state_dict(), '{0}/pose_refine_model_{1}_{2}.pth'.format(opt.outf, epoch, test_dis))
        else:
            torch.save(estimator.state_dict(), '{0}/pose_model_{1}_{2}.pth'.format(opt.outf, epoch, test_dis))
        print(epoch, '>>>>>>>>----------MODEL SAVED---------<<<<<<<<')

        # ---- evaluation on the occlusion test split (selected objects only) ----
        if opt.obj_id in opt.occ_list_obj:
            logger = setup_logger('epoch%d_occ_test' % epoch, os.path.join(opt.log_dir, 'epoch_%d_occ_test_log.txt' % epoch))
            logger.info('Occ test time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Testing started'))
            occ_test_dis = 0.0
            occ_test_count = 0
            occ_success_count = 0
            estimator.eval()
            refiner.eval()

            for j, data in enumerate(occtestdataloader, 0):
                points, choose, img, target, model_points, model_kp, vertex_gt, idx, target_r, target_t = data
                if len(points.size()) == 2:
                    logger.info('Occ test time {0} Lost detection!'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time))))
                    continue
                points, choose, img, target, model_points, model_kp, vertex_gt, idx, target_r, target_t = points.cuda(), choose.cuda(), img.cuda(), target.cuda(), model_points.cuda(), model_kp.cuda(), vertex_gt.cuda(), idx.cuda(), target_r.cuda(), target_t.cuda()
                vertex_pred, c_pred, emb = estimator(img, points, choose, idx)
                vertex_loss, pose_loss, dis, new_points, new_target = criterion(vertex_pred, vertex_gt, c_pred, points, target, model_points, model_kp, opt.obj_id, target_r, target_t)
                if opt.refine_start:
                    for ite in range(0, opt.iteration):
                        pred_r, pred_t = refiner(new_points, emb, idx)
                        dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_points, new_target, model_points, opt.obj_id)

                occ_test_dis += dis.item()
                logger.info('Occ test time {0} Test Frame No.{1} dis:{2}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), occ_test_count, dis))
                if dis.item() < diameter:
                    occ_success_count += 1
                occ_test_count += 1

            occ_test_dis = occ_test_dis / occ_test_count
            occ_success_rate = float(occ_success_count) / occ_test_count
            logger.info('Occ test time {0} Epoch {1} TEST FINISH Avg dis: {2} Success rate: {3}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), epoch, occ_test_dis, occ_success_rate))
            writer.add_scalar('occ test dis', occ_test_dis, epoch)
            writer.add_scalar('occ success rate', occ_success_rate, epoch)

        # Once the estimator is good enough, switch to refiner training with a
        # fresh optimizer, the refiner learning rate and a smaller batch size.
        if best_test < opt.refine_margin and not opt.refine_start:
            opt.refine_start = True
            opt.lr = opt.lr_refine
            opt.batch_size = int(opt.batch_size / opt.iteration)
            optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
            print('>>>>>>>>----------Refine started---------<<<<<<<<')

    writer.close()
コード例 #14
0
def main():
    if opt.dataset == 'ycb':
        opt.num_obj = 21
        opt.sym_list = [12, 15, 18, 19, 20]
        opt.num_points = 1000
        writer = SummaryWriter('experiments/runs/ycb/{0}'.format(opt.experiment_name))
        opt.outf = 'trained_models/ycb/{0}'.format(opt.experiment_name)
        opt.log_dir = 'experiments/logs/ycb/{0}'.format(opt.experiment_name)
        opt.repeat_num = 1
        if not os.path.exists(opt.outf):
            os.mkdir(opt.outf)
        if not os.path.exists(opt.log_dir):
            os.mkdir(opt.log_dir)
    else:
        print('Unknown dataset')
        return

    estimator = PoseNet(num_points = opt.num_points, num_vote = 9, num_obj = opt.num_obj)
    estimator.cuda()
    refiner = PoseRefineNet(num_points = opt.num_points, num_obj = opt.num_obj)
    refiner.cuda()

    if opt.resume_posenet != '':
        estimator.load_state_dict(torch.load('{0}/{1}'.format(opt.outf, opt.resume_posenet)))
    if opt.resume_refinenet != '':
        refiner.load_state_dict(torch.load('{0}/{1}'.format(opt.outf, opt.resume_refinenet)))
        opt.refine_start = True
        opt.lr = opt.lr_refine
        opt.batch_size = int(opt.batch_size / opt.iteration)
        optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
    else:
        opt.refine_start = False
        optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

    dataset = PoseDataset_ycb('train', opt.num_points, True, opt.dataset_root)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=opt.workers)
    test_dataset = PoseDataset_ycb('test', opt.num_points, False, opt.dataset_root)
    testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=opt.workers)

    print('>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}'.format(len(dataset), len(test_dataset), opt.num_points))

    criterion = Loss(opt.num_points, opt.sym_list)
    criterion_refine = Loss_refine(opt.num_points, opt.sym_list)
    best_test = np.Inf

    if opt.start_epoch == 1:
        for log in os.listdir(opt.log_dir):
            os.remove(os.path.join(opt.log_dir, log))
    st_time = time.time()
    train_scalar = 0

    for epoch in range(opt.start_epoch, opt.nepoch):
        logger = setup_logger('epoch%d' % epoch, os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
        logger.info('Train time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Training started'))
        train_count = 0
        train_loss_avg = 0.0
        train_loss = 0.0
        train_dis_avg = 0.0
        train_dis = 0.0
        if opt.refine_start:
            estimator.eval()
            refiner.train()
        else:
            estimator.train()
        optimizer.zero_grad()
        for rep in range(opt.repeat_num):
            for i, data in enumerate(dataloader, 0):
                points, choose, img, target, model_points, model_kp, vertex_gt, idx, target_r, target_t = data
                points, choose, img, target, model_points, model_kp, vertex_gt, idx, target_r, target_t = points.cuda(), choose.cuda(), img.cuda(), target.cuda(), model_points.cuda(), model_kp.cuda(), vertex_gt.cuda(), idx.cuda(), target_r.cuda(), target_t.cuda()
                vertex_pred, c_pred, emb = estimator(img, points, choose, idx)
                vertex_loss, pose_loss, dis, new_points, new_target = criterion(vertex_pred, vertex_gt, c_pred, points, target, model_points, model_kp, idx, target_r, target_t)
                loss = 10 * vertex_loss + pose_loss
                if opt.refine_start:
                    for ite in range(0, opt.iteration):
                        pred_r, pred_t = refiner(new_points, emb, idx)
                        dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_points, new_target, model_points, idx)
                        dis.backward()
                else:
                    loss.backward()
                train_loss_avg += loss.item()
                train_loss += loss.item()
                train_dis_avg += dis.item()
                train_dis += dis.item()
                train_count += 1
                train_scalar += 1

                if train_count % opt.batch_size == 0:
                    logger.info('Train time {0} Epoch {1} Batch {2} Frame {3} Avg_loss:{4} Avg_diss:{5}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), epoch, int(train_count / opt.batch_size), train_count, train_loss_avg / opt.batch_size, train_dis_avg / opt.batch_size))
                    writer.add_scalar('ycb training loss', train_loss_avg / opt.batch_size, train_scalar)
                    writer.add_scalar('ycb training dis', train_dis_avg / opt.batch_size, train_scalar)
                    optimizer.step()
                    optimizer.zero_grad()
                    train_loss_avg = 0
                    train_dis_avg = 0

                if train_count != 0 and train_count % 1000 == 0:
                    if opt.refine_start:
                        torch.save(refiner.state_dict(), '{0}/pose_refine_model_current.pth'.format(opt.outf))
                    else:
                        torch.save(estimator.state_dict(), '{0}/pose_model_current.pth'.format(opt.outf))

        print('>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.format(epoch))
        train_loss = train_loss / train_count
        train_dis = train_dis / train_count
        logger.info('Train time {0} Epoch {1} TRAIN FINISH Avg loss: {2} Avg dis: {3}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), epoch, train_loss, train_dis))

        logger = setup_logger('epoch%d_test' % epoch, os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Testing started'))
        test_loss = 0.0
        test_vertex_loss = 0.0
        test_pose_loss = 0.0
        test_dis = 0.0
        test_count = 0
        success_count = 0
        estimator.eval()
        refiner.eval()
        for j, data in enumerate(testdataloader, 0):
            points, choose, img, target, model_points, model_kp, vertex_gt, idx, target_r, target_t = data
            points, choose, img, target, model_points, model_kp, vertex_gt, idx, target_r, target_t = points.cuda(), choose.cuda(), img.cuda(), target.cuda(), model_points.cuda(), model_kp.cuda(), vertex_gt.cuda(), idx.cuda(), target_r.cuda(), target_t.cuda()
            vertex_pred, c_pred, emb = estimator(img, points, choose, idx)
            vertex_loss, pose_loss, dis, new_points, new_target = criterion(vertex_pred, vertex_gt, c_pred, points, target, model_points, model_kp, idx, target_r, target_t)
            loss = 10 * vertex_loss + pose_loss
            if opt.refine_start:
                for ite in range(0, opt.iteration):
                    pred_r, pred_t = refiner(new_points, emb, idx)
                    dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_points, new_target, model_points, idx)
            test_loss += loss.item()
            test_vertex_loss += vertex_loss.item()
            test_pose_loss += pose_loss.item()
            test_dis += dis.item()
            logger.info('Test time {0} Test Frame No.{1} loss:{2} dis:{3}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), test_count, loss, dis))
            test_count += 1
            if dis.item() < 0.02:
                success_count += 1

        test_loss = test_loss / test_count
        test_vertex_loss = test_vertex_loss / test_count
        test_pose_loss = test_pose_loss / test_count
        test_dis = test_dis / test_count
        logger.info('Test time {0} Epoch {1} TEST FINISH Avg loss: {2} Avg dis: {3}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), epoch, test_loss, test_dis))
        logger.info('Success rate: {}'.format(float(success_count) / test_count))
        writer.add_scalar('ycb test loss', test_loss, epoch)
        writer.add_scalar('ycb test vertex loss', test_vertex_loss, epoch)
        writer.add_scalar('ycb test pose loss', test_pose_loss, epoch)
        writer.add_scalar('ycb test dis', test_dis, epoch)
        writer.add_scalar('ycb success rate', float(success_count) / test_count, epoch)
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
        if test_dis <= best_test:
            best_test = test_dis
        if opt.refine_start:
            torch.save(refiner.state_dict(), '{0}/pose_refine_model_{1}_{2}.pth'.format(opt.outf, epoch, test_dis))
        else:
            torch.save(estimator.state_dict(), '{0}/pose_model_{1}_{2}.pth'.format(opt.outf, epoch, test_dis))
        print(epoch, '>>>>>>>>----------MODEL SAVED---------<<<<<<<<')

        if best_test < opt.refine_margin and not opt.refine_start:
            opt.refine_start = True
            opt.lr = opt.lr_refine
            opt.batch_size = int(opt.batch_size / opt.iteration)
            optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
            print('>>>>>>>>----------Refine started---------<<<<<<<<')

    writer.close()
コード例 #15
0
ファイル: train.py プロジェクト: chaddy1004/ARL_AFF_POSE
def _elapsed_since(start):
    """Format the wall-clock time elapsed since ``start`` as ``HHh MMm SSs``."""
    return time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - start))


def _apply_dataset_preset(opt):
    """Write the hard-coded settings for ``opt.dataset`` into ``opt``.

    Returns:
        True when ``opt.dataset`` names a known preset, False otherwise (in
        which case ``opt`` is left unchanged).
    """
    if opt.dataset == 'linemod':
        opt.num_objects = 13
        opt.num_points = 500
        opt.outf = 'trained_models/linemod'
        opt.log_dir = 'experiments/logs/linemod'
        opt.repeat_epoch = 20

    elif opt.dataset == 'ycb':
        opt.num_objects = 21  # number of object classes in the dataset
        opt.num_points = 1000  # number of points on the input pointcloud
        opt.outf = 'trained_models/ycb'  # folder to save trained models
        opt.log_dir = 'experiments/logs/ycb'  # folder to save logs
        opt.repeat_epoch = 1  # number of repeat times for one epoch training

    elif opt.dataset == 'ycb-syn':
        opt.num_objects = 31
        opt.num_points = 1000
        opt.dataset_root = '/data/Akeaveny/Datasets/ycb_syn'
        opt.outf = 'trained_models/ycb_syn/ycb_syn2'
        opt.log_dir = 'experiments/logs/ycb_syn/ycb_syn2'
        opt.w = 0.05
        opt.refine_margin = 0.01

    elif opt.dataset == 'arl':
        opt.num_objects = 10
        opt.num_points = 1000
        opt.dataset_root = '/data/Akeaveny/Datasets/arl_dataset'
        opt.outf = 'trained_models/arl/clutter/arl_finetune_syn_2'
        opt.log_dir = '/home/akeaveny/catkin_ws/src/object-rpe-ak/DenseFusion/experiments/logs/arl/clutter/arl_finetune_syn_2'
        opt.nepoch = 750
        opt.w = 0.05
        opt.refine_margin = 0.0045
        # TODO: these resume settings are hard-coded for one experiment.
        opt.repeat_epoch = 20
        opt.start_epoch = 0
        opt.resume_posenet = 'pose_model_1_0.012397416144377301.pth'
        opt.resume_refinenet = 'pose_refine_model_153_0.004032851301599294.pth'

    elif opt.dataset == 'arl1':
        opt.num_objects = 5
        opt.num_points = 1000
        opt.dataset_root = '/data/Akeaveny/Datasets/arl_dataset'
        opt.outf = 'trained_models/arl1/clutter/arl_real_2'
        opt.log_dir = '/home/akeaveny/catkin_ws/src/object-rpe-ak/DenseFusion/experiments/logs/arl1/clutter/arl_real_2'
        opt.nepoch = 750
        opt.w = 0.05
        opt.refine_margin = 0.015
        # To resume a run, set opt.start_epoch / opt.resume_posenet /
        # opt.resume_refinenet here.

    elif opt.dataset == 'elevator':
        opt.num_objects = 1
        opt.num_points = 1000
        opt.dataset_root = '/data/Akeaveny/Datasets/elevator_dataset'
        opt.outf = 'trained_models/elevator/elevator_2'
        opt.log_dir = '/home/akeaveny/catkin_ws/src/object-rpe-ak/DenseFusion/experiments/logs/elevator/elevator_2'
        opt.nepoch = 750
        opt.w = 0.05
        opt.refine_margin = 0.015
        # TODO
        opt.repeat_epoch = 40
        # To resume a run, set opt.start_epoch / opt.resume_posenet /
        # opt.resume_refinenet here.

    else:
        return False

    return True


def _dataset_class_for(name):
    """Return ``(dataset_class, test_flag)`` for the dataset called ``name``.

    ``test_flag`` is the third positional constructor argument used for the
    *test* split (the train split always passes True). It is presumably an
    "add noise" switch -- TODO confirm against the dataset classes.

    Raises:
        ValueError: if ``name`` is not a known dataset.
    """
    if name == 'ycb':
        return PoseDataset_ycb, False
    if name == 'linemod':
        return PoseDataset_linemod, False
    if name == 'ycb-syn':
        return PoseDataset_ycb_syn, True
    if name == 'arl':
        return PoseDataset_arl, True
    if name == 'arl1':
        return PoseDataset_arl1, True
    if name == 'elevator':
        return PoseDataset_elevator, True
    raise ValueError('Unknown dataset: {0}'.format(name))


def _build_data_and_losses(opt):
    """Create datasets, loaders and loss objects for the current ``opt``.

    Reads ``opt.refine_start``, so it must be called again whenever that flag
    changes. Also stores ``opt.sym_list`` and ``opt.num_points_mesh`` (taken
    from the training dataset) and prints a short summary.

    Returns:
        ``(dataset, dataloader, testdataloader, criterion, criterion_refine)``
    """
    dataset_cls, test_flag = _dataset_class_for(opt.dataset)

    dataset = dataset_cls('train', opt.num_points, True, opt.dataset_root,
                          opt.noise_trans, opt.refine_start)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=opt.workers)

    test_dataset = dataset_cls('test', opt.num_points, test_flag,
                               opt.dataset_root, 0.0, opt.refine_start)
    testdataloader = torch.utils.data.DataLoader(test_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=opt.workers)

    opt.sym_list = dataset.get_sym_list()
    opt.num_points_mesh = dataset.get_num_points_mesh()

    print(
        '>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}\nsymmetry object list: {3}'
        .format(len(dataset), len(test_dataset), opt.num_points_mesh,
                opt.sym_list))

    criterion = Loss(opt.num_points_mesh, opt.sym_list)
    criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)
    return dataset, dataloader, testdataloader, criterion, criterion_refine


def main():
    """Train the pose estimator and, once it is good enough, the refiner.

    Uses the module-level ``opt`` namespace for all configuration. The flow is:

    1. Seed the RNGs and apply the hard-coded preset for ``opt.dataset``
       (prints 'Unknown dataset' and returns if there is none).
    2. Build ``PoseNet`` / ``PoseRefineNet`` on the GPU and optionally load
       checkpoints from ``opt.outf``.
    3. For every epoch: train (frames are loaded one at a time and gradients
       are accumulated over ``opt.batch_size`` frames before each optimizer
       step), evaluate on the test split, save the model when the average
       test distance improves, and switch to learning-rate decay or to
       refiner training when the best distance drops below
       ``opt.decay_margin`` / ``opt.refine_margin``.
    """
    opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    if not _apply_dataset_preset(opt):
        print('Unknown dataset')
        return

    estimator = PoseNet(num_points=opt.num_points, num_obj=opt.num_objects)
    estimator.cuda()
    refiner = PoseRefineNet(num_points=opt.num_points, num_obj=opt.num_objects)
    refiner.cuda()

    if opt.resume_posenet != '':
        estimator.load_state_dict(
            torch.load('{0}/{1}'.format(opt.outf, opt.resume_posenet)))

    if opt.resume_refinenet != '':
        refiner.load_state_dict(
            torch.load('{0}/{1}'.format(opt.outf, opt.resume_refinenet)))
        # NOTE(review): refine_start/decay_start stay False here although the
        # optimizer is built over the refiner's parameters. Until one of the
        # margins below is reached, the loop therefore backpropagates the
        # estimator loss while stepping the refiner optimizer. Confirm this
        # is intended.
        opt.refine_start = False
        opt.decay_start = False
        opt.lr *= opt.lr_rate
        opt.w *= opt.w_rate
        opt.batch_size = int(opt.batch_size / opt.iteration)
        optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
    else:
        opt.refine_start = False
        opt.decay_start = False
        optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

    (dataset, dataloader, testdataloader, criterion,
     criterion_refine) = _build_data_and_losses(opt)

    # np.Inf was removed in NumPy 2.0; np.inf works on every version.
    best_test = np.inf

    # Checkpoints and per-epoch log files are written into these folders, so
    # make sure they exist before training starts.
    for directory in (opt.outf, opt.log_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # A run that starts at epoch 1 is treated as fresh: drop old log files.
    if opt.start_epoch == 1:
        for log in os.listdir(opt.log_dir):
            os.remove(os.path.join(opt.log_dir, log))
    st_time = time.time()

    # TODO (ak): set up TensorBoard (SummaryWriter on opt.log_dir) and log
    # the training loss/distance and the test distance.

    for epoch in range(opt.start_epoch, opt.nepoch):
        logger = setup_logger(
            'epoch%d' % epoch,
            os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
        logger.info('Train time {0}'.format(
            _elapsed_since(st_time) + ', ' + 'Training started'))
        train_count = 0
        train_dis_avg = 0.0
        if opt.refine_start:
            estimator.eval()
            refiner.train()
        else:
            estimator.train()
        optimizer.zero_grad()

        ##################
        # train
        ##################
        for _ in range(opt.repeat_epoch):
            for data in dataloader:
                points, choose, img, target, model_points, idx = (
                    tensor.cuda() for tensor in data)

                pred_r, pred_t, pred_c, emb = estimator(
                    img, points, choose, idx)
                loss, dis, new_points, new_target = criterion(
                    pred_r, pred_t, pred_c, target, model_points, idx, points,
                    opt.w, opt.refine_start)

                if opt.refine_start:
                    # Each refinement iteration contributes its own gradient.
                    for _ in range(0, opt.iteration):
                        pred_r, pred_t = refiner(new_points, emb, idx)
                        dis, new_points, new_target = criterion_refine(
                            pred_r, pred_t, new_target, model_points, idx,
                            new_points)
                        dis.backward()
                else:
                    loss.backward()

                train_dis_avg += dis.item()
                train_count += 1

                # Gradients accumulate across frames; step once per "batch".
                if train_count % opt.batch_size == 0:
                    logger.info(
                        'Train time {} Epoch {} Batch {} Frame {}/{} Avg_dis: {:.2f} [cm]'
                        .format(
                            _elapsed_since(st_time), epoch,
                            int(train_count / opt.batch_size), train_count,
                            len(dataset.list),
                            train_dis_avg / opt.batch_size * 100))
                    optimizer.step()
                    optimizer.zero_grad()
                    train_dis_avg = 0

                # Rolling checkpoint of whichever network is being trained.
                if train_count != 0 and train_count % 1000 == 0:
                    if opt.refine_start:
                        torch.save(
                            refiner.state_dict(),
                            '{0}/pose_refine_model_current.pth'.format(
                                opt.outf))
                    else:
                        torch.save(
                            estimator.state_dict(),
                            '{0}/pose_model_current.pth'.format(opt.outf))

        print(
            '>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.format(
                epoch))

        ##################
        # test
        ##################
        logger = setup_logger(
            'epoch%d_test' % epoch,
            os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(
            _elapsed_since(st_time) + ', ' + 'Testing started'))
        test_dis = 0.0
        test_count = 0
        estimator.eval()
        refiner.eval()

        # No backward pass happens during evaluation, so skip graph building.
        with torch.no_grad():
            for data in testdataloader:
                points, choose, img, target, model_points, idx = (
                    tensor.cuda() for tensor in data)
                pred_r, pred_t, pred_c, emb = estimator(
                    img, points, choose, idx)
                _, dis, new_points, new_target = criterion(
                    pred_r, pred_t, pred_c, target, model_points, idx, points,
                    opt.w, opt.refine_start)

                if opt.refine_start:
                    for _ in range(0, opt.iteration):
                        pred_r, pred_t = refiner(new_points, emb, idx)
                        dis, new_points, new_target = criterion_refine(
                            pred_r, pred_t, new_target, model_points, idx,
                            new_points)

                frame_dis = dis.item()
                test_dis += frame_dis
                # Log a plain float rather than the tensor's repr.
                logger.info(
                    'Test time {} Test Frame No.{} dis: {} [cm]'.format(
                        _elapsed_since(st_time), test_count, frame_dis * 100))

                test_count += 1

        test_dis = test_dis / test_count
        logger.info(
            'Test time {} Epoch {} TEST FINISH Avg dis: {} [cm]'.format(
                _elapsed_since(st_time), epoch, test_dis * 100))

        if test_dis <= best_test:
            best_test = test_dis
            if opt.refine_start:
                torch.save(
                    refiner.state_dict(),
                    '{0}/pose_refine_model_{1}_{2}.pth'.format(
                        opt.outf, epoch, test_dis))
            else:
                torch.save(
                    estimator.state_dict(),
                    '{0}/pose_model_{1}_{2}.pth'.format(
                        opt.outf, epoch, test_dis))
            print(epoch,
                  '>>>>>>>>----------BEST TEST MODEL SAVED---------<<<<<<<<')

        # Good enough to decay: shrink lr and w, restart the estimator
        # optimizer with the new learning rate.
        if best_test < opt.decay_margin and not opt.decay_start:
            opt.decay_start = True
            opt.lr *= opt.lr_rate
            opt.w *= opt.w_rate
            optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

        # Good enough to refine: from now on train the refiner instead.
        if best_test < opt.refine_margin and not opt.refine_start:
            opt.refine_start = True
            opt.batch_size = int(opt.batch_size / opt.iteration)
            optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)

            # The datasets take refine_start as a constructor argument, so
            # they (and everything derived from them) must be rebuilt.
            (dataset, dataloader, testdataloader, criterion,
             criterion_refine) = _build_data_and_losses(opt)
コード例 #16
0
ファイル: train.py プロジェクト: youjinqing123/6DDenseFusion
def main():
    opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    opt.num_objects = 3
    opt.num_points = 500
    opt.outf = 'trained_models'
    opt.log_dir = 'experiments/logs'
    opt.repeat_epoch = 20

    estimator = PoseNet(num_points=opt.num_points, num_obj=opt.num_objects)
    estimator.cuda()
    refiner = PoseRefineNet(num_points=opt.num_points, num_obj=opt.num_objects)
    refiner.cuda()

    if opt.resume_posenet != '':
        estimator.load_state_dict(
            torch.load('{0}/{1}'.format(opt.outf, opt.resume_posenet)))

    if opt.resume_refinenet != '':
        refiner.load_state_dict(
            torch.load('{0}/{1}'.format(opt.outf, opt.resume_refinenet)))
        opt.refine_start = True
        opt.decay_start = True
        opt.lr *= opt.lr_rate
        opt.w *= opt.w_rate
        opt.batch_size = int(opt.batch_size / opt.iteration)
        optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
    else:
        opt.refine_start = False
        opt.decay_start = False
        optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

    dataset = PoseDataset('train', opt.num_points, True, opt.dataset_root,
                          opt.noise_trans, opt.refine_start)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=opt.workers)

    test_dataset = PoseDataset('test', opt.num_points, False, opt.dataset_root,
                               0.0, opt.refine_start)
    testdataloader = torch.utils.data.DataLoader(test_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=opt.workers)

    opt.sym_list = dataset.get_sym_list()
    opt.num_points_mesh = dataset.get_num_points_mesh()

    print(
        '>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}'
        .format(len(dataset), len(test_dataset), opt.num_points_mesh))

    criterion = Loss(opt.num_points_mesh, opt.sym_list)
    criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)

    best_test = np.Inf

    if opt.start_epoch == 1:
        for log in os.listdir(opt.log_dir):
            os.remove(os.path.join(opt.log_dir, log))
    st_time = time.time()

    for epoch in range(opt.start_epoch, opt.nepoch):
        logger = setup_logger(
            'epoch%d' % epoch,
            os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
        logger.info('Train time {0}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) +
            ', ' + 'Training started'))
        train_count = 0
        train_dis_avg = 0.0
        if opt.refine_start:
            estimator.eval()  # affects dropout and batch normalization
            refiner.train()
        else:
            estimator.train()
        optimizer.zero_grad()

        for rep in range(opt.repeat_epoch):
            for i, data in enumerate(dataloader, 0):
                points, choose, img, target, model_points, idx = data
                #points        ->torch.Size([500, 3])  ->在crop出来的像素区域随机选取500个点,利用相机内参结合深度值算出来的点云cloud
                #choose        ->torch.Size([1, 500])
                #img           ->torch.Size([3, 80, 80])
                #target        ->torch.Size([500, 3])  ->真实模型上随机选取的mesh点进行ground truth pose变换后得到的点
                #model_points  ->torch.Size([500, 3])  ->真实模型上随机选取的mesh点在进行pose变换前的点
                #idx           ->torch.Size([1])
                #tensor([4], device='cuda:0')
                #img和points对应rgb和点云信息,需要在网络内部fusion
                points, choose, img, target, model_points, idx = Variable(points).cuda(), \
                                                                 Variable(choose).cuda(), \
                                                                 Variable(img).cuda(), \
                                                                 Variable(target).cuda(), \
                                                                 Variable(model_points).cuda(), \
                                                                 Variable(idx).cuda()
                pred_r, pred_t, pred_c, emb = estimator(
                    img, points, choose, idx)
                loss, dis, new_points, new_target = criterion(
                    pred_r, pred_t, pred_c, target, model_points, idx, points,
                    opt.w, opt.refine_start)

                if opt.refine_start:
                    for ite in range(0, opt.iteration):
                        pred_r, pred_t = refiner(new_points, emb, idx)
                        dis, new_points, new_target = criterion_refine(
                            pred_r, pred_t, new_target, model_points, idx,
                            new_points)
                        dis.backward()
                else:
                    loss.backward()

                train_dis_avg += dis.item()
                train_count += 1

                if train_count % opt.batch_size == 0:
                    logger.info(
                        'Train time {0} Epoch {1} Batch {2} Frame {3} Avg_dis:{4}'
                        .format(
                            time.strftime("%Hh %Mm %Ss",
                                          time.gmtime(time.time() - st_time)),
                            epoch, int(train_count / opt.batch_size),
                            train_count, train_dis_avg / opt.batch_size))
                    optimizer.step()
                    optimizer.zero_grad()
                    train_dis_avg = 0

                if train_count != 0 and train_count % 1000 == 0:
                    if opt.refine_start:
                        torch.save(
                            refiner.state_dict(),
                            '{0}/pose_refine_model_current.pth'.format(
                                opt.outf))
                    else:
                        torch.save(
                            estimator.state_dict(),
                            '{0}/pose_model_current.pth'.format(opt.outf))

        print(
            '>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.format(
                epoch))

        logger = setup_logger(
            'epoch%d_test' % epoch,
            os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) +
            ', ' + 'Testing started'))
        test_dis = 0.0
        test_count = 0
        estimator.eval()
        refiner.eval()

        for j, data in enumerate(testdataloader, 0):
            points, choose, img, target, model_points, idx = data
            points, choose, img, target, model_points, idx = Variable(points).cuda(), \
                                                             Variable(choose).cuda(), \
                                                             Variable(img).cuda(), \
                                                             Variable(target).cuda(), \
                                                             Variable(model_points).cuda(), \
                                                             Variable(idx).cuda()
            pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)
            _, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c,
                                                       target, model_points,
                                                       idx, points, opt.w,
                                                       opt.refine_start)

            if opt.refine_start:
                for ite in range(0, opt.iteration):
                    pred_r, pred_t = refiner(new_points, emb, idx)
                    dis, new_points, new_target = criterion_refine(
                        pred_r, pred_t, new_target, model_points, idx,
                        new_points)

            test_dis += dis.item()
            logger.info('Test time {0} Test Frame No.{1} dis:{2}'.format(
                time.strftime("%Hh %Mm %Ss",
                              time.gmtime(time.time() - st_time)), test_count,
                dis))

            test_count += 1

        test_dis = test_dis / test_count
        logger.info('Test time {0} Epoch {1} TEST FINISH Avg dis: {2}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)),
            epoch, test_dis))
        if test_dis <= best_test:
            best_test = test_dis
            if opt.refine_start:
                torch.save(
                    refiner.state_dict(),
                    '{0}/pose_refine_model_{1}_{2}.pth'.format(
                        opt.outf, epoch, test_dis))
            else:
                torch.save(
                    estimator.state_dict(),
                    '{0}/pose_model_{1}_{2}.pth'.format(
                        opt.outf, epoch, test_dis))
            print(epoch,
                  '>>>>>>>>----------BEST TEST MODEL SAVED---------<<<<<<<<')

        if best_test < opt.decay_margin and not opt.decay_start:
            opt.decay_start = True
            opt.lr *= opt.lr_rate
            opt.w *= opt.w_rate
            optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

        if best_test < opt.refine_margin and not opt.refine_start:
            opt.refine_start = True
            opt.batch_size = int(opt.batch_size / opt.iteration)
            optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)

            dataset = PoseDataset('train', opt.num_points, True,
                                  opt.dataset_root, opt.noise_trans,
                                  opt.refine_start)
            dataloader = torch.utils.data.DataLoader(dataset,
                                                     batch_size=1,
                                                     shuffle=True,
                                                     num_workers=opt.workers)

            test_dataset = PoseDataset('test', opt.num_points, False,
                                       opt.dataset_root, 0.0, opt.refine_start)
            testdataloader = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=1,
                shuffle=False,
                num_workers=opt.workers)

            opt.sym_list = dataset.get_sym_list()
            opt.num_points_mesh = dataset.get_num_points_mesh()

            print(
                '>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}'
                .format(len(dataset), len(test_dataset), opt.num_points_mesh))

            criterion = Loss(opt.num_points_mesh, opt.sym_list)
            criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)