def test(epoch, cfg, data_loader, model, obj_vtx, obj_info, criterions):
    """Run one evaluation pass over ``data_loader`` and score 6D poses.

    Depending on ``cfg.pytorch.task``, the rotation head output is turned
    into 2D-3D correspondences solved by PnP, and/or the translation head
    output is back-projected into a metric translation. Results accumulate
    in ``Evaluation`` objects and are scored with pose / ADD / ARP-2D
    metrics according to ``cfg.test.test_mode``.

    Args:
        epoch: current epoch index (used for log/vis paths and TB steps).
        cfg: global experiment configuration.
        data_loader: test-split loader yielding
            (obj, obj_id, inp, pose, c_box, s_box, box, trans_local).
        model: network returning (pred_rot, pred_trans).
        obj_vtx: per-object model vertices for the evaluator.
        obj_info: per-object metadata (min_x/min_y/min_z extents, ...).
        criterions: loss modules (unused at test time, kept for symmetry
            with train()).

    Returns:
        (loss_dict, preds): averaged losses (meters stay at zero — no loss
        is computed here) and accumulated pose predictions; (None, None)
        when results were scored from cache files instead.
    """
    model.eval()
    Eval = Evaluation(cfg.dataset, obj_info, obj_vtx)
    if 'trans' in cfg.pytorch.task.lower():
        Eval_trans = Evaluation(cfg.dataset, obj_info, obj_vtx)

    # Fast path: score previously cached poses instead of running the model.
    if not cfg.test.ignore_cache_file:
        est_cache_file = cfg.test.cache_file
        gt_cache_file = cfg.test.cache_file.replace('_est', '_gt')
        if os.path.exists(est_cache_file) and os.path.exists(gt_cache_file):
            Eval.pose_est_all = np.load(est_cache_file, allow_pickle=True).tolist()
            Eval.pose_gt_all = np.load(gt_cache_file, allow_pickle=True).tolist()
            fig_save_path = os.path.join(cfg.pytorch.save_path, str(epoch))
            mkdir_p(fig_save_path)
            if 'all' in cfg.test.test_mode.lower():
                Eval.evaluate_pose()
                Eval.evaluate_pose_add(fig_save_path)
                Eval.evaluate_pose_arp_2d(fig_save_path)
            elif 'pose' in cfg.test.test_mode.lower():
                Eval.evaluate_pose()
            elif 'add' in cfg.test.test_mode.lower():
                Eval.evaluate_pose_add(fig_save_path)
            elif 'arp' in cfg.test.test_mode.lower():
                Eval.evaluate_pose_arp_2d(fig_save_path)
            else:
                raise Exception("Wrong test mode: {}".format(cfg.test.test_mode))
            return None, None
        else:
            logger.info("test cache file {} and {} not exist!".format(
                est_cache_file, gt_cache_file))
            userAns = input("Generating cache file from model [Y(y)/N(n)]:")
            if userAns.lower() == 'n':
                sys.exit(0)
            else:
                logger.info("Generating test cache file!")

    preds = {}
    Loss = AverageMeter()
    Loss_rot = AverageMeter()
    Loss_trans = AverageMeter()
    num_iters = len(data_loader)
    bar = Bar('{}'.format(cfg.pytorch.exp_id[-60:]), max=num_iters)
    time_monitor = False
    vis_dir = os.path.join(cfg.pytorch.save_path, 'test_vis_{}'.format(epoch))
    if not os.path.exists(vis_dir):
        os.makedirs(vis_dir)
    for i, (obj, obj_id, inp, pose, c_box, s_box, box, trans_local) in enumerate(data_loader):
        if cfg.pytorch.gpu > -1:
            # FIX: `async` is a reserved keyword since Python 3.7 (SyntaxError);
            # the PyTorch keyword for asynchronous host->device copy is
            # `non_blocking`.
            inp_var = inp.cuda(cfg.pytorch.gpu, non_blocking=True).float()
        else:
            inp_var = inp.float()
        bs = len(inp)
        # forward propagation
        T_begin = time.time()
        pred_rot, pred_trans = model(inp_var)
        T_end = time.time() - T_begin
        if time_monitor:
            logger.info("time for a batch forward of resnet model is {}".format(T_end))

        if i % cfg.test.disp_interval == 0:
            # input image: reorder channels (presumably BGR->RGB — confirm
            # against the data loader) for tensorboard, BGR for cv2.imwrite
            inp_rgb = (inp[0].cpu().numpy().copy() * 255)[[2, 1, 0], :, :].astype(np.uint8)
            cfg.writer.add_image('input_image', inp_rgb, i)
            cv2.imwrite(os.path.join(vis_dir, '{}_inp.png'.format(i)),
                        inp_rgb.transpose(1, 2, 0)[:, :, ::-1])
            if 'rot' in cfg.pytorch.task.lower():
                # predicted coordinate maps, normalized to [0, 255] per channel
                pred_coor = pred_rot[0, 0:3].data.cpu().numpy().copy()
                pred_coor[0] = im_norm_255(pred_coor[0])
                pred_coor[1] = im_norm_255(pred_coor[1])
                pred_coor[2] = im_norm_255(pred_coor[2])
                pred_coor = np.asarray(pred_coor, dtype=np.uint8)
                cfg.writer.add_image('test_coor_x_pred', np.expand_dims(pred_coor[0], axis=0), i)
                cfg.writer.add_image('test_coor_y_pred', np.expand_dims(pred_coor[1], axis=0), i)
                cfg.writer.add_image('test_coor_z_pred', np.expand_dims(pred_coor[2], axis=0), i)
                # predicted confidence map (channel 3 of the rot head)
                pred_conf = pred_rot[0, 3].data.cpu().numpy().copy()
                pred_conf = (im_norm_255(pred_conf)).astype(np.uint8)
                cfg.writer.add_image('test_conf_pred', np.expand_dims(pred_conf, axis=0), i)
            if 'trans' in cfg.pytorch.task.lower():
                pred_trans_ = pred_trans[0].data.cpu().numpy().copy()
                gt_trans_ = trans_local[0].data.cpu().numpy().copy()
                step = i + (epoch - 1) * num_iters  # global TB step
                cfg.writer.add_scalar('test_trans_x_gt', gt_trans_[0], step)
                cfg.writer.add_scalar('test_trans_y_gt', gt_trans_[1], step)
                cfg.writer.add_scalar('test_trans_z_gt', gt_trans_[2], step)
                cfg.writer.add_scalar('test_trans_x_pred', pred_trans_[0], step)
                cfg.writer.add_scalar('test_trans_y_pred', pred_trans_[1], step)
                cfg.writer.add_scalar('test_trans_z_pred', pred_trans_[2], step)
                cfg.writer.add_scalar('test_trans_x_err', np.abs(pred_trans_[0] - gt_trans_[0]), step)
                cfg.writer.add_scalar('test_trans_y_err', np.abs(pred_trans_[1] - gt_trans_[1]), step)
                cfg.writer.add_scalar('test_trans_z_err', np.abs(pred_trans_[2] - gt_trans_[2]), step)

        if 'rot' in cfg.pytorch.task.lower():
            pred_coor = pred_rot[:, 0:3].data.cpu().numpy().copy()
            pred_conf = pred_rot[:, 3].data.cpu().numpy().copy()
        else:
            pred_coor = np.zeros(bs)
            pred_conf = np.zeros(bs)
        if 'trans' in cfg.pytorch.task.lower():
            pred_trans = pred_trans.data.cpu().numpy().copy()
        else:
            pred_trans = np.zeros(bs)
        col = list(zip(obj, obj_id.numpy(), pred_coor, pred_conf, pred_trans,
                       pose.numpy(), c_box.numpy(), s_box.numpy(), box.numpy()))
        for idx in range(len(col)):
            obj_, obj_id_, pred_coor_, pred_conf_, pred_trans_, pose_gt, c_box_, s_box_, box_ = col[idx]
            T_begin = time.time()
            if 'rot' in cfg.pytorch.task.lower():
                # build 2D-3D correspondences: rescale the normalized coords
                # by the object's extent (|min_x|, |min_y|, |min_z|)
                pred_coor_ = pred_coor_.transpose(1, 2, 0)
                pred_coor_[:, :, 0] = pred_coor_[:, :, 0] * abs(obj_info[obj_id_]['min_x'])
                pred_coor_[:, :, 1] = pred_coor_[:, :, 1] * abs(obj_info[obj_id_]['min_y'])
                pred_coor_[:, :, 2] = pred_coor_[:, :, 2] * abs(obj_info[obj_id_]['min_z'])
                pred_coor_ = pred_coor_.tolist()
                eroMask = False  # optional mask erosion, disabled
                if eroMask:
                    kernel = np.ones((3, 3), np.uint8)
                    pred_conf_ = cv2.erode(pred_conf_, kernel)
                # min-max normalize the confidence map to [0, 1]
                pred_conf_ = (pred_conf_ - pred_conf_.min()) / (pred_conf_.max() - pred_conf_.min())
                pred_conf_ = pred_conf_.tolist()
                select_pts_2d = []
                select_pts_3d = []
                c_w = int(c_box_[0])
                c_h = int(c_box_[1])
                s = int(s_box_)
                w_begin = c_w - s / 2.
                h_begin = c_h - s / 2.
                w_unit = s * 1.0 / cfg.dataiter.out_res
                h_unit = s * 1.0 / cfg.dataiter.out_res
                # points closer than 0.1% of the extent to the origin on all
                # three axes are treated as background
                min_x = 0.001 * abs(obj_info[obj_id_]['min_x'])
                min_y = 0.001 * abs(obj_info[obj_id_]['min_y'])
                min_z = 0.001 * abs(obj_info[obj_id_]['min_z'])
                for x in range(cfg.dataiter.out_res):
                    for y in range(cfg.dataiter.out_res):
                        if pred_conf_[x][y] < cfg.test.mask_threshold:
                            continue
                        if abs(pred_coor_[x][y][0]) < min_x and abs(pred_coor_[x][y][1]) < min_y and \
                                abs(pred_coor_[x][y][2]) < min_z:
                            continue
                        select_pts_2d.append([w_begin + y * w_unit, h_begin + x * h_unit])
                        select_pts_3d.append(pred_coor_[x][y])
                model_points = np.asarray(select_pts_3d, dtype=np.float32)
                image_points = np.asarray(select_pts_2d, dtype=np.float32)
            if 'trans' in cfg.pytorch.task.lower():
                # recover translation T from the translation-head ratios
                ratio_delta_c = pred_trans_[:2]
                ratio_depth = pred_trans_[2]
                pred_depth = ratio_depth * (cfg.dataiter.out_res / s_box_)
                pred_c = ratio_delta_c * box_[2:] + c_box_
                pred_x = (pred_c[0] - cfg.dataset.camera_matrix[0, 2]) * pred_depth / cfg.dataset.camera_matrix[0, 0]
                pred_y = (pred_c[1] - cfg.dataset.camera_matrix[1, 2]) * pred_depth / cfg.dataset.camera_matrix[1, 1]
                T_vector_trans = np.asarray([pred_x, pred_y, pred_depth])
                # identity rotation placeholder; replaced by the PnP rotation
                # below when the rot head is active
                pose_est_trans = np.concatenate(
                    (np.eye(3), np.asarray((T_vector_trans).reshape(3, 1))), axis=1)
            try:
                if 'rot' in cfg.pytorch.task.lower():
                    dist_coeffs = np.zeros((4, 1))  # assuming no lens distortion
                    if cfg.test.pnp == 'iterPnP':  # iterative PnP algorithm
                        success, R_vector, T_vector = cv2.solvePnP(
                            model_points, image_points, cfg.dataset.camera_matrix,
                            dist_coeffs, flags=cv2.SOLVEPNP_ITERATIVE)
                    elif cfg.test.pnp == 'ransac':  # RANSAC + EPnP
                        _, R_vector, T_vector, inliers = cv2.solvePnPRansac(
                            model_points, image_points, cfg.dataset.camera_matrix,
                            dist_coeffs, flags=cv2.SOLVEPNP_EPNP)
                    else:
                        raise NotImplementedError("Not support PnP algorithm: {}".format(cfg.test.pnp))
                    R_matrix = cv2.Rodrigues(R_vector, jacobian=0)[0]
                    pose_est = np.concatenate((R_matrix, np.asarray(T_vector).reshape(3, 1)), axis=1)
                    if 'trans' in cfg.pytorch.task.lower():
                        # combine PnP rotation with the regressed translation
                        pose_est_trans = np.concatenate(
                            (R_matrix, np.asarray((T_vector_trans).reshape(3, 1))), axis=1)
                Eval.pose_est_all[obj_].append(pose_est)
                Eval.pose_gt_all[obj_].append(pose_gt)
                Eval.num[obj_] += 1
                Eval.numAll += 1
                if 'trans' in cfg.pytorch.task.lower():
                    Eval_trans.pose_est_all[obj_].append(pose_est_trans)
                    Eval_trans.pose_gt_all[obj_].append(pose_gt)
                    Eval_trans.num[obj_] += 1
                    Eval_trans.numAll += 1
            # FIX: narrowed from a bare `except:` which also swallowed
            # SystemExit/KeyboardInterrupt
            except Exception:
                # still count the sample so per-class totals stay consistent
                Eval.num[obj_] += 1
                Eval.numAll += 1
                if 'trans' in cfg.pytorch.task.lower():
                    Eval_trans.num[obj_] += 1
                    Eval_trans.numAll += 1
                logger.info('error in solve PnP or Ransac')
            T_end = time.time() - T_begin
            if time_monitor:
                logger.info("time spend on PnP+RANSAC for one image is {}".format(T_end))

        Bar.suffix = 'test Epoch: [{0}][{1}/{2}]| Total: {total:} | ETA: {eta:} | Loss {loss.avg:.4f} | Loss_rot {loss_rot.avg:.4f} | Loss_trans {loss_trans.avg:.4f}'.format(
            epoch, i, num_iters, total=bar.elapsed_td, eta=bar.eta_td,
            loss=Loss, loss_rot=Loss_rot, loss_trans=Loss_trans)
        bar.next()

    epoch_save_path = os.path.join(cfg.pytorch.save_path, str(epoch))
    if not os.path.exists(epoch_save_path):
        os.makedirs(epoch_save_path)
    if 'rot' in cfg.pytorch.task.lower():
        logger.info("{} Evaluate of Rotation Branch of Epoch {} {}".format('-' * 40, epoch, '-' * 40))
        preds['poseGT'] = Eval.pose_gt_all
        preds['poseEst'] = Eval.pose_est_all
        if cfg.pytorch.test:
            np.save(os.path.join(epoch_save_path, 'pose_est_all_test.npy'), Eval.pose_est_all)
            np.save(os.path.join(epoch_save_path, 'pose_gt_all_test.npy'), Eval.pose_gt_all)
        else:
            np.save(os.path.join(epoch_save_path, 'pose_est_all_epoch{}.npy'.format(epoch)), Eval.pose_est_all)
            np.save(os.path.join(epoch_save_path, 'pose_gt_all_epoch{}.npy'.format(epoch)), Eval.pose_gt_all)
        # evaluation
        if 'all' in cfg.test.test_mode.lower():
            Eval.evaluate_pose()
            Eval.evaluate_pose_add(epoch_save_path)
            Eval.evaluate_pose_arp_2d(epoch_save_path)
        else:
            if 'pose' in cfg.test.test_mode.lower():
                Eval.evaluate_pose()
            if 'add' in cfg.test.test_mode.lower():
                Eval.evaluate_pose_add(epoch_save_path)
            if 'arp' in cfg.test.test_mode.lower():
                Eval.evaluate_pose_arp_2d(epoch_save_path)
    if 'trans' in cfg.pytorch.task.lower():
        logger.info("{} Evaluate of Translation Branch of Epoch {} {}".format('-' * 40, epoch, '-' * 40))
        preds['poseGT'] = Eval_trans.pose_gt_all
        preds['poseEst'] = Eval_trans.pose_est_all
        if cfg.pytorch.test:
            np.save(os.path.join(epoch_save_path, 'pose_est_all_test_trans.npy'), Eval_trans.pose_est_all)
            np.save(os.path.join(epoch_save_path, 'pose_gt_all_test_trans.npy'), Eval_trans.pose_gt_all)
        else:
            np.save(os.path.join(epoch_save_path, 'pose_est_all_trans_epoch{}.npy'.format(epoch)), Eval_trans.pose_est_all)
            np.save(os.path.join(epoch_save_path, 'pose_gt_all_trans_epoch{}.npy'.format(epoch)), Eval_trans.pose_gt_all)
        # evaluation
        if 'all' in cfg.test.test_mode.lower():
            Eval_trans.evaluate_pose()
            Eval_trans.evaluate_pose_add(epoch_save_path)
            Eval_trans.evaluate_pose_arp_2d(epoch_save_path)
        else:
            if 'pose' in cfg.test.test_mode.lower():
                Eval_trans.evaluate_pose()
            if 'add' in cfg.test.test_mode.lower():
                Eval_trans.evaluate_pose_add(epoch_save_path)
            if 'arp' in cfg.test.test_mode.lower():
                Eval_trans.evaluate_pose_arp_2d(epoch_save_path)
    bar.finish()
    return {'Loss': Loss.avg, 'Loss_rot': Loss_rot.avg, 'Loss_trans': Loss_trans.avg}, preds
def __init__(self, cfg, split):
    """Build the LineMOD annotation list for `split` ('train' or 'test').

    Per-object annotation caches are stored as .npy files under
    ``ref.cache_dir`` and regenerated from the raw images/pickles when
    missing. For training, real and ImGn synthetic samples are loaded
    according to ``cfg.dataset.img_type``, with optional subsampling of
    synthetic data (``syn_num`` / ``syn_samp_type``).

    Raises:
        ValueError: on an unknown split or an unknown ``syn_samp_type``.
    """
    self.cfg = cfg
    self.split = split
    self.infos = self.load_lm_model_info(ref.lm_model_info_pth)
    self.cam_K = ref.K
    logger.info('==> initializing {} {} data.'.format(cfg.dataset.name, split))
    # load dataset
    annot = []
    if split == 'test':
        cache_dir = os.path.join(ref.cache_dir, 'test')
        if not os.path.exists(cache_dir):
            os.makedirs(cache_dir)
        for obj in tqdm(self.cfg.dataset.classes):
            cache_pth = os.path.join(cache_dir, '{}.npy'.format(obj))
            if not os.path.exists(cache_pth):
                # build the per-object cache from raw test images
                annot_cache = []
                rgb_pths = glob(os.path.join(ref.lm_test_dir, obj, '*-color.png'))
                for rgb_pth in tqdm(rgb_pths):
                    item = self.col_test_item(rgb_pth)
                    item['obj'] = obj
                    annot_cache.append(item)
                np.save(cache_pth, annot_cache)
            annot.extend(np.load(cache_pth, allow_pickle=True).tolist())
        self.num = len(annot)
        self.annot = annot
        logger.info('load {} test samples.'.format(self.num))
    elif split == 'train':
        # FIX: initialize both counters up front. Previously each was only
        # assigned inside its own `if ... in img_type` branch, so a config
        # using only 'imgn' crashed on `len(annot) - self.real_num`, and the
        # summary log below could hit an undefined attribute.
        self.real_num = 0
        self.imgn_num = 0
        if 'real' in self.cfg.dataset.img_type:
            cache_dir = os.path.join(ref.cache_dir, 'train/real')
            if not os.path.exists(cache_dir):
                os.makedirs(cache_dir)
            for obj in tqdm(self.cfg.dataset.classes):
                cache_pth = os.path.join(cache_dir, '{}.npy'.format(obj))
                if not os.path.exists(cache_pth):
                    # build the per-object cache from raw real images
                    annot_cache = []
                    rgb_pths = glob(os.path.join(ref.lm_train_real_dir, obj, '*-color.png'))
                    for rgb_pth in tqdm(rgb_pths):
                        item = self.col_train_item(rgb_pth)
                        item['obj'] = obj
                        annot_cache.append(item)
                    np.save(cache_pth, annot_cache)
                annot.extend(np.load(cache_pth, allow_pickle=True).tolist())
            self.real_num = len(annot)
            logger.info('load {} real training samples.'.format(self.real_num))
        if 'imgn' in self.cfg.dataset.img_type:
            cache_dir = os.path.join(ref.cache_dir, 'train/imgn')
            if not os.path.exists(cache_dir):
                os.makedirs(cache_dir)
            for obj in tqdm(self.cfg.dataset.classes):
                cache_pth = os.path.join(cache_dir, '{}.npy'.format(obj))
                if not os.path.exists(cache_pth):
                    # build the per-object cache from synthetic coord pickles
                    annot_cache = []
                    coor_pths = sorted(glob(os.path.join(ref.lm_train_imgn_dir, obj, '*-coor.pkl')))
                    for coor_pth in tqdm(coor_pths):
                        item = self.col_imgn_item(coor_pth)
                        item['obj'] = obj
                        annot_cache.append(item)
                    np.save(cache_pth, annot_cache)
                annot_obj = np.load(cache_pth, allow_pickle=True).tolist()
                annot_obj_num = len(annot_obj)
                # optionally subsample the synthetic set down to syn_num
                if (annot_obj_num > self.cfg.dataset.syn_num) and (self.cfg.dataset.syn_samp_type != ''):
                    if self.cfg.dataset.syn_samp_type == 'uniform':
                        samp_idx = np.linspace(0, annot_obj_num - 1,
                                               self.cfg.dataset.syn_num, dtype=np.int32)
                    elif self.cfg.dataset.syn_samp_type == 'random':
                        samp_idx = random.sample(range(annot_obj_num),
                                                 self.cfg.dataset.syn_num)
                    else:
                        raise ValueError
                    annot_obj = np.asarray(annot_obj)[samp_idx].tolist()
                annot.extend(annot_obj)
            self.imgn_num = len(annot) - self.real_num
            logger.info('load {} imgn training samples.'.format(self.imgn_num))
        self.num = len(annot)
        self.annot = annot
        logger.info(
            'load {} training samples, including {} real samples and {} synthetic samples.'
            .format(self.num, self.real_num, self.imgn_num))
        self.bg_list = self.load_bg_list()
    else:
        raise ValueError
def train(epoch, cfg, data_loader, model, criterions, optimizer=None):
    """Train `model` for one epoch over `data_loader`.

    Computes the rotation (coordinate+confidence map) loss and/or the
    translation loss depending on ``cfg.pytorch.task`` and the per-head
    freeze flags, logs periodic visualizations/scalars to tensorboard, and
    steps `optimizer`.

    Returns:
        (loss_dict, preds): averaged 'Loss', 'Loss_rot', 'Loss_trans' and an
        (always empty here) predictions dict, mirroring test().
    """
    model.train()
    preds = {}
    Loss = AverageMeter()
    Loss_rot = AverageMeter()
    Loss_trans = AverageMeter()
    num_iters = len(data_loader)
    bar = Bar('{}'.format(cfg.pytorch.exp_id[-60:]), max=num_iters)
    time_monitor = False
    vis_dir = os.path.join(cfg.pytorch.save_path, 'train_vis_{}'.format(epoch))
    if not os.path.exists(vis_dir):
        os.makedirs(vis_dir)
    for i, (obj, obj_id, inp, target, loss_msk, trans_local, pose, c_box, s_box, box) in enumerate(data_loader):
        cur_iter = i + (epoch - 1) * num_iters  # global step for tensorboard
        if cfg.pytorch.gpu > -1:
            # FIX: `async` is a reserved keyword since Python 3.7
            # (SyntaxError); the PyTorch keyword is `non_blocking`.
            inp_var = inp.cuda(cfg.pytorch.gpu, non_blocking=True).float()
            target_var = target.cuda(cfg.pytorch.gpu, non_blocking=True).float()
            loss_msk_var = loss_msk.cuda(cfg.pytorch.gpu, non_blocking=True).float()
            trans_local_var = trans_local.cuda(cfg.pytorch.gpu, non_blocking=True).float()
            pose_var = pose.cuda(cfg.pytorch.gpu, non_blocking=True).float()
            c_box_var = c_box.cuda(cfg.pytorch.gpu, non_blocking=True).float()
            s_box_var = s_box.cuda(cfg.pytorch.gpu, non_blocking=True).float()
        else:
            inp_var = inp.float()
            target_var = target.float()
            loss_msk_var = loss_msk.float()
            trans_local_var = trans_local.float()
            pose_var = pose.float()
            c_box_var = c_box.float()
            s_box_var = s_box.float()
        bs = len(inp)
        # forward propagation
        T_begin = time.time()
        pred_rot, pred_trans = model(inp_var)
        T_end = time.time() - T_begin
        if time_monitor:
            logger.info("time for a batch forward of resnet model is {}".format(T_end))

        if i % cfg.test.disp_interval == 0:
            # input image: reverse channel order (presumably BGR->RGB —
            # confirm against the data loader) for tensorboard
            inp_rgb = (inp[0].cpu().numpy().copy() * 255)[::-1, :, :].astype(np.uint8)
            cfg.writer.add_image('input_image', inp_rgb, i)
            cv2.imwrite(os.path.join(vis_dir, '{}_inp.png'.format(i)),
                        inp_rgb.transpose(1, 2, 0)[:, :, ::-1])
            if 'rot' in cfg.pytorch.task.lower():
                # predicted coordinate maps, normalized to [0, 255]
                pred_coor = pred_rot[0, 0:3].data.cpu().numpy().copy()
                pred_coor[0] = im_norm_255(pred_coor[0])
                pred_coor[1] = im_norm_255(pred_coor[1])
                pred_coor[2] = im_norm_255(pred_coor[2])
                pred_coor = np.asarray(pred_coor, dtype=np.uint8)
                cfg.writer.add_image('train_coor_x_pred', np.expand_dims(pred_coor[0], axis=0), i)
                cfg.writer.add_image('train_coor_y_pred', np.expand_dims(pred_coor[1], axis=0), i)
                cfg.writer.add_image('train_coor_z_pred', np.expand_dims(pred_coor[2], axis=0), i)
                cv2.imwrite(os.path.join(vis_dir, '{}_coor_x_pred.png'.format(i)), pred_coor[0])
                cv2.imwrite(os.path.join(vis_dir, '{}_coor_y_pred.png'.format(i)), pred_coor[1])
                cv2.imwrite(os.path.join(vis_dir, '{}_coor_z_pred.png'.format(i)), pred_coor[2])
                # ground-truth coordinate maps
                gt_coor = target[0, 0:3].data.cpu().numpy().copy()
                gt_coor[0] = im_norm_255(gt_coor[0])
                gt_coor[1] = im_norm_255(gt_coor[1])
                gt_coor[2] = im_norm_255(gt_coor[2])
                gt_coor = np.asarray(gt_coor, dtype=np.uint8)
                cfg.writer.add_image('train_coor_x_gt', np.expand_dims(gt_coor[0], axis=0), i)
                cfg.writer.add_image('train_coor_y_gt', np.expand_dims(gt_coor[1], axis=0), i)
                cfg.writer.add_image('train_coor_z_gt', np.expand_dims(gt_coor[2], axis=0), i)
                cv2.imwrite(os.path.join(vis_dir, '{}_coor_x_gt.png'.format(i)), gt_coor[0])
                cv2.imwrite(os.path.join(vis_dir, '{}_coor_y_gt.png'.format(i)), gt_coor[1])
                cv2.imwrite(os.path.join(vis_dir, '{}_coor_z_gt.png'.format(i)), gt_coor[2])
                # confidence maps (predicted vs ground truth)
                pred_conf = pred_rot[0, 3].data.cpu().numpy().copy()
                pred_conf = (im_norm_255(pred_conf)).astype(np.uint8)
                gt_conf = target[0, 3].data.cpu().numpy().copy()
                cfg.writer.add_image('train_conf_pred', np.expand_dims(pred_conf, axis=0), i)
                cfg.writer.add_image('train_conf_gt', np.expand_dims(gt_conf, axis=0), i)
                cv2.imwrite(os.path.join(vis_dir, '{}_conf_gt.png'.format(i)), gt_conf)
                cv2.imwrite(os.path.join(vis_dir, '{}_conf_pred.png'.format(i)), pred_conf)
            if 'trans' in cfg.pytorch.task.lower():
                pred_trans_ = pred_trans[0].data.cpu().numpy().copy()
                gt_trans_ = trans_local[0].data.cpu().numpy().copy()
                cfg.writer.add_scalar('train_trans_x_gt', gt_trans_[0], cur_iter)
                cfg.writer.add_scalar('train_trans_y_gt', gt_trans_[1], cur_iter)
                cfg.writer.add_scalar('train_trans_z_gt', gt_trans_[2], cur_iter)
                cfg.writer.add_scalar('train_trans_x_pred', pred_trans_[0], cur_iter)
                cfg.writer.add_scalar('train_trans_y_pred', pred_trans_[1], cur_iter)
                cfg.writer.add_scalar('train_trans_z_pred', pred_trans_[2], cur_iter)
                # signed errors (test() logs absolute errors instead)
                cfg.writer.add_scalar('train_trans_x_err', pred_trans_[0] - gt_trans_[0], cur_iter)
                cfg.writer.add_scalar('train_trans_y_err', pred_trans_[1] - gt_trans_[1], cur_iter)
                cfg.writer.add_scalar('train_trans_z_err', pred_trans_[2] - gt_trans_[2], cur_iter)

        # loss: each head contributes only when its task is active and the
        # head is not frozen; a frozen/inactive head contributes the int 0
        if 'rot' in cfg.pytorch.task.lower() and not cfg.network.rot_head_freeze:
            if cfg.loss.rot_mask_loss:
                loss_rot = criterions[cfg.loss.rot_loss_type](loss_msk_var * pred_rot,
                                                             loss_msk_var * target_var)
            else:
                loss_rot = criterions[cfg.loss.rot_loss_type](pred_rot, target_var)
        else:
            loss_rot = 0
        if 'trans' in cfg.pytorch.task.lower() and not cfg.network.trans_head_freeze:
            loss_trans = criterions[cfg.loss.trans_loss_type](pred_trans, trans_local_var)
        else:
            loss_trans = 0
        loss = cfg.loss.rot_loss_weight * loss_rot + cfg.loss.trans_loss_weight * loss_trans

        Loss.update(loss.item() if loss != 0 else 0, bs)
        Loss_rot.update(loss_rot.item() if loss_rot != 0 else 0, bs)
        Loss_trans.update(loss_trans.item() if loss_trans != 0 else 0, bs)
        cfg.writer.add_scalar('data/loss_rot_trans', loss.item() if loss != 0 else 0, cur_iter)
        cfg.writer.add_scalar('data/loss_rot', loss_rot.item() if loss_rot != 0 else 0, cur_iter)
        cfg.writer.add_scalar('data/loss_trans', loss_trans.item() if loss_trans != 0 else 0, cur_iter)

        optimizer.zero_grad()
        model.zero_grad()
        T_begin = time.time()
        # FIX: when both heads are frozen `loss` is the plain int 0 and has
        # no .backward(); skip the update instead of crashing.
        if not isinstance(loss, int):
            loss.backward()
            optimizer.step()
        T_end = time.time() - T_begin
        if time_monitor:
            logger.info("time for backward of model: {}".format(T_end))

        Bar.suffix = 'train Epoch: [{0}][{1}/{2}]| Total: {total:} | ETA: {eta:} | Loss {loss.avg:.4f} | Loss_rot {loss_rot.avg:.4f} | Loss_trans {loss_trans.avg:.4f}'.format(
            epoch, i, num_iters, total=bar.elapsed_td, eta=bar.eta_td,
            loss=Loss, loss_rot=Loss_rot, loss_trans=Loss_trans)
        bar.next()
    bar.finish()
    return {'Loss': Loss.avg, 'Loss_rot': Loss_rot.avg, 'Loss_trans': Loss_trans.avg}, preds
def main():
    """Entry point: build model, criterions and loaders, then test or train.

    In test mode runs a single evaluation pass and saves predictions; in
    train mode runs the epoch loop with periodic checkpointing/evaluation
    and stepwise learning-rate decay.
    """
    cfg = config().parse()
    network, optimizer = build_model(cfg)
    criterions = {'L1': torch.nn.L1Loss(), 'L2': torch.nn.MSELoss()}
    if cfg.pytorch.gpu > -1:
        logger.info('Using GPU{}'.format(cfg.pytorch.gpu))
        network = network.cuda(cfg.pytorch.gpu)
        for k in criterions.keys():
            criterions[k] = criterions[k].cuda(cfg.pytorch.gpu)

    def _worker_init_fn(worker_id):
        # Seed python/numpy RNGs in each worker from its torch seed so data
        # augmentation differs between workers. DataLoader passes worker_id.
        torch_seed = torch.initial_seed()
        # FIX: numpy seeds must lie in [0, 2**32); the previous
        # `torch_seed // 2**32 - 1` could be negative (ValueError).
        np_seed = torch_seed % (2 ** 32)
        random.seed(torch_seed)
        np.random.seed(np_seed)

    test_loader = torch.utils.data.DataLoader(
        LM(cfg, 'test'),
        batch_size=cfg.train.train_batch_size if 'fast' in cfg.test.test_mode else 1,
        shuffle=False,
        num_workers=int(cfg.pytorch.threads_num),
        # FIX: pass the function itself. The original wrote
        # `_worker_init_fn()`, which seeded the main process once and handed
        # DataLoader `None`, so workers were never seeded.
        worker_init_fn=_worker_init_fn)

    obj_vtx = {}
    logger.info('load 3d object models...')
    for obj in tqdm(cfg.dataset.classes):
        obj_vtx[obj] = load_ply_vtx(
            os.path.join(ref.lm_model_dir, '{}/{}.ply'.format(obj, obj)))
    obj_info = LM.load_lm_model_info(ref.lm_model_info_pth)

    if cfg.pytorch.test:
        _, preds = test(0, cfg, test_loader, network, obj_vtx, obj_info, criterions)
        if preds is not None:
            torch.save({'cfg': pprint.pformat(cfg), 'preds': preds},
                       os.path.join(cfg.pytorch.save_path, 'preds.pth'))
        return

    train_loader = torch.utils.data.DataLoader(
        LM(cfg, 'train'),
        batch_size=cfg.train.train_batch_size,
        shuffle=True,
        num_workers=int(cfg.pytorch.threads_num),
        worker_init_fn=_worker_init_fn)
    for epoch in range(cfg.train.begin_epoch, cfg.train.end_epoch + 1):
        # 'all' keeps one checkpoint per epoch, otherwise overwrite 'last'
        mark = epoch if (cfg.pytorch.save_mode == 'all') else 'last'
        log_dict_train, _ = train(epoch, cfg, train_loader, network, criterions, optimizer)
        for k, v in log_dict_train.items():
            logger.info('{} {:8f} | '.format(k, v))
        if epoch % cfg.train.test_interval == 0:
            save_model(
                os.path.join(cfg.pytorch.save_path, 'model_{}.checkpoint'.format(mark)),
                network)  # optimizer
            log_dict_val, preds = test(epoch, cfg, test_loader, network,
                                       obj_vtx, obj_info, criterions)
        logger.info('\n')
        if epoch in cfg.train.lr_epoch_step:
            if optimizer is not None:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= cfg.train.lr_factor
                    logger.info("drop lr to {}".format(param_group['lr']))
    torch.save(network.cpu(), os.path.join(cfg.pytorch.save_path, 'model_cpu.pth'))
def evaluate_trans(self):
    """Evaluate translation error in detail.

    For every class, computes the L2 translation error and the per-axis
    absolute errors between estimated and ground-truth poses, then reports
    the fraction of samples below each threshold in `trans_thresh_list`
    (0.01..0.10 — presumably meters; confirm against the dataset units),
    per class and averaged over all classes.
    """
    # deep-copied so threshold bookkeeping cannot mutate the stored poses
    all_poses_est = copy.deepcopy(self.pose_est_all)
    all_poses_gt = copy.deepcopy(self.pose_gt_all)
    logger.info('\n* {} *\n {:^}\n* {} *'.format('-' * 100, 'Evaluation Translation', '-' * 100))
    # rot_thresh_list only determines the number of metric bins here;
    # rotation accuracy itself is not computed in this method
    rot_thresh_list = np.arange(1, 11, 1)
    trans_thresh_list = np.arange(0.01, 0.11, 0.01)
    num_metric = len(rot_thresh_list)
    num_classes = len(self.classes)
    # accuracy tables: rows = classes, cols = thresholds
    trans_acc = np.zeros((num_classes, num_metric))
    x_acc = np.zeros((num_classes, num_metric))
    y_acc = np.zeros((num_classes, num_metric))
    z_acc = np.zeros((num_classes, num_metric))
    num_classes = len(self.classes)
    # NOTE(review): these four threshold arrays are filled below but never
    # read afterwards — apparently leftover code
    threshold_2 = np.zeros((num_classes, 3), dtype=np.float32)
    threshold_5 = np.zeros((num_classes, 3), dtype=np.float32)
    threshold_10 = np.zeros((num_classes, 3), dtype=np.float32)
    threshold_20 = np.zeros((num_classes, 3), dtype=np.float32)
    for i in range(num_classes):
        for j in range(3):
            threshold_2[i][j] = 2
            threshold_5[i][j] = 5
            threshold_10[i][j] = 10
            threshold_20[i][j] = 20
    num_valid_class = len(self.classes)
    for i, cls_name in enumerate(self.classes):
        curr_poses_gt = all_poses_gt[cls_name]
        curr_poses_est = all_poses_est[cls_name]
        num = len(curr_poses_gt)
        # per-sample error columns for this class
        cur_trans_rst = np.zeros((num, 1))
        cur_x_rst = np.zeros((num, 1))
        cur_y_rst = np.zeros((num, 1))
        cur_z_rst = np.zeros((num, 1))
        for j in range(num):
            RT = curr_poses_est[j]  # est pose
            pose_gt = curr_poses_gt[j]  # gt pose
            # Euclidean distance between translation columns (last column
            # of the 3x4 pose matrices)
            t_dist_est = LA.norm(RT[:, 3].reshape(3) - pose_gt[:, 3].reshape(3))
            err_xyz = np.abs(RT[:, 3] - pose_gt[:, 3])
            cur_x_rst[j, 0], cur_y_rst[j, 0], cur_z_rst[j, 0] = err_xyz
            cur_trans_rst[j, 0] = t_dist_est
        # accuracy = fraction of samples under each threshold
        for thresh_idx in range(num_metric):
            trans_acc[i, thresh_idx] = np.mean(cur_trans_rst < trans_thresh_list[thresh_idx])
            x_acc[i, thresh_idx] = np.mean(cur_x_rst < trans_thresh_list[thresh_idx])
            y_acc[i, thresh_idx] = np.mean(cur_y_rst < trans_thresh_list[thresh_idx])
            z_acc[i, thresh_idx] = np.mean(cur_z_rst < trans_thresh_list[thresh_idx])
        logger.info("------------ {} -----------".format(cls_name))
        logger.info("{:>24}: {:>7}, {:>7}, {:>7}, {:>7}".format("trans_thresh", "TraAcc", "x", "y", "z"))
        # report thresholds 0.02, 0.05 and 0.10 only
        show_list = [1, 4, 9]
        for show_idx in show_list:
            logger.info("{:>16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}, {:>7.2f}".format(
                'average_accuracy', '{:.2f}'.format(trans_thresh_list[show_idx]),
                trans_acc[i, show_idx] * 100, x_acc[i, show_idx] * 100,
                y_acc[i, show_idx] * 100, z_acc[i, show_idx] * 100))
        print(' ')
    # overall performance: unweighted mean over classes
    show_list = [1, 4, 9]
    logger.info("---------- performance over {} classes -----------".format(num_valid_class))
    logger.info("{:>24}: {:>7}, {:>7}, {:>7}, {:>7}".format("trans_thresh", "TraAcc", "x", "y", "z"))
    for show_idx in show_list:
        logger.info("{:>16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}, {:>7.2f}".format(
            'average_accuracy', '{:.2f}'.format(trans_thresh_list[show_idx]),
            np.sum(trans_acc[:, show_idx]) / num_valid_class * 100,
            np.sum(x_acc[:, show_idx]) / num_valid_class * 100,
            np.sum(y_acc[:, show_idx]) / num_valid_class * 100,
            np.sum(z_acc[:, show_idx]) / num_valid_class * 100))
    print(' ')
def evaluate_pose(self):
    """Evaluate 6D pose accuracy and log per-class / overall tables.

    For each class, computes rotation and translation distances via
    `calc_rt_dist_m` and reports the fraction of samples under paired
    thresholds (rotation 1..10 deg bins, translation 0.01..0.10 bins);
    'SpcAcc' requires both to hold simultaneously. The z-axis-symmetric
    'eggbox' class is given a 180-degree flip retry.
    """
    all_poses_est = self.pose_est_all
    all_poses_gt = self.pose_gt_all
    logger.info('\n* {} *\n {:^}\n* {} *'.format('-' * 100, 'Evaluation 6D Pose', '-' * 100))
    rot_thresh_list = np.arange(1, 11, 1)
    trans_thresh_list = np.arange(0.01, 0.11, 0.01)
    num_metric = len(rot_thresh_list)
    num_classes = len(self.classes)
    # accuracy tables: rows = classes, cols = threshold bins
    rot_acc = np.zeros((num_classes, num_metric))
    trans_acc = np.zeros((num_classes, num_metric))
    space_acc = np.zeros((num_classes, num_metric))
    num_valid_class = len(self.classes)
    for i, cls_name in enumerate(self.classes):
        curr_poses_gt = all_poses_gt[cls_name]
        curr_poses_est = all_poses_est[cls_name]
        num = len(curr_poses_gt)
        cur_rot_rst = np.zeros((num, 1))
        cur_trans_rst = np.zeros((num, 1))
        for j in range(num):
            r_dist_est, t_dist_est = calc_rt_dist_m(curr_poses_est[j], curr_poses_gt[j])
            # eggbox is symmetric under a 180-degree rotation about z:
            # retry with the flipped estimate when the error exceeds 90 deg
            if cls_name == 'eggbox' and r_dist_est > 90:
                RT_z = np.array([[-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 1, 0]])
                curr_pose_est_sym = se3_mul(curr_poses_est[j], RT_z)
                r_dist_est, t_dist_est = calc_rt_dist_m(curr_pose_est_sym, curr_poses_gt[j])
            cur_rot_rst[j, 0] = r_dist_est
            cur_trans_rst[j, 0] = t_dist_est
        # accuracy = fraction of samples under each paired threshold;
        # space_acc requires both rotation AND translation to pass
        for thresh_idx in range(num_metric):
            rot_acc[i, thresh_idx] = np.mean(cur_rot_rst < rot_thresh_list[thresh_idx])
            trans_acc[i, thresh_idx] = np.mean(cur_trans_rst < trans_thresh_list[thresh_idx])
            space_acc[i, thresh_idx] = np.mean(np.logical_and(
                cur_rot_rst < rot_thresh_list[thresh_idx],
                cur_trans_rst < trans_thresh_list[thresh_idx]))
        logger.info("------------ {} -----------".format(cls_name))
        logger.info("{:>24}: {:>7}, {:>7}, {:>7}".format("[rot_thresh, trans_thresh", "RotAcc", "TraAcc", "SpcAcc"))
        # '[-1, -1]' row: mean accuracy over all threshold bins
        logger.info("{:<16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}".format(
            'average_accuracy', '[{:>2}, {:.2f}]'.format(-1, -1),
            np.mean(rot_acc[i, :]) * 100,
            np.mean(trans_acc[i, :]) * 100,
            np.mean(space_acc[i, :]) * 100))
        # report threshold pairs (2, 0.02), (5, 0.05) and (10, 0.10)
        show_list = [1, 4, 9]
        for show_idx in show_list:
            logger.info("{:>16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}".format(
                'average_accuracy',
                '[{:>2}, {:.2f}]'.format(rot_thresh_list[show_idx], trans_thresh_list[show_idx]),
                rot_acc[i, show_idx] * 100,
                trans_acc[i, show_idx] * 100,
                space_acc[i, show_idx] * 100))
        print(' ')
    # overall performance: unweighted mean over classes
    show_list = [1, 4, 9]
    logger.info("---------- performance over {} classes -----------".format(num_valid_class))
    logger.info("{:>24}: {:>7}, {:>7}, {:>7}".format("[rot_thresh, trans_thresh", "RotAcc", "TraAcc", "SpcAcc"))
    logger.info("{:<16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}".format(
        'average_accuracy', '[{:>2}, {:4.2f}]'.format(-1, -1),
        np.sum(rot_acc[:, :]) / (num_valid_class * num_metric) * 100,
        np.sum(trans_acc[:, :]) / (num_valid_class * num_metric) * 100,
        np.sum(space_acc[:, :]) / (num_valid_class * num_metric) * 100))
    for show_idx in show_list:
        logger.info("{:>16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}".format(
            'average_accuracy',
            '[{:>2}, {:.2f}]'.format(rot_thresh_list[show_idx], trans_thresh_list[show_idx]),
            np.sum(rot_acc[:, show_idx]) / num_valid_class * 100,
            np.sum(trans_acc[:, show_idx]) / num_valid_class * 100,
            np.sum(space_acc[:, show_idx]) / num_valid_class * 100))
    print(' ')
def evaluate_pose_arp_2d(self, output_dir):
    """Evaluate the Average Re-Projection 2D (ARP-2D) error.

    For each class, projects the model points with both estimated and
    ground-truth poses via `arp_2d`, counts poses whose mean reprojection
    error is below 2/5/10/20 px, and also sweeps thresholds from 0 to 50 px
    in `dx` steps to integrate an area-under-curve accuracy. Per-class
    accuracy curves are plotted to `output_dir` and the raw (x, y) curve
    data is pickled alongside.

    Args:
        output_dir: directory for the 'arp_2d_<cls>.png' plots and the
            'arp_2d_xys.pkl' curve dump.
    """
    all_poses_est = self.pose_est_all
    all_poses_gt = self.pose_gt_all
    models = self.models
    logger.info('\n* {} *\n {:^}\n* {} *'.format(
        '-' * 100, 'Metric ARP_2D (Average Re-Projection 2D)', '-' * 100))
    K = self.camera_matrix
    num_classes = len(self.classes)
    count_all = np.zeros((num_classes), dtype=np.float32)
    # per-class hit counters for the fixed pixel thresholds
    count_correct = {k: np.zeros((num_classes), dtype=np.float32) for k in ['2', '5', '10', '20']}
    threshold_2 = np.zeros((num_classes), dtype=np.float32)
    threshold_5 = np.zeros((num_classes), dtype=np.float32)
    threshold_10 = np.zeros((num_classes), dtype=np.float32)
    threshold_20 = np.zeros((num_classes), dtype=np.float32)
    dx = 0.1
    threshold_mean = np.tile(np.arange(0, 50, dx).astype(np.float32),
                             (num_classes, 1))  # (num_class, num_iter, num_thresh)
    num_thresh = threshold_mean.shape[-1]
    count_correct['mean'] = np.zeros((num_classes, num_thresh), dtype=np.float32)
    for i in range(num_classes):
        threshold_2[i] = 2
        threshold_5[i] = 5
        threshold_10[i] = 10
        threshold_20[i] = 20
    num_valid_class = len(self.classes)
    for i, cls_name in enumerate(self.classes):
        curr_poses_gt = all_poses_gt[cls_name]
        curr_poses_est = all_poses_est[cls_name]
        num = len(curr_poses_gt)
        count_all[i] = num
        for j in range(num):
            RT = curr_poses_est[j]  # est pose
            pose_gt = curr_poses_gt[j]  # gt pose
            error_rotation = re(RT[:3, :3], pose_gt[:3, :3])
            # eggbox is symmetric under a 180-degree rotation about z:
            # score the flipped estimate when the rotation error > 90 deg
            if cls_name == 'eggbox' and error_rotation > 90:
                RT_z = np.array([[-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 1, 0]])
                RT_sym = se3_mul(RT, RT_z)
                error = arp_2d(RT_sym[:3, :3], RT_sym[:, 3],
                               pose_gt[:3, :3], pose_gt[:, 3], models[cls_name], K)
            else:
                error = arp_2d(RT[:3, :3], RT[:, 3],
                               pose_gt[:3, :3], pose_gt[:, 3], models[cls_name], K)
            if error < threshold_2[i]:
                count_correct['2'][i] += 1
            if error < threshold_5[i]:
                count_correct['5'][i] += 1
            if error < threshold_10[i]:
                count_correct['10'][i] += 1
            if error < threshold_20[i]:
                count_correct['20'][i] += 1
            # fill the accuracy-vs-threshold curve
            for thresh_i in range(num_thresh):
                if error < threshold_mean[i, thresh_i]:
                    count_correct['mean'][i, thresh_i] += 1

    # store plot data
    plot_data = {}
    sum_acc_mean = np.zeros(1)
    sum_acc_02 = np.zeros(1)
    sum_acc_05 = np.zeros(1)
    sum_acc_10 = np.zeros(1)
    sum_acc_20 = np.zeros(1)
    for i, cls_name in enumerate(self.classes):
        # skip classes without any evaluated poses
        if count_all[i] == 0:
            continue
        plot_data[cls_name] = []
        logger.info("** {} **".format(cls_name))
        from scipy.integrate import simps
        # area under the accuracy curve, normalized by the 50 px range
        area = simps(count_correct['mean'][i] / float(count_all[i]), dx=dx) / (50.0)
        acc_mean = area * 100
        sum_acc_mean[0] += acc_mean
        acc_02 = 100 * float(count_correct['2'][i]) / float(count_all[i])
        sum_acc_02[0] += acc_02
        acc_05 = 100 * float(count_correct['5'][i]) / float(count_all[i])
        sum_acc_05[0] += acc_05
        acc_10 = 100 * float(count_correct['10'][i]) / float(count_all[i])
        sum_acc_10[0] += acc_10
        acc_20 = 100 * float(count_correct['20'][i]) / float(count_all[i])
        sum_acc_20[0] += acc_20
        # plot accuracy (%) vs reprojection-error threshold (px)
        fig = plt.figure()
        x_s = np.arange(0, 50, dx).astype(np.float32)
        y_s = 100 * count_correct['mean'][i] / float(count_all[i])
        plot_data[cls_name].append((x_s, y_s))
        plt.plot(x_s, y_s, '-')
        plt.xlim(0, 50)
        plt.ylim(0, 100)
        plt.grid(True)
        plt.xlabel("px")
        plt.ylabel("correctly estimated poses in %")
        plt.savefig(os.path.join(output_dir, 'arp_2d_{}.png'.format(cls_name)), dpi=fig.dpi)
        plt.close()
        logger.info('threshold=[0, 50], area: {:.2f}'.format(acc_mean))
        logger.info('threshold=2, correct poses: {}, all poses: {}, accuracy: {:.2f}'.format(
            count_correct['2'][i], count_all[i], acc_02))
        logger.info('threshold=5, correct poses: {}, all poses: {}, accuracy: {:.2f}'.format(
            count_correct['5'][i], count_all[i], acc_05))
        logger.info('threshold=10, correct poses: {}, all poses: {}, accuracy: {:.2f}'.format(
            count_correct['10'][i], count_all[i], acc_10))
        logger.info('threshold=20, correct poses: {}, all poses: {}, accuracy: {:.2f}'.format(
            count_correct['20'][i], count_all[i], acc_20))
        logger.info(" ")
    # dump the raw curves for later re-plotting
    with open(os.path.join(output_dir, 'arp_2d_xys.pkl'), 'wb') as f:
        cPickle.dump(plot_data, f, protocol=2)
    logger.info("=" * 30)
    logger.info(' ')
    # overall performance of arp 2d (single-iteration loop kept as-is)
    for iter_i in range(1):
        logger.info("---------- arp 2d performance over {} classes -----------".format(num_valid_class))
        logger.info("** iter {} **".format(iter_i + 1))
        logger.info('threshold=[0, 50], area: {:.2f}'.format(
            sum_acc_mean[iter_i] / num_valid_class))
        logger.info('threshold=2, mean accuracy: {:.2f}'.format(
            sum_acc_02[iter_i] / num_valid_class))
        logger.info('threshold=5, mean accuracy: {:.2f}'.format(
            sum_acc_05[iter_i] / num_valid_class))
        logger.info('threshold=10, mean accuracy: {:.2f}'.format(
            sum_acc_10[iter_i] / num_valid_class))
        logger.info('threshold=20, mean accuracy: {:.2f}'.format(
            sum_acc_20[iter_i] / num_valid_class))
        logger.info(" ")
    logger.info("=" * 30)
def evaluate_pose_add(self, output_dir):
    """Evaluate 6D pose estimates with the ADD metric.

    For every class, counts estimated poses whose average model-point
    distance to the ground-truth pose is below 2% / 5% / 10% of the model
    diameter (symmetric objects use the closest-point variant ADD-S via
    `adi`), integrates an accuracy-vs-threshold curve over [0, 0.1] m,
    saves one plot per class plus a pickle of the curve data, and logs
    per-class and mean accuracies.

    Args:
        output_dir: directory for the per-class plots and the
            '<method>_xys.pkl' curve dump (created if missing).
    """
    # `simps` was renamed `simpson` and removed in SciPy >= 1.14; also
    # hoisted out of the per-class loop (it used to be imported there).
    try:
        from scipy.integrate import simpson as _simpson
    except ImportError:  # older SciPy
        from scipy.integrate import simps as _simpson

    all_poses_est = self.pose_est_all
    all_poses_gt = self.pose_gt_all
    models_info = self.models_info
    models = self.models
    logger.info('\n* {} *\n {:^}\n* {} *'.format('-' * 100, 'Metric ADD', '-' * 100))
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    eval_method = 'add'
    num_classes = len(self.classes)
    count_all = np.zeros((num_classes), dtype=np.float32)
    count_correct = {k: np.zeros((num_classes), dtype=np.float32) for k in ['0.02', '0.05', '0.10']}
    # absolute per-class distance thresholds (fractions of the model diameter)
    threshold_002 = np.zeros((num_classes), dtype=np.float32)
    threshold_005 = np.zeros((num_classes), dtype=np.float32)
    threshold_010 = np.zeros((num_classes), dtype=np.float32)
    dx = 0.0001
    # dense threshold sweep over [0, 0.1) used for the accuracy curve
    threshold_mean = np.tile(np.arange(0, 0.1, dx).astype(np.float32),
                             (num_classes, 1))  # (num_class, num_thresh)
    num_thresh = threshold_mean.shape[-1]
    count_correct['mean'] = np.zeros((num_classes, num_thresh), dtype=np.float32)
    self.classes = sorted(self.classes)
    for i, cls_name in enumerate(self.classes):
        diameter = models_info[ref.obj2idx(cls_name)]['diameter']
        threshold_002[i] = 0.02 * diameter
        threshold_005[i] = 0.05 * diameter
        threshold_010[i] = 0.10 * diameter
        threshold_mean[i, :] *= diameter
        curr_poses_gt = all_poses_gt[cls_name]
        curr_poses_est = all_poses_est[cls_name]
        num = len(curr_poses_gt)
        count_all[i] = num
        for j in range(num):
            RT = curr_poses_est[j]  # est pose
            pose_gt = curr_poses_gt[j]  # gt pose
            if cls_name == 'eggbox' or cls_name == 'glue' or cls_name == 'bowl' or cls_name == 'cup':
                # symmetric object: closest-point distance (ADD-S).
                # NOTE(review): eval_method is sticky — once any symmetric
                # class is seen it stays 'adi' and names the output pickle;
                # kept as-is to preserve existing output filenames.
                eval_method = 'adi'
                error = adi(RT[:3, :3], RT[:, 3], pose_gt[:3, :3], pose_gt[:, 3], models[cls_name])
            else:
                error = add(RT[:3, :3], RT[:, 3], pose_gt[:3, :3], pose_gt[:, 3], models[cls_name])
            if error < threshold_002[i]:
                count_correct['0.02'][i] += 1
            if error < threshold_005[i]:
                count_correct['0.05'][i] += 1
            if error < threshold_010[i]:
                count_correct['0.10'][i] += 1
            for thresh_i in range(num_thresh):
                if error < threshold_mean[i, thresh_i]:
                    count_correct['mean'][i, thresh_i] += 1

    # Only classes that actually have samples contribute to the means;
    # previously this was len(self.classes), which over-divides when some
    # classes are skipped below for having no poses.
    num_valid_class = int(np.count_nonzero(count_all))

    # store plot data
    plot_data = {}
    sum_acc_mean = np.zeros(1)
    sum_acc_002 = np.zeros(1)
    sum_acc_005 = np.zeros(1)
    sum_acc_010 = np.zeros(1)
    for i, cls_name in enumerate(self.classes):
        if count_all[i] == 0:
            continue
        plot_data[cls_name] = []
        logger.info("** {} **".format(cls_name))
        # normalized area under the accuracy/threshold curve, in percent
        area = _simpson(count_correct['mean'][i] / float(count_all[i]), dx=dx) / 0.1
        acc_mean = area * 100
        sum_acc_mean[0] += acc_mean
        acc_002 = 100 * float(count_correct['0.02'][i]) / float(count_all[i])
        sum_acc_002[0] += acc_002
        acc_005 = 100 * float(count_correct['0.05'][i]) / float(count_all[i])
        sum_acc_005[0] += acc_005
        acc_010 = 100 * float(count_correct['0.10'][i]) / float(count_all[i])
        sum_acc_010[0] += acc_010
        fig = plt.figure()
        x_s = np.arange(0, 0.1, dx).astype(np.float32)
        y_s = count_correct['mean'][i] / float(count_all[i])
        plot_data[cls_name].append((x_s, y_s))
        plt.plot(x_s, y_s, '-')
        plt.xlim(0, 0.1)
        plt.ylim(0, 1)
        plt.xlabel("Average distance threshold in meter (symmetry)")
        plt.ylabel("accuracy")
        plt.savefig(os.path.join(output_dir, 'acc_thres_{}.png'.format(cls_name)), dpi=fig.dpi)
        plt.close()
        logger.info('threshold=[0.0, 0.10], area: {:.2f}'.format(acc_mean))
        logger.info('threshold=0.02, correct poses: {}, all poses: {}, accuracy: {:.2f}'.format(
            count_correct['0.02'][i], count_all[i], acc_002))
        logger.info('threshold=0.05, correct poses: {}, all poses: {}, accuracy: {:.2f}'.format(
            count_correct['0.05'][i], count_all[i], acc_005))
        logger.info('threshold=0.10, correct poses: {}, all poses: {}, accuracy: {:.2f}'.format(
            count_correct['0.10'][i], count_all[i], acc_010))
        logger.info(" ")
    with open(os.path.join(output_dir, '{}_xys.pkl'.format(eval_method)), 'wb') as f:
        cPickle.dump(plot_data, f, protocol=2)
    logger.info("=" * 30)
    logger.info(' ')
    # overall performance of add
    denom = max(num_valid_class, 1)  # guard against an all-empty run
    for iter_i in range(1):
        logger.info("---------- add performance over {} classes -----------".format(num_valid_class))
        logger.info("** iter {} **".format(iter_i + 1))
        logger.info('threshold=[0.0, 0.10], area: {:.2f}'.format(
            sum_acc_mean[iter_i] / denom))
        logger.info('threshold=0.02, mean accuracy: {:.2f}'.format(
            sum_acc_002[iter_i] / denom))
        logger.info('threshold=0.05, mean accuracy: {:.2f}'.format(
            sum_acc_005[iter_i] / denom))
        logger.info('threshold=0.10, mean accuracy: {:.2f}'.format(
            sum_acc_010[iter_i] / denom))
        logger.info(' ')
    logger.info("=" * 30)
def _freeze_or_schedule(net, freeze, lr, params_lr_list):
    """Freeze *net* in place, or append its trainable params as an optimizer
    param group with learning rate *lr*.

    Replaces three copy-pasted freeze/append blocks; the original also wrapped
    the `requires_grad` assignment in `torch.no_grad()`, which is unnecessary
    (it is an attribute assignment, not a traced tensor op).
    """
    if freeze:
        for param in net.parameters():
            param.requires_grad = False
    else:
        params_lr_list.append({
            'params': filter(lambda p: p.requires_grad, net.parameters()),
            'lr': float(lr)
        })


def build_model(cfg):
    """Build the CDPN network and its optimizer from *cfg*.

    Constructs a ResNet backbone plus rotation and translation head nets,
    freezing each sub-net or scheduling its parameters per the config, then
    initializes weights either from an experiment checkpoint
    (cfg.pytorch.load_model) or from the torchvision model zoo (backbone only).

    Args:
        cfg: experiment config; reads cfg.network.*, cfg.train.* and
            cfg.pytorch.load_model.

    Returns:
        (model, optimizer) — optimizer is None when every sub-net is frozen.
    """
    ## get model and optimizer
    if 'resnet' in cfg.network.arch:
        params_lr_list = []
        # backbone net
        block_type, layers, channels, _ = resnet_spec[cfg.network.back_layers_num]
        backbone_net = ResNetBackboneNet(block_type, layers,
                                         cfg.network.back_input_channel,
                                         cfg.network.back_freeze)
        _freeze_or_schedule(backbone_net, cfg.network.back_freeze,
                            cfg.train.lr_backbone, params_lr_list)
        # rotation head net
        rot_head_net = RotHeadNet(channels[-1], cfg.network.rot_layers_num,
                                  cfg.network.rot_filters_num,
                                  cfg.network.rot_conv_kernel_size,
                                  cfg.network.rot_output_conv_kernel_size,
                                  cfg.network.rot_output_channels,
                                  cfg.network.rot_head_freeze)
        _freeze_or_schedule(rot_head_net, cfg.network.rot_head_freeze,
                            cfg.train.lr_rot_head, params_lr_list)
        # translation head net
        trans_head_net = TransHeadNet(channels[-1], cfg.network.trans_layers_num,
                                      cfg.network.trans_filters_num,
                                      cfg.network.trans_conv_kernel_size,
                                      cfg.network.trans_output_channels,
                                      cfg.network.trans_head_freeze)
        _freeze_or_schedule(trans_head_net, cfg.network.trans_head_freeze,
                            cfg.train.lr_trans_head, params_lr_list)
        # CDPN (Coordinates-based Disentangled Pose Network)
        model = CDPN(backbone_net, rot_head_net, trans_head_net)
        # get optimizer
        if params_lr_list:
            optimizer = torch.optim.RMSprop(params_lr_list,
                                            alpha=cfg.train.alpha,
                                            eps=float(cfg.train.epsilon),
                                            weight_decay=cfg.train.weightDecay,
                                            momentum=cfg.train.momentum)
        else:
            # every sub-net frozen: nothing to optimize
            optimizer = None

    ## model initialization
    if cfg.pytorch.load_model != '':
        # resume from an experiment checkpoint (CPU-mapped so it loads anywhere)
        logger.info("=> loading model '{}'".format(cfg.pytorch.load_model))
        checkpoint = torch.load(cfg.pytorch.load_model,
                                map_location=lambda storage, loc: storage)
        if isinstance(checkpoint, dict):  # was: type(checkpoint) == type({})
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint.state_dict()
        if 'resnet' in cfg.network.arch:
            model_dict = model.state_dict()
            # filter out unnecessary params
            filtered_state_dict = {
                k: v
                for k, v in state_dict.items() if k in model_dict
            }
            # update state dict
            model_dict.update(filtered_state_dict)
            # load params to net
            model.load_state_dict(model_dict)
    else:
        if 'resnet' in cfg.network.arch:
            logger.info("=> loading official model from model zoo for backbone")
            _, _, _, name = resnet_spec[cfg.network.back_layers_num]
            official_resnet = model_zoo.load_url(model_urls[name])
            # drop original resnet fc layer; 'None' default avoids KeyError
            # when the key is absent
            official_resnet.pop('fc.weight', None)
            official_resnet.pop('fc.bias', None)
            model.backbone.load_state_dict(official_resnet)
    return model, optimizer