def main(args):
    """Train a PoseNet pose estimator jointly with a PoseAug augmentation GAN.

    Builds the dataset/model/optimizer stack, then runs ``args.epochs`` epochs:
    each epoch refreshes the train loaders, trains the GAN, and — once past the
    ``args.warmup`` epoch — also trains/evaluates PoseNet on the augmented
    ("fake") and detected ("real") 2D-3D pairs.  Best checkpoints are kept
    separately for the 3DHP P1 and H36M (S9/S11) P1 metrics.

    Side effects: creates a timestamped checkpoint directory (mutates
    ``args.checkpoint``), writes a log file, TensorBoard summaries, and
    checkpoint files.  Requires a CUDA device.
    """
    print('==> Using settings {}'.format(args))

    device = torch.device("cuda")

    print('==> Loading dataset...')
    data_dict = data_preparation(args)

    print("==> Creating PoseNet model...")
    model_pos = model_pos_preparation(args, data_dict['dataset'], device)
    model_pos_eval = model_pos_preparation(args, data_dict['dataset'], device)  # used for evaluation only

    # prepare optimizer for posenet
    posenet_optimizer = torch.optim.Adam(model_pos.parameters(), lr=args.lr_p)
    posenet_lr_scheduler = get_scheduler(posenet_optimizer, policy='lambda', nepoch_fix=0, nepoch=args.epochs)

    print("==> Creating PoseAug model...")
    poseaug_dict = get_poseaug_model(args, data_dict['dataset'])

    # loss function
    criterion = nn.MSELoss(reduction='mean').to(device)

    # GAN trick: data buffer for fake data (history pool of generated samples)
    fake_3d_sample = Sample_from_Pool()
    fake_2d_sample = Sample_from_Pool()

    # NOTE: args.checkpoint is mutated here — the run directory is
    # <checkpoint>/<posenet_name>/<keypoints>/<timestamp>_<note>.
    args.checkpoint = path.join(args.checkpoint, args.posenet_name, args.keypoints,
                                datetime.datetime.now().isoformat() + '_' + args.note)
    os.makedirs(args.checkpoint, exist_ok=True)
    print('==> Making checkpoint dir: {}'.format(args.checkpoint))

    logger = Logger(os.path.join(args.checkpoint, 'log.txt'), args)
    logger.record_args(str(model_pos))
    logger.set_names(['epoch', 'lr', 'error_h36m_p1', 'error_h36m_p2', 'error_3dhp_p1', 'error_3dhp_p2'])

    # Init monitor for net work training #########################################
    summary = Summary(args.checkpoint)
    writer = summary.create_summary()

    ##########################################################
    # start training
    ##########################################################
    start_epoch = 0
    dhpp1_best = None   # best (lowest) 3DHP protocol-1 error seen so far
    s911p1_best = None  # best (lowest) H36M S9/S11 protocol-1 error seen so far
    # Loop variable is unused: epoch bookkeeping lives in `summary.epoch`,
    # advanced by summary.summary_epoch_update() at the bottom of the loop.
    for _ in range(start_epoch, args.epochs):
        if summary.epoch == 0:  # evaluate the pre-train model for epoch 0.
            h36m_p1, h36m_p2, dhp_p1, dhp_p2 = evaluate_posenet(args, data_dict, model_pos, model_pos_eval, device,
                                                                summary, writer, tag='_fake')
            h36m_p1, h36m_p2, dhp_p1, dhp_p2 = evaluate_posenet(args, data_dict, model_pos, model_pos_eval, device,
                                                                summary, writer, tag='_real')
            summary.summary_epoch_update()

        # update train loader
        dataloader_update(args=args, data_dict=data_dict, device=device)

        # Train for one epoch
        train_gan(args, poseaug_dict, data_dict, model_pos, criterion, fake_3d_sample, fake_2d_sample, summary, writer)

        # PoseNet is only trained/evaluated after the GAN warmup phase; during
        # warmup the h36m_p1/... values logged below are the ones from the
        # epoch-0 evaluation above (stale by design — TODO confirm intent).
        if summary.epoch > args.warmup:
            train_posenet(model_pos, data_dict['train_fake2d3d_loader'], posenet_optimizer, criterion, device)
            h36m_p1, h36m_p2, dhp_p1, dhp_p2 = evaluate_posenet(args, data_dict, model_pos, model_pos_eval, device,
                                                                summary, writer, tag='_fake')
            train_posenet(model_pos, data_dict['train_det2d3d_loader'], posenet_optimizer, criterion, device)
            h36m_p1, h36m_p2, dhp_p1, dhp_p2 = evaluate_posenet(args, data_dict, model_pos, model_pos_eval, device,
                                                                summary, writer, tag='_real')

        # Update learning rates ########################
        poseaug_dict['scheduler_G'].step()
        poseaug_dict['scheduler_d3d'].step()
        poseaug_dict['scheduler_d2d'].step()
        posenet_lr_scheduler.step()
        lr_now = posenet_optimizer.param_groups[0]['lr']
        print('\nEpoch: %d | LR: %.8f' % (summary.epoch, lr_now))

        # Update log file
        logger.append([summary.epoch, lr_now, h36m_p1, h36m_p2, dhp_p1, dhp_p2])

        # Update checkpoint: keep one "best" snapshot per metric.
        if dhpp1_best is None or dhpp1_best > dhp_p1:
            dhpp1_best = dhp_p1
            logger.record_args("==> Saving checkpoint at epoch '{}', with dhp_p1 {}".format(summary.epoch, dhpp1_best))
            save_ckpt({'epoch': summary.epoch, 'model_pos': model_pos.state_dict()}, args.checkpoint,
                      suffix='best_dhp_p1')

        if s911p1_best is None or s911p1_best > h36m_p1:
            s911p1_best = h36m_p1
            logger.record_args("==> Saving checkpoint at epoch '{}', with s911p1 {}".format(summary.epoch, s911p1_best))
            save_ckpt({'epoch': summary.epoch, 'model_pos': model_pos.state_dict()}, args.checkpoint,
                      suffix='best_h36m_p1')

        summary.summary_epoch_update()

    writer.close()
    logger.close()
# NOTE(review): this is a second `def main` — if both definitions really live in
# one module, this one shadows the first; they look like two separate training
# scripts (PoseAug vs. baseline) joined together — confirm against the repo.
def main(args):
    """Train a baseline PoseNet (no PoseAug) and checkpoint the best model.

    Runs ``args.epochs`` epochs of supervised training on the H36M train
    loader, evaluating each epoch on the H36M and 3DHP test sets.  Step-based
    learning-rate decay state (``lr_now``, ``glob_step``) is threaded through
    the project-local ``train`` call, which returns the updated values.

    Side effects: creates a timestamped checkpoint directory, writes a log
    file, periodic and best checkpoints, and a final loss/error plot
    (``log.eps``).  Requires a CUDA device.
    """
    print('==> Using settings {}'.format(args))

    device = torch.device("cuda")

    print('==> Loading dataset...')
    data_dict = data_preparation(args)

    print("==> Creating PoseNet model...")
    model_pos = model_pos_preparation(args, data_dict['dataset'], device)

    print("==> Prepare optimizer...")
    criterion = nn.MSELoss(reduction='mean').to(device)
    optimizer = torch.optim.Adam(model_pos.parameters(), lr=args.lr)

    # Run directory: <checkpoint>/<posenet_name>/<keypoints>/<MMDDhhmmss>_<note>
    ckpt_dir_path = path.join(args.checkpoint, args.posenet_name, args.keypoints,
                              datetime.datetime.now().strftime('%m%d%H%M%S') + '_' + args.note)
    os.makedirs(ckpt_dir_path, exist_ok=True)
    print('==> Making checkpoint dir: {}'.format(ckpt_dir_path))

    logger = Logger(os.path.join(ckpt_dir_path, 'log.txt'), args)
    logger.set_names(['epoch', 'lr', 'loss_train', 'error_h36m_p1', 'error_h36m_p2', 'error_3dhp_p1', 'error_3dhp_p2'])

    #################################################
    # ########## start training here
    #################################################
    start_epoch = 0
    error_best = None  # best (lowest) H36M P1 error seen so far
    glob_step = 0      # global step counter consumed by step-based lr decay in train()
    lr_now = args.lr

    for epoch in range(start_epoch, args.epochs):
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr_now))

        # Train for one epoch; train() returns the updated lr/step state.
        epoch_loss, lr_now, glob_step = train(data_dict['train_loader'], model_pos, criterion, optimizer, device,
                                              args.lr, lr_now, glob_step, args.lr_decay, args.lr_gamma,
                                              max_norm=args.max_norm)

        # Evaluate (3DHP uses flip augmentation — see flipaug tag)
        error_h36m_p1, error_h36m_p2 = evaluate(data_dict['H36M_test'], model_pos, device)
        error_3dhp_p1, error_3dhp_p2 = evaluate(data_dict['3DHP_test'], model_pos, device, flipaug='_flip')

        # Update log file
        logger.append([epoch + 1, lr_now, epoch_loss, error_h36m_p1, error_h36m_p2, error_3dhp_p1, error_3dhp_p2])

        # Update checkpoint: best-by-H36M-P1, plus a periodic snapshot.
        if error_best is None or error_best > error_h36m_p1:
            error_best = error_h36m_p1
            save_ckpt({'state_dict': model_pos.state_dict(), 'epoch': epoch + 1}, ckpt_dir_path, suffix='best')

        if (epoch + 1) % args.snapshot == 0:
            save_ckpt({'state_dict': model_pos.state_dict(), 'epoch': epoch + 1}, ckpt_dir_path)

    logger.close()
    logger.plot(['loss_train', 'error_h36m_p1'])
    savefig(path.join(ckpt_dir_path, 'log.eps'))

    return