def main(args):
    """Train a MeshRegNet model, or only evaluate it when ``args.evaluate`` is set.

    Creates a timestamped experiment folder under ``checkpoints/``, builds the
    training loaders (skipped in evaluate mode) and the validation loader (when
    evaluating or when ``args.eval_freq != -1``), optionally resumes model and
    optimizer state from ``args.resume``, then loops over epochs. Each epoch runs
    a training pass (pickling the accumulated losses to ``results/results.pkl``
    and saving a checkpoint), followed by a validation pass every
    ``args.eval_freq`` epochs. In evaluate mode the loop stops after the first
    validation pass.

    Args:
        args: parsed command-line namespace; every attribute read below must exist.

    Raises:
        ValueError: if ``args.optimizer`` is neither ``"adam"`` nor ``"sgd"``.
    """
    setseeds.set_all_seeds(args.manual_seed)

    # Initialize hosting: the experiment folder is keyed by train datasets,
    # train splits, mini factor and start time (to the minute).
    dat_str = "_".join(args.train_datasets)
    now = datetime.now()
    split_str = "_".join(args.train_splits)
    exp_id = (
        f"checkpoints/{dat_str}_{split_str}_mini{args.mini_factor}/"
        f"{now.year}_{now.month:02d}_{now.day:02d}_{now.hour:02d}_{now.minute:02d}/"
    )

    # Initialize local checkpoint folder
    print(f"Saving experiment logs, models, and training curves and images at {exp_id}")
    save_args(args, exp_id, "opt")
    img_folder = os.path.join(exp_id, "images")
    os.makedirs(img_folder, exist_ok=True)
    result_folder = os.path.join(exp_id, "results")
    result_path = os.path.join(result_folder, "results.pkl")
    os.makedirs(result_folder, exist_ok=True)
    # Touch a marker file named after args.pyapt_id and the start time
    # (presumably consumed by an external job launcher -- TODO confirm).
    pyapt_path = os.path.join(result_folder, f"{args.pyapt_id}__{now.strftime('%H_%M_%S')}")
    with open(pyapt_path, "a") as t_f:
        t_f.write(" ")

    # Training loaders: one per (split, dataset) pair, concatenated into a single
    # loader; the worker budget is split evenly between them. Not needed when
    # only evaluating.
    loader = None
    if not args.evaluate:
        loaders = []
        for train_split, dat_name in zip(args.train_splits, args.train_datasets):
            train_dataset, _ = get_dataset.get_dataset(
                dat_name,
                split=train_split,
                meta={"version": args.version, "split_mode": args.split_mode},
                use_cache=args.use_cache,
                mini_factor=args.mini_factor,
                no_augm=args.no_augm,
                block_rot=args.block_rot,
                max_rot=args.max_rot,
                center_idx=args.center_idx,
                scale_jittering=args.scale_jittering,
                center_jittering=args.center_jittering,
                fraction=args.fraction,
                mode="strong",
                sample_nb=None,
            )
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=int(args.workers / len(args.train_datasets)),
                drop_last=True,
                collate_fn=collate.meshreg_collate,
            )
            loaders.append(train_loader)
        loader = concatloader.ConcatDataloader(loaders)

    # Validation loader (no augmentation); only built when it will be used.
    val_loader = None
    if args.evaluate or args.eval_freq != -1:
        val_dataset, _ = get_dataset.get_dataset(
            args.val_dataset,
            split=args.val_split,
            meta={"version": args.version, "split_mode": args.split_mode},
            use_cache=args.use_cache,
            mini_factor=args.mini_factor,
            no_augm=True,
            block_rot=args.block_rot,
            max_rot=args.max_rot,
            center_idx=args.center_idx,
            scale_jittering=args.scale_jittering,
            center_jittering=args.center_jittering,
            sample_nb=None,
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=int(args.workers),
            drop_last=False,
            collate_fn=collate.meshreg_collate,
        )

    model = MeshRegNet(
        mano_center_idx=args.center_idx,
        mano_lambda_joints2d=args.mano_lambda_joints2d,
        mano_lambda_joints3d=args.mano_lambda_joints3d,
        mano_lambda_recov_joints3d=args.mano_lambda_recov_joints3d,
        mano_lambda_recov_verts3d=args.mano_lambda_recov_verts3d,
        mano_lambda_verts2d=args.mano_lambda_verts2d,
        mano_lambda_verts3d=args.mano_lambda_verts3d,
        mano_lambda_shape=args.mano_lambda_shape,
        mano_use_shape=args.mano_lambda_shape > 0,
        mano_lambda_pose_reg=args.mano_lambda_pose_reg,
        obj_lambda_recov_verts3d=args.obj_lambda_recov_verts3d,
        obj_lambda_verts2d=args.obj_lambda_verts2d,
        obj_lambda_verts3d=args.obj_lambda_verts3d,
        obj_trans_factor=args.obj_trans_factor,
        obj_scale_factor=args.obj_scale_factor,
        mano_fhb_hand="fhbhands" in args.train_datasets,
    )
    model.cuda()

    # Initialize model: optionally replace it with a resumed checkpoint.
    if args.resume is not None:
        opts = reloadmodel.load_opts(args.resume)
        model, epoch = reloadmodel.reload_model(args.resume, opts)
        model.cuda()
        if args.evaluate:
            # Run exactly one (evaluation) iteration of the epoch loop.
            args.epochs = epoch + 1
    else:
        epoch = 0
    if args.freeze_batchnorm:
        freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm

    model_params = filter(lambda p: p.requires_grad, model.parameters())
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(model_params, lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            model_params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
        )
    else:
        raise ValueError(f"Unsupported optimizer {args.optimizer!r}, expected 'adam' or 'sgd'")

    if args.resume is not None:
        reloadmodel.reload_optimizer(args.resume, optimizer)

    # Without a decay factor there is no scheduler; keep the name bound so the
    # training pass and the checkpoint dict below can always reference it.
    scheduler = None
    if args.lr_decay_gamma:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step, gamma=args.lr_decay_gamma)

    fig = plt.figure(figsize=(10, 10))
    save_results = {
        "opt": dict(vars(args)),
        "train_losses": [],
        "val_losses": [],
    }
    monitor = MetricMonitor()
    print("Will monitor metrics.")

    for epoch_idx in tqdm(range(epoch, args.epochs), desc="epoch"):
        if not args.freeze_batchnorm:
            model.train()
        else:
            model.eval()

        if not args.evaluate:
            save_dict = epochpass.epoch_pass(
                loader,
                model,
                train=True,
                optimizer=optimizer,
                scheduler=scheduler,
                epoch=epoch_idx,
                img_folder=img_folder,
                fig=fig,
                display_freq=args.display_freq,
                epoch_display_freq=args.epoch_display_freq,
                lr_decay_gamma=args.lr_decay_gamma,
                monitor=monitor,
            )
            save_results["train_losses"].append(save_dict)
            with open(result_path, "wb") as p_f:
                pickle.dump(save_results, p_f)
            modelio.save_checkpoint(
                {
                    "epoch": epoch_idx + 1,
                    "network": "correspnet",
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler,
                },
                is_best=True,
                checkpoint=exp_id,
                snapshot=args.snapshot,
            )

        if args.evaluate or (args.eval_freq != -1 and epoch_idx % args.eval_freq == 0):
            val_save_dict = epochpass.epoch_pass(
                val_loader,
                model,
                train=False,
                optimizer=None,
                scheduler=None,
                epoch=epoch_idx,
                img_folder=img_folder,
                fig=fig,
                display_freq=args.display_freq,
                epoch_display_freq=args.epoch_display_freq,
                lr_decay_gamma=args.lr_decay_gamma,
                monitor=monitor,
            )
            save_results["val_losses"].append(val_save_dict)
            # Export monitored metrics as plotly HTML (and optionally matplotlib).
            monitor.plot(os.path.join(exp_id, "training.html"), plotly=True)
            monitor.plot_histogram(os.path.join(exp_id, "evaluation.html"), plotly=True)
            if args.matplotlib:
                monitor.plot(result_folder, matplotlib=True)
                monitor.plot_histogram(result_folder, matplotlib=True)
            monitor.save_metrics(os.path.join(exp_id, "metrics.pkl"))
            if args.evaluate:
                print(val_save_dict)
                break
def main(args):
    """Evaluate a resumed MeshRegNet checkpoint on one split and dump results to JSON.

    Seeds all RNGs, creates a timestamped experiment folder, builds an
    un-augmented, unshuffled loader over ``args.val_dataset`` / ``args.val_split``,
    reloads the model from ``args.resume`` and runs ``evalpass.epoch_pass`` on it
    with ``dump_results_path`` set to ``<args.json_folder>/<args.val_split>.json``.
    Rendered images go to ``<exp_id>/renders/epochXXXX`` when
    ``args.render_results`` is set.

    Args:
        args: parsed command-line namespace; every attribute read below must exist.
    """
    # Seed every RNG source, in the same order as before.
    for seed_fn in (torch.cuda.manual_seed_all, torch.manual_seed, np.random.seed, random.seed):
        seed_fn(args.manual_seed)

    # Initialize hosting: folder keyed by dataset, date and run options.
    started_at = datetime.now()
    exp_id = "".join(
        [
            f"checkpoints/{args.val_dataset}_mini{args.mini_factor}/",
            f"{started_at.year}_{started_at.month:02d}_{started_at.day:02d}/",
            f"{args.com}_frac{args.fraction}_mode{args.mode}_bs{args.batch_size}_",
            f"objs{args.obj_scale_factor}_objt{args.obj_trans_factor}",
        ]
    )

    # Initialize local checkpoint folder
    save_args(args, exp_id, "opt")
    result_folder = os.path.join(exp_id, "results")
    os.makedirs(result_folder, exist_ok=True)
    # Touch a marker file named after args.pyapt_id and the start time
    # (presumably read by an external job launcher -- TODO confirm).
    marker_name = f"{args.pyapt_id}__{started_at.strftime('%H_%M_%S')}"
    with open(os.path.join(result_folder, marker_name), "a") as marker_file:
        marker_file.write(" ")

    dataset_kwargs = dict(
        split=args.val_split,
        meta={"version": args.version, "split_mode": "paper"},
        use_cache=args.use_cache,
        mini_factor=args.mini_factor,
        mode=args.mode,
        fraction=args.fraction,
        no_augm=True,
        center_idx=args.center_idx,
        scale_jittering=0,
        center_jittering=0,
        sample_nb=None,
        has_dist2strong=True,
    )
    val_dataset, _ = get_dataset.get_dataset(args.val_dataset, **dataset_kwargs)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=int(args.workers),
        drop_last=False,
        collate_fn=collate.meshreg_collate,
    )

    opts = reloadmodel.load_opts(args.resume)
    model, epoch = reloadmodel.reload_model(args.resume, opts)
    epoch_tag = f"epoch{epoch:04d}"

    render_folder = None
    if args.render_results:
        render_folder = os.path.join(exp_id, "renders", epoch_tag)
        os.makedirs(render_folder, exist_ok=True)
        print(f"Rendering to {render_folder}")

    img_folder = os.path.join(exp_id, "images", epoch_tag)
    os.makedirs(img_folder, exist_ok=True)

    freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm
    fig = plt.figure(figsize=(12, 4))
    # Built for parity with the training scripts; not consumed below.
    save_results = {"opt": dict(vars(args)), "val_losses": []}

    os.makedirs(args.json_folder, exist_ok=True)
    json_path = os.path.join(args.json_folder, f"{args.val_split}.json")
    evalpass.epoch_pass(
        val_loader,
        model,
        optimizer=None,
        scheduler=None,
        epoch=epoch,
        img_folder=img_folder,
        fig=fig,
        display_freq=args.display_freq,
        dump_results_path=json_path,
        render_folder=render_folder,
        render_freq=args.render_freq,
        true_root=args.true_root,
    )
    print(f"Saved results for split {args.val_split} to {json_path}")
def main(args):
    """Run resumed MeshRegNet checkpoints over a dataset split and pickle per-image samples.

    Every checkpoint in ``args.resumes`` is reloaded and run on each batch, but
    only the FIRST model's outputs are stored. For each image the stored dict
    holds object/hand ground-truth and predicted vertices, faces, predicted
    MANO pose/shape, ground-truth hand pose/shape/translation (from
    ``dataset.pose_dataset.get_hand_info``) and a fixed 4x4 extrinsic matrix,
    all converted with ``to_cpu_npy``. The list is pickled to
    ``all_samples.pkl`` in the current working directory.

    Args:
        args: parsed command-line namespace; every attribute read below must exist.

    Raises:
        ValueError: if ``args.resumes`` is empty (the dataset needs the
            ``center_idx`` option of a loaded checkpoint).
    """
    if not args.resumes:
        raise ValueError("args.resumes must contain at least one checkpoint path")

    setseeds.set_all_seeds(args.manual_seed)

    # Initialize hosting
    exp_id = f"checkpoints/{args.dataset}/{args.com}"
    # Initialize local checkpoint folder
    print(f"Saving info about experiment at {exp_id}")
    save_args(args, exp_id, "opt")
    render_folder = os.path.join(exp_id, "images")
    os.makedirs(render_folder, exist_ok=True)

    # Load models
    models = []
    for resume in args.resumes:
        print('Resuming', resume)
        opts = reloadmodel.load_opts(resume)
        model, _ = reloadmodel.reload_model(resume, opts)
        models.append(model)
        freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm

    # The root-centering index comes from the last loaded checkpoint's options.
    dataset, _ = get_dataset.get_dataset(
        args.dataset,
        split=args.split,
        meta={"version": args.version, "split_mode": args.split_mode},
        mode=args.mode,
        use_cache=args.use_cache,
        no_augm=True,
        center_idx=opts["center_idx"],
        sample_nb=None,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=int(args.workers),
        drop_last=False,
        collate_fn=collate.meshreg_collate,
    )

    # Put models on GPU and evaluation mode
    for model in models:
        model.cuda()
        model.eval()

    all_samples = []
    for batch in tqdm(loader):  # Loop over batches
        # Compute model outputs
        all_results = []
        with torch.no_grad():
            for model in models:
                _, results, _ = model(batch)
                all_results.append(results)

        network_out = all_results[0]  # only the first model's predictions are stored
        # One sample dict per image in the batch.
        for img_idx, _ in enumerate(batch[BaseQueries.IMAGE]):
            sample_idx = batch['idx'][img_idx]
            handpose, handtrans, handshape = dataset.pose_dataset.get_hand_info(sample_idx)
            sample_dict = dict()
            sample_dict['obj_faces'] = batch[BaseQueries.OBJFACES][img_idx, :, :]
            sample_dict['obj_verts_gt'] = batch[BaseQueries.OBJVERTS3D][img_idx, :, :]
            sample_dict['hand_faces'], _ = manoutils.get_closed_faces()
            sample_dict['hand_verts_gt'] = batch[BaseQueries.HANDVERTS3D][img_idx, :, :]
            sample_dict['obj_verts_pred'] = network_out['recov_objverts3d'][img_idx, :, :]
            sample_dict['hand_verts_pred'] = network_out['recov_handverts3d'][img_idx, :, :]
            sample_dict['hand_pose_pred'] = network_out['pose'][img_idx, :]
            sample_dict['hand_beta_pred'] = network_out['shape'][img_idx, :]
            sample_dict['side'] = batch[BaseQueries.SIDE][img_idx]
            sample_dict['hand_pose_gt'] = handpose
            sample_dict['hand_beta_gt'] = handshape
            sample_dict['hand_trans_gt'] = handtrans
            # Fixed extrinsic that flips the y and z axes.
            sample_dict['hand_extr_gt'] = torch.Tensor(
                [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
            )
            for k in sample_dict.keys():
                sample_dict[k] = to_cpu_npy(sample_dict[k])
            all_samples.append(sample_dict)

    print('Saving final dict', len(all_samples))
    with open('all_samples.pkl', 'wb') as handle:
        pickle.dump(all_samples, handle)
    print('Done saving')
def main(args):
    """Run a resumed MeshRegNet on a validation split and pickle per-image samples.

    Seeds all RNGs, creates a timestamped experiment folder, reloads the model
    from ``args.resume`` and, for every image of every batch, stores a dict of
    ground-truth / predicted hand and object meshes plus MANO parameters
    (converted with ``to_cpu_npy``). The list is pickled to ``all_samples.pkl``
    in the current working directory.

    Args:
        args: parsed command-line namespace; every attribute read below must exist.
    """
    # Seed every RNG source, in the same order as before.
    for seed_fn in (torch.cuda.manual_seed_all, torch.manual_seed, np.random.seed, random.seed):
        seed_fn(args.manual_seed)

    # Initialize hosting
    started_at = datetime.now()
    exp_id = "".join(
        [
            f"checkpoints/{args.val_dataset}_mini{args.mini_factor}/",
            f"{started_at.year}_{started_at.month:02d}_{started_at.day:02d}/",
            f"{args.com}_frac{args.fraction}_mode{args.mode}_bs{args.batch_size}_",
            f"objs{args.obj_scale_factor}_objt{args.obj_trans_factor}",
        ]
    )

    # Initialize local checkpoint folder
    save_args(args, exp_id, "opt")
    result_folder = os.path.join(exp_id, "results")
    os.makedirs(result_folder, exist_ok=True)
    # Touch a marker file named after args.pyapt_id and the start time
    # (presumably read by an external job launcher -- TODO confirm).
    marker_name = f"{args.pyapt_id}__{started_at.strftime('%H_%M_%S')}"
    with open(os.path.join(result_folder, marker_name), "a") as marker_file:
        marker_file.write(" ")

    val_dataset, _ = get_dataset.get_dataset(
        args.val_dataset,
        split=args.val_split,
        meta={"version": args.version, "split_mode": args.split_mode},
        use_cache=args.use_cache,
        mini_factor=args.mini_factor,
        mode=args.mode,
        fraction=args.fraction,
        no_augm=True,
        center_idx=args.center_idx,
        scale_jittering=0,
        center_jittering=0,
        sample_nb=None,
        has_dist2strong=True,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=int(args.workers),
        drop_last=False,
        collate_fn=collate.meshreg_collate,
    )

    opts = reloadmodel.load_opts(args.resume)
    model, _ = reloadmodel.reload_model(args.resume, opts)
    freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm
    model.eval()
    model.cuda()

    def _pack_sample(batch, network_out, img_idx):
        """Collect GT/predicted meshes and MANO params for one image as numpy values."""
        sample_idx = batch['idx'][img_idx]
        handpose, handtrans, handshape = val_dataset.pose_dataset.get_hand_info(sample_idx)
        hand_faces, _ = manoutils.get_closed_faces()
        # Key order matches the original dict insertion order.
        raw = {
            'obj_faces': batch[BaseQueries.OBJFACES][img_idx, :, :],
            'obj_verts_gt': batch[BaseQueries.OBJVERTS3D][img_idx, :, :],
            'hand_faces': hand_faces,
            'hand_verts_gt': batch[BaseQueries.HANDVERTS3D][img_idx, :, :],
            'obj_verts_pred': network_out['recov_objverts3d'][img_idx, :, :],
            'hand_verts_pred': network_out['recov_handverts3d'][img_idx, :, :],
            'hand_pose_pred': network_out['pose'][img_idx, :],
            'hand_beta_pred': network_out['shape'][img_idx, :],
            'side': batch[BaseQueries.SIDE][img_idx],
            'hand_pose_gt': handpose,
            'hand_beta_gt': handshape,
            'hand_trans_gt': handtrans,
            # Fixed extrinsic that flips the y and z axes.
            'hand_extr_gt': torch.Tensor([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]),
        }
        return {key: to_cpu_npy(value) for key, value in raw.items()}

    all_samples = []
    for batch in tqdm(val_loader):
        with torch.no_grad():
            _, network_out, _ = model(batch)
        # One sample per image in the batch.
        for img_idx, _ in enumerate(batch[BaseQueries.IMAGE]):
            all_samples.append(_pack_sample(batch, network_out, img_idx))

    print('Saving final dict', len(all_samples))
    with open('all_samples.pkl', 'wb') as handle:
        pickle.dump(all_samples, handle)
    print('Done saving')
def main(args):
    """Train MeshRegNet with an additional warping-consistency objective (WarpRegNet).

    Builds one fully-supervised loader per training dataset, plus an optional
    consistency loader over ``args.consist_dataset`` (``mode="full"``,
    ``sample_nb``/``spacing`` frames per sample). All of them are combined in a
    ``concatloader.ConcatLoader``. MeshRegNet is wrapped in
    ``warpreg.WarpRegNet`` for the consistency loss. Each epoch runs
    ``epochpassconsist.epoch_pass``, logs to ``Monitor``, pickles losses and
    saves a checkpoint. Every ``args.eval_freq`` epochs it also runs a
    validation pass.

    Args:
        args: parsed command-line namespace; every attribute read below must exist.

    Raises:
        ValueError: if the training and consistency datasets report different
            input resolutions, or if ``args.optimizer`` is neither ``"adam"``
            nor ``"sgd"``.
    """
    setseeds.set_all_seeds(args.manual_seed)
    train_dat_str = "_".join(args.train_datasets)
    dat_str = f"{train_dat_str}_warp{args.consist_dataset}"
    now = datetime.now()
    exp_id = (
        f"checkpoints/{dat_str}_mini{args.mini_factor}/{now.year}_{now.month:02d}_{now.day:02d}/"
        f"{args.com}_frac{args.fraction:.1e}"
        f"_ld{args.lambda_data:.1e}_lc{args.lambda_consist:.1e}"
        f"crit{args.consist_criterion}sca{args.consist_scale}"
        f"opt{args.optimizer}_lr{args.lr}_crit{args.criterion2d}"
        f"_mom{args.momentum}_bs{args.batch_size}"
        f"_lmj3d{args.mano_lambda_joints3d:.1e}"
        f"_lmbeta{args.mano_lambda_shape:.1e}"
        f"_lmpr{args.mano_lambda_pose_reg:.1e}"
        f"_lmrj3d{args.mano_lambda_recov_joints3d:.1e}"
        f"_lmrw3d{args.mano_lambda_recov_verts3d:.1e}"
        f"_lov2d{args.obj_lambda_verts2d:.1e}_lov3d{args.obj_lambda_verts3d:.1e}"
        f"_lovr3d{args.obj_lambda_recov_verts3d:.1e}"
        f"cj{args.center_jittering}seed{args.manual_seed}"
        f"sample_nb{args.sample_nb}_spac{args.spacing}"
        f"csteps{args.progressive_consist_steps}"
    )
    if args.no_augm:
        exp_id = f"{exp_id}_no_augm"
    if args.no_consist_augm:
        exp_id = f"{exp_id}_noconsaugm"
    if args.block_rot:
        exp_id = f"{exp_id}_block_rot"
    if args.consist_gt_refs:
        exp_id = f"{exp_id}_gt_refs"

    # Initialize local checkpoint folder
    save_args(args, exp_id, "opt")
    monitor = Monitor(exp_id, hosting_folder=exp_id)
    img_folder = os.path.join(exp_id, "images")
    os.makedirs(img_folder, exist_ok=True)
    result_folder = os.path.join(exp_id, "results")
    result_path = os.path.join(result_folder, "results.pkl")
    os.makedirs(result_folder, exist_ok=True)
    # Touch a marker file named after args.pyapt_id and the start time
    # (presumably read by an external job launcher -- TODO confirm).
    pyapt_path = os.path.join(result_folder, f"{args.pyapt_id}__{now.strftime('%H_%M_%S')}")
    with open(pyapt_path, "a") as t_f:
        t_f.write(" ")

    # Number of training loaders (one per train dataset + optional consistency
    # loader); the worker budget is split evenly between them.
    train_loader_nb = len(args.train_datasets) + int(args.consist_dataset is not None)

    loaders = []
    # Resolution of the training images, forwarded to WarpRegNet. Stays None
    # until a training (or consistency) dataset provides it.
    input_res = None
    for train_split, dat_name in zip(args.train_splits, args.train_datasets):
        train_dataset, input_res = get_dataset.get_dataset(
            dat_name,
            split=train_split,
            meta={"version": args.version, "split_mode": "objects"},
            block_rot=args.block_rot,
            max_rot=args.max_rot,
            center_idx=args.center_idx,
            center_jittering=args.center_jittering,
            fraction=args.fraction,
            mini_factor=args.mini_factor,
            mode="strong",
            no_augm=args.no_augm,
            scale_jittering=args.scale_jittering,
            sample_nb=1,
            use_cache=args.use_cache,
        )
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=int(args.workers / train_loader_nb),
            drop_last=True,
            collate_fn=collate_fn,
        )
        loaders.append(train_loader)

    if args.consist_dataset is not None:
        consist_dataset, consist_input_res = get_dataset.get_dataset(
            args.consist_dataset,
            split=args.consist_split,
            meta={"version": args.version, "split_mode": "objects"},
            use_cache=args.use_cache,
            mini_factor=args.mini_factor,
            block_rot=args.block_rot,
            center_idx=args.center_idx,
            center_jittering=args.center_jittering,
            fraction=args.fraction,
            max_rot=args.max_rot,
            mode="full",
            no_augm=args.no_consist_augm,
            sample_nb=args.sample_nb,
            scale_jittering=args.scale_jittering,
            spacing=args.spacing,  # Otherwise black padding gets included on warps
        )
        print(f"Got consist dataset {args.consist_dataset} of size {len(consist_dataset)}")
        if input_res is None:
            # No supervised training dataset: use the consistency resolution.
            input_res = consist_input_res
        elif input_res != consist_input_res:
            raise ValueError(
                "train and consist dataset should have same input sizes "
                f"but got {input_res} and {consist_input_res}"
            )
        consist_loader = torch.utils.data.DataLoader(
            consist_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=int(args.workers / train_loader_nb),
            drop_last=True,
            collate_fn=collate_fn,
        )
        loaders.append(consist_loader)
    loader = concatloader.ConcatLoader(loaders)

    if args.eval_freq != -1:
        # The validation resolution is deliberately discarded so that it does
        # not replace the training resolution given to WarpRegNet.
        val_dataset, _ = get_dataset.get_dataset(
            args.val_dataset,
            split=args.val_split,
            meta={"version": args.version, "split_mode": "objects"},
            use_cache=args.use_cache,
            mini_factor=args.mini_factor,
            no_augm=True,
            block_rot=args.block_rot,
            max_rot=args.max_rot,
            center_idx=args.center_idx,
            scale_jittering=args.scale_jittering,
            center_jittering=args.center_jittering,
            sample_nb=None,
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=int(args.workers),
            drop_last=False,
            collate_fn=collate.meshreg_collate,
        )

    model = MeshRegNet(
        mano_lambda_joints3d=args.mano_lambda_joints3d,
        mano_lambda_joints2d=args.mano_lambda_joints2d,
        mano_lambda_pose_reg=args.mano_lambda_pose_reg,
        mano_lambda_recov_joints3d=args.mano_lambda_recov_joints3d,
        mano_lambda_recov_verts3d=args.mano_lambda_recov_verts3d,
        mano_lambda_verts2d=args.mano_lambda_verts2d,
        mano_lambda_verts3d=args.mano_lambda_verts3d,
        mano_lambda_shape=args.mano_lambda_shape,
        mano_use_shape=args.mano_lambda_shape > 0,
        obj_lambda_verts2d=args.obj_lambda_verts2d,
        obj_lambda_verts3d=args.obj_lambda_verts3d,
        obj_lambda_recov_verts3d=args.obj_lambda_recov_verts3d,
        obj_trans_factor=args.obj_trans_factor,
        obj_scale_factor=args.obj_scale_factor,
        mano_fhb_hand="fhbhands" in args.train_datasets,
    )
    model.cuda()

    # Initialize model: optionally replace it with a resumed checkpoint.
    if args.resume is not None:
        opts = reloadmodel.load_opts(args.resume)
        model, epoch = reloadmodel.reload_model(args.resume, opts)
    else:
        epoch = 0
    if args.freeze_batchnorm:
        freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm
    model_params = filter(lambda p: p.requires_grad, model.parameters())

    pmodel = warpreg.WarpRegNet(
        input_res,
        model,
        lambda_consist=args.lambda_consist,
        lambda_data=args.lambda_data,
        criterion=args.consist_criterion,
        consist_scale=args.consist_scale,
        gt_refs=args.consist_gt_refs,
        progressive_steps=args.progressive_consist_steps,
        use_backward=args.consist_use_backward,
    )
    pmodel.cuda()

    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(model_params, lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            model_params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
        )
    else:
        raise ValueError(f"Unsupported optimizer {args.optimizer!r}, expected 'adam' or 'sgd'")

    if args.resume is not None:
        reloadmodel.reload_optimizer(args.resume, optimizer)

    # Without a decay factor there is no scheduler; keep the name bound so the
    # epoch pass and the checkpoint dict below can always reference it.
    scheduler = None
    if args.lr_decay_gamma:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step, gamma=args.lr_decay_gamma)

    fig = plt.figure(figsize=(10, 10))
    save_results = {
        "opt": dict(vars(args)),
        "train_losses": [],
        "val_losses": [],
    }

    for epoch_idx in tqdm(range(epoch, args.epochs), desc="epoch"):
        if not args.freeze_batchnorm:
            model.train()
        else:
            model.eval()
        save_dict, avg_meters, _ = epochpassconsist.epoch_pass(
            loader,
            model,
            train=True,
            optimizer=optimizer,
            scheduler=scheduler,
            epoch=epoch_idx,
            img_folder=img_folder,
            fig=fig,
            display_freq=args.display_freq,
            epoch_display_freq=args.epoch_display_freq,
            lr_decay_gamma=args.lr_decay_gamma,
            premodel=pmodel,
            loader_nb=train_loader_nb,
        )
        monitor.log_train(epoch_idx + 1, {key: val.avg for key, val in avg_meters.average_meters.items()})
        monitor.metrics.save_metrics(epoch_idx + 1, save_dict)
        monitor.metrics.plot_metrics()
        save_results["train_losses"].append(save_dict)
        with open(result_path, "wb") as p_f:
            pickle.dump(save_results, p_f)
        modelio.save_checkpoint(
            {
                "epoch": epoch_idx + 1,
                "network": "correspnet",
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler,
            },
            is_best=True,
            checkpoint=exp_id,
            snapshot=args.snapshot,
        )

        if args.eval_freq != -1 and epoch_idx % args.eval_freq == 0:
            val_save_dict, val_avg_meters, _ = epochpass.epoch_pass(
                val_loader,
                model,
                train=False,
                optimizer=None,
                scheduler=None,
                epoch=epoch_idx,
                img_folder=img_folder,
                fig=fig,
                display_freq=args.display_freq,
                epoch_display_freq=args.epoch_display_freq,
                lr_decay_gamma=args.lr_decay_gamma,
            )
            save_results["val_losses"].append(val_save_dict)
            monitor.log_val(
                epoch_idx + 1, {key: val.avg for key, val in val_avg_meters.average_meters.items()}
            )
            monitor.metrics.save_metrics(epoch_idx + 1, val_save_dict)
            monitor.metrics.plot_metrics()
def main(args):
    """Render per-image comparison figures for several resumed MeshRegNet checkpoints.

    Each checkpoint in ``args.resumes`` is reloaded and run on every batch of
    ``args.dataset`` / ``args.split``. ``fastrender.comp_render`` renders error
    maps for all models at once. For each image a grid is drawn with one row
    per model: input image with hand/object contours, error overlay, and
    rotated view. A shared colorbar spans 0 to ``args.max_val * 100`` cm. Each
    figure is saved as ``<exp_id>/images/renderNNNNNN.png``.

    Args:
        args: parsed command-line namespace; every attribute read below must
            exist, and ``args.model_names`` must have one title per checkpoint.

    Raises:
        ValueError: if ``args.resumes`` is empty (the dataset needs the
            ``center_idx`` option of a loaded checkpoint).
    """
    if not args.resumes:
        raise ValueError("args.resumes must contain at least one checkpoint path")

    setseeds.set_all_seeds(args.manual_seed)

    # Initialize hosting
    exp_id = f"checkpoints/{args.dataset}/{args.com}"
    # Initialize local checkpoint folder
    print(f"Saving info about experiment at {exp_id}")
    save_args(args, exp_id, "opt")
    render_folder = os.path.join(exp_id, "images")
    os.makedirs(render_folder, exist_ok=True)

    # Load models
    models = []
    for resume in args.resumes:
        opts = reloadmodel.load_opts(resume)
        model, _ = reloadmodel.reload_model(resume, opts)
        models.append(model)
        freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm

    # The root-centering index comes from the last loaded checkpoint's options.
    dataset, _ = get_dataset.get_dataset(
        args.dataset,
        split=args.split,
        meta={},
        mode=args.mode,
        use_cache=args.use_cache,
        no_augm=True,
        center_idx=opts["center_idx"],
        sample_nb=None,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=int(args.workers),
        drop_last=False,
        collate_fn=collate.meshreg_collate,
    )

    fig = plt.figure(figsize=(10, 10))
    # Figure grid: one row per model (+1 spare row), three columns.
    row_nb = len(models) + 1
    col_nb = 3
    # Colorbar settings shared by every figure. plt.get_cmap replaces
    # matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
    cmap = plt.get_cmap("jet")
    norm = matplotlib.colors.Normalize(vmin=0, vmax=1)

    # Put models on GPU and evaluation mode
    for model in models:
        model.cuda()
        model.eval()

    render_step = 0
    for batch in tqdm(loader):
        # Compute model outputs
        all_results = []
        with torch.no_grad():
            for model in models:
                _, results, _ = model(batch)
                all_results.append(results)

        # Densely render error maps for the meshes of all models in a single call
        render_results, _ = fastrender.comp_render(
            batch, all_results, rotate=True, modes=("all", "obj", "hand"), max_val=args.max_val
        )

        for img_idx, img in enumerate(batch[BaseQueries.IMAGE]):
            # Get rendered results for current image (one entry per model)
            render_ress = [res[img_idx] for res in render_results["all"]]
            renderot_ress = [res[img_idx] for res in render_results["all_rotated"]]

            # Initialize figure
            fig.clf()
            axes = fig.subplots(row_nb, col_nb)

            # Display cmap (axes must be re-added after every fig.clf())
            cax = fig.add_axes([0.27, 0.05, 0.5, 0.02])
            cb = matplotlib.colorbar.ColorbarBase(
                cax, cmap=cmap, norm=norm, ticks=[0, 1], orientation="horizontal"
            )
            cb.ax.set_xticklabels(["0", str(args.max_val * 100)])
            cb.set_label("3D mesh error (cm)")

            # Get masks for hand and object in current image
            obj_masks = [res.cpu()[img_idx][:, :].sum(2).numpy() for res in render_results["obj"]]
            hand_masks = [res.cpu()[img_idx][:, :].sum(2).numpy() for res in render_results["hand"]]

            # Compute bounding boxes of masks, then a crop that encompasses the
            # spatial extent of all models' results
            crop = vizdemo.get_common_crop([vizdemo.get_crop(res) for res in render_ress])
            rot_crop = vizdemo.get_common_crop([vizdemo.get_crop(res) for res in renderot_ress])

            for model_idx, (render_res, renderot_res) in enumerate(zip(render_ress, renderot_ress)):
                # Column 1: input image with predicted contours
                ax = vizdemo.get_axis(axes, row_nb, col_nb, model_idx, 0)
                # White square background, with the input image copied top-left
                viz_img = 255 * img.new_ones(max(img.shape), max(img.shape), 3)
                viz_img[:img.shape[0], :img.shape[1], :3] = img
                # Clamp rendered values to [0, 1] so imshow displays them as-is
                render_res = render_res.clamp(0, 1)
                renderot_res = renderot_res.clamp(0, 1)
                obj_mask = obj_masks[model_idx] > 0
                hand_mask = hand_masks[model_idx] > 0
                # Draw hand and object contours
                contoured_img = vizdemo.draw_contours(viz_img.numpy(), hand_mask, color=(0, 210, 255))
                contoured_img = vizdemo.draw_contours(contoured_img, obj_mask, color=(255, 50, 50))
                ax.imshow(contoured_img[crop[0]:crop[2], crop[1]:crop[3]])

                # Column 2: image with rendering overlay
                ax = vizdemo.get_axis(axes, row_nb, col_nb, model_idx, 1)
                ax.set_title(args.model_names[model_idx])
                ax.imshow(viz_img[crop[0]:crop[2], crop[1]:crop[3]])
                ax.imshow(render_res[crop[0]:crop[2], crop[1]:crop[3]])

                # Column 3: rotated render
                ax = vizdemo.get_axis(axes, row_nb, col_nb, model_idx, 2)
                ax.imshow(renderot_res[rot_crop[0]:rot_crop[2], rot_crop[1]:rot_crop[3]])

            fig.tight_layout()
            save_path = os.path.join(render_folder, f"render{render_step:06d}.png")
            fig.savefig(save_path)
            print(f"Saved demo visualization at {save_path}")
            render_step += 1