Example #1
0
def test(epoch, cfg, data_loader, model, obj_vtx, obj_info, criterions):
    """Evaluate the model for one epoch and compute pose metrics.

    Runs the network over ``data_loader``; the rotation branch output is
    converted to 2D-3D correspondences and solved with PnP (iterative or
    RANSAC), the translation branch output is back-projected with the
    camera intrinsics. Results are accumulated into ``Evaluation`` objects
    and the configured metrics (pose / add / arp_2d) are reported.

    Args:
        epoch: current epoch index, used for save paths and writer steps.
        cfg: global experiment configuration.
        data_loader: test data loader yielding
            (obj, obj_id, inp, pose, c_box, s_box, box, trans_local).
        model: network returning ``(pred_rot, pred_trans)``.
        obj_vtx: per-object 3D model vertices.
        obj_info: per-object model extents (``min_x``/``min_y``/``min_z``).
        criterions: loss functions; unused here but kept so train/test
            share one call signature.

    Returns:
        ``(loss_dict, preds)`` after a full pass, or ``(None, None)`` when
        results were restored from the cache files.
    """
    model.eval()
    Eval = Evaluation(cfg.dataset, obj_info, obj_vtx)
    if 'trans' in cfg.pytorch.task.lower():
        Eval_trans = Evaluation(cfg.dataset, obj_info, obj_vtx)

    # Fast path: restore cached pose estimates instead of re-running the
    # network, then evaluate and return early.
    if not cfg.test.ignore_cache_file:
        est_cache_file = cfg.test.cache_file
        gt_cache_file = cfg.test.cache_file.replace('_est', '_gt')
        if os.path.exists(est_cache_file) and os.path.exists(gt_cache_file):
            # Cache files hold dicts of python lists, hence allow_pickle.
            Eval.pose_est_all = np.load(est_cache_file,
                                        allow_pickle=True).tolist()
            Eval.pose_gt_all = np.load(gt_cache_file,
                                       allow_pickle=True).tolist()
            fig_save_path = os.path.join(cfg.pytorch.save_path, str(epoch))
            mkdir_p(fig_save_path)
            if 'all' in cfg.test.test_mode.lower():
                Eval.evaluate_pose()
                Eval.evaluate_pose_add(fig_save_path)
                Eval.evaluate_pose_arp_2d(fig_save_path)
            elif 'pose' in cfg.test.test_mode.lower():
                Eval.evaluate_pose()
            elif 'add' in cfg.test.test_mode.lower():
                Eval.evaluate_pose_add(fig_save_path)
            elif 'arp' in cfg.test.test_mode.lower():
                Eval.evaluate_pose_arp_2d(fig_save_path)
            else:
                raise Exception("Wrong test mode: {}".format(
                    cfg.test.test_mode))

            return None, None

        else:
            logger.info("test cache file {} and {} not exist!".format(
                est_cache_file, gt_cache_file))
            userAns = input("Generating cache file from model [Y(y)/N(n)]:")
            if userAns.lower() == 'n':
                sys.exit(0)
            else:
                logger.info("Generating test cache file!")

    preds = {}
    # Loss meters are never updated during testing; they only keep the
    # progress-bar format identical to training.
    Loss = AverageMeter()
    Loss_rot = AverageMeter()
    Loss_trans = AverageMeter()
    num_iters = len(data_loader)
    bar = Bar('{}'.format(cfg.pytorch.exp_id[-60:]), max=num_iters)

    time_monitor = False
    vis_dir = os.path.join(cfg.pytorch.save_path, 'test_vis_{}'.format(epoch))
    if not os.path.exists(vis_dir):
        os.makedirs(vis_dir)
    for i, (obj, obj_id, inp, pose, c_box, s_box, box,
            trans_local) in enumerate(data_loader):
        if cfg.pytorch.gpu > -1:
            # NOTE: `async=True` is a SyntaxError on Python >= 3.7 (reserved
            # keyword); the modern Tensor.cuda argument is `non_blocking`.
            inp_var = inp.cuda(cfg.pytorch.gpu, non_blocking=True).float()
        else:
            inp_var = inp.float()

        bs = len(inp)
        # forward propagation
        T_begin = time.time()
        pred_rot, pred_trans = model(inp_var)
        T_end = time.time() - T_begin
        if time_monitor:
            logger.info(
                "time for a batch forward of resnet model is {}".format(T_end))

        # Periodic visualization of inputs / predictions.
        if i % cfg.test.disp_interval == 0:
            # input image (BGR->RGB channel reorder, scale to uint8)
            inp_rgb = (inp[0].cpu().numpy().copy() *
                       255)[[2, 1, 0], :, :].astype(np.uint8)
            cfg.writer.add_image('input_image', inp_rgb, i)
            cv2.imwrite(os.path.join(vis_dir, '{}_inp.png'.format(i)),
                        inp_rgb.transpose(1, 2, 0)[:, :, ::-1])
            if 'rot' in cfg.pytorch.task.lower():
                # predicted coordinates map, one image per axis
                pred_coor = pred_rot[0, 0:3].data.cpu().numpy().copy()
                pred_coor[0] = im_norm_255(pred_coor[0])
                pred_coor[1] = im_norm_255(pred_coor[1])
                pred_coor[2] = im_norm_255(pred_coor[2])
                pred_coor = np.asarray(pred_coor, dtype=np.uint8)
                cfg.writer.add_image('test_coor_x_pred',
                                     np.expand_dims(pred_coor[0], axis=0), i)
                cfg.writer.add_image('test_coor_y_pred',
                                     np.expand_dims(pred_coor[1], axis=0), i)
                cfg.writer.add_image('test_coor_z_pred',
                                     np.expand_dims(pred_coor[2], axis=0), i)
                # predicted confidence map (channel 3)
                pred_conf = pred_rot[0, 3].data.cpu().numpy().copy()
                pred_conf = (im_norm_255(pred_conf)).astype(np.uint8)
                cfg.writer.add_image('test_conf_pred',
                                     np.expand_dims(pred_conf, axis=0), i)
            if 'trans' in cfg.pytorch.task.lower():
                pred_trans_ = pred_trans[0].data.cpu().numpy().copy()
                gt_trans_ = trans_local[0].data.cpu().numpy().copy()
                cfg.writer.add_scalar('test_trans_x_gt', gt_trans_[0],
                                      i + (epoch - 1) * num_iters)
                cfg.writer.add_scalar('test_trans_y_gt', gt_trans_[1],
                                      i + (epoch - 1) * num_iters)
                cfg.writer.add_scalar('test_trans_z_gt', gt_trans_[2],
                                      i + (epoch - 1) * num_iters)
                cfg.writer.add_scalar('test_trans_x_pred', pred_trans_[0],
                                      i + (epoch - 1) * num_iters)
                cfg.writer.add_scalar('test_trans_y_pred', pred_trans_[1],
                                      i + (epoch - 1) * num_iters)
                cfg.writer.add_scalar('test_trans_z_pred', pred_trans_[2],
                                      i + (epoch - 1) * num_iters)
                cfg.writer.add_scalar('test_trans_x_err',
                                      np.abs(pred_trans_[0] - gt_trans_[0]),
                                      i + (epoch - 1) * num_iters)
                cfg.writer.add_scalar('test_trans_y_err',
                                      np.abs(pred_trans_[1] - gt_trans_[1]),
                                      i + (epoch - 1) * num_iters)
                cfg.writer.add_scalar('test_trans_z_err',
                                      np.abs(pred_trans_[2] - gt_trans_[2]),
                                      i + (epoch - 1) * num_iters)

        if 'rot' in cfg.pytorch.task.lower():
            pred_coor = pred_rot[:, 0:3].data.cpu().numpy().copy()
            pred_conf = pred_rot[:, 3].data.cpu().numpy().copy()
        else:
            # placeholders so the per-sample zip below stays uniform
            pred_coor = np.zeros(bs)
            pred_conf = np.zeros(bs)

        if 'trans' in cfg.pytorch.task.lower():
            pred_trans = pred_trans.data.cpu().numpy().copy()
        else:
            pred_trans = np.zeros(bs)

        col = list(
            zip(obj, obj_id.numpy(), pred_coor, pred_conf, pred_trans,
                pose.numpy(), c_box.numpy(), s_box.numpy(), box.numpy()))
        for idx in range(len(col)):
            obj_, obj_id_, pred_coor_, pred_conf_, pred_trans_, pose_gt, c_box_, s_box_, box_ = col[
                idx]
            T_begin = time.time()
            if 'rot' in cfg.pytorch.task.lower():
                # building 2D-3D correspondences: rescale the normalized
                # coordinate map back to object extents (CHW -> HWC first)
                pred_coor_ = pred_coor_.transpose(1, 2, 0)
                pred_coor_[:, :, 0] = pred_coor_[:, :, 0] * abs(
                    obj_info[obj_id_]['min_x'])
                pred_coor_[:, :, 1] = pred_coor_[:, :, 1] * abs(
                    obj_info[obj_id_]['min_y'])
                pred_coor_[:, :, 2] = pred_coor_[:, :, 2] * abs(
                    obj_info[obj_id_]['min_z'])
                pred_coor_ = pred_coor_.tolist()
                eroMask = False
                if eroMask:
                    kernel = np.ones((3, 3), np.uint8)
                    pred_conf_ = cv2.erode(pred_conf_, kernel)
                # NOTE(review): min-max normalization divides by zero when
                # the confidence map is constant — presumably never happens
                # for a trained model; confirm before hardening.
                pred_conf_ = (pred_conf_ - pred_conf_.min()) / (
                    pred_conf_.max() - pred_conf_.min())
                pred_conf_ = pred_conf_.tolist()

                select_pts_2d = []
                select_pts_3d = []
                # map output-resolution pixels back to the original image
                # via the crop box center/scale
                c_w = int(c_box_[0])
                c_h = int(c_box_[1])
                s = int(s_box_)
                w_begin = c_w - s / 2.
                h_begin = c_h - s / 2.
                w_unit = s * 1.0 / cfg.dataiter.out_res
                h_unit = s * 1.0 / cfg.dataiter.out_res

                # points closer than 0.1% of the extent to the origin on all
                # three axes are treated as background
                min_x = 0.001 * abs(obj_info[obj_id_]['min_x'])
                min_y = 0.001 * abs(obj_info[obj_id_]['min_y'])
                min_z = 0.001 * abs(obj_info[obj_id_]['min_z'])
                for x in range(cfg.dataiter.out_res):
                    for y in range(cfg.dataiter.out_res):
                        if pred_conf_[x][y] < cfg.test.mask_threshold:
                            continue
                        if abs(pred_coor_[x][y][0]) < min_x  and abs(pred_coor_[x][y][1]) < min_y  and \
                            abs(pred_coor_[x][y][2]) < min_z:
                            continue
                        select_pts_2d.append(
                            [w_begin + y * w_unit, h_begin + x * h_unit])
                        select_pts_3d.append(pred_coor_[x][y])

                model_points = np.asarray(select_pts_3d, dtype=np.float32)
                image_points = np.asarray(select_pts_2d, dtype=np.float32)

            if 'trans' in cfg.pytorch.task.lower():
                # compute T from translation head: predicted depth ratio and
                # center offset, back-projected through the intrinsics
                ratio_delta_c = pred_trans_[:2]
                ratio_depth = pred_trans_[2]
                pred_depth = ratio_depth * (cfg.dataiter.out_res / s_box_)
                pred_c = ratio_delta_c * box_[2:] + c_box_
                pred_x = (pred_c[0] - cfg.dataset.camera_matrix[0, 2]
                          ) * pred_depth / cfg.dataset.camera_matrix[0, 0]
                pred_y = (pred_c[1] - cfg.dataset.camera_matrix[1, 2]
                          ) * pred_depth / cfg.dataset.camera_matrix[1, 1]
                T_vector_trans = np.asarray([pred_x, pred_y, pred_depth])
                pose_est_trans = np.concatenate(
                    (np.eye(3), np.asarray((T_vector_trans).reshape(3, 1))),
                    axis=1)

            try:
                if 'rot' in cfg.pytorch.task.lower():
                    dist_coeffs = np.zeros(
                        (4, 1))  # Assuming no lens distortion
                    if cfg.test.pnp == 'iterPnP':  # iterative PnP algorithm
                        success, R_vector, T_vector = cv2.solvePnP(
                            model_points,
                            image_points,
                            cfg.dataset.camera_matrix,
                            dist_coeffs,
                            flags=cv2.SOLVEPNP_ITERATIVE)
                    elif cfg.test.pnp == 'ransac':  # ransac algorithm
                        _, R_vector, T_vector, inliers = cv2.solvePnPRansac(
                            model_points,
                            image_points,
                            cfg.dataset.camera_matrix,
                            dist_coeffs,
                            flags=cv2.SOLVEPNP_EPNP)
                    else:
                        raise NotImplementedError(
                            "Not support PnP algorithm: {}".format(
                                cfg.test.pnp))
                    R_matrix = cv2.Rodrigues(R_vector, jacobian=0)[0]
                    pose_est = np.concatenate(
                        (R_matrix, np.asarray(T_vector).reshape(3, 1)), axis=1)
                    if 'trans' in cfg.pytorch.task.lower():
                        # combine PnP rotation with the translation-head T
                        pose_est_trans = np.concatenate(
                            (R_matrix,
                             np.asarray((T_vector_trans).reshape(3, 1))),
                            axis=1)
                    Eval.pose_est_all[obj_].append(pose_est)
                    Eval.pose_gt_all[obj_].append(pose_gt)
                    Eval.num[obj_] += 1
                    Eval.numAll += 1
                if 'trans' in cfg.pytorch.task.lower():
                    Eval_trans.pose_est_all[obj_].append(pose_est_trans)
                    Eval_trans.pose_gt_all[obj_].append(pose_gt)
                    Eval_trans.num[obj_] += 1
                    Eval_trans.numAll += 1
            except Exception:
                # PnP can fail (e.g. too few correspondences); count the
                # sample so recall metrics stay honest, but keep going.
                # A bare `except:` here would also swallow KeyboardInterrupt.
                Eval.num[obj_] += 1
                Eval.numAll += 1
                if 'trans' in cfg.pytorch.task.lower():
                    Eval_trans.num[obj_] += 1
                    Eval_trans.numAll += 1
                logger.info('error in solve PnP or Ransac')

            T_end = time.time() - T_begin
            if time_monitor:
                logger.info(
                    "time spend on PnP+RANSAC for one image is {}".format(
                        T_end))

        Bar.suffix = 'test Epoch: [{0}][{1}/{2}]| Total: {total:} | ETA: {eta:} | Loss {loss.avg:.4f} | Loss_rot {loss_rot.avg:.4f} | Loss_trans {loss_trans.avg:.4f}'.format(
            epoch,
            i,
            num_iters,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=Loss,
            loss_rot=Loss_rot,
            loss_trans=Loss_trans)
        bar.next()

    # Persist estimates and run the configured evaluations.
    epoch_save_path = os.path.join(cfg.pytorch.save_path, str(epoch))
    if not os.path.exists(epoch_save_path):
        os.makedirs(epoch_save_path)
    if 'rot' in cfg.pytorch.task.lower():
        logger.info("{} Evaluate of Rotation Branch of Epoch {} {}".format(
            '-' * 40, epoch, '-' * 40))
        preds['poseGT'] = Eval.pose_gt_all
        preds['poseEst'] = Eval.pose_est_all
        if cfg.pytorch.test:
            np.save(os.path.join(epoch_save_path, 'pose_est_all_test.npy'),
                    Eval.pose_est_all)
            np.save(os.path.join(epoch_save_path, 'pose_gt_all_test.npy'),
                    Eval.pose_gt_all)
        else:
            np.save(
                os.path.join(epoch_save_path,
                             'pose_est_all_epoch{}.npy'.format(epoch)),
                Eval.pose_est_all)
            np.save(
                os.path.join(epoch_save_path,
                             'pose_gt_all_epoch{}.npy'.format(epoch)),
                Eval.pose_gt_all)
        # evaluation
        if 'all' in cfg.test.test_mode.lower():
            Eval.evaluate_pose()
            Eval.evaluate_pose_add(epoch_save_path)
            Eval.evaluate_pose_arp_2d(epoch_save_path)
        else:
            if 'pose' in cfg.test.test_mode.lower():
                Eval.evaluate_pose()
            if 'add' in cfg.test.test_mode.lower():
                Eval.evaluate_pose_add(epoch_save_path)
            if 'arp' in cfg.test.test_mode.lower():
                Eval.evaluate_pose_arp_2d(epoch_save_path)

    if 'trans' in cfg.pytorch.task.lower():
        logger.info("{} Evaluate of Translation Branch of Epoch {} {}".format(
            '-' * 40, epoch, '-' * 40))
        preds['poseGT'] = Eval_trans.pose_gt_all
        preds['poseEst'] = Eval_trans.pose_est_all
        if cfg.pytorch.test:
            np.save(
                os.path.join(epoch_save_path, 'pose_est_all_test_trans.npy'),
                Eval_trans.pose_est_all)
            np.save(
                os.path.join(epoch_save_path, 'pose_gt_all_test_trans.npy'),
                Eval_trans.pose_gt_all)
        else:
            np.save(
                os.path.join(epoch_save_path,
                             'pose_est_all_trans_epoch{}.npy'.format(epoch)),
                Eval_trans.pose_est_all)
            np.save(
                os.path.join(epoch_save_path,
                             'pose_gt_all_trans_epoch{}.npy'.format(epoch)),
                Eval_trans.pose_gt_all)
        # evaluation
        if 'all' in cfg.test.test_mode.lower():
            Eval_trans.evaluate_pose()
            Eval_trans.evaluate_pose_add(epoch_save_path)
            Eval_trans.evaluate_pose_arp_2d(epoch_save_path)
        else:
            if 'pose' in cfg.test.test_mode.lower():
                Eval_trans.evaluate_pose()
            if 'add' in cfg.test.test_mode.lower():
                Eval_trans.evaluate_pose_add(epoch_save_path)
            if 'arp' in cfg.test.test_mode.lower():
                Eval_trans.evaluate_pose_arp_2d(epoch_save_path)

    bar.finish()
    return {
        'Loss': Loss.avg,
        'Loss_rot': Loss_rot.avg,
        'Loss_trans': Loss_trans.avg
    }, preds
Example #2
0
 def __init__(self, cfg, split):
     """Build the LineMOD dataset index for `split` ('train' or 'test').

     Annotations are collected per object class and cached as .npy files
     under ``ref.cache_dir`` so subsequent runs skip the directory scan.
     Training data may mix real images and synthetic ('imgn') images,
     with optional sub-sampling of the synthetic set.

     Raises:
         ValueError: for an unknown split or synthetic sampling type.
     """
     self.cfg = cfg
     self.split = split
     self.infos = self.load_lm_model_info(ref.lm_model_info_pth)
     self.cam_K = ref.K
     logger.info('==> initializing {} {} data.'.format(
         cfg.dataset.name, split))
     # load dataset
     annot = []
     if split == 'test':
         cache_dir = os.path.join(ref.cache_dir, 'test')
         if not os.path.exists(cache_dir):
             os.makedirs(cache_dir)
         for obj in tqdm(self.cfg.dataset.classes):
             cache_pth = os.path.join(cache_dir, '{}.npy'.format(obj))
             if not os.path.exists(cache_pth):
                 annot_cache = []
                 rgb_pths = glob(
                     os.path.join(ref.lm_test_dir, obj, '*-color.png'))
                 for rgb_pth in tqdm(rgb_pths):
                     item = self.col_test_item(rgb_pth)
                     item['obj'] = obj
                     annot_cache.append(item)
                 np.save(cache_pth, annot_cache)
             annot.extend(np.load(cache_pth, allow_pickle=True).tolist())
         self.num = len(annot)
         self.annot = annot
         logger.info('load {} test samples.'.format(self.num))
     elif split == 'train':
         # Both counters must exist even when the corresponding image
         # type is disabled: the imgn branch and the summary log below
         # read them unconditionally (previously an AttributeError when
         # only one of 'real'/'imgn' was configured).
         self.real_num = 0
         self.imgn_num = 0
         if 'real' in self.cfg.dataset.img_type:
             cache_dir = os.path.join(ref.cache_dir, 'train/real')
             if not os.path.exists(cache_dir):
                 os.makedirs(cache_dir)
             for obj in tqdm(self.cfg.dataset.classes):
                 cache_pth = os.path.join(cache_dir, '{}.npy'.format(obj))
                 if not os.path.exists(cache_pth):
                     annot_cache = []
                     rgb_pths = glob(
                         os.path.join(ref.lm_train_real_dir, obj,
                                      '*-color.png'))
                     for rgb_pth in tqdm(rgb_pths):
                         item = self.col_train_item(rgb_pth)
                         item['obj'] = obj
                         annot_cache.append(item)
                     np.save(cache_pth, annot_cache)
                 annot.extend(
                     np.load(cache_pth, allow_pickle=True).tolist())
             self.real_num = len(annot)
             logger.info('load {} real training samples.'.format(
                 self.real_num))
         if 'imgn' in self.cfg.dataset.img_type:
             cache_dir = os.path.join(ref.cache_dir, 'train/imgn')
             if not os.path.exists(cache_dir):
                 os.makedirs(cache_dir)
             for obj in tqdm(self.cfg.dataset.classes):
                 cache_pth = os.path.join(cache_dir, '{}.npy'.format(obj))
                 if not os.path.exists(cache_pth):
                     annot_cache = []
                     coor_pths = sorted(
                         glob(
                             os.path.join(ref.lm_train_imgn_dir, obj,
                                          '*-coor.pkl')))
                     for coor_pth in tqdm(coor_pths):
                         item = self.col_imgn_item(coor_pth)
                         item['obj'] = obj
                         annot_cache.append(item)
                     np.save(cache_pth, annot_cache)
                 annot_obj = np.load(cache_pth, allow_pickle=True).tolist()
                 annot_obj_num = len(annot_obj)
                 # Optionally sub-sample the synthetic annotations down to
                 # cfg.dataset.syn_num per object.
                 if (annot_obj_num > self.cfg.dataset.syn_num) and (
                         self.cfg.dataset.syn_samp_type != ''):
                     if self.cfg.dataset.syn_samp_type == 'uniform':
                         samp_idx = np.linspace(0,
                                                annot_obj_num - 1,
                                                self.cfg.dataset.syn_num,
                                                dtype=np.int32)
                     elif self.cfg.dataset.syn_samp_type == 'random':
                         samp_idx = random.sample(range(annot_obj_num),
                                                  self.cfg.dataset.syn_num)
                     else:
                         raise ValueError(
                             'unknown syn_samp_type: {}'.format(
                                 self.cfg.dataset.syn_samp_type))
                     annot_obj = np.asarray(annot_obj)[samp_idx].tolist()
                 annot.extend(annot_obj)
             self.imgn_num = len(annot) - self.real_num
             logger.info('load {} imgn training samples.'.format(
                 self.imgn_num))
         self.num = len(annot)
         self.annot = annot
         logger.info(
             'load {} training samples, including {} real samples and {} synthetic samples.'
             .format(self.num, self.real_num, self.imgn_num))
         self.bg_list = self.load_bg_list()
     else:
         raise ValueError('unknown split: {}'.format(split))
Example #3
0
def train(epoch, cfg, data_loader, model, criterions, optimizer=None):
    """Train the model for one epoch.

    Computes the rotation-head loss (optionally masked) and the
    translation-head loss, combines them with their configured weights,
    backpropagates, and logs losses plus periodic visualizations.

    Args:
        epoch: current epoch index (1-based; used for writer steps).
        cfg: global experiment configuration.
        data_loader: training data loader yielding (obj, obj_id, inp,
            target, loss_msk, trans_local, pose, c_box, s_box, box).
        model: network returning ``(pred_rot, pred_trans)``.
        criterions: dict of loss functions keyed by loss-type name.
        optimizer: optimizer to step; must be provided for training.

    Returns:
        ``(loss_dict, preds)`` where loss_dict holds epoch-average losses
        and preds is always empty here.
    """
    model.train()
    preds = {}
    Loss = AverageMeter()
    Loss_rot = AverageMeter()
    Loss_trans = AverageMeter()
    num_iters = len(data_loader)
    bar = Bar('{}'.format(cfg.pytorch.exp_id[-60:]), max=num_iters)

    time_monitor = False
    vis_dir = os.path.join(cfg.pytorch.save_path, 'train_vis_{}'.format(epoch))
    if not os.path.exists(vis_dir):
        os.makedirs(vis_dir)
    for i, (obj, obj_id, inp, target, loss_msk, trans_local, pose, c_box,
            s_box, box) in enumerate(data_loader):
        cur_iter = i + (epoch - 1) * num_iters
        if cfg.pytorch.gpu > -1:
            # NOTE: `async=True` is a SyntaxError on Python >= 3.7 (reserved
            # keyword); the modern Tensor.cuda argument is `non_blocking`.
            inp_var = inp.cuda(cfg.pytorch.gpu, non_blocking=True).float()
            target_var = target.cuda(cfg.pytorch.gpu,
                                     non_blocking=True).float()
            loss_msk_var = loss_msk.cuda(cfg.pytorch.gpu,
                                         non_blocking=True).float()
            trans_local_var = trans_local.cuda(cfg.pytorch.gpu,
                                               non_blocking=True).float()
            pose_var = pose.cuda(cfg.pytorch.gpu, non_blocking=True).float()
            c_box_var = c_box.cuda(cfg.pytorch.gpu, non_blocking=True).float()
            s_box_var = s_box.cuda(cfg.pytorch.gpu, non_blocking=True).float()
        else:
            inp_var = inp.float()
            target_var = target.float()
            loss_msk_var = loss_msk.float()
            trans_local_var = trans_local.float()
            pose_var = pose.float()
            c_box_var = c_box.float()
            s_box_var = s_box.float()

        bs = len(inp)
        # forward propagation
        T_begin = time.time()
        pred_rot, pred_trans = model(inp_var)
        T_end = time.time() - T_begin
        if time_monitor:
            logger.info(
                "time for a batch forward of resnet model is {}".format(T_end))

        # Periodic visualization of inputs / predictions / targets.
        if i % cfg.test.disp_interval == 0:
            # input image (reverse channels, scale to uint8)
            inp_rgb = (inp[0].cpu().numpy().copy() * 255)[::-1, :, :].astype(
                np.uint8)
            cfg.writer.add_image('input_image', inp_rgb, i)
            cv2.imwrite(os.path.join(vis_dir, '{}_inp.png'.format(i)),
                        inp_rgb.transpose(1, 2, 0)[:, :, ::-1])
            if 'rot' in cfg.pytorch.task.lower():
                # predicted coordinates map, one image per axis
                pred_coor = pred_rot[0, 0:3].data.cpu().numpy().copy()
                pred_coor[0] = im_norm_255(pred_coor[0])
                pred_coor[1] = im_norm_255(pred_coor[1])
                pred_coor[2] = im_norm_255(pred_coor[2])
                pred_coor = np.asarray(pred_coor, dtype=np.uint8)
                cfg.writer.add_image('train_coor_x_pred',
                                     np.expand_dims(pred_coor[0], axis=0), i)
                cfg.writer.add_image('train_coor_y_pred',
                                     np.expand_dims(pred_coor[1], axis=0), i)
                cfg.writer.add_image('train_coor_z_pred',
                                     np.expand_dims(pred_coor[2], axis=0), i)
                cv2.imwrite(
                    os.path.join(vis_dir, '{}_coor_x_pred.png'.format(i)),
                    pred_coor[0])
                cv2.imwrite(
                    os.path.join(vis_dir, '{}_coor_y_pred.png'.format(i)),
                    pred_coor[1])
                cv2.imwrite(
                    os.path.join(vis_dir, '{}_coor_z_pred.png'.format(i)),
                    pred_coor[2])
                # ground-truth coordinates map
                gt_coor = target[0, 0:3].data.cpu().numpy().copy()
                gt_coor[0] = im_norm_255(gt_coor[0])
                gt_coor[1] = im_norm_255(gt_coor[1])
                gt_coor[2] = im_norm_255(gt_coor[2])
                gt_coor = np.asarray(gt_coor, dtype=np.uint8)
                cfg.writer.add_image('train_coor_x_gt',
                                     np.expand_dims(gt_coor[0], axis=0), i)
                cfg.writer.add_image('train_coor_y_gt',
                                     np.expand_dims(gt_coor[1], axis=0), i)
                cfg.writer.add_image('train_coor_z_gt',
                                     np.expand_dims(gt_coor[2], axis=0), i)
                cv2.imwrite(
                    os.path.join(vis_dir, '{}_coor_x_gt.png'.format(i)),
                    gt_coor[0])
                cv2.imwrite(
                    os.path.join(vis_dir, '{}_coor_y_gt.png'.format(i)),
                    gt_coor[1])
                cv2.imwrite(
                    os.path.join(vis_dir, '{}_coor_z_gt.png'.format(i)),
                    gt_coor[2])
                # confidence map (predicted and ground-truth, channel 3)
                pred_conf = pred_rot[0, 3].data.cpu().numpy().copy()
                pred_conf = (im_norm_255(pred_conf)).astype(np.uint8)
                gt_conf = target[0, 3].data.cpu().numpy().copy()
                cfg.writer.add_image('train_conf_pred',
                                     np.expand_dims(pred_conf, axis=0), i)
                cfg.writer.add_image('train_conf_gt',
                                     np.expand_dims(gt_conf, axis=0), i)
                cv2.imwrite(os.path.join(vis_dir, '{}_conf_gt.png'.format(i)),
                            gt_conf)
                cv2.imwrite(
                    os.path.join(vis_dir, '{}_conf_pred.png'.format(i)),
                    pred_conf)
            if 'trans' in cfg.pytorch.task.lower():
                pred_trans_ = pred_trans[0].data.cpu().numpy().copy()
                gt_trans_ = trans_local[0].data.cpu().numpy().copy()
                cfg.writer.add_scalar('train_trans_x_gt', gt_trans_[0],
                                      i + (epoch - 1) * num_iters)
                cfg.writer.add_scalar('train_trans_y_gt', gt_trans_[1],
                                      i + (epoch - 1) * num_iters)
                cfg.writer.add_scalar('train_trans_z_gt', gt_trans_[2],
                                      i + (epoch - 1) * num_iters)
                cfg.writer.add_scalar('train_trans_x_pred', pred_trans_[0],
                                      i + (epoch - 1) * num_iters)
                cfg.writer.add_scalar('train_trans_y_pred', pred_trans_[1],
                                      i + (epoch - 1) * num_iters)
                cfg.writer.add_scalar('train_trans_z_pred', pred_trans_[2],
                                      i + (epoch - 1) * num_iters)
                cfg.writer.add_scalar('train_trans_x_err',
                                      pred_trans_[0] - gt_trans_[0],
                                      i + (epoch - 1) * num_iters)
                cfg.writer.add_scalar('train_trans_y_err',
                                      pred_trans_[1] - gt_trans_[1],
                                      i + (epoch - 1) * num_iters)
                cfg.writer.add_scalar('train_trans_z_err',
                                      pred_trans_[2] - gt_trans_[2],
                                      i + (epoch - 1) * num_iters)

        # loss: each head contributes only when its task is active and the
        # head is not frozen; otherwise the int 0 stands in.
        if 'rot' in cfg.pytorch.task.lower(
        ) and not cfg.network.rot_head_freeze:
            if cfg.loss.rot_mask_loss:
                loss_rot = criterions[cfg.loss.rot_loss_type](
                    loss_msk_var * pred_rot, loss_msk_var * target_var)
            else:
                loss_rot = criterions[cfg.loss.rot_loss_type](pred_rot,
                                                              target_var)
        else:
            loss_rot = 0
        if 'trans' in cfg.pytorch.task.lower(
        ) and not cfg.network.trans_head_freeze:
            loss_trans = criterions[cfg.loss.trans_loss_type](pred_trans,
                                                              trans_local_var)
        else:
            loss_trans = 0
        loss = cfg.loss.rot_loss_weight * loss_rot + cfg.loss.trans_loss_weight * loss_trans

        Loss.update(loss.item() if loss != 0 else 0, bs)
        Loss_rot.update(loss_rot.item() if loss_rot != 0 else 0, bs)
        Loss_trans.update(loss_trans.item() if loss_trans != 0 else 0, bs)

        cfg.writer.add_scalar('data/loss_rot_trans',
                              loss.item() if loss != 0 else 0, cur_iter)
        cfg.writer.add_scalar('data/loss_rot',
                              loss_rot.item() if loss_rot != 0 else 0,
                              cur_iter)
        cfg.writer.add_scalar('data/loss_trans',
                              loss_trans.item() if loss_trans != 0 else 0,
                              cur_iter)

        optimizer.zero_grad()
        model.zero_grad()
        T_begin = time.time()
        # When both heads are frozen `loss` is the int 0 and has no
        # .backward(); skip the update instead of crashing.
        if isinstance(loss, torch.Tensor):
            loss.backward()
            optimizer.step()
        T_end = time.time() - T_begin
        if time_monitor:
            logger.info("time for backward of model: {}".format(T_end))

        Bar.suffix = 'train Epoch: [{0}][{1}/{2}]| Total: {total:} | ETA: {eta:} | Loss {loss.avg:.4f} | Loss_rot {loss_rot.avg:.4f} | Loss_trans {loss_trans.avg:.4f}'.format(
            epoch,
            i,
            num_iters,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=Loss,
            loss_rot=Loss_rot,
            loss_trans=Loss_trans)
        bar.next()
    bar.finish()
    return {
        'Loss': Loss.avg,
        'Loss_rot': Loss_rot.avg,
        'Loss_trans': Loss_trans.avg
    }, preds
Example #4
0
def main():
    """Entry point: parse config, build the model and criterions, then run
    either a single test pass (``cfg.pytorch.test``) or the full
    train/evaluate loop with step LR decay and periodic checkpointing."""
    cfg = config().parse()
    network, optimizer = build_model(cfg)
    criterions = {'L1': torch.nn.L1Loss(), 'L2': torch.nn.MSELoss()}

    if cfg.pytorch.gpu > -1:
        logger.info('Using GPU{}'.format(cfg.pytorch.gpu))
        network = network.cuda(cfg.pytorch.gpu)
        for k in criterions.keys():
            criterions[k] = criterions[k].cuda(cfg.pytorch.gpu)

    def _worker_init_fn(worker_id):
        # Re-seed python/numpy RNGs inside each DataLoader worker so random
        # augmentations differ across workers but stay reproducible.
        # NOTE: the original passed `_worker_init_fn()` (i.e. None) to the
        # DataLoaders, so workers were never actually seeded, and the old
        # `torch_seed // 2**32 - 1` could produce a negative (invalid) seed.
        torch_seed = torch.initial_seed()
        np_seed = (torch_seed + worker_id) % 2**32  # numpy needs a 32-bit seed
        random.seed(torch_seed)
        np.random.seed(np_seed)

    test_loader = torch.utils.data.DataLoader(
        LM(cfg, 'test'),
        batch_size=cfg.train.train_batch_size
        if 'fast' in cfg.test.test_mode else 1,
        shuffle=False,
        num_workers=int(cfg.pytorch.threads_num),
        worker_init_fn=_worker_init_fn)  # pass the callable, not its result

    obj_vtx = {}
    logger.info('load 3d object models...')
    for obj in tqdm(cfg.dataset.classes):
        obj_vtx[obj] = load_ply_vtx(
            os.path.join(ref.lm_model_dir, '{}/{}.ply'.format(obj, obj)))
    obj_info = LM.load_lm_model_info(ref.lm_model_info_pth)

    if cfg.pytorch.test:
        # Test-only mode: evaluate once, optionally dump predictions, and exit.
        _, preds = test(0, cfg, test_loader, network, obj_vtx, obj_info,
                        criterions)
        if preds is not None:
            torch.save({
                'cfg': pprint.pformat(cfg),
                'preds': preds
            }, os.path.join(cfg.pytorch.save_path, 'preds.pth'))
        return

    train_loader = torch.utils.data.DataLoader(
        LM(cfg, 'train'),
        batch_size=cfg.train.train_batch_size,
        shuffle=True,
        num_workers=int(cfg.pytorch.threads_num),
        worker_init_fn=_worker_init_fn)

    for epoch in range(cfg.train.begin_epoch, cfg.train.end_epoch + 1):
        mark = epoch if (cfg.pytorch.save_mode == 'all') else 'last'
        log_dict_train, _ = train(epoch, cfg, train_loader, network,
                                  criterions, optimizer)
        for k, v in log_dict_train.items():
            logger.info('{} {:8f} | '.format(k, v))
        if epoch % cfg.train.test_interval == 0:
            save_model(
                os.path.join(cfg.pytorch.save_path,
                             'model_{}.checkpoint'.format(mark)),
                network)  # optimizer state intentionally not saved
            log_dict_val, preds = test(epoch, cfg, test_loader, network,
                                       obj_vtx, obj_info, criterions)
        logger.info('\n')
        if epoch in cfg.train.lr_epoch_step:
            # Step-decay the learning rate at the configured epochs.
            if optimizer is not None:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= cfg.train.lr_factor
                    logger.info("drop lr to {}".format(param_group['lr']))

    torch.save(network.cpu(),
               os.path.join(cfg.pytorch.save_path, 'model_cpu.pth'))
Example #5
0
    def evaluate_trans(self):
        """Evaluate translation error in detail.

        For each class, computes the per-sample translation distance and the
        per-axis |x|/|y|/|z| errors between estimated and ground-truth poses,
        then logs the fraction of samples under several distance thresholds
        (0.01..0.10, presumably meters -- TODO confirm units with caller),
        per class and averaged over all classes.
        """
        all_poses_est = copy.deepcopy(self.pose_est_all)
        all_poses_gt = copy.deepcopy(self.pose_gt_all)

        logger.info('\n* {} *\n {:^}\n* {} *'.format('-' * 100, 'Evaluation Translation', '-' * 100))
        rot_thresh_list = np.arange(1, 11, 1)
        trans_thresh_list = np.arange(0.01, 0.11, 0.01)
        # Both threshold lists have length 10, so num_metric indexes either.
        num_metric = len(rot_thresh_list)
        num_classes = len(self.classes)

        trans_acc = np.zeros((num_classes, num_metric))
        x_acc = np.zeros((num_classes, num_metric))
        y_acc = np.zeros((num_classes, num_metric))
        z_acc = np.zeros((num_classes, num_metric))

        # NOTE(review): the original also built threshold_2/5/10/20 arrays
        # and re-assigned num_classes; both were never used and are removed.

        num_valid_class = len(self.classes)
        for i, cls_name in enumerate(self.classes):
            curr_poses_gt = all_poses_gt[cls_name]
            curr_poses_est = all_poses_est[cls_name]
            num = len(curr_poses_gt)
            cur_trans_rst = np.zeros((num, 1))
            cur_x_rst = np.zeros((num, 1))
            cur_y_rst = np.zeros((num, 1))
            cur_z_rst = np.zeros((num, 1))

            for j in range(num):
                RT = curr_poses_est[j]  # est pose
                pose_gt = curr_poses_gt[j]  # gt pose
                # Euclidean distance between the two translation vectors
                # (last column of each 3x4 pose matrix).
                t_dist_est = LA.norm(RT[:, 3].reshape(3) - pose_gt[:, 3].reshape(3))
                err_xyz = np.abs(RT[:, 3] - pose_gt[:, 3])
                cur_x_rst[j, 0], cur_y_rst[j, 0], cur_z_rst[j, 0] = err_xyz
                cur_trans_rst[j, 0] = t_dist_est

            for thresh_idx in range(num_metric):
                trans_acc[i, thresh_idx] = np.mean(cur_trans_rst < trans_thresh_list[thresh_idx])
                x_acc[i, thresh_idx] = np.mean(cur_x_rst < trans_thresh_list[thresh_idx])
                y_acc[i, thresh_idx] = np.mean(cur_y_rst < trans_thresh_list[thresh_idx])
                z_acc[i, thresh_idx] = np.mean(cur_z_rst < trans_thresh_list[thresh_idx])

            logger.info("------------ {} -----------".format(cls_name))
            logger.info("{:>24}: {:>7}, {:>7}, {:>7}, {:>7}".format("trans_thresh", "TraAcc", "x", "y", "z"))
            show_list = [1, 4, 9]  # thresholds 0.02, 0.05, 0.10
            for show_idx in show_list:
                logger.info("{:>16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}, {:>7.2f}".format('average_accuracy',
                                    '{:.2f}'.format(trans_thresh_list[show_idx]),
                                    trans_acc[i, show_idx] * 100, x_acc[i, show_idx] * 100,
                                    y_acc[i, show_idx] * 100, z_acc[i, show_idx] * 100))
        print(' ')
        # overall performance averaged over all classes
        show_list = [1, 4, 9]
        logger.info("---------- performance over {} classes -----------".format(num_valid_class))
        logger.info("{:>24}: {:>7}, {:>7}, {:>7}, {:>7}".format("trans_thresh", "TraAcc", "x", "y", "z"))

        for show_idx in show_list:
            logger.info("{:>16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}, {:>7.2f}".format('average_accuracy',
                                '{:.2f}'.format(trans_thresh_list[show_idx]),
                                np.sum(trans_acc[:, show_idx]) / num_valid_class * 100,
                                np.sum(x_acc[:, show_idx]) / num_valid_class * 100,
                                np.sum(y_acc[:, show_idx]) / num_valid_class * 100,
                                np.sum(z_acc[:, show_idx]) / num_valid_class * 100))
        print(' ')
Example #6
0
    def evaluate_pose(self):
        """
        Evaluate 6D pose and display.

        Computes per-sample rotation and translation errors (via
        calc_rt_dist_m) between estimated and ground-truth poses, then logs
        per-class and class-averaged accuracy under rotation thresholds 1..10,
        translation thresholds 0.01..0.10, and the joint ("space") accuracy
        where both must hold at the same threshold index.
        """
        all_poses_est = self.pose_est_all
        all_poses_gt = self.pose_gt_all
        logger.info('\n* {} *\n {:^}\n* {} *'.format('-' * 100, 'Evaluation 6D Pose', '-' * 100))
        rot_thresh_list = np.arange(1, 11, 1)
        trans_thresh_list = np.arange(0.01, 0.11, 0.01)
        num_metric = len(rot_thresh_list)  # both lists have length 10
        num_classes = len(self.classes)
        rot_acc = np.zeros((num_classes, num_metric))
        trans_acc = np.zeros((num_classes, num_metric))
        space_acc = np.zeros((num_classes, num_metric))

        num_valid_class = len(self.classes)
        for i, cls_name in enumerate(self.classes):
            curr_poses_gt = all_poses_gt[cls_name]
            curr_poses_est = all_poses_est[cls_name]
            num = len(curr_poses_gt)
            cur_rot_rst = np.zeros((num, 1))
            cur_trans_rst = np.zeros((num, 1))

            for j in range(num):
                r_dist_est, t_dist_est = calc_rt_dist_m(curr_poses_est[j], curr_poses_gt[j])
                # eggbox is treated as 180-degree symmetric about z: when the
                # raw rotation error is large, retry against the z-flipped pose
                # and keep that error instead.
                if cls_name == 'eggbox' and r_dist_est > 90:
                    RT_z = np.array([[-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 1, 0]])
                    curr_pose_est_sym = se3_mul(curr_poses_est[j], RT_z)
                    r_dist_est, t_dist_est = calc_rt_dist_m(curr_pose_est_sym, curr_poses_gt[j])
                # logger.info('t_dist: {}'.format(t_dist_est))
                cur_rot_rst[j, 0] = r_dist_est
                cur_trans_rst[j, 0] = t_dist_est

            # cur_rot_rst = np.vstack(all_rot_err[cls_idx, iter_i])
            # cur_trans_rst = np.vstack(all_trans_err[cls_idx, iter_i])
            # Fraction of samples under each threshold; "space" requires rot
            # AND trans to pass at the same threshold index.
            for thresh_idx in range(num_metric):
                rot_acc[i, thresh_idx] = np.mean(cur_rot_rst < rot_thresh_list[thresh_idx])
                trans_acc[i, thresh_idx] = np.mean(cur_trans_rst < trans_thresh_list[thresh_idx])
                space_acc[i, thresh_idx] = np.mean(np.logical_and(cur_rot_rst < rot_thresh_list[thresh_idx],
                                                                  cur_trans_rst < trans_thresh_list[thresh_idx]))

            logger.info("------------ {} -----------".format(cls_name))
            logger.info("{:>24}: {:>7}, {:>7}, {:>7}".format("[rot_thresh, trans_thresh", "RotAcc", "TraAcc", "SpcAcc"))
            # The [-1, -1] row is the mean over all thresholds.
            logger.info(
                "{:<16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}".format('average_accuracy', '[{:>2}, {:.2f}]'.format(-1, -1),
                                                                   np.mean(rot_acc[i, :]) * 100,
                                                                   np.mean(trans_acc[i, :]) * 100,
                                                                   np.mean(space_acc[i, :]) * 100))
            show_list = [1, 4, 9]  # thresholds (2, 0.02), (5, 0.05), (10, 0.10)
            for show_idx in show_list:
                logger.info("{:>16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}"
                            .format('average_accuracy',
                                    '[{:>2}, {:.2f}]'.format(rot_thresh_list[show_idx], trans_thresh_list[show_idx]),
                                    rot_acc[i, show_idx] * 100, trans_acc[i, show_idx] * 100,
                                    space_acc[i, show_idx] * 100))
        print(' ')
        # overall performance averaged over all classes
        show_list = [1, 4, 9]
        logger.info("---------- performance over {} classes -----------".format(num_valid_class))
        logger.info("{:>24}: {:>7}, {:>7}, {:>7}"
                    .format("[rot_thresh, trans_thresh", "RotAcc", "TraAcc", "SpcAcc"))
        logger.info(
            "{:<16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}".format('average_accuracy', '[{:>2}, {:4.2f}]'.format(-1, -1),
                                                               np.sum(rot_acc[:, :]) / (
                                                                           num_valid_class * num_metric) * 100,
                                                               np.sum(trans_acc[:, :]) / (
                                                                       num_valid_class * num_metric) * 100,
                                                               np.sum(space_acc[:, :]) / (
                                                                       num_valid_class * num_metric) * 100))
        for show_idx in show_list:
            logger.info("{:>16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}"
                        .format('average_accuracy',
                                '[{:>2}, {:.2f}]'.format(rot_thresh_list[show_idx], trans_thresh_list[show_idx]),
                                np.sum(rot_acc[:, show_idx]) / num_valid_class * 100,
                                np.sum(trans_acc[:, show_idx]) / num_valid_class * 100,
                                np.sum(space_acc[:, show_idx]) / num_valid_class * 100))
        print(' ')
Example #7
0
    def evaluate_pose_arp_2d(self, output_dir):
        '''
        Evaluate average re-projection 2d error (ARP-2D).

        For every class, computes the mean 2D re-projection pixel error of
        each estimated pose against its ground truth, counts poses under
        2/5/10/20 px, plots an accuracy-vs-threshold curve over [0, 50] px
        into `output_dir`, and logs per-class and class-averaged results.
        The raw curves are also pickled to arp_2d_xys.pkl.
        '''
        all_poses_est = self.pose_est_all
        all_poses_gt = self.pose_gt_all
        models = self.models
        logger.info('\n* {} *\n {:^}\n* {} *'.format('-' * 100, 'Metric ARP_2D (Average Re-Projection 2D)', '-' * 100))
        K = self.camera_matrix
        num_classes = len(self.classes)
        count_all = np.zeros((num_classes), dtype=np.float32)
        count_correct = {k: np.zeros((num_classes), dtype=np.float32) for k in ['2', '5', '10', '20']}

        threshold_2 = np.zeros((num_classes), dtype=np.float32)
        threshold_5 = np.zeros((num_classes), dtype=np.float32)
        threshold_10 = np.zeros((num_classes), dtype=np.float32)
        threshold_20 = np.zeros((num_classes), dtype=np.float32)
        dx = 0.1  # pixel step of the accuracy curve
        threshold_mean = np.tile(np.arange(0, 50, dx).astype(np.float32),
                                 (num_classes, 1))  # (num_class, num_thresh)
        num_thresh = threshold_mean.shape[-1]
        count_correct['mean'] = np.zeros((num_classes, num_thresh), dtype=np.float32)

        for i in range(num_classes):
            threshold_2[i] = 2
            threshold_5[i] = 5
            threshold_10[i] = 10
            threshold_20[i] = 20

        num_valid_class = len(self.classes)
        for i, cls_name in enumerate(self.classes):
            curr_poses_gt = all_poses_gt[cls_name]
            curr_poses_est = all_poses_est[cls_name]
            num = len(curr_poses_gt)
            count_all[i] = num
            for j in range(num):
                RT = curr_poses_est[j]  # est pose
                pose_gt = curr_poses_gt[j]  # gt pose
                error_rotation = re(RT[:3, :3], pose_gt[:3, :3])
                # eggbox is treated as 180-degree symmetric about z: when the
                # raw rotation error is large, score the z-flipped pose instead.
                if cls_name == 'eggbox' and error_rotation > 90:
                    RT_z = np.array([[-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 1, 0]])
                    RT_sym = se3_mul(RT, RT_z)
                    error = arp_2d(RT_sym[:3, :3], RT_sym[:, 3], pose_gt[:3, :3], pose_gt[:, 3],
                                   models[cls_name], K)
                else:
                    error = arp_2d(RT[:3, :3], RT[:, 3], pose_gt[:3, :3], pose_gt[:, 3],
                                   models[cls_name], K)

                if error < threshold_2[i]: count_correct['2'][i] += 1
                if error < threshold_5[i]: count_correct['5'][i] += 1
                if error < threshold_10[i]: count_correct['10'][i] += 1
                if error < threshold_20[i]: count_correct['20'][i] += 1
                for thresh_i in range(num_thresh):
                    if error < threshold_mean[i, thresh_i]:
                        count_correct['mean'][i, thresh_i] += 1

        # store plot data
        plot_data = {}
        sum_acc_mean = np.zeros(1)
        sum_acc_02 = np.zeros(1)
        sum_acc_05 = np.zeros(1)
        sum_acc_10 = np.zeros(1)
        sum_acc_20 = np.zeros(1)
        # SciPy renamed simps -> simpson (alias removed in SciPy >= 1.14);
        # import once here instead of inside the loop on every iteration.
        try:
            from scipy.integrate import simpson as simps
        except ImportError:  # SciPy < 1.6
            from scipy.integrate import simps
        for i, cls_name in enumerate(self.classes):
            if count_all[i] == 0:
                continue
            plot_data[cls_name] = []
            logger.info("** {} **".format(cls_name))
            # area under the accuracy curve, normalized to [0, 1] by the 50 px range
            area = simps(count_correct['mean'][i] / float(count_all[i]), dx=dx) / (50.0)
            acc_mean = area * 100
            sum_acc_mean[0] += acc_mean
            acc_02 = 100 * float(count_correct['2'][i]) / float(count_all[i])
            sum_acc_02[0] += acc_02
            acc_05 = 100 * float(count_correct['5'][i]) / float(count_all[i])
            sum_acc_05[0] += acc_05
            acc_10 = 100 * float(count_correct['10'][i]) / float(count_all[i])
            sum_acc_10[0] += acc_10
            acc_20 = 100 * float(count_correct['20'][i]) / float(count_all[i])
            sum_acc_20[0] += acc_20

            fig = plt.figure()
            x_s = np.arange(0, 50, dx).astype(np.float32)
            y_s = 100 * count_correct['mean'][i] / float(count_all[i])
            plot_data[cls_name].append((x_s, y_s))
            plt.plot(x_s, y_s, '-')
            plt.xlim(0, 50)
            plt.ylim(0, 100)
            plt.grid(True)
            plt.xlabel("px")
            plt.ylabel("correctly estimated poses in %")
            plt.savefig(os.path.join(output_dir, 'arp_2d_{}.png'.format(cls_name)), dpi=fig.dpi)
            plt.close()

            logger.info('threshold=[0, 50], area: {:.2f}'.format(acc_mean))
            logger.info('threshold=2, correct poses: {}, all poses: {}, accuracy: {:.2f}'.format(
                count_correct['2'][i], count_all[i], acc_02))
            logger.info('threshold=5, correct poses: {}, all poses: {}, accuracy: {:.2f}'.format(
                count_correct['5'][i], count_all[i], acc_05))
            logger.info(
                'threshold=10, correct poses: {}, all poses: {}, accuracy: {:.2f}'.format(
                    count_correct['10'][i], count_all[i], acc_10))
            logger.info(
                'threshold=20, correct poses: {}, all poses: {}, accuracy: {:.2f}'.format(
                    count_correct['20'][i], count_all[i], acc_20))
            logger.info(" ")

        with open(os.path.join(output_dir, 'arp_2d_xys.pkl'), 'wb') as f:
            cPickle.dump(plot_data, f, protocol=2)
        logger.info("=" * 30)
        logger.info(' ')
        # overall performance of arp 2d (single-iteration loop kept for log format)
        for iter_i in range(1):
            logger.info("---------- arp 2d performance over {} classes -----------".format(num_valid_class))
            logger.info("** iter {} **".format(iter_i + 1))
            logger.info('threshold=[0, 50], area: {:.2f}'.format(
                sum_acc_mean[iter_i] / num_valid_class))
            logger.info('threshold=2, mean accuracy: {:.2f}'.format(
                sum_acc_02[iter_i] / num_valid_class))
            logger.info('threshold=5, mean accuracy: {:.2f}'.format(
                sum_acc_05[iter_i] / num_valid_class))
            logger.info('threshold=10, mean accuracy: {:.2f}'.format(
                sum_acc_10[iter_i] / num_valid_class))
            logger.info('threshold=20, mean accuracy: {:.2f}'.format(
                sum_acc_20[iter_i] / num_valid_class))
            logger.info(" ")
        logger.info("=" * 30)
Example #8
0
    def evaluate_pose_add(self, output_dir):
        """
        Evaluate 6D pose by the ADD metric (ADI for symmetric objects).

        For each class, measures the average 3D model-point distance between
        estimated and ground-truth poses, counts poses under 2%/5%/10% of the
        model diameter, plots accuracy-vs-threshold curves into `output_dir`,
        and logs per-class and class-averaged results. The raw curves are
        pickled to {add|adi}_xys.pkl.
        """
        all_poses_est = self.pose_est_all
        all_poses_gt = self.pose_gt_all
        models_info = self.models_info
        models = self.models
        logger.info('\n* {} *\n {:^}\n* {} *'.format('-' * 100, 'Metric ADD', '-' * 100))
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        eval_method = 'add'
        num_classes = len(self.classes)
        count_all = np.zeros((num_classes), dtype=np.float32)
        count_correct = {k: np.zeros((num_classes), dtype=np.float32) for k in ['0.02', '0.05', '0.10']}

        threshold_002 = np.zeros((num_classes), dtype=np.float32)
        threshold_005 = np.zeros((num_classes), dtype=np.float32)
        threshold_010 = np.zeros((num_classes), dtype=np.float32)
        dx = 0.0001  # step of the accuracy curve (fraction of diameter)
        threshold_mean = np.tile(np.arange(0, 0.1, dx).astype(np.float32), (num_classes, 1))  # (num_class, num_thresh)
        num_thresh = threshold_mean.shape[-1]
        count_correct['mean'] = np.zeros((num_classes, num_thresh), dtype=np.float32)

        # NOTE(review): sorts self.classes in place -- a visible side effect
        # on the instance; subsequent evaluations see the sorted order.
        self.classes = sorted(self.classes)
        num_valid_class = len(self.classes)
        for i, cls_name in enumerate(self.classes):
            # thresholds are fractions of each model's diameter
            threshold_002[i] = 0.02 * models_info[ref.obj2idx(cls_name)]['diameter']
            threshold_005[i] = 0.05 * models_info[ref.obj2idx(cls_name)]['diameter']
            threshold_010[i] = 0.10 * models_info[ref.obj2idx(cls_name)]['diameter']
            threshold_mean[i, :] *= models_info[ref.obj2idx(cls_name)]['diameter']
            curr_poses_gt = all_poses_gt[cls_name]
            curr_poses_est = all_poses_est[cls_name]
            num = len(curr_poses_gt)
            count_all[i] = num
            for j in range(num):
                RT = curr_poses_est[j]  # est pose
                pose_gt = curr_poses_gt[j]  # gt pose
                # symmetric objects use ADI (closest-point distance) instead of ADD
                if cls_name == 'eggbox' or cls_name == 'glue' or cls_name == 'bowl' or cls_name == 'cup':
                    eval_method = 'adi'
                    error = adi(RT[:3, :3], RT[:, 3], pose_gt[:3, :3], pose_gt[:, 3], models[cls_name])
                else:
                    error = add(RT[:3, :3], RT[:, 3], pose_gt[:3, :3], pose_gt[:, 3], models[cls_name])
                if error < threshold_002[i]:
                    count_correct['0.02'][i] += 1
                if error < threshold_005[i]:
                    count_correct['0.05'][i] += 1
                if error < threshold_010[i]:
                    count_correct['0.10'][i] += 1
                for thresh_i in range(num_thresh):
                    if error < threshold_mean[i, thresh_i]:
                        count_correct['mean'][i, thresh_i] += 1

        plot_data = {}
        sum_acc_mean = np.zeros(1)
        sum_acc_002 = np.zeros(1)
        sum_acc_005 = np.zeros(1)
        sum_acc_010 = np.zeros(1)
        # SciPy renamed simps -> simpson (alias removed in SciPy >= 1.14);
        # import once here instead of inside the loop on every iteration.
        try:
            from scipy.integrate import simpson as simps
        except ImportError:  # SciPy < 1.6
            from scipy.integrate import simps
        for i, cls_name in enumerate(self.classes):
            if count_all[i] == 0:
                continue
            plot_data[cls_name] = []
            logger.info("** {} **".format(cls_name))
            # area under the accuracy curve, normalized by the 0.1 range
            area = simps(count_correct['mean'][i] / float(count_all[i]), dx=dx) / 0.1
            acc_mean = area * 100
            sum_acc_mean[0] += acc_mean
            acc_002 = 100 * float(count_correct['0.02'][i]) / float(count_all[i])
            sum_acc_002[0] += acc_002
            acc_005 = 100 * float(count_correct['0.05'][i]) / float(count_all[i])
            sum_acc_005[0] += acc_005
            acc_010 = 100 * float(count_correct['0.10'][i]) / float(count_all[i])
            sum_acc_010[0] += acc_010

            fig = plt.figure()
            x_s = np.arange(0, 0.1, dx).astype(np.float32)
            y_s = count_correct['mean'][i] / float(count_all[i])
            plot_data[cls_name].append((x_s, y_s))
            plt.plot(x_s, y_s, '-')
            plt.xlim(0, 0.1)
            plt.ylim(0, 1)
            plt.xlabel("Average distance threshold in meter (symmetry)")
            plt.ylabel("accuracy")
            plt.savefig(os.path.join(output_dir, 'acc_thres_{}.png'.format(cls_name, )), dpi=fig.dpi)
            plt.close()
            logger.info('threshold=[0.0, 0.10], area: {:.2f}'.format(acc_mean))
            logger.info('threshold=0.02, correct poses: {}, all poses: {}, accuracy: {:.2f}'.format(
                count_correct['0.02'][i],
                count_all[i],
                acc_002))
            logger.info('threshold=0.05, correct poses: {}, all poses: {}, accuracy: {:.2f}'.format(
                count_correct['0.05'][i],
                count_all[i],
                acc_005))
            logger.info('threshold=0.10, correct poses: {}, all poses: {}, accuracy: {:.2f}'.format(
                count_correct['0.10'][i],
                count_all[i],
                acc_010))
            logger.info(" ")

        with open(os.path.join(output_dir, '{}_xys.pkl'.format(eval_method)), 'wb') as f:
            cPickle.dump(plot_data, f, protocol=2)

        logger.info("=" * 30)
        logger.info(' ')
        # overall performance of add (single-iteration loop kept for log format)
        for iter_i in range(1):
            logger.info("---------- add performance over {} classes -----------".format(num_valid_class))
            logger.info("** iter {} **".format(iter_i + 1))
            logger.info('threshold=[0.0, 0.10], area: {:.2f}'.format(
                sum_acc_mean[iter_i] / num_valid_class))
            logger.info('threshold=0.02, mean accuracy: {:.2f}'.format(
                sum_acc_002[iter_i] / num_valid_class))
            logger.info('threshold=0.05, mean accuracy: {:.2f}'.format(
                sum_acc_005[iter_i] / num_valid_class))
            logger.info('threshold=0.10, mean accuracy: {:.2f}'.format(
                sum_acc_010[iter_i] / num_valid_class))
            logger.info(' ')
        logger.info("=" * 30)
Example #9
0
def build_model(cfg):
    """Build the CDPN network and its optimizer from the config.

    The network is a ResNet backbone feeding separate rotation and translation
    head nets. Any sub-net whose ``*_freeze`` flag is set has its parameters
    frozen and excluded from the optimizer. Weights are then initialized
    either from ``cfg.pytorch.load_model`` (partial, key-filtered load) or,
    for the backbone only, from the torchvision model zoo.

    Returns:
        tuple: (model, optimizer); optimizer is None when every sub-net is
        frozen (no trainable parameters).

    Raises:
        ValueError: if ``cfg.network.arch`` is not a supported architecture.
    """
    if 'resnet' not in cfg.network.arch:
        # The original fell through and hit a NameError on `model`; fail loudly.
        raise ValueError('unsupported network arch: {}'.format(cfg.network.arch))

    params_lr_list = []
    # backbone net
    block_type, layers, channels, name = resnet_spec[
        cfg.network.back_layers_num]
    backbone_net = ResNetBackboneNet(block_type, layers,
                                     cfg.network.back_input_channel,
                                     cfg.network.back_freeze)
    if cfg.network.back_freeze:
        # Freezing only needs requires_grad=False; the original wrapped this
        # in torch.no_grad(), which has no effect on the flag assignment.
        for param in backbone_net.parameters():
            param.requires_grad = False
    else:
        params_lr_list.append({
            'params':
            filter(lambda p: p.requires_grad, backbone_net.parameters()),
            'lr':
            float(cfg.train.lr_backbone)
        })
    # rotation head net
    rot_head_net = RotHeadNet(channels[-1], cfg.network.rot_layers_num,
                              cfg.network.rot_filters_num,
                              cfg.network.rot_conv_kernel_size,
                              cfg.network.rot_output_conv_kernel_size,
                              cfg.network.rot_output_channels,
                              cfg.network.rot_head_freeze)
    if cfg.network.rot_head_freeze:
        for param in rot_head_net.parameters():
            param.requires_grad = False
    else:
        params_lr_list.append({
            'params':
            filter(lambda p: p.requires_grad, rot_head_net.parameters()),
            'lr':
            float(cfg.train.lr_rot_head)
        })
    # translation head net
    trans_head_net = TransHeadNet(channels[-1],
                                  cfg.network.trans_layers_num,
                                  cfg.network.trans_filters_num,
                                  cfg.network.trans_conv_kernel_size,
                                  cfg.network.trans_output_channels,
                                  cfg.network.trans_head_freeze)
    if cfg.network.trans_head_freeze:
        for param in trans_head_net.parameters():
            param.requires_grad = False
    else:
        params_lr_list.append({
            'params':
            filter(lambda p: p.requires_grad, trans_head_net.parameters()),
            'lr':
            float(cfg.train.lr_trans_head)
        })
    # CDPN (Coordinates-based Disentangled Pose Network)
    model = CDPN(backbone_net, rot_head_net, trans_head_net)
    # optimizer over the unfrozen parameter groups; None when all are frozen
    if params_lr_list:
        optimizer = torch.optim.RMSprop(params_lr_list,
                                        alpha=cfg.train.alpha,
                                        eps=float(cfg.train.epsilon),
                                        weight_decay=cfg.train.weightDecay,
                                        momentum=cfg.train.momentum)
    else:
        optimizer = None

    ## model initialization
    if cfg.pytorch.load_model != '':
        logger.info("=> loading model '{}'".format(cfg.pytorch.load_model))
        checkpoint = torch.load(cfg.pytorch.load_model,
                                map_location=lambda storage, loc: storage)
        # A checkpoint may be a plain state-dict wrapper or a whole module.
        if isinstance(checkpoint, dict):
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint.state_dict()

        model_dict = model.state_dict()
        # filter out params the current model does not have
        filtered_state_dict = {
            k: v
            for k, v in state_dict.items() if k in model_dict
        }
        # update state dict and load params into the net
        model_dict.update(filtered_state_dict)
        model.load_state_dict(model_dict)
    else:
        logger.info(
            "=> loading official model from model zoo for backbone")
        _, _, _, name = resnet_spec[cfg.network.back_layers_num]
        official_resnet = model_zoo.load_url(model_urls[name])
        # drop original resnet fc layer; the None default avoids a KeyError
        # when the checkpoint has no fc layer
        official_resnet.pop('fc.weight', None)
        official_resnet.pop('fc.bias', None)
        model.backbone.load_state_dict(official_resnet)

    return model, optimizer