예제 #1
0
파일: predict.py 프로젝트: 12-wu/-
def main(args):
    """Inpaint random holes in a single image with a trained completion network.

    Loads the network weights from ``args.model`` and the training config JSON
    from ``args.config`` (which must contain ``'mpv'``, the 3-channel mean
    pixel value). ``args.input_img`` is resized and randomly cropped to
    ``args.img_size`` x ``args.img_size``. Random holes are then generated
    (sizes bounded by ``args.hole_{min,max}_{w,h}``, count governed by
    ``args.max_holes``) and filled. A 3-image grid of
    original | masked input | inpainted is written to ``args.output_img``.
    """
    args.model = os.path.expanduser(args.model)
    args.config = os.path.expanduser(args.config)
    args.input_img = os.path.expanduser(args.input_img)
    args.output_img = os.path.expanduser(args.output_img)

    # =============================================
    # Load model
    # =============================================
    with open(args.config, 'r') as f:
        config = json.load(f)
    mpv = torch.tensor(config['mpv']).view(1, 3, 1, 1)

    # Load the weights into the bare network. Wrapping it in DataParallel
    # *before* loading would make load_state_dict expect 'module.'-prefixed
    # keys, which checkpoints saved via `model.module.state_dict()` lack.
    # DataParallel gives nothing for this CPU-only inference, so
    # config['data_parallel'] is ignored here. A prefix (from a checkpoint
    # saved off a wrapped model) is stripped so both formats load.
    state_dict = torch.load(args.model, map_location='cpu')
    prefix = 'module.'
    if any(key.startswith(prefix) for key in state_dict):
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
    model = CompletionNetwork()
    model.load_state_dict(state_dict)
    # Inference mode, so that any BatchNorm/Dropout layers do not use
    # training-time behaviour on this single image.
    model.eval()

    # =============================================
    # Predict
    # =============================================
    # convert img to tensor; force 3 channels to match the 3-channel mpv and
    # close the underlying file once the pixels are loaded
    with Image.open(args.input_img) as src:
        img = src.convert('RGB')
    img = transforms.Resize(args.img_size)(img)
    img = transforms.RandomCrop((args.img_size, args.img_size))(img)
    x = transforms.ToTensor()(img)
    x = torch.unsqueeze(x, dim=0)

    # create a single-channel hole mask covering the image
    mask = gen_input_mask(
        shape=(1, 1, x.shape[2], x.shape[3]),
        hole_size=(
            (args.hole_min_w, args.hole_max_w),
            (args.hole_min_h, args.hole_max_h),
        ),
        max_holes=args.max_holes,
    )

    # inpaint
    with torch.no_grad():
        # fill the holes with the mean pixel value and append the mask as a
        # 4th input channel
        x_mask = x - x * mask + mpv * mask
        input_cn = torch.cat((x_mask, mask), dim=1)
        output = model(input_cn)
        # Blend into the masked image (as the training snapshots do), so the
        # original hole content can never reach the result.
        inpainted = poisson_blend(x_mask, output, mask)
        imgs = torch.cat((x, x_mask, inpainted), dim=0)
        save_image(imgs, args.output_img, nrow=3)
    print('output img was saved as %s.' % args.output_img)
def main(args):
    """Inpaint random holes in a single image (scalar mean-pixel-value variant).

    Loads the network weights from ``args.model`` and the training config JSON
    from ``args.config`` (which must contain ``'mean_pv'``). ``args.input_img``
    is resized and randomly cropped to ``args.img_size`` x ``args.img_size``.
    A random hole mask with the same shape as the image tensor is generated,
    the holes are filled, and the original, masked and inpainted images are
    concatenated side by side (along the width axis) into ``args.output_img``.
    """
    args.model = os.path.expanduser(args.model)
    args.config = os.path.expanduser(args.config)
    args.input_img = os.path.expanduser(args.input_img)
    args.output_img = os.path.expanduser(args.output_img)

    # =============================================
    # Load model
    # =============================================
    with open(args.config, 'r') as f:
        config = json.load(f)
    mpv = config['mean_pv']
    model = CompletionNetwork()
    model.load_state_dict(torch.load(args.model, map_location='cpu'))
    # Inference mode, so that any BatchNorm/Dropout layers do not use
    # training-time behaviour on this single image.
    model.eval()

    # =============================================
    # Predict
    # =============================================
    # convert img to tensor; force RGB so the tensor always has 3 channels,
    # and close the underlying file once the pixels are loaded
    with Image.open(args.input_img) as src:
        img = src.convert('RGB')
    img = transforms.Resize(args.img_size)(img)
    img = transforms.RandomCrop((args.img_size, args.img_size))(img)
    x = transforms.ToTensor()(img)
    x = torch.unsqueeze(x, dim=0)

    # create mask (same shape as x, i.e. one mask plane per image channel)
    msk = gen_input_mask(
        shape=x.shape,
        hole_size=(
            (args.hole_min_w, args.hole_max_w),
            (args.hole_min_h, args.hole_max_h),
        ),
        max_holes=args.max_holes,
    )

    # inpaint
    with torch.no_grad():
        # holes are filled with the scalar mean pixel value
        input_cn = x - x * msk + mpv * msk
        output = model(input_cn)
        inpainted = poisson_blend(input_cn, output, msk)
        # dim=-1: the three images are joined horizontally into one picture
        imgs = torch.cat((x, input_cn, inpainted), dim=-1)
        # save_image returns None, so its result is not kept
        save_image(imgs, args.output_img, nrow=3)
    print('output img was saved as %s.' % args.output_img)
예제 #3
0
def main(args):
    """Train the completion network (CN) and context discriminator (CD).

    Training runs in three phases:
      1. CN alone for ``args.steps_1`` steps with ``completion_network_loss``.
      2. CD alone for ``args.steps_2`` steps, on CN completions (labelled
         fake) vs. training images (labelled real); CN is not updated.
      3. CN and CD adversarially for ``args.steps_3`` steps. CD's loss is
         scaled by ``args.alpha``, and CN's loss averages the completion loss
         with ``args.alpha`` times the adversarial loss.

    Each optimizer step accumulates gradients over ``args.bdivs`` mini-batches
    of size ``args.bsize // args.bdivs``. Every ``args.snaperiod_{1,2,3}``
    steps, a grid of test completions and the current weights are saved under
    ``args.result_dir/phase_{1,2,3}``. The argument namespace, with the mean
    pixel value added as ``'mpv'``, is written to
    ``args.result_dir/config.json``.

    Raises:
        Exception: if no CUDA device is available.
    """
    # ================================================
    # Preparation
    # ================================================
    args.data_dir = os.path.expanduser(args.data_dir)
    args.result_dir = os.path.expanduser(args.result_dir)
    if args.init_model_cn is not None:
        args.init_model_cn = os.path.expanduser(args.init_model_cn)
    if args.init_model_cd is not None:
        args.init_model_cd = os.path.expanduser(args.init_model_cd)
    if not torch.cuda.is_available():
        raise Exception('At least one gpu must be available.')
    gpu = torch.device('cuda:0')

    # create result directories (makedirs also creates result_dir itself)
    for s in ['phase_1', 'phase_2', 'phase_3']:
        os.makedirs(os.path.join(args.result_dir, s), exist_ok=True)

    # dataset
    trnsfm = transforms.Compose([
        transforms.Resize(args.cn_input_size),
        transforms.RandomCrop((args.cn_input_size, args.cn_input_size)),
        transforms.ToTensor(),
    ])
    print('loading dataset... (it may take a few minutes)')
    train_dset = ImageDataset(os.path.join(args.data_dir, 'train'),
                              trnsfm,
                              recursive_search=args.recursive_search)
    test_dset = ImageDataset(os.path.join(args.data_dir, 'test'),
                             trnsfm,
                             recursive_search=args.recursive_search)
    train_loader = DataLoader(train_dset,
                              batch_size=(args.bsize // args.bdivs),
                              shuffle=True)

    # compute mean pixel value (per RGB channel) of the training dataset
    if args.mpv is None:
        mpv = np.zeros(shape=(3, ))
        pbar = tqdm(total=len(train_dset.imgpaths),
                    desc='computing mean pixel value for training dataset...')
        for imgpath in train_dset.imgpaths:
            # Convert to RGB so grayscale/palette images still yield exactly
            # 3 channel means; `with` closes each file after loading.
            with Image.open(imgpath) as img:
                arr = np.array(img.convert('RGB'), dtype=np.float32) / 255.
            mpv += arr.mean(axis=(0, 1))
            pbar.update()
        mpv /= len(train_dset.imgpaths)
        pbar.close()
    else:
        mpv = np.array(args.mpv)

    # save training config (mpv as plain floats so it is JSON-serializable)
    args_dict = vars(args)
    args_dict['mpv'] = [float(mpv[i]) for i in range(3)]
    with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f:
        json.dump(args_dict, f)

    # make mpv & alpha tensor
    mpv = torch.tensor(mpv.astype(np.float32).reshape(1, 3, 1, 1)).to(gpu)
    alpha = torch.tensor(args.alpha).to(gpu)

    # ------------------------------------------------
    # Local helpers
    # ------------------------------------------------
    hole_size = ((args.hole_min_w, args.hole_max_w),
                 (args.hole_min_h, args.hole_max_h))
    ld_size = (args.ld_input_size, args.ld_input_size)

    def random_hole_area(x):
        # Random ld_input_size-square area inside the images of batch x.
        return gen_hole_area(size=ld_size,
                             mask_size=(x.shape[3], x.shape[2]))

    def random_mask(x, hole_area=None):
        # Random hole mask of shape (N, 1, H, W) for batch x, on the gpu.
        if hole_area is None:
            hole_area = random_hole_area(x)
        return gen_input_mask(
            shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
            hole_size=hole_size,
            hole_area=hole_area,
            max_holes=args.max_holes,
        ).to(gpu)

    def masked_input(x, mask):
        # Fill holes with mpv; return (x_mask, CN input = x_mask ++ mask).
        x_mask = x - x * mask + mpv * mask
        return x_mask, torch.cat((x_mask, mask), dim=1)

    def load_weights(model, path):
        # Load a checkpoint into an *unwrapped* model; a 'module.' key prefix
        # (checkpoint saved off a DataParallel wrapper) is stripped.
        state_dict = torch.load(path, map_location='cpu')
        prefix = 'module.'
        if any(key.startswith(prefix) for key in state_dict):
            state_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }
        model.load_state_dict(state_dict)

    def save_weights(model, path):
        # Save the unwrapped module's weights so checkpoints carry no prefix.
        if args.data_parallel:
            torch.save(model.module.state_dict(), path)
        else:
            torch.save(model.state_dict(), path)

    def set_requires_grad(model, flag):
        # Toggle requires_grad on every parameter of model.
        for param in model.parameters():
            param.requires_grad_(flag)

    def snapshot(phase, step, models):
        # Complete a random test batch; save the image grid and the weights
        # of every (name, model) pair under result_dir/phase.
        phase_dir = os.path.join(args.result_dir, phase)
        with torch.no_grad():
            x = sample_random_batch(
                test_dset, batch_size=args.num_test_completions).to(gpu)
            mask = random_mask(x)
            x_mask, input_cn = masked_input(x, mask)
            output = model_cn(input_cn)
            completed = poisson_blend(x_mask, output, mask)
            imgs = torch.cat(
                (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0)
            save_image(imgs, os.path.join(phase_dir, 'step%d.png' % step),
                       nrow=len(x))
            for name, model in models:
                save_weights(model,
                             os.path.join(phase_dir,
                                          '%s_step%d' % (name, step)))

    # ================================================
    # Training Phase 1
    # ================================================
    model_cn = CompletionNetwork()
    # Load before wrapping: this script saves unprefixed state dicts.
    if args.init_model_cn is not None:
        load_weights(model_cn, args.init_model_cn)
    if args.data_parallel:
        model_cn = DataParallel(model_cn)
    if args.optimizer == 'adadelta':
        opt_cn = Adadelta(model_cn.parameters())
    else:
        opt_cn = Adam(model_cn.parameters())
    model_cn = model_cn.to(gpu)

    # training
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_1)
    while pbar.n < args.steps_1:
        for x in train_loader:
            # forward
            x = x.to(gpu)
            mask = random_mask(x)
            _, input_cn = masked_input(x, mask)
            output = model_cn(input_cn)
            loss = completion_network_loss(x, output, mask)

            # backward (gradients accumulate until the optimizer step)
            loss.backward()
            cnt_bdivs += 1

            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0
                opt_cn.step()
                opt_cn.zero_grad()
                pbar.set_description('phase 1 | train loss: %.5f' % loss.cpu())
                pbar.update()
                if pbar.n % args.snaperiod_1 == 0:
                    snapshot('phase_1', pbar.n, [('model_cn', model_cn)])
                if pbar.n >= args.steps_1:
                    break
    pbar.close()

    # ================================================
    # Training Phase 2
    # ================================================
    model_cd = ContextDiscriminator(
        local_input_shape=(3, args.ld_input_size, args.ld_input_size),
        global_input_shape=(3, args.cn_input_size, args.cn_input_size),
        arc=args.arc,
    )
    # Load before wrapping: this script saves unprefixed state dicts.
    if args.init_model_cd is not None:
        load_weights(model_cd, args.init_model_cd)
    if args.data_parallel:
        model_cd = DataParallel(model_cd)
    if args.optimizer == 'adadelta':
        opt_cd = Adadelta(model_cd.parameters())
    else:
        opt_cd = Adam(model_cd.parameters())
    model_cd = model_cd.to(gpu)
    bceloss = BCELoss()

    # training
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_2)
    while pbar.n < args.steps_2:
        for x in train_loader:
            # fake forward: CN completions; CN is not trained in this phase,
            # so its forward pass needs no autograd graph
            x = x.to(gpu)
            hole_area_fake = random_hole_area(x)
            mask = random_mask(x, hole_area_fake)
            fake = torch.zeros((len(x), 1)).to(gpu)
            _, input_cn = masked_input(x, mask)
            with torch.no_grad():
                input_gd_fake = model_cn(input_cn)
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake, input_gd_fake))
            loss_fake = bceloss(output_fake, fake)

            # real forward: training images at an independent random area
            hole_area_real = random_hole_area(x)
            real = torch.ones((len(x), 1)).to(gpu)
            input_ld_real = crop(x, hole_area_real)
            output_real = model_cd((input_ld_real, x))
            loss_real = bceloss(output_real, real)

            # reduce
            loss = (loss_fake + loss_real) / 2.

            # backward
            loss.backward()
            cnt_bdivs += 1

            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0
                opt_cd.step()
                opt_cd.zero_grad()
                pbar.set_description('phase 2 | train loss: %.5f' % loss.cpu())
                pbar.update()
                if pbar.n % args.snaperiod_2 == 0:
                    snapshot('phase_2', pbar.n, [('model_cd', model_cd)])
                if pbar.n >= args.steps_2:
                    break
    pbar.close()

    # ================================================
    # Training Phase 3
    # ================================================
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_3)
    while pbar.n < args.steps_3:
        for x in train_loader:
            x = x.to(gpu)
            hole_area_fake = random_hole_area(x)
            mask = random_mask(x, hole_area_fake)

            # ---- discriminator loss ----
            # fake forward (completions detached: only CD learns from this)
            fake = torch.zeros((len(x), 1)).to(gpu)
            _, input_cn = masked_input(x, mask)
            output_cn = model_cn(input_cn)
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake, input_gd_fake))
            loss_cd_fake = bceloss(output_fake, fake)

            # real forward
            hole_area_real = random_hole_area(x)
            real = torch.ones((len(x), 1)).to(gpu)
            input_ld_real = crop(x, hole_area_real)
            output_real = model_cd((input_ld_real, x))
            loss_cd_real = bceloss(output_real, real)

            # reduce & backward model_cd
            loss_cd = (loss_cd_fake + loss_cd_real) * alpha / 2.
            loss_cd.backward()

            cnt_bdivs += 1
            if cnt_bdivs >= args.bdivs:
                opt_cd.step()
                opt_cd.zero_grad()

            # ---- completion network loss ----
            # Freeze CD while back-propagating CN's loss. Otherwise this
            # backward also writes the adversarial gradient (taken w.r.t.
            # "real" labels) into CD's .grad buffers, and a later
            # opt_cd.step() would apply it, pushing CD to call completions
            # real. Gradients still flow through CD into output_cn.
            set_requires_grad(model_cd, False)
            loss_cn_1 = completion_network_loss(x, output_cn, mask)
            input_ld_fake = crop(output_cn, hole_area_fake)
            output_fake = model_cd((input_ld_fake, output_cn))
            loss_cn_2 = bceloss(output_fake, real)
            loss_cn = (loss_cn_1 + alpha * loss_cn_2) / 2.
            loss_cn.backward()
            set_requires_grad(model_cd, True)

            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0
                opt_cn.step()
                opt_cn.zero_grad()
                pbar.set_description(
                    'phase 3 | train loss (cd): %.5f (cn): %.5f' %
                    (loss_cd.cpu(), loss_cn.cpu()))
                pbar.update()
                if pbar.n % args.snaperiod_3 == 0:
                    snapshot('phase_3', pbar.n,
                             [('model_cn', model_cn), ('model_cd', model_cd)])
                if pbar.n >= args.steps_3:
                    break
    pbar.close()
def train_DG(n, steps):
    """Adversarially train the completion network and the discriminator.

    Runs until ``steps`` optimizer steps have been taken. Each step
    accumulates gradients over ``args.bdivs`` mini-batches. ``n`` only labels
    the progress bar and the snapshot file names. Every ``args.snaperiod_3``
    steps, a 3-sample test grid and both models' weights are saved under
    ``args.result_dir/phase_3``.

    NOTE(review): depends on module-level globals defined outside this view
    (train_loader, test_dset, gpu, model_cn, model_cd, opt_cn, opt_cd,
    bceloss, alpha, args).
    """
    cnt_bdivs = 0
    pbar = tqdm(total=steps)
    while pbar.n < steps:
        for street_img, mask_people, mask, left_top in train_loader:
            street_img = street_img.to(gpu)
            mask_people = mask_people.to(gpu)
            mask = mask.to(gpu)

            # ---- discriminator loss ----
            # fake forward (completions detached: only model_cd learns here)
            fake = torch.zeros((len(street_img), 1)).to(gpu)
            input_cn = torch.cat((street_img, mask_people), dim=1)
            output_cn = model_cn(input_cn)
            output_fake = model_cd(output_cn.detach())
            loss_cd_fake = bceloss(output_fake, fake)

            # real forward
            real = torch.ones((len(street_img), 1)).to(gpu)
            output_real = model_cd(street_img)
            loss_cd_real = bceloss(output_real, real)

            # reduce & backward model_cd
            loss_cd = (loss_cd_fake + loss_cd_real) * alpha / 2.
            loss_cd.backward()

            cnt_bdivs += 1
            if cnt_bdivs >= args.bdivs:
                opt_cd.step()
                opt_cd.zero_grad()

            # ---- completion network loss ----
            # Freeze model_cd while back-propagating the generator loss.
            # Otherwise this backward also writes the adversarial gradient
            # (taken w.r.t. "real" labels) into model_cd's .grad buffers, and
            # a later opt_cd.step() would apply it. Gradients still flow
            # through model_cd into output_cn.
            for param in model_cd.parameters():
                param.requires_grad_(False)
            loss_cn_1 = completion_network_loss_P(street_img,
                                                  mask_people,
                                                  mask,
                                                  output_cn,
                                                  left_top,
                                                  height=128,
                                                  width=64)
            output_fake = model_cd(output_cn)
            loss_cn_2 = bceloss(output_fake, real)

            # reduce & backward model_cn
            loss_cn = (loss_cn_1 + alpha * loss_cn_2) / 2.
            loss_cn.backward()
            for param in model_cd.parameters():
                param.requires_grad_(True)

            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0
                opt_cn.step()
                opt_cn.zero_grad()
                pbar.set_description(
                    '%d | phase 3 | train loss (cd): %.5f (cn): %.5f' %
                    (n, loss_cd.cpu(), loss_cn.cpu()))
                pbar.update()
                # test
                if pbar.n % args.snaperiod_3 == 0:
                    with torch.no_grad():
                        x1, x2, x3 = sample_random_batch(test_dset,
                                                         batch_size=3)
                        x1 = x1.to(gpu)
                        x2 = x2.to(gpu)
                        # binarize the third sample into a 0/255 uint8 mask
                        x3 = (x3).le(0.5).to(torch.uint8) * 255

                        test_input = torch.cat((x1, x2), dim=1)
                        output = model_cn(test_input)
                        completed = poisson_blend(output, x1, x3)
                        imgs = torch.cat((x1.cpu(), x2.cpu(), output.cpu(),
                                          completed.cpu()),
                                         dim=2)

                        imgpath = os.path.join(args.result_dir, 'phase_3',
                                               '%d_step%d.png' % (n, pbar.n))
                        model_cn_path = os.path.join(
                            args.result_dir, 'phase_3',
                            '%d_model_cn_step%d' % (n, pbar.n))
                        model_cd_path = os.path.join(
                            args.result_dir, 'phase_3',
                            '%d_model_cd_step%d' % (n, pbar.n))
                        save_image(imgs, imgpath, nrow=len(x1))
                        if args.data_parallel:
                            torch.save(model_cn.module.state_dict(),
                                       model_cn_path)
                            torch.save(model_cd.module.state_dict(),
                                       model_cd_path)
                        else:
                            torch.save(model_cn.state_dict(), model_cn_path)
                            torch.save(model_cd.state_dict(), model_cd_path)
                # terminate
                if pbar.n >= steps:
                    break
    pbar.close()
def train_G(n, steps):
    """Train the completion network alone for ``steps`` optimizer steps.

    Gradients from ``args.bdivs`` consecutive mini-batches are summed before
    each ``opt_cn`` step. Every ``args.snaperiod_1`` steps, a 3-sample test
    grid and the network weights are written to ``args.result_dir/phase_1``.
    ``n`` is used as a file-name prefix and also labels the progress bar.

    NOTE(review): relies on module-level globals defined outside this view
    (train_loader, test_dset, gpu, model_cn, opt_cn, args).
    """

    def save_preview(step):
        # Complete a random 3-sample test batch, then dump the image grid and
        # the current network weights for this step.
        with torch.no_grad():
            x1, x2, x3 = sample_random_batch(test_dset, batch_size=3)
            x1 = x1.to(gpu)
            x2 = x2.to(gpu)
            # binarize the third sample into a 0/255 uint8 mask
            x3 = x3.le(0.5).to(torch.uint8) * 255

            prediction = model_cn(torch.cat((x1, x2), dim=1))
            blended = poisson_blend(prediction, x1, x3)
            grid = torch.cat(
                (x1.cpu(), x2.cpu(), prediction.cpu(), blended.cpu()), dim=2)

            out_dir = os.path.join(args.result_dir, 'phase_1')
            save_image(grid,
                       os.path.join(out_dir, '%d_step%d.png' % (n, step)),
                       nrow=len(x1))
            net = model_cn.module if args.data_parallel else model_cn
            torch.save(net.state_dict(),
                       os.path.join(out_dir,
                                    '%d_model_cn_step%d' % (n, step)))

    accumulated = 0
    progress = tqdm(total=steps)
    while progress.n < steps:
        for street_img, people_mask, region_mask, left_top in train_loader:
            street_img = street_img.to(gpu)
            people_mask = people_mask.to(gpu)
            region_mask = region_mask.to(gpu)

            prediction = model_cn(torch.cat((street_img, people_mask), dim=1))
            loss = completion_network_loss_P(street_img,
                                             people_mask,
                                             region_mask,
                                             prediction,
                                             left_top,
                                             height=128,
                                             width=64)
            loss.backward()
            accumulated += 1

            # keep accumulating gradients until a full step's worth is in
            if accumulated < args.bdivs:
                continue

            accumulated = 0
            opt_cn.step()
            opt_cn.zero_grad()
            progress.set_description('%d | phase 1 | train loss: %.5f' %
                                     (n, loss.cpu()))
            progress.update()
            if progress.n % args.snaperiod_1 == 0:
                save_preview(progress.n)
            if progress.n >= steps:
                break
    progress.close()
예제 #6
0
def _sample_hole_area(args, shape):
    """Pick a random args.ld_input_size square region for a batch of `shape`.

    `shape` is the (N, C, H, W) batch shape. The returned area is what
    `crop` cuts out for the local discriminator, and it is passed as
    `hole_area` to `gen_input_mask`.
    """
    return gen_hole_area(
        size=(args.ld_input_size, args.ld_input_size),
        mask_size=(shape[3], shape[2]),
    )


def _sample_hole_mask(args, shape, hole_area):
    """Draw a random hole mask of `shape` from the args.hole_*/max_holes settings."""
    return gen_input_mask(
        shape=shape,
        hole_size=(
            (args.hole_min_w, args.hole_max_w),
            (args.hole_min_h, args.hole_max_h),
        ),
        hole_area=hole_area,
        max_holes=args.max_holes,
    )


def _save_test_snapshot(args, model_cn, test_dset, mpv, device, phase_dir,
                        step):
    """Complete a random test batch and save <result_dir>/<phase_dir>/step<step>.png.

    The saved grid holds the masked inputs followed by their Poisson-blended
    completions. A fresh mask is drawn for the test batch so its batch size
    always matches. The last training batch can be smaller than args.bsize
    because the DataLoader does not drop it.
    """
    with torch.no_grad():
        x = sample_random_batch(test_dset, batch_size=args.bsize).to(device)
        hole_area = _sample_hole_area(args, x.shape)
        msk = _sample_hole_mask(args, x.shape, hole_area).to(device)
        x_masked = x - x * msk + mpv * msk
        output = model_cn(x_masked)
        completed = poisson_blend(x_masked, output, msk)
        imgs = torch.cat((x_masked.cpu(), completed.cpu()), dim=0)
        save_image(imgs,
                   os.path.join(args.result_dir, phase_dir,
                                'step%d.png' % step),
                   nrow=len(x))


def main(args):
    """Train the completion network (CN) and the context discriminator (CD).

    Phases:
      1. CN alone with `completion_network_loss` for args.steps_1 steps.
      2. CD alone (BCE, CN completions vs. real images) for args.steps_2 steps.
      3. Both jointly for args.steps_3 steps. The CD loss is scaled by
         args.alpha, and the CN loss adds args.alpha * adversarial loss.

    Snapshots and weights go to args.result_dir/phase_{1,2,3}. The argument
    namespace plus the computed mean pixel value is written to
    args.result_dir/config.json.

    Raises:
        Exception: if no CUDA device is available or the training set is empty.
    """
    # ================================================
    # Preparation
    # ================================================
    args.data_dir = os.path.expanduser(args.data_dir)
    args.result_dir = os.path.expanduser(args.result_dir)

    if not torch.cuda.is_available():
        raise Exception('At least one gpu must be available.')
    gpu_cn = torch.device('cuda:0')
    # with more than one gpu, the discriminator is placed on a second device
    gpu_cd = gpu_cn if args.num_gpus == 1 else torch.device('cuda:1')

    # create result directories (if necessary)
    for s in ('phase_1', 'phase_2', 'phase_3'):
        os.makedirs(os.path.join(args.result_dir, s), exist_ok=True)

    # dataset
    trnsfm = transforms.Compose([
        transforms.Resize(args.cn_input_size),
        transforms.RandomCrop((args.cn_input_size, args.cn_input_size)),
        transforms.ToTensor(),
    ])
    print('loading dataset... (it may take a few minutes)')
    train_dir = os.path.join(args.data_dir, 'train')
    train_dset = ImageDataset(train_dir, trnsfm)
    test_dset = ImageDataset(os.path.join(args.data_dir, 'test'), trnsfm)
    if len(train_dset) == 0:
        # an empty loader would make the `while pbar.n < steps` loops spin forever
        raise Exception('no training images found in %s' % train_dir)
    train_loader = DataLoader(train_dset, batch_size=args.bsize, shuffle=True)

    # compute the mean pixel value of train dataset
    mean_pv = 0.
    imgpaths = train_dset.imgpaths[:min(args.max_mpv_samples, len(train_dset))]
    if args.comp_mpv and imgpaths:
        pbar = tqdm(total=len(imgpaths), desc='computing the mean pixel value')
        for imgpath in imgpaths:
            # close each file now instead of leaking the handle until GC
            with Image.open(imgpath) as img:
                pixels = np.array(img, dtype=np.float32) / 255.
            # keep a plain Python float: np.float32 is not JSON-serializable
            mean_pv += float(pixels.mean())
            pbar.update()
        mean_pv /= len(imgpaths)
        pbar.close()
    mpv = torch.tensor(mean_pv).to(gpu_cn)

    # save training config (vars() returns args.__dict__, so args gains mean_pv)
    args_dict = vars(args)
    args_dict['mean_pv'] = mean_pv
    with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f:
        json.dump(args_dict, f)

    # ================================================
    # Training Phase 1
    # ================================================
    model_cn = CompletionNetwork().to(gpu_cn)
    if args.optimizer == 'adadelta':
        opt_cn = Adadelta(model_cn.parameters())
    else:
        opt_cn = Adam(model_cn.parameters())

    pbar = tqdm(total=args.steps_1)
    while pbar.n < args.steps_1:
        for x in train_loader:
            opt_cn.zero_grad()

            hole_area = _sample_hole_area(args, x.shape)
            msk = _sample_hole_mask(args, x.shape, hole_area)

            # fill the holes with the mean pixel value
            x = x.to(gpu_cn)
            msk = msk.to(gpu_cn)
            input_cn = x - x * msk + mpv * msk
            output = model_cn(input_cn)

            loss = completion_network_loss(x, output, msk)
            loss.backward()
            opt_cn.step()

            pbar.set_description('phase 1 | train loss: %.5f' % loss.item())
            pbar.update()

            if pbar.n % args.snaperiod_1 == 0:
                _save_test_snapshot(args, model_cn, test_dset, mpv, gpu_cn,
                                    'phase_1', pbar.n)
                torch.save(
                    model_cn.state_dict(),
                    os.path.join(args.result_dir, 'phase_1',
                                 'model_cn_step%d' % pbar.n))

            if pbar.n >= args.steps_1:
                break
    pbar.close()

    # ================================================
    # Training Phase 2
    # ================================================
    model_cd = ContextDiscriminator(
        local_input_shape=(3, args.ld_input_size, args.ld_input_size),
        global_input_shape=(3, args.cn_input_size, args.cn_input_size),
    ).to(gpu_cd)
    if args.optimizer == 'adadelta':
        opt_cd = Adadelta(model_cd.parameters())
    else:
        opt_cd = Adam(model_cd.parameters())
    criterion_cd = BCELoss()

    pbar = tqdm(total=args.steps_2)
    while pbar.n < args.steps_2:
        for x in train_loader:
            x = x.to(gpu_cn)
            opt_cd.zero_grad()

            # fake: CN completions, cropped at the area the holes were drawn in
            hole_area_fake = _sample_hole_area(args, x.shape)
            msk = _sample_hole_mask(args, x.shape, hole_area_fake).to(gpu_cn)
            fake = torch.zeros((len(x), 1)).to(gpu_cd)
            input_cn = x - x * msk + mpv * msk
            # the CN is not updated in this phase; don't build its graph
            with torch.no_grad():
                output_cn = model_cn(input_cn)
            input_gd_fake = output_cn
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake.to(gpu_cd),
                                    input_gd_fake.to(gpu_cd)))
            loss_fake = criterion_cd(output_fake, fake)

            # real: unmodified images, cropped at an independent random area
            hole_area_real = _sample_hole_area(args, x.shape)
            real = torch.ones((len(x), 1)).to(gpu_cd)
            input_gd_real = x
            input_ld_real = crop(input_gd_real, hole_area_real)
            output_real = model_cd((input_ld_real.to(gpu_cd),
                                    input_gd_real.to(gpu_cd)))
            loss_real = criterion_cd(output_real, real)

            loss = (loss_fake + loss_real) / 2.
            loss.backward()
            opt_cd.step()

            pbar.set_description('phase 2 | train loss: %.5f' % loss.item())
            pbar.update()

            if pbar.n % args.snaperiod_2 == 0:
                _save_test_snapshot(args, model_cn, test_dset, mpv, gpu_cn,
                                    'phase_2', pbar.n)
                torch.save(
                    model_cd.state_dict(),
                    os.path.join(args.result_dir, 'phase_2',
                                 'model_cd_step%d' % pbar.n))

            if pbar.n >= args.steps_2:
                break
    pbar.close()

    # ================================================
    # Training Phase 3
    # ================================================
    alpha = torch.tensor(args.alpha).to(gpu_cd)
    pbar = tqdm(total=args.steps_3)
    while pbar.n < args.steps_3:
        for x in train_loader:
            x = x.to(gpu_cn)

            # ------------------------------------------------
            # train model_cd
            # ------------------------------------------------
            opt_cd.zero_grad()

            # fake
            hole_area_fake = _sample_hole_area(args, x.shape)
            msk = _sample_hole_mask(args, x.shape, hole_area_fake).to(gpu_cn)
            fake = torch.zeros((len(x), 1)).to(gpu_cd)
            input_cn = x - x * msk + mpv * msk
            output_cn = model_cn(input_cn)
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake.to(gpu_cd),
                                    input_gd_fake.to(gpu_cd)))
            loss_cd_1 = criterion_cd(output_fake, fake)

            # real
            hole_area_real = _sample_hole_area(args, x.shape)
            real = torch.ones((len(x), 1)).to(gpu_cd)
            input_gd_real = x
            input_ld_real = crop(input_gd_real, hole_area_real)
            output_real = model_cd((input_ld_real.to(gpu_cd),
                                    input_gd_real.to(gpu_cd)))
            loss_cd_2 = criterion_cd(output_real, real)

            loss_cd = (loss_cd_1 + loss_cd_2) * alpha / 2.
            loss_cd.backward()
            opt_cd.step()

            # ------------------------------------------------
            # train model_cn
            # ------------------------------------------------
            opt_cn.zero_grad()

            loss_cn_1 = completion_network_loss(x, output_cn, msk).to(gpu_cd)
            input_gd_fake = output_cn
            # crop where this batch's holes were drawn (hole_area_fake) so the
            # local discriminator actually judges the completed region
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake.to(gpu_cd),
                                    input_gd_fake.to(gpu_cd)))
            loss_cn_2 = criterion_cd(output_fake, real)

            loss_cn = (loss_cn_1 + alpha * loss_cn_2) / 2.
            loss_cn.backward()
            opt_cn.step()

            pbar.set_description(
                'phase 3 | train loss (cd): %.5f train loss (cn): %.5f' %
                (loss_cd.item(), loss_cn.item()))
            pbar.update()

            if pbar.n % args.snaperiod_3 == 0:
                _save_test_snapshot(args, model_cn, test_dset, mpv, gpu_cn,
                                    'phase_3', pbar.n)
                torch.save(
                    model_cn.state_dict(),
                    os.path.join(args.result_dir, 'phase_3',
                                 'model_cn_step%d' % pbar.n))
                torch.save(
                    model_cd.state_dict(),
                    os.path.join(args.result_dir, 'phase_3',
                                 'model_cd_step%d' % pbar.n))

            if pbar.n >= args.steps_3:
                break
    pbar.close()
예제 #7
0
def main(args):
    """Inpaint one image with each of its three stored masks and save a grid.

    For an input image ``<dir>/<name><ext>``, masks are read from
    ``<dir with "test" replaced by "masks">/<name>_mask<i><ext>`` for i = 0..2.
    Each grid row (nrow=4) shows the input, the masked input, the
    Poisson-blended completion and that completion after ``gray_to_binary``.
    """
    args.model = os.path.expanduser(args.model)
    args.config = os.path.expanduser(args.config)
    args.input_img = os.path.expanduser(args.input_img)
    args.output_img = os.path.expanduser(args.output_img)

    # =============================================
    # Load model
    # =============================================
    with open(args.config, 'r') as f:
        config = json.load(f)
    # a single mean pixel value, broadcast over every hole pixel
    mpv = torch.tensor(config['mpv']).view(1, 1, 1, 1)
    model = CompletionNetwork()
    model.load_state_dict(torch.load(args.model, map_location='cpu'))

    # =============================================
    # Predict
    # =============================================
    # convert img to tensor (the with-block closes the source file)
    with Image.open(args.input_img) as src:
        img = transforms.Resize(args.img_size)(src)
        img = transforms.RandomCrop((args.img_size, args.img_size))(img)
    x = torch.unsqueeze(transforms.ToTensor()(img), dim=0)

    # Masks live in a parallel directory: "test" in the image's directory is
    # replaced by "masks", e.g. data/test/a.png -> data/masks/a_mask0.png.
    # splitext keeps extensions of any length (.png, .jpeg, ...).
    stem, ext = os.path.splitext(args.input_img)
    img_dir, img_name = os.path.split(stem)
    mask_stem = os.path.join(img_dir.replace('test', 'masks'), img_name)

    to_tensor = transforms.ToTensor()
    rows = []
    with torch.no_grad():
        for i in range(3):
            mask_filename = '%s_mask%d%s' % (mask_stem, i, ext)
            with Image.open(mask_filename) as mask_src:
                mask_img = mask_src.convert('L')

            # copying into a (1, 1, H, W) buffer raises if the mask size does
            # not match the cropped image
            mask = torch.zeros((1, 1, x.shape[2], x.shape[3]))
            mask[0, 0, :, :] = to_tensor(mask_img)

            x_mask = x - x * mask + mpv * mask
            net_in = torch.cat((x_mask, mask), dim=1)
            output = model(net_in)
            inpainted = poisson_blend(x, output, mask)
            # clone in case gray_to_binary modifies its argument in place
            binary_out = gray_to_binary(inpainted.clone())
            rows.extend((x, x_mask, inpainted, binary_out))

    # concatenate once instead of growing a tensor every iteration
    save_image(torch.cat(rows, dim=0), args.output_img, nrow=4)

    print('output img was saved as %s.' % args.output_img)