# Shared imports for the training scripts below. `model_utils` and the
# data/train/test helpers (make_data_loader, train, test_disp, test_flow,
# UNet, ResNetUNet, multiscaleloss, make_self_supervised_loss) are
# repo-local and imported from their own modules.
import os
import random
import sys
import time

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter  # or torch.utils.tensorboard

import model_utils  # repo-local


def save_checkpoint(model, optimizer, epoch, global_step, args):
    save_dir = model_utils.make_joint_checkpoint_name(args, epoch)
    save_dir = os.path.join(args.savemodel, save_dir)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    model_path = os.path.join(save_dir, 'model_{:04d}.pth'.format(epoch))
    if epoch % args.save_freq == 0:
        torch.save({
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch
        }, model_path)
        print('<=== checkpoint has been saved to {}.'.format(model_path))
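# The directory name above comes from `model_utils.make_joint_checkpoint_name`,
# whose implementation is not shown in this section. A hypothetical sketch,
# assuming it simply encodes a couple of run hyperparameters plus the epoch
# into a directory name; the repo's real helper likely encodes more:
def make_joint_checkpoint_name_sketch(args, epoch):
    # HYPOTHETICAL naming scheme; the real one lives in model_utils.
    return 'joint_lr{}_ep{:04d}'.format(args.lr, epoch)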
def main(args):
    train_loader, test_loader = make_data_loader(args)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    if args.resnet_arch is None:
        model = UNet()
    else:
        model = ResNetUNet(args.resnet_arch)
    # model = DataParallelWithCallback(model)
    model = nn.DataParallel(model).cuda()
    print('#parameters in warp disp model: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.0004)

    if args.loadmodel is not None:
        state_dict = torch.load(args.loadmodel)['state_dict']
        # if LooseVersion(torch.__version__) >= LooseVersion('0.4.0'):
        #     keys = list(state_dict.keys())
        #     for k in keys:
        #         if k.find('num_batches_tracked') >= 0:
        #             state_dict.pop(k)
        model.load_state_dict(state_dict)
        print('==> A pre-trained checkpoint has been loaded: {}.'.format(
            args.loadmodel))

    start_epoch = 1
    if args.auto_resume:
        # NOTE: auto-resume is not implemented in this script; the search
        # below is unreachable and kept for parity with the other scripts.
        raise NotImplementedError
        # search for the latest saved checkpoint
        epoch_found = -1
        for epoch in range(args.epochs + 1, 1, -1):
            ckpt_path = model_utils.make_joint_checkpoint_name(args, epoch)
            ckpt_path = os.path.join(args.savemodel, ckpt_path)
            if os.path.exists(ckpt_path):
                epoch_found = epoch
                break
        if epoch_found > 0:
            ckpt = torch.load(ckpt_path)
            assert ckpt['epoch'] == epoch_found, [ckpt['epoch'], epoch_found]
            start_epoch = ckpt['epoch'] + 1
            optimizer.load_state_dict(ckpt['optimizer'])
            model.load_state_dict(ckpt['state_dict'])
            print('==> Automatically resumed training from {}.'.format(
                ckpt_path))

    crit = multiscaleloss(downsample_factors=(16, 8, 4, 2, 1),
                          weights=(1, 1, 2, 4, 8),
                          loss='l1',
                          size_average=True).cuda()

    start_full_time = time.time()
    train_print_format = '{}\t{:d}\t{:d}\t{:d}\t{:d}\t{:.3f}\t{:.3f}\t{:.3f}'\
                         '\t{:.6f}'
    test_print_format = '{}\t{:d}\t{:d}\t{:.3f}\t{:.2f}\t{:.3f}\t{:.2f}'\
                        '\t{:.6f}'

    os.makedirs(os.path.join(args.savemodel, 'tensorboard'), exist_ok=True)
    writer = SummaryWriter(os.path.join(args.savemodel, 'tensorboard'))

    global_step = 0
    for epoch in range(start_epoch, args.epochs + 1):
        total_err = 0
        total_test_err_pct = 0
        total_test_loss = 0
        lr = adjust_learning_rate(optimizer, epoch, len(train_loader))

        ## training ##
        start_time = time.time()
        for batch_idx, data in enumerate(train_loader):
            end = time.time()
            loss, losses = train(model, crit, optimizer, data)
            global_step += 1
            writer.add_scalar('train/total_loss', loss * 20, global_step)
            if (batch_idx + 1) % args.print_freq == 0:
                print(train_print_format.format(
                    'Train', global_step, epoch, batch_idx, len(train_loader),
                    loss, end - start_time, time.time() - start_time, lr))
                sys.stdout.flush()
                start_time = time.time()

        ## test ##
        start_time = time.time()
        for batch_idx, batch_data in enumerate(test_loader):
            err, err_pct, loss = test_disp(model, crit, batch_data, args.cmd)
            total_err += err
            total_test_err_pct += err_pct
            total_test_loss += loss
        writer.add_scalar('test/loss',
                          total_test_loss / (len(test_loader) + 1e-30) * 20,
                          epoch)
        writer.add_scalar('test/err',
                          total_err / (len(test_loader) + 1e-30), epoch)
        writer.add_scalar('test/err_pct',
                          total_test_err_pct / (len(test_loader) + 1e-30) * 100,
                          epoch)
        print(test_print_format.format(
            'Test', global_step, epoch,
            total_err / (len(test_loader) + 1e-30),
            total_test_err_pct / (len(test_loader) + 1e-30) * 100,
            total_test_loss / (len(test_loader) + 1e-30),
            time.time() - start_time, lr))
        sys.stdout.flush()

        save_checkpoint(model, optimizer, epoch, global_step, args)

    print('full time = %.2f HR' % ((time.time() - start_full_time) / 3600))
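# `multiscaleloss` above is imported from elsewhere in the repo. A minimal
# call-compatible sketch, assuming the model emits a coarse-to-fine pyramid of
# predictions and that ground truth is matched to each scale by average
# pooling; the real implementation may differ in downsampling and weighting.
class MultiScaleLossSketch(nn.Module):
    def __init__(self, downsample_factors, weights, loss='l1',
                 size_average=True):
        super(MultiScaleLossSketch, self).__init__()
        assert len(downsample_factors) == len(weights)
        self.downsample_factors = downsample_factors
        self.weights = weights
        self.reduction = 'mean' if size_average else 'sum'
        self.base_loss = (nn.functional.l1_loss if loss == 'l1'
                          else nn.functional.mse_loss)

    def forward(self, preds, target):
        # One prediction per entry of downsample_factors, coarsest first.
        total = 0.
        for pred, factor, weight in zip(preds, self.downsample_factors,
                                        self.weights):
            scaled = (nn.functional.avg_pool2d(target, factor)
                      if factor > 1 else target)
            total = total + weight * self.base_loss(
                pred, scaled, reduction=self.reduction)
        return total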
def main(args):
    train_loader, flow_test_loader, disp_test_loader = make_data_loader(args)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    model = model_utils.make_model(
        args,
        do_flow=not args.no_flow,
        do_disp=not args.no_disp,
        do_seg=(args.do_seg or args.do_seg_distill))
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.0004)

    if args.loadmodel is not None:
        ckpt = torch.load(args.loadmodel)
        state_dict = ckpt['state_dict']
        model.load_state_dict(model_utils.patch_model_state_dict(state_dict))
        print('==> A pre-trained checkpoint has been loaded.')

    start_epoch = 1
    if args.auto_resume:
        # search for the latest saved checkpoint
        epoch_found = -1
        for epoch in range(args.epochs + 1, 1, -1):
            ckpt_dir = model_utils.make_joint_checkpoint_name(args, epoch)
            ckpt_dir = os.path.join(args.savemodel, ckpt_dir)
            ckpt_path = os.path.join(ckpt_dir,
                                     'model_{:04d}.pth'.format(epoch))
            if os.path.exists(ckpt_path):
                epoch_found = epoch
                break
        if epoch_found > 0:
            ckpt = torch.load(ckpt_path)
            assert ckpt['epoch'] == epoch_found, [ckpt['epoch'], epoch_found]
            start_epoch = ckpt['epoch'] + 1
            optimizer.load_state_dict(ckpt['optimizer'])
            model.load_state_dict(ckpt['state_dict'])
            print('==> Automatically resumed training from {}.'.format(
                ckpt_path))
    elif args.resume is not None:
        ckpt = torch.load(args.resume)
        start_epoch = ckpt['epoch'] + 1
        optimizer.load_state_dict(ckpt['optimizer'])
        model.load_state_dict(ckpt['state_dict'])
        print('==> Manually resumed training from {}.'.format(args.resume))

    cudnn.benchmark = True

    (flow_crit, flow_occ_crit), flow_down_scales, flow_weights = \
        model_utils.make_flow_criteria(args)
    (disp_crit, disp_occ_crit), disp_down_scales, disp_weights = \
        model_utils.make_disp_criteria(args)
    # Segmentation and self-supervised criteria are unused in this script.
    hard_seg_crit = None
    soft_seg_crit = None
    self_supervised_crit = None
    criteria = (disp_crit, disp_occ_crit, flow_crit, flow_occ_crit)

    min_loss = 100000000000000000
    min_epo = 0
    min_err_pct = 10000

    start_full_time = time.time()
    train_print_format = '{}\t{:d}\t{:d}\t{:d}\t{:d}\t{:.3f}\t{:.3f}\t{:.3f}'\
                         '\t{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\t{:.6f}'
    test_print_format = '{}\t{:d}\t{:d}\t{:.3f}\t{:.2f}\t{:.3f}\t{:.2f}'\
                        '\t{:.2f}\t{:.2f}\t{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}'\
                        '\t{:.6f}'

    global_step = 0
    for epoch in range(start_epoch, args.epochs + 1):
        total_train_loss = 0
        total_err = 0
        total_test_err_pct = 0
        total_disp_occ_acc = 0
        total_epe = 0
        total_flow_occ_acc = 0
        total_seg_acc = 0
        lr = adjust_learning_rate(optimizer, epoch, len(train_loader))

        ## training ##
        start_time = time.time()
        for batch_idx, batch_data in enumerate(train_loader):
            end = time.time()
            train_res = train(model, optimizer, batch_data, criteria, args)
            loss, flow_loss, flow_occ_loss, disp_loss, disp_occ_loss = train_res
            global_step += 1
            if (batch_idx + 1) % args.print_freq == 0:
                print(train_print_format.format(
                    'Train', global_step, epoch, batch_idx, len(train_loader),
                    loss, flow_loss, flow_occ_loss, disp_loss, disp_occ_loss,
                    end - start_time, time.time() - start_time, lr))
                sys.stdout.flush()
                start_time = time.time()
            total_train_loss += loss

        ## test ##
        # should have used the validation set to select the best model
        start_time = time.time()
        for batch_idx, batch_data in enumerate(flow_test_loader):
            loss_data = test_flow(model, batch_data, criteria, args.cmd,
                                  flow_down_scales[0])
            epe, flow_occ_acc, loss, flow_loss, flow_occ_loss = loss_data
            total_epe += epe
            total_flow_occ_acc += flow_occ_acc
        for batch_idx, batch_data in enumerate(disp_test_loader):
            loss_data = test_disp(model, batch_data, criteria, args.cmd)
            err, err_pct, disp_occ_acc, loss, disp_loss, disp_occ_loss = \
                loss_data
            total_err += err
            total_test_err_pct += err_pct
            total_disp_occ_acc += disp_occ_acc
        if total_test_err_pct / len(disp_test_loader) * 100 < min_err_pct:
            min_loss = total_err / len(disp_test_loader)
            min_epo = epoch
            min_err_pct = total_test_err_pct / len(disp_test_loader) * 100
        print(test_print_format.format(
            'Test', global_step, epoch,
            total_epe / len(flow_test_loader) * args.div_flow,
            total_flow_occ_acc / len(flow_test_loader) * 100,
            total_err / len(disp_test_loader),
            total_test_err_pct / len(disp_test_loader) * 100,
            total_disp_occ_acc / len(disp_test_loader) * 100,
            flow_loss, flow_occ_loss,
            disp_loss * args.disp_loss_weight,
            disp_occ_loss * args.disp_loss_weight,
            time.time() - start_time, lr))

        save_checkpoint(model, optimizer, epoch, global_step, args)

    print('Elapsed time = %.2f HR' % ((time.time() - start_full_time) / 3600))
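# `adjust_learning_rate` is called at the top of every epoch in these scripts
# but is not shown. A hedged sketch, assuming a plain step schedule; base_lr
# and lr_steps are illustrative assumptions, not the repo's real values (the
# actual schedule might also use iters_per_epoch, e.g. for warmup).
def adjust_learning_rate_sketch(optimizer, epoch, iters_per_epoch,
                                base_lr=1e-3, lr_steps=20):
    # ASSUMED schedule: halve the learning rate every lr_steps epochs.
    lr = base_lr * (0.5 ** ((epoch - 1) // lr_steps))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr  # returned so the caller can log it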
def main(args):
    train_loader, test_loader = make_data_loader(args)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    model = model_utils.make_model(
        args,
        do_flow=not args.no_flow,
        do_disp=not args.no_disp,
        do_seg=(args.do_seg or args.do_seg_distill))
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.0004)

    if args.loadmodel is not None:
        ckpt = torch.load(args.loadmodel)
        state_dict = ckpt['state_dict']
        # strict=False so that keys absent from the checkpoint come back in
        # missing_keys instead of raising; the asserts below then verify that
        # only the randomly initialized segmentation decoder is missing.
        missing_keys, unexpected_keys = model.load_state_dict(
            model_utils.patch_model_state_dict(state_dict), strict=False)
        assert not unexpected_keys, \
            'Got unexpected keys: {}'.format(unexpected_keys)
        for mk in missing_keys:
            assert mk.find('seg_decoder') >= 0, \
                'Only the segmentation decoder can be initialized randomly.'
        print('==> A pre-trained model has been loaded.')

    start_epoch = 1
    if args.auto_resume:
        # search for the latest saved checkpoint
        epoch_found = -1
        for epoch in range(args.epochs + 1, 1, -1):
            ckpt_dir = model_utils.make_joint_checkpoint_name(args, epoch)
            ckpt_dir = os.path.join(args.savemodel, ckpt_dir)
            ckpt_path = os.path.join(ckpt_dir,
                                     'model_{:04d}.pth'.format(epoch))
            if os.path.exists(ckpt_path):
                epoch_found = epoch
                break
        if epoch_found > 0:
            ckpt = torch.load(ckpt_path)
            assert ckpt['epoch'] == epoch_found, [ckpt['epoch'], epoch_found]
            start_epoch = ckpt['epoch'] + 1
            optimizer.load_state_dict(ckpt['optimizer'])
            model.load_state_dict(ckpt['state_dict'])
            print('==> Automatically resumed training from {}.'.format(
                ckpt_path))
    elif args.resume is not None:
        ckpt = torch.load(args.resume)
        start_epoch = ckpt['epoch'] + 1
        optimizer.load_state_dict(ckpt['optimizer'])
        model.load_state_dict(ckpt['state_dict'])
        print('==> Manually resumed training from {}.'.format(args.resume))

    cudnn.benchmark = True

    (flow_crit, flow_occ_crit), flow_down_scales, flow_weights = \
        model_utils.make_flow_criteria(args)
    (disp_crit, disp_occ_crit), disp_down_scales, disp_weights = \
        model_utils.make_disp_criteria(args)
    hard_seg_crit = model_utils.make_seg_criterion(args, hard_lab=True)
    soft_seg_crit = model_utils.make_seg_criterion(args, hard_lab=False)
    args.hard_seg_loss_weight *= float(disp_weights[0])
    args.soft_seg_loss_weight *= float(disp_weights[0])
    self_supervised_crit = make_self_supervised_loss(
        args,
        disp_downscales=disp_down_scales,
        disp_pyramid_weights=disp_weights,
        flow_downscales=flow_down_scales,
        flow_pyramid_weights=flow_weights).cuda()
    criteria = (disp_crit, disp_occ_crit, flow_crit, flow_occ_crit,
                hard_seg_crit, soft_seg_crit, self_supervised_crit)

    min_loss = 100000000000000000
    min_epo = 0
    min_err_pct = 10000

    start_full_time = time.time()
    train_print_format = '{}\t{:d}\t{:d}\t{:d}\t{:d}\t{:.3f}\t{:.3f}\t{:.3f}'\
                         '\t{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}'\
                         '\t{:.3f}\t{:.6f}'
    test_print_format = '{}\t{:d}\t{:d}\t{:.3f}\t{:.2f}\t{:.3f}\t{:.2f}'\
                        '\t{:.2f}\t{:.2f}\t{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}'\
                        '\t{:.3f}\t{:.6f}'

    global_step = 0
    for epoch in range(start_epoch, args.epochs + 1):
        total_train_loss = 0
        total_err = 0
        total_test_err_pct = 0
        total_disp_occ_acc = 0
        total_epe = 0
        total_flow_occ_acc = 0
        total_seg_acc = 0
        lr = adjust_learning_rate(optimizer, epoch, len(train_loader))

        ## training ##
        start_time = time.time()
        for batch_idx, batch_data in enumerate(train_loader):
            end = time.time()
            # (cur_im, nxt_im), (flow, flow_occ), (left_im, right_im), \
            #     (disp, disp_occ, seg_im) = data
            # if args.seg_root_dir is None:
            #     seg_im = None
            train_res = train(model, optimizer, batch_data, criteria, args)
            (loss, flow_loss, flow_occ_loss, disp_loss, disp_occ_loss,
             seg_loss, seg_distill_loss, ss_loss, ss_losses) = train_res
            global_step += 1
            if (batch_idx + 1) % args.print_freq == 0:
                print(train_print_format.format(
                    'Train', global_step, epoch, batch_idx, len(train_loader),
                    loss, flow_loss, flow_occ_loss, disp_loss, disp_occ_loss,
                    seg_loss, seg_distill_loss, ss_loss,
                    end - start_time, time.time() - start_time, lr))
                for k, v in ss_losses.items():
                    print('{: <10}\t{:.3f}'.format(k, v))
                sys.stdout.flush()
                start_time = time.time()
            total_train_loss += loss

        # should have had a validation set
        save_checkpoint(model, optimizer, epoch, global_step, args)

    print('Elapsed time = %.2f HR' % ((time.time() - start_full_time) / 3600))
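# `model_utils.patch_model_state_dict` is applied before every pre-trained
# load above but is not shown. A minimal sketch, assuming its job is the
# common one of reconciling the 'module.' prefix that nn.DataParallel adds to
# every parameter key; the repo's real helper may do more than this.
def patch_model_state_dict_sketch(state_dict, strip_prefix='module.'):
    patched = {}
    for key, value in state_dict.items():
        # Strip the DataParallel prefix so keys match a bare, unwrapped model.
        if key.startswith(strip_prefix):
            key = key[len(strip_prefix):]
        patched[key] = value
    return patched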
def main(args):
    train_loader, test_loader = make_data_loader(args)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    model = model_utils.make_model(args, do_seg=args.do_seg)
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.0004)

    if args.loadmodel is not None:
        ckpt = torch.load(args.loadmodel)
        state_dict = ckpt['state_dict']
        model.load_state_dict(model_utils.patch_model_state_dict(state_dict))
        print('==> A pre-trained checkpoint has been loaded: {}.'.format(
            args.loadmodel))

    start_epoch = 1
    if args.auto_resume:
        # search for the latest saved checkpoint
        epoch_found = -1
        for epoch in range(args.epochs + 1, 1, -1):
            ckpt_path = model_utils.make_joint_checkpoint_name(args, epoch)
            ckpt_path = os.path.join(args.savemodel, ckpt_path)
            if os.path.exists(ckpt_path):
                epoch_found = epoch
                break
        if epoch_found > 0:
            ckpt = torch.load(ckpt_path)
            assert ckpt['epoch'] == epoch_found, [ckpt['epoch'], epoch_found]
            start_epoch = ckpt['epoch'] + 1
            optimizer.load_state_dict(ckpt['optimizer'])
            model.load_state_dict(ckpt['state_dict'])
            print('==> Automatically resumed training from {}.'.format(
                ckpt_path))

    cudnn.benchmark = True

    (flow_crit, flow_occ_crit), flow_down_scales, flow_weights = \
        model_utils.make_flow_criteria(args)
    (disp_crit, disp_occ_crit), disp_down_scales, disp_weights = \
        model_utils.make_disp_criteria(args)
    criteria = (disp_crit, disp_occ_crit, flow_crit, flow_occ_crit)

    min_loss = 100000000000000000
    min_epo = 0
    min_err_pct = 10000

    start_full_time = time.time()
    train_print_format = '{}\t{:d}\t{:d}\t{:d}\t{:d}\t{:.3f}\t{:.3f}\t{:.3f}'\
                         '\t{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\t\t{:.6f}'
    test_print_format = '{}\t{:d}\t{:d}\t{:.3f}\t{:.2f}\t{:.3f}\t{:.2f}'\
                        '\t{:.2f}\t{:.2f}\t{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}'\
                        '\t{:.3f}\t{:.6f}'

    global_step = 0
    for epoch in range(start_epoch, args.epochs + 1):
        total_train_loss = 0
        total_err = 0
        total_test_err_pct = 0
        total_disp_occ_acc = 0
        total_epe = 0
        total_flow_occ_acc = 0
        total_seg_acc = 0
        lr = adjust_learning_rate(optimizer, epoch, len(train_loader))

        ## training ##
        start_time = time.time()
        for batch_idx, batch_data in enumerate(train_loader):
            end = time.time()
            train_res = train(model, optimizer, batch_data, criteria, args)
            loss, flow_loss, flow_occ_loss, disp_loss, disp_occ_loss = train_res
            global_step += 1
            if (batch_idx + 1) % args.print_freq == 0:
                print(train_print_format.format(
                    'Train', global_step, epoch, batch_idx, len(train_loader),
                    loss, flow_loss, flow_occ_loss, disp_loss, disp_occ_loss,
                    end - start_time, time.time() - start_time, lr))
                sys.stdout.flush()
                start_time = time.time()
            total_train_loss += loss

        save_checkpoint(model, optimizer, epoch, global_step, args)

    print('Elapsed time = %.2f HR' % ((time.time() - start_full_time) / 3600))
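# None of these scripts show their entry point. A hedged sketch of one,
# assembled only from flags the code above actually reads; every default is
# an illustrative assumption, not the repo's real value.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=1e-3)           # ASSUMED
    parser.add_argument('--epochs', type=int, default=100)          # ASSUMED
    parser.add_argument('--seed', type=int, default=1)              # ASSUMED
    parser.add_argument('--savemodel', type=str, default='./ckpt')  # ASSUMED
    parser.add_argument('--loadmodel', type=str, default=None)
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--print_freq', type=int, default=20)       # ASSUMED
    parser.add_argument('--save_freq', type=int, default=1)         # ASSUMED
    parser.add_argument('--do_seg', action='store_true')
    main(parser.parse_args())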