def main(args):
    setseeds.set_all_seeds(args.manual_seed)
    # Build the experiment identifier (hosting folder)
    dat_str = "_".join(args.train_datasets)
    now = datetime.now()
    split_str = "_".join(args.train_splits)
    exp_id = (
        f"checkpoints/{dat_str}_{split_str}_mini{args.mini_factor}/"
        f"{now.year}_{now.month:02d}_{now.day:02d}_{now.hour:02d}_{now.minute:02d}/"
        # f"{args.com}_frac{args.fraction:.1e}"
        # f"lr{args.lr}_mom{args.momentum}_bs{args.batch_size}_"
        # f"_lmbeta{args.mano_lambda_shape:.1e}"
        # f"_lmpr{args.mano_lambda_pose_reg:.1e}"
        # f"_lmrj3d{args.mano_lambda_recov_joints3d:.1e}"
        # f"_lovr3d{args.obj_lambda_recov_verts3d:.1e}"
        # f"seed{args.manual_seed}"
    )
    # if args.no_augm:
    #     exp_id = f"{exp_id}_no_augm"
    # if args.block_rot:
    #     exp_id = f"{exp_id}_block_rot"
    # if args.freeze_batchnorm:
    #     exp_id = f"{exp_id}_fbn"

    # Initialize local checkpoint folder
    print(f"Saving experiment logs, models, training curves, and images to {exp_id}")
    save_args(args, exp_id, "opt")
    img_folder = os.path.join(exp_id, "images")
    os.makedirs(img_folder, exist_ok=True)
    result_folder = os.path.join(exp_id, "results")
    result_path = os.path.join(result_folder, "results.pkl")
    os.makedirs(result_folder, exist_ok=True)
    # Touch a marker file tagged with the job id and launch time
    pyapt_path = os.path.join(result_folder, f"{args.pyapt_id}__{now.strftime('%H_%M_%S')}")
    with open(pyapt_path, "a") as t_f:
        t_f.write(" ")

    loaders = []
    if not args.evaluate:
        for train_split, dat_name in zip(args.train_splits, args.train_datasets):
            train_dataset, _ = get_dataset.get_dataset(
                dat_name,
                split=train_split,
                meta={"version": args.version, "split_mode": args.split_mode},
                use_cache=args.use_cache,
                mini_factor=args.mini_factor,
                no_augm=args.no_augm,
                block_rot=args.block_rot,
                max_rot=args.max_rot,
                center_idx=args.center_idx,
                scale_jittering=args.scale_jittering,
                center_jittering=args.center_jittering,
                fraction=args.fraction,
                mode="strong",
                sample_nb=None,
            )
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                # Split the worker budget evenly across the training datasets
                num_workers=int(args.workers / len(args.train_datasets)),
                drop_last=True,
                collate_fn=collate.meshreg_collate,
            )
            loaders.append(train_loader)
        loader = concatloader.ConcatDataloader(loaders)
    if args.evaluate or args.eval_freq != -1:
        val_dataset, _ = get_dataset.get_dataset(
            args.val_dataset,
            split=args.val_split,
            meta={"version": args.version, "split_mode": args.split_mode},
            use_cache=args.use_cache,
            mini_factor=args.mini_factor,
            no_augm=True,  # never augment validation samples
            block_rot=args.block_rot,
            max_rot=args.max_rot,
            center_idx=args.center_idx,
            scale_jittering=args.scale_jittering,
            center_jittering=args.center_jittering,
            sample_nb=None,
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=int(args.workers),
            drop_last=False,
            collate_fn=collate.meshreg_collate,
        )

    model = MeshRegNet(
        mano_center_idx=args.center_idx,
        mano_lambda_joints2d=args.mano_lambda_joints2d,
        mano_lambda_joints3d=args.mano_lambda_joints3d,
        mano_lambda_recov_joints3d=args.mano_lambda_recov_joints3d,
        mano_lambda_recov_verts3d=args.mano_lambda_recov_verts3d,
        mano_lambda_verts2d=args.mano_lambda_verts2d,
        mano_lambda_verts3d=args.mano_lambda_verts3d,
        mano_lambda_shape=args.mano_lambda_shape,
        mano_use_shape=args.mano_lambda_shape > 0,
        mano_lambda_pose_reg=args.mano_lambda_pose_reg,
        obj_lambda_recov_verts3d=args.obj_lambda_recov_verts3d,
        obj_lambda_verts2d=args.obj_lambda_verts2d,
        obj_lambda_verts3d=args.obj_lambda_verts3d,
        obj_trans_factor=args.obj_trans_factor,
        obj_scale_factor=args.obj_scale_factor,
        mano_fhb_hand="fhbhands" in args.train_datasets,
    )
    model.cuda()
    # Initialize model weights from a checkpoint when resuming
    if args.resume is not None:
        opts = reloadmodel.load_opts(args.resume)
        model, epoch = reloadmodel.reload_model(args.resume, opts)
        model.cuda()
        if args.evaluate:
            args.epochs = epoch + 1
    else:
        epoch = 0
    if args.freeze_batchnorm:
        freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm

    model_params = filter(lambda p: p.requires_grad, model.parameters())
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(model_params, lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            model_params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
        )
    if args.resume is not None:
        reloadmodel.reload_optimizer(args.resume, optimizer)
    scheduler = None  # stays None when no learning-rate decay is requested
    if args.lr_decay_gamma:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step, gamma=args.lr_decay_gamma)

    fig = plt.figure(figsize=(10, 10))
    save_results = {}
    save_results["opt"] = dict(vars(args))
    save_results["train_losses"] = []
    save_results["val_losses"] = []
    monitor = MetricMonitor()
    print("Will monitor metrics.")
    for epoch_idx in tqdm(range(epoch, args.epochs), desc="epoch"):
        # Keep batchnorm statistics frozen when requested
        if not args.freeze_batchnorm:
            model.train()
        else:
            model.eval()
        if not args.evaluate:
            save_dict = epochpass.epoch_pass(
                loader,
                model,
                train=True,
                optimizer=optimizer,
                scheduler=scheduler,
                epoch=epoch_idx,
                img_folder=img_folder,
                fig=fig,
                display_freq=args.display_freq,
                epoch_display_freq=args.epoch_display_freq,
                lr_decay_gamma=args.lr_decay_gamma,
                monitor=monitor,
            )
            save_results["train_losses"].append(save_dict)
            with open(result_path, "wb") as p_f:
                pickle.dump(save_results, p_f)
            modelio.save_checkpoint(
                {
                    "epoch": epoch_idx + 1,
                    "network": "correspnet",
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler,
                },
                is_best=True,
                checkpoint=exp_id,
                snapshot=args.snapshot,
            )
        if args.evaluate or (args.eval_freq != -1 and epoch_idx % args.eval_freq == 0):
            val_save_dict = epochpass.epoch_pass(
                val_loader,
                model,
                train=False,
                optimizer=None,
                scheduler=None,
                epoch=epoch_idx,
                img_folder=img_folder,
                fig=fig,
                display_freq=args.display_freq,
                epoch_display_freq=args.epoch_display_freq,
                lr_decay_gamma=args.lr_decay_gamma,
                monitor=monitor,
            )
            save_results["val_losses"].append(val_save_dict)
            monitor.plot(os.path.join(exp_id, "training.html"), plotly=True)
            monitor.plot_histogram(os.path.join(exp_id, "evaluation.html"), plotly=True)
            if args.matplotlib:
                monitor.plot(result_folder, matplotlib=True)
                monitor.plot_histogram(result_folder, matplotlib=True)
            monitor.save_metrics(os.path.join(exp_id, "metrics.pkl"))
            if args.evaluate:
                print(val_save_dict)
                break
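
# The loop above mixes batches from several datasets through
# `concatloader.ConcatDataloader`, whose implementation is not shown here.
# The sketch below is a hypothetical stand-in (not the repo's actual class)
# illustrating one way such a loader can interleave batches from multiple
# torch DataLoaders while keeping the mixing ratio constant.


class InterleavedLoader:
    """Round-robin over several DataLoaders, one batch at a time.

    Iteration stops when the shortest loader is exhausted, so every
    pass yields the same proportion of batches from each dataset.
    """

    def __init__(self, loaders):
        self.loaders = loaders

    def __len__(self):
        # One round yields len(self.loaders) batches
        return min(len(loader) for loader in self.loaders) * len(self.loaders)

    def __iter__(self):
        iterators = [iter(loader) for loader in self.loaders]
        while True:
            for iterator in iterators:
                try:
                    yield next(iterator)
                except StopIteration:
                    return
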
def main(args):
    setseeds.set_all_seeds(args.manual_seed)
    train_dat_str = "_".join(args.train_datasets)
    dat_str = f"{train_dat_str}_warp{args.consist_dataset}"
    now = datetime.now()
    # Encode the main hyperparameters in the experiment identifier
    exp_id = (
        f"checkpoints/{dat_str}_mini{args.mini_factor}/{now.year}_{now.month:02d}_{now.day:02d}/"
        f"{args.com}_frac{args.fraction:.1e}"
        f"_ld{args.lambda_data:.1e}_lc{args.lambda_consist:.1e}"
        f"crit{args.consist_criterion}sca{args.consist_scale}"
        f"opt{args.optimizer}_lr{args.lr}_crit{args.criterion2d}"
        f"_mom{args.momentum}_bs{args.batch_size}"
        f"_lmj3d{args.mano_lambda_joints3d:.1e}"
        f"_lmbeta{args.mano_lambda_shape:.1e}"
        f"_lmpr{args.mano_lambda_pose_reg:.1e}"
        f"_lmrj3d{args.mano_lambda_recov_joints3d:.1e}"
        f"_lmrw3d{args.mano_lambda_recov_verts3d:.1e}"
        f"_lov2d{args.obj_lambda_verts2d:.1e}_lov3d{args.obj_lambda_verts3d:.1e}"
        f"_lovr3d{args.obj_lambda_recov_verts3d:.1e}"
        f"cj{args.center_jittering}seed{args.manual_seed}"
        f"sample_nb{args.sample_nb}_spac{args.spacing}"
        f"csteps{args.progressive_consist_steps}"
    )
    if args.no_augm:
        exp_id = f"{exp_id}_no_augm"
    if args.no_consist_augm:
        exp_id = f"{exp_id}_noconsaugm"
    if args.block_rot:
        exp_id = f"{exp_id}_block_rot"
    if args.consist_gt_refs:
        exp_id = f"{exp_id}_gt_refs"

    # Initialize local checkpoint folder
    save_args(args, exp_id, "opt")
    monitor = Monitor(exp_id, hosting_folder=exp_id)
    img_folder = os.path.join(exp_id, "images")
    os.makedirs(img_folder, exist_ok=True)
    result_folder = os.path.join(exp_id, "results")
    result_path = os.path.join(result_folder, "results.pkl")
    os.makedirs(result_folder, exist_ok=True)
    # Touch a marker file tagged with the job id and launch time
    pyapt_path = os.path.join(result_folder, f"{args.pyapt_id}__{now.strftime('%H_%M_%S')}")
    with open(pyapt_path, "a") as t_f:
        t_f.write(" ")

    # Count the loaders so the worker budget can be split between them
    train_loader_nb = 0
    if len(args.train_datasets):
        train_loader_nb = train_loader_nb + len(args.train_datasets)
    if args.consist_dataset is not None:
        train_loader_nb = train_loader_nb + 1

    loaders = []
    # Fixed: the original condition `len(args.train_datasets) is not None`
    # was always True, since len() returns an int
    if len(args.train_datasets):
        for train_split, dat_name in zip(args.train_splits, args.train_datasets):
            train_dataset, input_res = get_dataset.get_dataset(
                dat_name,
                split=train_split,
                meta={"version": args.version, "split_mode": "objects"},
                block_rot=args.block_rot,
                max_rot=args.max_rot,
                center_idx=args.center_idx,
                center_jittering=args.center_jittering,
                fraction=args.fraction,
                mini_factor=args.mini_factor,
                mode="strong",
                no_augm=args.no_augm,
                scale_jittering=args.scale_jittering,
                sample_nb=1,
                use_cache=args.use_cache,
            )
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=int(args.workers / train_loader_nb),
                drop_last=True,
                collate_fn=collate_fn,  # module-level collate function
            )
            loaders.append(train_loader)
    if args.consist_dataset is not None:
        consist_dataset, consist_input_res = get_dataset.get_dataset(
            args.consist_dataset,
            split=args.consist_split,
            meta={"version": args.version, "split_mode": "objects"},
            use_cache=args.use_cache,
            mini_factor=args.mini_factor,
            block_rot=args.block_rot,
            center_idx=args.center_idx,
            center_jittering=args.center_jittering,
            fraction=args.fraction,
            max_rot=args.max_rot,
            mode="full",
            no_augm=args.no_consist_augm,
            sample_nb=args.sample_nb,
            scale_jittering=args.scale_jittering,
            spacing=args.spacing,  # otherwise black padding gets included on warps
        )
        print(f"Got consist dataset {args.consist_dataset} of size {len(consist_dataset)}")
        if input_res != consist_input_res:
            raise ValueError(
                "train and consist datasets should have the same input size, "
                f"but got {input_res} and {consist_input_res}"
            )
        consist_loader = torch.utils.data.DataLoader(
            consist_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=int(args.workers / train_loader_nb),
            drop_last=True,
            collate_fn=collate_fn,
        )
        loaders.append(consist_loader)
    loader = concatloader.ConcatLoader(loaders)

    if args.eval_freq != -1:
        val_dataset, input_res = get_dataset.get_dataset(
            args.val_dataset,
            split=args.val_split,
            meta={"version": args.version, "split_mode": "objects"},
            use_cache=args.use_cache,
            mini_factor=args.mini_factor,
            no_augm=True,  # never augment validation samples
            block_rot=args.block_rot,
            max_rot=args.max_rot,
            center_idx=args.center_idx,
            scale_jittering=args.scale_jittering,
            center_jittering=args.center_jittering,
            sample_nb=None,
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=int(args.workers),
            drop_last=False,
            collate_fn=collate.meshreg_collate,
        )

    model = MeshRegNet(
        mano_lambda_joints3d=args.mano_lambda_joints3d,
        mano_lambda_joints2d=args.mano_lambda_joints2d,
        mano_lambda_pose_reg=args.mano_lambda_pose_reg,
        mano_lambda_recov_joints3d=args.mano_lambda_recov_joints3d,
        mano_lambda_recov_verts3d=args.mano_lambda_recov_verts3d,
        mano_lambda_verts2d=args.mano_lambda_verts2d,
        mano_lambda_verts3d=args.mano_lambda_verts3d,
        mano_lambda_shape=args.mano_lambda_shape,
        mano_use_shape=args.mano_lambda_shape > 0,
        obj_lambda_verts2d=args.obj_lambda_verts2d,
        obj_lambda_verts3d=args.obj_lambda_verts3d,
        obj_lambda_recov_verts3d=args.obj_lambda_recov_verts3d,
        obj_trans_factor=args.obj_trans_factor,
        obj_scale_factor=args.obj_scale_factor,
        mano_fhb_hand="fhbhands" in args.train_datasets,
    )
    model.cuda()
    # Initialize model weights from a checkpoint when resuming
    if args.resume is not None:
        opts = reloadmodel.load_opts(args.resume)
        model, epoch = reloadmodel.reload_model(args.resume, opts)
        model.cuda()  # make sure reloaded weights live on GPU, as in the resume path above
    else:
        epoch = 0
    if args.freeze_batchnorm:
        freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm
    model_params = filter(lambda p: p.requires_grad, model.parameters())
    # Wrap the base network with the warp-consistency supervision module
    pmodel = warpreg.WarpRegNet(
        input_res,
        model,
        lambda_consist=args.lambda_consist,
        lambda_data=args.lambda_data,
        criterion=args.consist_criterion,
        consist_scale=args.consist_scale,
        gt_refs=args.consist_gt_refs,
        progressive_steps=args.progressive_consist_steps,
        use_backward=args.consist_use_backward,
    )
    pmodel.cuda()

    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(model_params, lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            model_params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
        )
    if args.resume is not None:
        reloadmodel.reload_optimizer(args.resume, optimizer)
    scheduler = None  # stays None when no learning-rate decay is requested
    if args.lr_decay_gamma:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step, gamma=args.lr_decay_gamma)

    fig = plt.figure(figsize=(10, 10))
    save_results = {}
    save_results["opt"] = dict(vars(args))
    save_results["train_losses"] = []
    save_results["val_losses"] = []
    for epoch_idx in tqdm(range(epoch, args.epochs), desc="epoch"):
        # Keep batchnorm statistics frozen when requested
        if not args.freeze_batchnorm:
            model.train()
        else:
            model.eval()
        save_dict, avg_meters, _ = epochpassconsist.epoch_pass(
            loader,
            model,
            train=True,
            optimizer=optimizer,
            scheduler=scheduler,
            epoch=epoch_idx,
            img_folder=img_folder,
            fig=fig,
            display_freq=args.display_freq,
            epoch_display_freq=args.epoch_display_freq,
            lr_decay_gamma=args.lr_decay_gamma,
            premodel=pmodel,
            loader_nb=train_loader_nb,
        )
        monitor.log_train(
            epoch_idx + 1,
            {key: val.avg for key, val in avg_meters.average_meters.items()},
        )
        monitor.metrics.save_metrics(epoch_idx + 1, save_dict)
        monitor.metrics.plot_metrics()
        save_results["train_losses"].append(save_dict)
        with open(result_path, "wb") as p_f:
            pickle.dump(save_results, p_f)
        modelio.save_checkpoint(
            {
                "epoch": epoch_idx + 1,
                "network": "correspnet",
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler,
            },
            is_best=True,
            checkpoint=exp_id,
            snapshot=args.snapshot,
        )
        if args.eval_freq != -1 and epoch_idx % args.eval_freq == 0:
            val_save_dict, val_avg_meters, _ = epochpass.epoch_pass(
                val_loader,
                model,
                train=False,
                optimizer=None,
                scheduler=None,
                epoch=epoch_idx,
                img_folder=img_folder,
                fig=fig,
                display_freq=args.display_freq,
                epoch_display_freq=args.epoch_display_freq,
                lr_decay_gamma=args.lr_decay_gamma,
            )
            save_results["val_losses"].append(val_save_dict)
            monitor.log_val(
                epoch_idx + 1,
                {key: val.avg for key, val in val_avg_meters.average_meters.items()},
            )
            monitor.metrics.save_metrics(epoch_idx + 1, val_save_dict)
            monitor.metrics.plot_metrics()
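
# `progressive_consist_steps` suggests the consistency term is ramped up over
# the first training steps. The actual schedule lives inside
# `warpreg.WarpRegNet` and is not shown here; the helper below is only a
# plausible sketch (an assumption, not the repo's verified schedule) of such
# a linear warm-up of the consistency weight.


def consist_weight(step, lambda_consist, progressive_steps):
    """Linearly ramp the consistency loss weight from 0 to lambda_consist.

    `step` is the global training step; once `progressive_steps` steps have
    elapsed, the full weight is applied.
    """
    if progressive_steps <= 0:
        return lambda_consist
    return lambda_consist * min(1.0, step / progressive_steps)


# Example: with lambda_consist=1e-2 and progressive_steps=1000,
# consist_weight(250, 1e-2, 1000) returns 2.5e-3.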