def main(args):
    if not args.fine_tune or args.fine_tune not in ['rhd', 'stb']:
        raise Exception('expect --fine_tune in [rhd|stb], got {}'.format(args.fine_tune))
    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)
    args.datasets = [args.fine_tune, ]
    misc.print_args(args)
    auc_best = 0

    print("\nCREATE NETWORK")
    model = models.SIKNet()
    model = model.to(device)

    # fine-tuning supervises joints and shape; the quaternion term is disabled
    criterion = losses.SIKLoss(lambda_quat=0.0, lambda_joint=1.0, lambda_shape=1.0)

    optimizer = torch.optim.Adam(
        [
            {'params': model.invk_layers.parameters(), 'initial_lr': args.learning_rate},
            {'params': model.shapereg_layers.parameters(), 'initial_lr': args.learning_rate},
        ],
        lr=args.learning_rate,
    )

    train_dataset = datasets.SIKOFFLINE(
        data_root=args.data_root,
        data_split="train",
        data_source=args.datasets,
    )
    val_dataset = datasets.SIKOFFLINE(
        data_root=args.data_root,
        data_split="test",
        data_source=args.datasets,
    )
    print("Total train dataset size: {}".format(len(train_dataset)))
    print("Total val dataset size: {}".format(len(val_dataset)))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    if args.evaluate or args.resume:
        misc.load_checkpoint(
            model,
            os.path.join(
                args.checkpoint,
                args.fine_tune,
                '{}_{}.pth.tar'.format(args.saved_prefix, args.fine_tune),
            ),
        )

    if args.evaluate:
        validate(val_loader, model, criterion, args=args)
        cprint('Eval All Done', 'yellow', attrs=['bold'])
        return 0

    model = torch.nn.DataParallel(model)
    print("\nUSING {} GPUs".format(torch.cuda.device_count()))

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, args.lr_decay_step, gamma=args.gamma, last_epoch=args.start_epoch
    )

    for epoch in range(args.start_epoch, args.epochs + 1):
        print('\nEpoch: %d' % epoch)
        for i in range(len(optimizer.param_groups)):
            print('group %d lr:' % i, optimizer.param_groups[i]['lr'])

        ############# train for one epoch ###############
        train(
            train_loader,
            model,
            criterion,
            optimizer,
            args=args,
        )
        ##################################################

        auc_all = validate(val_loader, model, criterion, args)
        misc.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.module.state_dict(),
            },
            checkpoint=args.checkpoint,
            filename='{}_{}.pth.tar'.format(args.saved_prefix, args.fine_tune),
            snapshot=args.snapshot,
            is_best=auc_all > auc_best,
        )
        if auc_all > auc_best:
            auc_best = auc_all
        scheduler.step()

    cprint('All Done', 'yellow', attrs=['bold'])
    return 0  # end of main
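# --- Hedged sketch (not part of the original script): a minimal argparse entry
# point wired to the attribute names this main() reads from `args`. The flag
# names mirror those attributes; the defaults are illustrative assumptions, not
# the repository's actual values.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Fine-tune SIKNet on RHD or STB')
    parser.add_argument('--fine_tune', choices=['rhd', 'stb'], required=True)
    parser.add_argument('--data_root', default='data')
    parser.add_argument('--checkpoint', default='checkpoints')
    parser.add_argument('--saved_prefix', default='ckp_siknet')
    parser.add_argument('--snapshot', type=int, default=5)
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--train_batch', type=int, default=64)
    parser.add_argument('--test_batch', type=int, default=32)
    parser.add_argument('--workers', type=int, default=8)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--lr_decay_step', type=int, default=50)
    parser.add_argument('--gamma', type=float, default=0.1)
    parser.add_argument('--start_epoch', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=100)
    main(parser.parse_args())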
def main(args):
    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)
    print_args(args)

    print("\nCREATE NETWORK")
    model = models.SIKNet()
    model = model.to(device)

    criterion = losses.SIKLoss(
        lambda_quat=1.0,  # only the quaternion loss is active in this stage
        lambda_joint=0.0,
        lambda_shape=0.0,
    )

    optimizer = torch.optim.Adam(
        [
            {'params': model.invk_layers.parameters(), 'initial_lr': args.learning_rate},
            {'params': model.shapereg_layers.parameters(), 'initial_lr': args.learning_rate},
        ],
        lr=args.learning_rate,
    )

    train_dataset = datasets.SIK1M(data_root=args.data_root, data_split="train")
    val_dataset = datasets.SIK1M(data_root=args.data_root, data_split="test")
    print("Total train dataset size: {}".format(len(train_dataset)))
    print("Total val dataset size: {}".format(len(val_dataset)))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    if args.evaluate_siknet_synth_pth:
        # evaluation-only path: load the given weights and freeze the IK layers
        misc.load_checkpoint(model, args.evaluate_siknet_synth_pth)
        for params in model.invk_layers.parameters():
            params.requires_grad = False
        validate(val_loader, model, criterion, args=args)
        cprint('Eval All Done', 'yellow', attrs=['bold'])
        return 0

    model = torch.nn.DataParallel(model)
    print("\nUSING {} GPUs".format(torch.cuda.device_count()))

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, args.lr_decay_step, gamma=args.gamma, last_epoch=args.start_epoch
    )

    for epoch in range(args.start_epoch, args.epochs + 1):
        print('\nEpoch: %d' % (epoch + 1))
        for i in range(len(optimizer.param_groups)):
            print('group %d lr:' % i, optimizer.param_groups[i]['lr'])

        ############# train for one epoch ###############
        train(
            train_loader,
            model,
            criterion,
            optimizer,
            args=args,
        )
        ##################################################

        misc.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.module.state_dict(),
            },
            checkpoint=args.checkpoint,
            filename='{}.pth.tar'.format(args.saved_prefix),
            snapshot=args.snapshot,
            is_best=False,
        )
        if epoch and epoch % 10 == 0:
            # run validation every 10 epochs
            validate(val_loader, model, criterion, args)
        scheduler.step()

    cprint('All Done', 'yellow', attrs=['bold'])
    return 0  # end of main
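# --- Hedged aside (illustration only, not from the original script): the
# 'initial_lr' entries in the Adam param groups above matter because StepLR is
# constructed with last_epoch=args.start_epoch. When last_epoch is not -1,
# torch.optim.lr_scheduler expects every param group to already carry
# 'initial_lr' and raises a KeyError otherwise. Minimal standalone check:
def _steplr_resume_demo():
    import torch
    layer = torch.nn.Linear(2, 2)
    opt = torch.optim.Adam(
        [{'params': layer.parameters(), 'initial_lr': 1e-3}], lr=1e-3
    )
    # Resuming "from epoch 5" works only because 'initial_lr' is present above.
    torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.1, last_epoch=5)
    return opt.param_groups[0]['lr']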
def main(args):
    if not args.fine_tune or args.fine_tune not in ['rhd', 'stb']:
        raise Exception('expect --fine_tune in [rhd|stb], got {}'.format(args.fine_tune))
    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)
    args.datasets = [args.fine_tune, ]
    misc.print_args(args)
    auc_best = 0

    print("\nCREATE NETWORK")
    model = models.NetBiHand(
        net_modules=['seed', 'lift', 'sik'],
        njoints=21,
        inp_res=256,
        out_hm_res=64,
        out_dep_res=64,
        upstream_hg_stacks=2,
        upstream_hg_blocks=1,
    )
    model = model.to(device)

    criterion = losses.SIKLoss(lambda_quat=0.0, lambda_joint=1.0, lambda_shape=1.0)

    # only the SIKNet branch is optimized; the upstream stages are frozen below
    optimizer = torch.optim.Adam(
        [
            {'params': model.siknet.parameters(), 'initial_lr': args.learning_rate},
        ],
        lr=args.learning_rate,
    )

    train_dataset = HandDataset(
        data_split='train',
        train=True,
        scale_jittering=0.2,
        center_jettering=0.2,
        max_rot=0.5 * np.pi,
        subset_name=args.datasets,
        data_root=args.data_root,
    )
    val_dataset = HandDataset(
        data_split='test',
        train=False,
        subset_name=args.datasets,
        data_root=args.data_root,
    )
    print("Total train dataset size: {}".format(len(train_dataset)))
    print("Total val dataset size: {}".format(len(val_dataset)))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    model.load_checkpoints(
        ckp_seednet=os.path.join(args.checkpoint, 'ckp_seednet_all.pth.tar'),
        ckp_liftnet=os.path.join(
            args.checkpoint, args.fine_tune,
            'ckp_liftnet_{}.pth.tar'.format(args.fine_tune),
        ),
        ckp_siknet=os.path.join(args.checkpoint, 'ckp_siknet_synth.pth.tar'),
    )
    for params in model.upstream.parameters():
        params.requires_grad = False

    if args.evaluate or args.resume:
        model.load_checkpoints(
            ckp_siknet=os.path.join(
                args.checkpoint, args.fine_tune,
                'ckp_siknet_{}.pth.tar'.format(args.fine_tune),
            )
        )

    if args.evaluate:
        validate(val_loader, model, criterion, args=args)
        cprint('Eval All Done', 'yellow', attrs=['bold'])
        return 0

    model = torch.nn.DataParallel(model)
    print("\nUSING {} GPUs".format(torch.cuda.device_count()))

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, args.lr_decay_step, gamma=args.gamma, last_epoch=args.start_epoch
    )

    for epoch in range(args.start_epoch, args.epochs + 1):
        print('\nEpoch: %d' % epoch)
        for i in range(len(optimizer.param_groups)):
            print('group %d lr:' % i, optimizer.param_groups[i]['lr'])

        ############# train for one epoch ###############
        train(
            train_loader,
            model,
            criterion,
            optimizer,
            args=args,
        )
        ##################################################

        auc_all = validate(val_loader, model, criterion, args)
        misc.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.module.state_dict(),
            },
            checkpoint=args.checkpoint,
            filename='{}_{}.pth.tar'.format(args.saved_prefix, args.fine_tune),
            snapshot=args.snapshot,
            is_best=auc_all > auc_best,
        )
        if auc_all > auc_best:
            auc_best = auc_all
        scheduler.step()

    cprint('All Done', 'yellow', attrs=['bold'])
    return 0  # end of main
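# --- Hedged sketch (illustration only, not part of the original script): with
# model.upstream frozen above, only the SIKNet branch should contribute
# trainable parameters. A small helper like this could confirm that before
# launching a long fine-tuning run; the name count_trainable is ours.
def count_trainable(module):
    """Return (trainable, total) parameter counts for a module."""
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    total = sum(p.numel() for p in module.parameters())
    return trainable, total

# Possible usage inside main(), right after the requires_grad loop:
#   print('trainable/total params: {}/{}'.format(*count_trainable(model)))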