def forward(net, BG, FG, dp, loss):
    """Run one matting forward pass and populate `loss` in place.

    Under `torch.no_grad()`, predicts background and foreground images with
    the BG / FG networks, then (with gradients) runs `net` to predict alpha
    inside the unknown trimap region, composites, and computes masked L1
    losses against the ground truth.

    Args:
        net: alpha network; input is cat([rgb, fg_pred, bg_pred, tri - 1])
            (10 channels), output a 1-channel alpha map (presumably already
            in [0, 1] — it is not clamped here; TODO confirm).
        BG: background predictor; called as BG(rgb, f_u_mask, small_fumask),
            returns a 3-tuple whose middle element is the image in [-1, 1].
        FG: foreground predictor; same calling shape plus bg_img=...
        dp: data tuple (rgb, fg_gt, bg_gt, gt, tri, small_tri); value ranges
            and layouts are given in the comments below.
        loss: dict, filled with 'L_alpha', 'L_comp', 'L_grad', 'L_total'.

    Returns:
        List of tensors for logging/visualization, with `loss` as the last
        element.
    """
    rgb, fg_gt, bg_gt, gt, tri, small_tri = dp
    with torch.no_grad():
        # rgb, fg_gt, bg_gt is in [-1, 1]
        # gt is in [0, 1]
        # tri is 1-channel trimap in {0, 1, 2}, 0 BG 1 UN 2 FG
        # preprocess which does not need gradients
        rgb = rgb.cuda().float().clamp(-1., 1.)  # [b, 3, h, w]
        tri = tri.cuda().float().clamp(0., 2.)  # [b, 1, h, w]
        gt = gt.cuda().float().clamp(0., 1.)  # [b, 1, h, w]
        fg_gt = fg_gt.cuda().float().clamp(-1., 1.)  # [b, 3, h, w]
        bg_gt = bg_gt.cuda().float().clamp(-1., 1.)  # [b, 3, h, w]
        small_tri = small_tri.cuda().float().clamp(0., 2.)
        # boolean mask of the unknown region (trimap == 1)
        mask = torch.eq(tri, 1.)
        # bg and fg should be a float32 tensor in [0, 1]
        f_u_mask = (tri > 0.01).float()  # foreground + unknown pixels
        b_u_mask = (tri < 1.99).float()  # background + unknown pixels
        small_fumask = (small_tri > 0.01).float()
        small_bumask = (small_tri < 1.99).float()
        # bg_pred and fg_pred are already in [-1, 1]
        _, bg_pred, _ = BG(rgb, f_u_mask, small_fumask)
        # keep the observed image outside the unknown region:
        # mask * bg_pred + (1 - mask) * rgb
        bg_pred_ = torch.where(mask, bg_pred, rgb)
        _, fg_pred, _ = FG(rgb, b_u_mask, small_bumask, bg_img=bg_pred_)
        # map predictions and input from [-1, 1] to [0, 1] for compositing
        fg_ = ((fg_pred + 1.0) / 2.0).clamp(0., 1.)
        bg_ = ((bg_pred + 1.0) / 2.0).clamp(0., 1.)
        rgb_ = ((rgb + 1.0) / 2.0).clamp(0., 1.)
        input_x = torch.cat([rgb, fg_pred, bg_pred, tri - 1.], axis=1)
    # network forward — outside no_grad so the losses below can backprop
    # through `net` (BG/FG predictions above stay detached)
    pred = net(input_x)
    # trust the trimap outside the unknown region; tri/2 maps {0, 2} -> {0, 1}
    alpha = torch.where(mask, pred, tri / 2.0)
    # composition
    comp = fg_ * alpha + bg_ * (1. - alpha)
    # loss calculation, restricted to the unknown region
    valid_mask = mask.float()
    loss['L_alpha'] = L.L1_mask(alpha, gt, valid_mask)
    loss['L_comp'] = L.L1_mask(comp, rgb_, valid_mask)
    loss['L_grad'] = L.L1_grad(alpha, gt, valid_mask)
    loss['L_total'] = (loss['L_alpha'] + loss['L_comp']) * 0.5 + loss['L_grad']
    return [
        fg_gt, bg_gt, gt, tri, fg_, bg_, rgb_, alpha, comp, valid_mask, loss
    ]
def main():
    """Train the 66-channel flow-completion ResNet (DFI).

    Builds output dirs, adapts an ImageNet-pretrained conv1 to 66 input
    channels, then runs an SGD training loop over `FlowSeq`, logging to
    TensorBoard and checkpointing periodically.

    Relies on module-level names configured elsewhere in this file:
    `args`, `resnet_models`, `FlowSeq`, `get_1x_lr_params`,
    `get_10x_lr_params`, `load_ckpt`, `save_ckpt`, `adjust_learning_rate`,
    `print_loss_dict`, `write_loss_dict`, and the loss module `L`.
    """
    # Resolve output locations: named runs go under ./snapshots and ./logs,
    # otherwise everything lives under args.save_dir / args.log_dir.
    if args.model_name is not None:
        model_save_dir = './snapshots/' + args.model_name + '/ckpt/'
        sample_dir = './snapshots/' + args.model_name + '/images/'
        log_dir = './logs/' + args.model_name
    else:
        model_save_dir = os.path.join(args.save_dir, 'ckpt')
        sample_dir = os.path.join(args.save_dir, 'images')
        log_dir = args.log_dir
    # exist_ok=True replaces the racy exists()+makedirs pattern.
    for d in (model_save_dir, sample_dir, log_dir):
        os.makedirs(d, exist_ok=True)

    # Persist the run configuration next to the logs for reproducibility.
    with open(os.path.join(log_dir, 'config.yml'), 'w') as f:
        yaml.dump(vars(args), f)
    writer = SummaryWriter(log_dir=log_dir)

    torch.manual_seed(7777777)
    if not args.CPU:
        torch.cuda.manual_seed(7777777)

    flow_resnet = resnet_models.Flow_Branch_Multi(input_chanels=66, NoLabels=4)
    # NOTE(review): torch.load on an untrusted path would be unsafe
    # (pickle); assumed trusted local pretrained weights here.
    saved_state_dict = torch.load(args.RESNET_PRETRAIN_MODEL)
    # Adapt the pretrained 3-channel conv1 to 66 input channels: average the
    # RGB filters, rescale so the total input magnitude is preserved, and
    # tile across the 66 new channels. Iterate a snapshot of the keys since
    # we assign into the dict while looping.
    for key in list(saved_state_dict):
        if key.startswith('conv1.'):  # prefix test, was a fragile `in i[:7]`
            conv1_weight = saved_state_dict[key]
            conv1_weight_mean = torch.mean(conv1_weight, dim=1, keepdim=True)
            saved_state_dict[key] = (conv1_weight_mean / 66.0).repeat(1, 66, 1, 1)
    flow_resnet.load_state_dict(saved_state_dict, strict=False)
    flow_resnet = nn.DataParallel(flow_resnet).cuda()
    flow_resnet.train()

    # Backbone params at base LR, new head params at 10x (per project helpers).
    optimizer = optim.SGD(
        [{'params': get_1x_lr_params(flow_resnet.module), 'lr': args.LR},
         {'params': get_10x_lr_params(flow_resnet.module), 'lr': 10 * args.LR}],
        lr=args.LR,
        momentum=0.9,
        weight_decay=args.WEIGHT_DECAY)

    train_dataset = FlowSeq(args)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              drop_last=True,
                              num_workers=args.n_threads)

    if args.resume:
        if args.PRETRAINED_MODEL is not None:
            resume_iter = load_ckpt(args.PRETRAINED_MODEL,
                                    [('model', flow_resnet)],
                                    [('optimizer', optimizer)],
                                    strict=True)
            print('Model Resume from', resume_iter, 'iter')
        else:
            print('Cannot load Pretrained Model')
            return

    if args.PRETRAINED:
        if args.PRETRAINED_MODEL is not None:
            resume_iter = load_ckpt(args.PRETRAINED_MODEL,
                                    [('model', flow_resnet)],
                                    strict=True)
            print('Model Resume from', resume_iter, 'iter')

    train_iterator = iter(train_loader)
    loss = {}
    start_iter = 0 if not args.resume else resume_iter

    for i in tqdm(range(start_iter, args.max_iter)):
        try:
            flow_mask_cat, flow_masked, gt_flow, mask = next(train_iterator)
        except StopIteration:
            # Epoch exhausted — restart the loader. Was a bare `except:`,
            # which also swallowed KeyboardInterrupt and real worker errors.
            print('Loader Restart')
            train_iterator = iter(train_loader)
            flow_mask_cat, flow_masked, gt_flow, mask = next(train_iterator)

        input_x = flow_mask_cat.cuda()
        gt_flow = gt_flow.cuda()
        mask = mask.cuda()
        flow_masked = flow_masked.cuda()

        flow1x = flow_resnet(input_x)
        f_res = flow1x[:, :2, :, :]  # forward-flow residual (2 channels)
        r_res = flow1x[:, 2:, :, :]  # reverse-flow residual (2 channels)

        # Masked L1 reconstruction; channel offsets 10:12 / 32:34 pick the
        # mask channels matching the forward / reverse target frames
        # (presumably fixed by FlowSeq's packing — verify against dataset).
        loss['1x_recon'] = L.L1_mask(f_res, gt_flow[:, :2, :, :],
                                     mask[:, 10:12, :, :])
        loss['1x_recon'] += L.L1_mask(r_res, gt_flow[:, 2:, ...],
                                      mask[:, 32:34, ...])
        # Hard-example mining on a single mask channel each direction.
        loss['f_recon_hard'], new_mask = L.L1_mask_hard_mining(
            f_res, gt_flow[:, :2, :, :], mask[:, 10:11, :, :])
        loss['r_recon_hard'], new_mask = L.L1_mask_hard_mining(
            r_res, gt_flow[:, 2:, ...], mask[:, 32:33, ...])
        loss_total = loss['1x_recon'] + args.LAMBDA_HARD * (
            loss['f_recon_hard'] + loss['r_recon_hard'])

        if i % args.NUM_ITERS_DECAY == 0:
            adjust_learning_rate(optimizer, i, args.lr_decay_steps)
            print('LR has been changed')

        optimizer.zero_grad()
        loss_total.backward()
        optimizer.step()

        if i % args.PRINT_EVERY == 0:
            print('=========================================================')
            print(args.model_name,
                  "Rank[{}] Iter [{}/{}]".format(0, i + 1, args.max_iter))
            print('=========================================================')
            print_loss_dict(loss)
            write_loss_dict(loss, writer, i)

        if (i + 1) % args.MODEL_SAVE_STEP == 0:
            save_ckpt(os.path.join(model_save_dir, 'DFI_%d.pth' % i),
                      [('model', flow_resnet)],
                      [('optimizer', optimizer)], i)
            print('Model has been saved at %d Iters' % i)

    writer.close()