def test_refine_stage(args):
    """Run the flow-refinement network over the test split and save refined flows.

    Loads a pretrained refinement model (ResNet-101 or ResNet-50 branch),
    completes the forward and reverse flows inside the masked region, and
    writes the results as flow files under ``args.output_root``.

    Args:
        args: namespace with dataset/model/runtime options (PRETRAINED_MODEL,
            batch_size, n_threads, ResNet101, output_root, ...).
    """
    torch.manual_seed(777)
    torch.cuda.manual_seed(777)

    eval_dataset = FlowRefine.FlowSeq(args, isTest=True)
    eval_dataloader = DataLoader(eval_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 drop_last=False,
                                 num_workers=args.n_threads)

    # 66 input channels / 4 output channels: forward + reverse flow branches.
    if args.ResNet101:
        dfc_resnet101 = resnet_models.Flow_Branch(66, 4)
        dfc_resnet = nn.DataParallel(dfc_resnet101).cuda()
    else:
        dfc_resnet50 = resnet_models.Flow_Branch_Multi(input_chanels=66, NoLabels=4)
        dfc_resnet = nn.DataParallel(dfc_resnet50).cuda()
    dfc_resnet.eval()

    resume_iter = load_ckpt(args.PRETRAINED_MODEL, [('model', dfc_resnet)], strict=True)
    print('Load Pretrained Model from', args.PRETRAINED_MODEL)

    # BUG FIX: the loop iterates over *batches*, so the tqdm total must be
    # len(eval_dataloader) (number of batches), not len(eval_dataset)
    # (number of samples); the old total over-counted for batch_size > 1.
    for i, item in tqdm(enumerate(eval_dataloader), total=len(eval_dataloader)):
        with torch.no_grad():
            input_x = item[0].cuda()
            flow_masked = item[1].cuda()
            gt_flow = item[2].cuda()  # loaded but unused here; kept for parity with the dataset tuple
            mask = item[3].cuda()
            output_dir = item[4][0]

            res_flow = dfc_resnet(input_x)
            res_flow_f = res_flow[:, :2, :, :]  # forward flow (u, v)
            res_flow_r = res_flow[:, 2:, :, :]  # reverse flow (u, v)

            # Use the prediction inside the hole, keep the known flow outside.
            # NOTE(review): forward uses one broadcast mask channel (10:11) while
            # reverse uses two (32:34); presumably both mask channels per flow
            # are identical — confirm against the dataset layout.
            res_complete_f = res_flow_f * mask[:, 10:11, :, :] + flow_masked[:, 10:12, :, :] * (1. - mask[:, 10:11, :, :])
            res_complete_r = res_flow_r * mask[:, 32:34, :, :] + flow_masked[:, 32:34, :, :] * (1. - mask[:, 32:34, :, :])

            # output_dir holds 'forward_path,reverse_path'.
            output_dir_split = output_dir.split(',')
            output_file_f = os.path.join(args.output_root, output_dir_split[0])
            output_file_r = os.path.join(args.output_root, output_dir_split[1])
            output_basedir = os.path.dirname(output_file_f)
            if not os.path.exists(output_basedir):
                os.makedirs(output_basedir)

            res_save_f = res_complete_f[0].permute(1, 2, 0).contiguous().cpu().data.numpy()
            cvb.write_flow(res_save_f, output_file_f)
            res_save_r = res_complete_r[0].permute(1, 2, 0).contiguous().cpu().data.numpy()
            cvb.write_flow(res_save_r, output_file_r)

    sys.stdout.write('\n')
    # Drop the model reference and flush cached GPU memory before the next stage.
    dfc_resnet = None
    torch.cuda.empty_cache()
    print('Refined Results Saved in', args.output_root)
def test_initial_stage(args):
    """Run the initial flow-completion network over the test split.

    Completes the (single-direction) flow inside the masked region with a
    pretrained model and writes each result as a flow file under
    ``args.output_root``.
    """
    torch.manual_seed(777)
    torch.cuda.manual_seed(777)

    args.INITIAL_HOLE = True
    args.get_mask = True

    dataset = FlowInitial.FlowSeq(args, isTest=True)
    loader = DataLoader(dataset,
                        batch_size=args.batch_size,
                        shuffle=False,
                        drop_last=False,
                        num_workers=args.n_threads)

    # 33 input channels / 2 output channels for the initial stage.
    if args.ResNet101:
        backbone = resnet_models.Flow_Branch(33, 2)
    else:
        backbone = resnet_models.Flow_Branch_Multi(input_chanels=33, NoLabels=2)
    dfc_resnet = nn.DataParallel(backbone).cuda()
    dfc_resnet.eval()

    resume_iter = load_ckpt(args.PRETRAINED_MODEL, [('model', dfc_resnet)], strict=True)
    print('Load Pretrained Model from', args.PRETRAINED_MODEL)

    task_bar = ProgressBar(dataset.__len__())
    for batch_idx, batch in enumerate(loader):
        with torch.no_grad():
            input_x = batch[0].cuda()
            flow_masked = batch[1].cuda()
            mask = batch[3].cuda()
            output_dir = batch[4][0]

            pred = dfc_resnet(input_x)
            # Prediction inside the hole, known flow everywhere else.
            hole = mask[:, 10:11, :, :]
            res_complete = pred * hole + flow_masked[:, 10:12, :, :] * (1. - hole)

            target_file = os.path.join(args.output_root, output_dir.split(',')[0])
            target_dir = os.path.dirname(target_file)
            if not os.path.exists(target_dir):
                os.makedirs(target_dir)

            flow_np = res_complete[0].permute(1, 2, 0).contiguous().cpu().data.numpy()
            cvb.write_flow(flow_np, target_file)
            task_bar.update()
    print('Initial Results Saved in', args.output_root)
def main():
    """Train the two-branch (forward + reverse) flow-completion network.

    Creates checkpoint/sample/log directories, dumps the run configuration,
    initialises the ResNet backbone from an ImageNet checkpoint (adapting
    conv1 to 66 input channels), then runs the SGD training loop with masked
    L1 reconstruction plus hard-example mining.  Reads the module-level
    ``args`` namespace.
    """
    if args.model_name is not None:
        model_save_dir = './snapshots/' + args.model_name + '/ckpt/'
        sample_dir = './snapshots/' + args.model_name + '/images/'
        log_dir = './logs/' + args.model_name
    else:
        model_save_dir = os.path.join(args.save_dir, 'ckpt')
        sample_dir = os.path.join(args.save_dir, 'images')
        log_dir = args.log_dir
    for d in (model_save_dir, sample_dir, log_dir):
        if not os.path.exists(d):
            os.makedirs(d)

    # Persist the full configuration next to the logs for reproducibility.
    with open(os.path.join(log_dir, 'config.yml'), 'w') as f:
        yaml.dump(vars(args), f)
    writer = SummaryWriter(log_dir=log_dir)

    torch.manual_seed(7777777)
    if not args.CPU:
        torch.cuda.manual_seed(7777777)

    flow_resnet = resnet_models.Flow_Branch_Multi(input_chanels=66, NoLabels=4)
    saved_state_dict = torch.load(args.RESNET_PRETRAIN_MODEL)
    for i in saved_state_dict:
        if 'conv1.' in i[:7]:
            # Adapt the pretrained 3-channel conv1 weights to 66 input
            # channels: average over RGB, then spread the mean uniformly
            # (divided by 66) across all input channels.
            conv1_weight = saved_state_dict[i]
            conv1_weight_mean = torch.mean(conv1_weight, dim=1, keepdim=True)
            saved_state_dict[i] = (conv1_weight_mean / 66.0).repeat(1, 66, 1, 1)
    flow_resnet.load_state_dict(saved_state_dict, strict=False)
    flow_resnet = nn.DataParallel(flow_resnet).cuda()
    flow_resnet.train()

    # 1x LR for the pretrained backbone, 10x for the newly initialised head.
    optimizer = optim.SGD(
        [{'params': get_1x_lr_params(flow_resnet.module), 'lr': args.LR},
         {'params': get_10x_lr_params(flow_resnet.module), 'lr': 10 * args.LR}],
        lr=args.LR,
        momentum=0.9,
        weight_decay=args.WEIGHT_DECAY)

    train_dataset = FlowSeq(args)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              drop_last=True,
                              num_workers=args.n_threads)

    if args.resume:
        if args.PRETRAINED_MODEL is not None:
            resume_iter = load_ckpt(args.PRETRAINED_MODEL,
                                    [('model', flow_resnet)],
                                    [('optimizer', optimizer)],
                                    strict=True)
            print('Model Resume from', resume_iter, 'iter')
        else:
            print('Cannot load Pretrained Model')
            return

    if args.PRETRAINED:
        if args.PRETRAINED_MODEL is not None:
            resume_iter = load_ckpt(args.PRETRAINED_MODEL,
                                    [('model', flow_resnet)],
                                    strict=True)
            print('Model Resume from', resume_iter, 'iter')

    train_iterator = iter(train_loader)
    loss = {}
    start_iter = 0 if not args.resume else resume_iter

    for i in tqdm(range(start_iter, args.max_iter)):
        try:
            flow_mask_cat, flow_masked, gt_flow, mask = next(train_iterator)
        except StopIteration:
            # BUG FIX: was a bare `except:` which also swallowed
            # KeyboardInterrupt and genuine loader errors — only a drained
            # iterator should trigger a restart of the epoch.
            print('Loader Restart')
            train_iterator = iter(train_loader)
            flow_mask_cat, flow_masked, gt_flow, mask = next(train_iterator)

        input_x = flow_mask_cat.cuda()
        gt_flow = gt_flow.cuda()
        mask = mask.cuda()
        flow_masked = flow_masked.cuda()

        flow1x = flow_resnet(input_x)
        f_res = flow1x[:, :2, :, :]  # forward-flow prediction
        r_res = flow1x[:, 2:, :, :]  # reverse-flow prediction

        # Masked L1 reconstruction on both flow directions ...
        loss['1x_recon'] = L.L1_mask(f_res, gt_flow[:, :2, :, :],
                                     mask[:, 10:12, :, :])
        loss['1x_recon'] += L.L1_mask(r_res, gt_flow[:, 2:, ...],
                                      mask[:, 32:34, ...])
        # ... plus hard-example mining on the worst-reconstructed pixels.
        loss['f_recon_hard'], new_mask = L.L1_mask_hard_mining(
            f_res, gt_flow[:, :2, :, :], mask[:, 10:11, :, :])
        loss['r_recon_hard'], new_mask = L.L1_mask_hard_mining(
            r_res, gt_flow[:, 2:, ...], mask[:, 32:33, ...])
        loss_total = loss['1x_recon'] + args.LAMBDA_HARD * (
            loss['f_recon_hard'] + loss['r_recon_hard'])

        if i % args.NUM_ITERS_DECAY == 0:
            adjust_learning_rate(optimizer, i, args.lr_decay_steps)
            print('LR has been changed')

        optimizer.zero_grad()
        loss_total.backward()
        optimizer.step()

        if i % args.PRINT_EVERY == 0:
            print('=========================================================')
            print(args.model_name,
                  "Rank[{}] Iter [{}/{}]".format(0, i + 1, args.max_iter))
            print('=========================================================')
            print_loss_dict(loss)
            write_loss_dict(loss, writer, i)

        if (i + 1) % args.MODEL_SAVE_STEP == 0:
            save_ckpt(os.path.join(model_save_dir, 'DFI_%d.pth' % i),
                      [('model', flow_resnet)],
                      [('optimizer', optimizer)], i)
            print('Model has been saved at %d Iters' % i)

    writer.close()
def flow_completion(self):
    """One step of the initial flow-completion stage.

    Lazily builds the data list, dataloader and model on the first call
    (signalled by ``self.i == -1``), then completes and saves one batch of
    flows per subsequent call.  Once the loader is exhausted (or the last
    batch index is reached) the model is torn down, counters reset, and
    ``self.state`` is advanced to the next pipeline stage.
    """
    if self.i == -1:
        # --- one-time setup: data list, output paths, model, checkpoint ---
        data_list_dir = os.path.join(self.args.dataset_root, 'data')
        os.makedirs(data_list_dir, exist_ok=True)
        initial_data_list = os.path.join(data_list_dir, 'initial_test_list.txt')
        print('Generate datalist for initial step')
        data_list.gen_flow_initial_test_mask_list(
            flow_root=self.args.DATA_ROOT,
            output_txt_path=initial_data_list)
        self.args.EVAL_LIST = os.path.join(data_list_dir, 'initial_test_list.txt')
        self.args.output_root = os.path.join(self.args.dataset_root, 'Flow_res', 'initial_res')
        self.args.PRETRAINED_MODEL = self.args.PRETRAINED_MODEL_1
        if self.args.img_size is not None:
            # Flow completion runs at half the input image resolution.
            self.args.IMAGE_SHAPE = [
                self.args.img_size[0] // 2, self.args.img_size[1] // 2
            ]
            self.args.RES_SHAPE = self.args.IMAGE_SHAPE
        print('Flow Completion in First Step')
        self.args.MASK_ROOT = self.args.mask_root
        eval_dataset = FlowInitial.FlowSeq(self.args, isTest=True)
        # Stored as an iterator so each call to this method pulls one batch.
        self.flow_refinement_dataloader = iter(
            DataLoader(eval_dataset,
                       batch_size=self.settings.batch_size,
                       shuffle=False,
                       drop_last=False,
                       num_workers=self.args.n_threads))
        # 33 input channels / 2 output channels for the initial stage.
        if self.args.ResNet101:
            dfc_resnet101 = resnet_models.Flow_Branch(33, 2)
            self.dfc_resnet = nn.DataParallel(dfc_resnet101).to(
                self.args.device)
        else:
            dfc_resnet50 = resnet_models.Flow_Branch_Multi(
                input_chanels=33, NoLabels=2)
            self.dfc_resnet = nn.DataParallel(dfc_resnet50).to(
                self.args.device)
        self.dfc_resnet.eval()
        io.load_ckpt(self.args.PRETRAINED_MODEL,
                     [('model', self.dfc_resnet)], strict=True)
        print('Load Pretrained Model from', self.args.PRETRAINED_MODEL)

    self.i += 1  # batch counter; -1 means "not yet initialised"
    complete = False
    with torch.no_grad():
        try:
            item = next(self.flow_refinement_dataloader)
            input_x = item[0].to(self.args.device)
            flow_masked = item[1].to(self.args.device)
            mask = item[3].to(self.args.device)
            output_dir = item[4][0]
            res_flow = self.dfc_resnet(input_x)
            # Prediction inside the hole, original flow outside it.
            res_complete = res_flow * mask[:, 10: 11, :, :] + flow_masked[:, 10:12, :, :] * (1. - mask[:, 10:11, :, :])
            output_dir_split = output_dir.split(',')
            output_file = os.path.join(self.args.output_root, output_dir_split[0])
            output_basedir = os.path.dirname(output_file)
            if not os.path.exists(output_basedir):
                os.makedirs(output_basedir)
            res_save = res_complete[0].permute(
                1, 2, 0).contiguous().cpu().data.numpy()
            cvb.write_flow(res_save, output_file)
        except StopIteration:
            # Loader drained — no batch processed this call.
            complete = True
    if self.i == len(self.flow_refinement_dataloader) - 1 or complete:
        # All batches done: hand the results over to the next stage and reset.
        self.args.flow_root = self.args.output_root
        del self.flow_refinement_dataloader, self.dfc_resnet
        self.i = -1
        self.state += 1