# Imports needed by the snippets below (the original excerpts omit them).
import os

import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter  # the original repo may use tensorboardX instead
from tqdm import tqdm

# Repo-local helpers referenced below are assumed to be importable from the
# surrounding project (import paths not shown in the excerpts): L (loss
# functions), resnet_models, FlowSeq, load_ckpt, save_ckpt, get_1x_lr_params,
# get_10x_lr_params, adjust_learning_rate, print_loss_dict, write_loss_dict,
# and the global `args` namespace.


def forward(net, BG, FG, dp, loss):
    rgb, fg_gt, bg_gt, gt, tri, small_tri = dp
    with torch.no_grad():
        # rgb, fg_gt, bg_gt is in [-1, 1]
        # gt is in [0, 1]
        # tri is 1-channel trimap in {0, 1, 2}, 0 BG 1 UN 2 FG
        # preprocess which does not need gradients
        rgb = rgb.cuda().float().clamp(-1., 1.)  # [b, 3, h, w]
        tri = tri.cuda().float().clamp(0., 2.)  # [b, 1, h, w]
        gt = gt.cuda().float().clamp(0., 1.)  # [b, 1, h, w]
        fg_gt = fg_gt.cuda().float().clamp(-1., 1.)  # [b, 3, h, w]
        bg_gt = bg_gt.cuda().float().clamp(-1., 1.)  # [b, 3, h, w]
        small_tri = small_tri.cuda().float().clamp(0., 2.)
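        # unknown-region mask (trimap == 1); the network prediction is only trusted here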
        mask = torch.eq(tri, 1.)

        # bg and fg should be a float32 tensor in [0, 1]
        f_u_mask = (tri > 0.01).float()
        b_u_mask = (tri < 1.99).float()
        small_fumask = (small_tri > 0.01).float()
        small_bumask = (small_tri < 1.99).float()
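        # f_u_mask / small_fumask cover foreground + unknown pixels (the region the BG
        # estimator must fill); b_u_mask / small_bumask cover background + unknown
        # pixels (the region the FG estimator must fill)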
        # bg_pred and fg_pred are already in [-1, 1]
        _, bg_pred, _ = BG(rgb, f_u_mask, small_fumask)
        # equivalent to mask * bg_pred + (1. - mask) * rgb
        bg_pred_ = torch.where(mask, bg_pred, rgb)
        _, fg_pred, _ = FG(rgb, b_u_mask, small_bumask, bg_img=bg_pred_)
        fg_ = ((fg_pred + 1.0) / 2.0).clamp(0., 1.)
        bg_ = ((bg_pred + 1.0) / 2.0).clamp(0., 1.)
        rgb_ = ((rgb + 1.0) / 2.0).clamp(0., 1.)

        # 3 (rgb) + 3 (fg) + 3 (bg) + 1 (trimap shifted to [-1, 1]) = 10 input channels
        input_x = torch.cat([rgb, fg_pred, bg_pred, tri - 1.], dim=1)

    # network forward
    pred = net(input_x)
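    # keep the prediction only in the unknown region; known pixels take 0 (BG) or 1 (FG) from the trimap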
    alpha = torch.where(mask, pred, tri / 2.0)
    # composition
    comp = fg_ * alpha + bg_ * (1. - alpha)

    # loss calculation
    valid_mask = mask.float()
    loss['L_alpha'] = L.L1_mask(alpha, gt, valid_mask)
    loss['L_comp'] = L.L1_mask(comp, rgb_, valid_mask)
    loss['L_grad'] = L.L1_grad(alpha, gt, valid_mask)
    loss['L_total'] = (loss['L_alpha'] + loss['L_comp']) * 0.5 + loss['L_grad']

    return [
        fg_gt, bg_gt, gt, tri, fg_, bg_, rgb_, alpha, comp, valid_mask, loss
    ]
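

# Hedged usage sketch (not part of the original excerpt): one way forward()
# could be driven from a training step. `matting_net`, `bg_net`, `fg_net`,
# `optimizer`, and `batch` are hypothetical names; `batch` must unpack to
# (rgb, fg_gt, bg_gt, gt, tri, small_tri) as expected above.
def train_step(matting_net, bg_net, fg_net, optimizer, batch):
    outputs = forward(matting_net, bg_net, fg_net, batch, loss={})
    loss = outputs[-1]
    optimizer.zero_grad()
    loss['L_total'].backward()
    optimizer.step()
    return {k: v.item() for k, v in loss.items()}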


def main():

    image_size = [args.IMAGE_SHAPE[0], args.IMAGE_SHAPE[1]]

    if args.model_name is not None:
        model_save_dir = './snapshots/' + args.model_name + '/ckpt/'
        sample_dir = './snapshots/' + args.model_name + '/images/'
        log_dir = './logs/' + args.model_name
    else:
        model_save_dir = os.path.join(args.save_dir, 'ckpt')
        sample_dir = os.path.join(args.save_dir, 'images')
        log_dir = args.log_dir

    if not os.path.exists(model_save_dir):
        os.makedirs(model_save_dir)
    if not os.path.exists(sample_dir):
        os.makedirs(sample_dir)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    with open(os.path.join(log_dir, 'config.yml'), 'w') as f:
        yaml.dump(vars(args), f)

    writer = SummaryWriter(log_dir=log_dir)

    torch.manual_seed(7777777)
    if not args.CPU:
        torch.cuda.manual_seed(7777777)

    flow_resnet = resnet_models.Flow_Branch_Multi(input_chanels=66, NoLabels=4)
    saved_state_dict = torch.load(args.RESNET_PRETRAIN_MODEL)
    # Inflate the pretrained 3-channel conv1 weights to the 66 input channels
    # expected here: average over the RGB channels, scale by 1/66, and repeat
    # across the 66 new input channels.
    for i in saved_state_dict:
        if i.startswith('conv1.'):
            conv1_weight = saved_state_dict[i]
            conv1_weight_mean = torch.mean(conv1_weight, dim=1, keepdim=True)
            conv1_weight_new = (conv1_weight_mean / 66.0).repeat(1, 66, 1, 1)
            saved_state_dict[i] = conv1_weight_new
    flow_resnet.load_state_dict(saved_state_dict, strict=False)
    flow_resnet = nn.DataParallel(flow_resnet).cuda()
    flow_resnet.train()

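    # SGD with two parameter groups: the pretrained backbone at the base LR and
    # the newly added layers at 10x the base LR (per the get_*_lr_params helpers)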
    optimizer = optim.SGD([{
        'params': get_1x_lr_params(flow_resnet.module),
        'lr': args.LR
    }, {
        'params': get_10x_lr_params(flow_resnet.module),
        'lr': 10 * args.LR
    }],
                          lr=args.LR,
                          momentum=0.9,
                          weight_decay=args.WEIGHT_DECAY)

    train_dataset = FlowSeq(args)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              drop_last=True,
                              num_workers=args.n_threads)

    if args.resume:
        if args.PRETRAINED_MODEL is not None:
            resume_iter = load_ckpt(args.PRETRAINED_MODEL,
                                    [('model', flow_resnet)],
                                    [('optimizer', optimizer)],
                                    strict=True)
            print('Model resumed from iter', resume_iter)
        else:
            print('Cannot load Pretrained Model')
            return

    if args.PRETRAINED:
        if args.PRETRAINED_MODEL is not None:
            resume_iter = load_ckpt(args.PRETRAINED_MODEL,
                                    [('model', flow_resnet)],
                                    strict=True)
            print('Model resumed from iter', resume_iter)

    train_iterator = iter(train_loader)

    loss = {}

    start_iter = 0 if not args.resume else resume_iter

    for i in tqdm(range(start_iter, args.max_iter)):
        try:
            flow_mask_cat, flow_masked, gt_flow, mask = next(train_iterator)
        except StopIteration:  # epoch exhausted; restart the loader
            print('Loader Restart')
            train_iterator = iter(train_loader)
            flow_mask_cat, flow_masked, gt_flow, mask = next(train_iterator)

        input_x = flow_mask_cat.cuda()
        gt_flow = gt_flow.cuda()
        mask = mask.cuda()
        flow_masked = flow_masked.cuda()

        flow1x = flow_resnet(input_x)
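        # split the 4-channel output into forward (f_res) and reverse (r_res) flow predictions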
        f_res = flow1x[:, :2, :, :]
        r_res = flow1x[:, 2:, :, :]

        # fake_flow_f = f_res * mask[:,10:12,:,:] + flow_masked[:,10:12,:,:] * (1. - mask[:,10:12,:,:])
        # fake_flow_r = r_res * mask[:,32:34,:,:] + flow_masked[:,32:34,:,:] * (1. - mask[:,32:34,:,:])

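        # masked L1 reconstruction on forward and reverse flow, plus
        # hard-example-mined L1 terms weighted by LAMBDA_HARD below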
        loss['1x_recon'] = L.L1_mask(f_res, gt_flow[:, :2, :, :],
                                     mask[:, 10:12, :, :])
        loss['1x_recon'] += L.L1_mask(r_res, gt_flow[:, 2:, ...],
                                      mask[:, 32:34, ...])
        loss['f_recon_hard'], new_mask = L.L1_mask_hard_mining(
            f_res, gt_flow[:, :2, :, :], mask[:, 10:11, :, :])
        loss['r_recon_hard'], new_mask = L.L1_mask_hard_mining(
            r_res, gt_flow[:, 2:, ...], mask[:, 32:33, ...])

        loss_total = loss['1x_recon'] + args.LAMBDA_HARD * (
            loss['f_recon_hard'] + loss['r_recon_hard'])

        if i % args.NUM_ITERS_DECAY == 0:
            adjust_learning_rate(optimizer, i, args.lr_decay_steps)
            print('LR has been changed')

        optimizer.zero_grad()
        loss_total.backward()
        optimizer.step()

        if i % args.PRINT_EVERY == 0:
            print('=========================================================')
            print(args.model_name,
                  "Rank[{}] Iter [{}/{}]".format(0, i + 1, args.max_iter))
            print('=========================================================')
            print_loss_dict(loss)
            write_loss_dict(loss, writer, i)

        if (i + 1) % args.MODEL_SAVE_STEP == 0:
            save_ckpt(os.path.join(model_save_dir, 'DFI_%d.pth' % i),
                      [('model', flow_resnet)], [('optimizer', optimizer)], i)
            print('Model has been saved at %d Iters' % i)

    writer.close()
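

# Entry-point sketch (assumption: `args` is built by an argparse/config module
# elsewhere in the original repo, which this excerpt does not show).
if __name__ == '__main__':
    main()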