def main(args):
    """Fine-tune the SIK branch of BiHand on RHD or STB.

    Loads the pretrained SeedNet/LiftNet upstream checkpoints, freezes
    them, then trains only SIKNet, snapshotting a checkpoint every epoch
    and flagging it as best whenever validation AUC improves.

    Args:
        args: parsed CLI namespace; must provide ``fine_tune`` in
            ``{'rhd', 'stb'}`` plus checkpoint/data/optimizer settings.

    Returns:
        0 on completion (both eval-only and training paths).

    Raises:
        ValueError: if ``args.fine_tune`` is missing or unrecognized.
    """
    # Fail fast on a missing/unknown dataset name.  `not in` also covers
    # None/empty, so the old `not args.fine_tune or ...` guard is subsumed.
    # ValueError (was bare Exception) is the conventional type for a bad
    # argument value; callers catching Exception still work.
    if args.fine_tune not in ('rhd', 'stb'):
        raise ValueError('expect --fine_tune in [rhd|stb], got {}'.format(
            args.fine_tune))
    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)
    args.datasets = [
        args.fine_tune,
    ]
    misc.print_args(args)
    auc_best = 0

    print("\nCREATE NETWORK")
    model = models.NetBiHand(
        net_modules=['seed', 'lift', 'sik'],
        njoints=21,
        inp_res=256,
        out_hm_res=64,
        out_dep_res=64,
        upstream_hg_stacks=2,
        upstream_hg_blocks=1,
    )
    model = model.to(device)

    # Quaternion loss disabled; supervise joints and shape only.
    criterion = losses.SIKLoss(lambda_quat=0.0,
                               lambda_joint=1.0,
                               lambda_shape=1.0)

    # Only SIKNet parameters are optimized.  'initial_lr' must be set
    # explicitly because the StepLR below is built with last_epoch != -1.
    optimizer = torch.optim.Adam(
        [
            {
                'params': model.siknet.parameters(),
                'initial_lr': args.learning_rate
            },
        ],
        lr=args.learning_rate,
    )

    train_dataset = HandDataset(
        data_split='train',
        train=True,
        scale_jittering=0.2,
        center_jettering=0.2,
        max_rot=0.5 * np.pi,
        subset_name=args.datasets,
        data_root=args.data_root,
    )
    val_dataset = HandDataset(
        data_split='test',
        train=False,
        subset_name=args.datasets,
        data_root=args.data_root,
    )
    print("Total train dataset size: {}".format(len(train_dataset)))
    print("Total val dataset size: {}".format(len(val_dataset)))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # Start from the pretrained upstream plus the synthetically
    # pretrained SIKNet.
    model.load_checkpoints(
        ckp_seednet=os.path.join(args.checkpoint, 'ckp_seednet_all.pth.tar'),
        ckp_liftnet=os.path.join(
            args.checkpoint, args.fine_tune,
            'ckp_liftnet_{}.pth.tar'.format(args.fine_tune)),
        ckp_siknet=os.path.join(args.checkpoint, 'ckp_siknet_synth.pth.tar'))
    # Upstream (seed + lift) stays frozen; only SIKNet is fine-tuned.
    for params in model.upstream.parameters():
        params.requires_grad = False

    if args.evaluate or args.resume:
        # Resume/evaluate from the dataset-specific fine-tuned SIKNet.
        model.load_checkpoints(ckp_siknet=os.path.join(
            args.checkpoint, args.fine_tune,
            'ckp_siknet_{}.pth.tar'.format(args.fine_tune)))
        if args.evaluate:
            validate(val_loader, model, criterion, args=args)
            cprint('Eval All Done', 'yellow', attrs=['bold'])
            return 0

    model = torch.nn.DataParallel(model)
    print("\nUSING {} GPUs".format(torch.cuda.device_count()))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                args.lr_decay_step,
                                                gamma=args.gamma,
                                                last_epoch=args.start_epoch)

    for epoch in range(args.start_epoch, args.epochs + 1):
        print('\nEpoch: %d' % (epoch))
        for i in range(len(optimizer.param_groups)):
            print('group %d lr:' % i, optimizer.param_groups[i]['lr'])
        ############# train for one epoch ###############
        train(
            train_loader,
            model,
            criterion,
            optimizer,
            args=args,
        )
        ##################################################
        auc_all = validate(val_loader, model, criterion, args)
        # Snapshot every epoch; mark as best when AUC improves.
        misc.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.module.state_dict(),
            },
            checkpoint=args.checkpoint,
            filename='{}_{}.pth.tar'.format(args.saved_prefix,
                                            args.fine_tune),
            snapshot=args.snapshot,
            is_best=auc_all > auc_best)
        if auc_all > auc_best:
            auc_best = auc_all
        scheduler.step()
    cprint('All Done', 'yellow', attrs=['bold'])
    return 0


# end of main
def main(args):
    """Train/evaluate SIKNet on pre-extracted (offline) joint data.

    Builds a bare ``models.SIKNet``, optimizes its inverse-kinematics and
    shape-regression layers on the ``SIKOFFLINE`` dataset for the chosen
    target (``rhd`` or ``stb``), and checkpoints whenever validation AUC
    improves.

    Args:
        args: parsed CLI namespace; must provide ``fine_tune`` in
            ``{'rhd', 'stb'}`` plus checkpoint/data/optimizer settings.

    Returns:
        0 on completion (both eval-only and training paths).

    Raises:
        ValueError: if ``args.fine_tune`` is missing or unrecognized.
    """
    # Fail fast on a missing/unknown dataset name.  `not in` also covers
    # None/empty.  ValueError (was bare Exception) is the conventional
    # type for a bad argument value; callers catching Exception still work.
    if args.fine_tune not in ('rhd', 'stb'):
        raise ValueError('expect --fine_tune in [rhd|stb], got {}'.format(
            args.fine_tune))
    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)
    args.datasets = [args.fine_tune, ]
    misc.print_args(args)
    auc_best = 0

    print("\nCREATE NETWORK")
    model = models.SIKNet()
    model = model.to(device)

    # Quaternion loss disabled; supervise joints and shape only.
    criterion = losses.SIKLoss(
        lambda_quat=0.0,
        lambda_joint=1.0,
        lambda_shape=1.0
    )

    # Two parameter groups (IK layers + shape-regression layers);
    # 'initial_lr' must be set explicitly because the StepLR below is
    # built with last_epoch != -1.
    optimizer = torch.optim.Adam(
        [
            {
                'params': model.invk_layers.parameters(),
                'initial_lr': args.learning_rate
            },
            {
                'params': model.shapereg_layers.parameters(),
                'initial_lr': args.learning_rate
            },
        ],
        lr=args.learning_rate,
    )

    train_dataset = datasets.SIKOFFLINE(
        data_root=args.data_root,
        data_split="train",
        data_source=args.datasets,
    )
    val_dataset = datasets.SIKOFFLINE(
        data_root=args.data_root,
        data_split="test",
        data_source=args.datasets,
    )
    print("Total train dataset size: {}".format(len(train_dataset)))
    print("Total val dataset size: {}".format(len(val_dataset)))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True
    )

    if args.evaluate or args.resume:
        # Resume/evaluate from the dataset-specific checkpoint.
        misc.load_checkpoint(
            model,
            os.path.join(
                args.checkpoint, args.fine_tune,
                '{}_{}.pth.tar'.format(args.saved_prefix, args.fine_tune)
            )
        )
        if args.evaluate:
            validate(val_loader, model, criterion, args=args)
            cprint('Eval All Done', 'yellow', attrs=['bold'])
            return 0

    model = torch.nn.DataParallel(model)
    print("\nUSING {} GPUs".format(torch.cuda.device_count()))
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        args.lr_decay_step,
        gamma=args.gamma,
        last_epoch=args.start_epoch
    )

    for epoch in range(args.start_epoch, args.epochs + 1):
        print('\nEpoch: %d' % (epoch))
        for i in range(len(optimizer.param_groups)):
            print('group %d lr:' % i, optimizer.param_groups[i]['lr'])
        ############# train for one epoch ###############
        train(
            train_loader,
            model,
            criterion,
            optimizer,
            args=args,
        )
        ##################################################
        auc_all = validate(val_loader, model, criterion, args)
        # Snapshot every epoch; mark as best when AUC improves.
        misc.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.module.state_dict(),
            },
            checkpoint=args.checkpoint,
            filename='{}_{}.pth.tar'.format(args.saved_prefix,
                                            args.fine_tune),
            snapshot=args.snapshot,
            is_best=auc_all > auc_best
        )
        if auc_all > auc_best:
            auc_best = auc_all
        scheduler.step()
    cprint('All Done', 'yellow', attrs=['bold'])
    return 0


# end of main
def main(args):
    """Train/evaluate the SeedNet (heatmap) stage of BiHand.

    Builds NetBiHand with only the modules in ``args.net_modules``,
    optimizes the upstream SeedNet, and after epoch 50 validates every
    5th epoch, checkpointing the SeedNet weights and flagging the best
    heatmap accuracy.

    Args:
        args: parsed CLI namespace with dataset, checkpoint and
            optimizer settings.

    Returns:
        0 on completion (both eval-only and training paths).
    """
    best_acc = 0
    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)
    misc.print_args(args)

    print("\nCREATE NETWORK")
    model = models.NetBiHand(
        net_modules=args.net_modules,  # only train hm
        njoints=21,
        inp_res=256,
        out_hm_res=64,
        out_dep_res=64,
        upstream_hg_stacks=2,
        upstream_hg_blocks=1,
    )
    model = model.to(device)

    # define loss function (criterion) and optimizer
    criterion_ups = losses.UpstreamLoss(lambda_hm=100.0, lambda_mask=1.0)
    criterion = {
        'ups': criterion_ups,
    }
    # Only the SeedNet parameters are optimized; 'initial_lr' must be
    # set explicitly because the StepLR below is built with
    # last_epoch != -1.
    optimizer = torch.optim.Adam(
        [
            {
                'params': model.upstream.seednet.parameters(),
                'initial_lr': args.learning_rate
            },
        ],
        lr=args.learning_rate,
    )

    print("\nCREATE DATASET")
    train_dataset = HandDataset(
        data_split='train',
        train=True,
        scale_jittering=0.2,
        center_jettering=0.2,
        max_rot=0.5 * np.pi,
        subset_name=args.datasets,
        data_root=args.data_root,
    )
    val_dataset = HandDataset(
        data_split='test',
        train=False,
        subset_name=args.datasets,
        data_root=args.data_root,
    )
    print("Total train dataset size: {}".format(len(train_dataset)))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    # FIX: validation loader previously used shuffle=True; evaluation
    # should be deterministic, matching the sibling training scripts.
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    print("Total test dataset size: {}".format(len(val_dataset)))

    print("\nLOAD CHECKPOINT")
    if args.resume_seednet_pth:
        model.load_checkpoints(ckp_seednet=args.resume_seednet_pth)
    else:
        # Fresh run: Kaiming-init every conv layer.
        for m in model.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)

    if args.evaluate_seednet_pth:
        # Eval-only path: load the given weights, validate, and exit.
        model.load_checkpoints(ckp_seednet=args.evaluate_seednet_pth)
        validate(val_loader, model, criterion, args=args)
        return 0

    model = torch.nn.DataParallel(model)
    print("\nUSING {} GPUs".format(torch.cuda.device_count()))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                args.lr_decay_step,
                                                gamma=args.gamma,
                                                last_epoch=args.start_epoch)

    for epoch in range(args.start_epoch, args.epochs + 1):
        print('\nEpoch: %d' % (epoch + 1))
        for i in range(len(optimizer.param_groups)):
            print('group %d lr:' % i, optimizer.param_groups[i]['lr'])
        ############# train for one epoch ###############
        train(
            train_loader,
            model,
            criterion,
            optimizer,
            args=args,
        )
        ##################################################
        # Validation is expensive: only run it after epoch 50 and every
        # 5th epoch; otherwise carry forward the best accuracy so the
        # checkpoint below is never wrongly flagged as best.
        acc_hm = best_acc
        if epoch >= 50 and epoch % 5 == 0:
            acc_hm = validate(val_loader, model, criterion, args=args)
        misc.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.module.upstream.seednet.state_dict(),
            },
            checkpoint=args.checkpoint,
            filename='{}.pth.tar'.format(args.saved_prefix),
            snapshot=args.snapshot,
            is_best=acc_hm > best_acc)
        if acc_hm > best_acc:
            best_acc = acc_hm
        scheduler.step()
    cprint('All Done', 'yellow', attrs=['bold'])
    return 0


# end of main