Пример #1
0
def predict(model_path, input_img):
    """Run a CompletionNetwork checkpoint on a PIL image and return its output.

    The image is resized to 224x224, turned into a single-image batch and fed
    directly to the network (no mask channel is concatenated here) in eval mode
    without gradient tracking.

    Args:
        model_path: path to a state dict loadable into ``CompletionNetwork``.
        input_img: PIL image to process.

    Returns:
        The first output of the batch converted to an RGB ``PIL.Image``.
    """
    network = CompletionNetwork()
    checkpoint = torch.load(model_path, map_location='cpu')
    network.load_state_dict(checkpoint)

    resized = input_img.resize((224, 224))
    batch = transforms.ToTensor()(resized).unsqueeze(0)

    network.eval()
    with torch.no_grad():
        prediction = network(batch)

    to_pil = transforms.ToPILImage()
    return to_pil(prediction[0]).convert("RGB")
Пример #2
0
def main(args):
    """Inpaint randomly masked holes in one RGB image and save a comparison strip.

    Loads the CompletionNetwork weights (``args.model``) and the training config
    (``args.config``, whose ``mpv`` entry is reshaped to three per-channel
    values). It then resizes and randomly crops ``args.input_img`` to
    ``args.img_size`` and punches holes with ``gen_input_mask``. The holes are
    filled with mpv, the network is run, and its output is Poisson-blended back
    into the image.

    Writes ``[input | masked input | completion]`` (one row, ``nrow=3``) to
    ``args.output_img`` and prints the output path.
    """
    args.model = os.path.expanduser(args.model)
    args.config = os.path.expanduser(args.config)
    args.input_img = os.path.expanduser(args.input_img)
    args.output_img = os.path.expanduser(args.output_img)

    # =============================================
    # Load model
    # =============================================
    with open(args.config, 'r') as f:
        config = json.load(f)
    # mean pixel value, one per RGB channel, broadcast over height and width
    mpv = torch.tensor(config['mpv']).view(1, 3, 1, 1)
    model = CompletionNetwork()
    model.load_state_dict(torch.load(args.model, map_location='cpu'))

    # =============================================
    # Predict
    # =============================================
    # Convert to RGB so the tensor has the same 3 channels as mpv. Without this,
    # an RGBA image fails to broadcast against mpv, and a grayscale image yields
    # a 1-channel x that cannot be concatenated with the 3-channel x_mask below.
    with Image.open(args.input_img) as src:
        img = src.convert('RGB')
    img = transforms.Resize(args.img_size)(img)
    img = transforms.RandomCrop((args.img_size, args.img_size))(img)
    x = transforms.ToTensor()(img)
    x = torch.unsqueeze(x, dim=0)

    # create mask (1 inside holes, 0 elsewhere; single channel)
    mask = gen_input_mask(
        shape=(1, 1, x.shape[2], x.shape[3]),
        hole_size=(
            (args.hole_min_w, args.hole_max_w),
            (args.hole_min_h, args.hole_max_h),
        ),
        max_holes=args.max_holes,
    )

    # inpaint
    model.eval()
    with torch.no_grad():
        # keep known pixels, replace hole pixels with the mean pixel value
        x_mask = x - x * mask + mpv * mask
        model_input = torch.cat((x_mask, mask), dim=1)
        output = model(model_input)
        inpainted = poisson_blend(x_mask, output, mask)
        imgs = torch.cat((x, x_mask, inpainted), dim=0)
        save_image(imgs, args.output_img, nrow=3)
    print('output img was saved as %s.' % args.output_img)
Пример #3
0
def main(args):
    """Inpaint randomly masked holes in one image with a single-value mpv.

    Expands ``~`` in the model/config/input/output paths and loads the
    CompletionNetwork weights. It then fills holes produced by
    ``gen_input_mask`` with the mean pixel value (mpv), runs the network and
    merges its output back with ``rejoiner``.

    Saves ``[input | masked input | completion]`` (``nrow=3``) to
    ``args.output_img``.

    The mpv comes from the training config when it is recorded there.
    Otherwise a hard-coded fallback value is used.
    """
    args.model = os.path.expanduser(args.model)
    args.config = os.path.expanduser(args.config)
    args.input_img = os.path.expanduser(args.input_img)
    args.output_img = os.path.expanduser(args.output_img)

    # =============================================
    # Load model
    # =============================================
    with open(args.config, 'r') as f:
        config = json.load(f)
    # Prefer the mpv stored in the training config so inference fills holes
    # with the same value the network was trained with. The constant is only a
    # fallback for configs whose 'mpv' entry is missing or null.
    mpv_value = config.get('mpv')
    if mpv_value is None:
        mpv_value = 0.13465263
    mpv = torch.tensor(mpv_value, dtype=torch.float32).view(1, 1, 1, 1)
    model = CompletionNetwork()
    model.load_state_dict(torch.load(args.model, map_location='cpu'))

    # =============================================
    # Predict
    # =============================================
    # convert img to tensor (inside `with` so the file handle is released)
    with Image.open(args.input_img) as img:
        x = transforms.ToTensor()(img)
    x = torch.unsqueeze(x, dim=0)

    # create mask (single channel, same spatial size as the image)
    mask = gen_input_mask(shape=(1, 1, x.shape[2], x.shape[3]))

    # inpaint
    model.eval()
    with torch.no_grad():
        # keep known pixels, replace hole pixels with the mean pixel value
        x_mask = x - x * mask + mpv * mask
        model_input = torch.cat((x_mask, mask), dim=1)
        output = model(model_input)
        inpainted = rejoiner(x_mask, output, mask)
        imgs = torch.cat((x, x_mask, inpainted), dim=0)
        save_image(imgs, args.output_img, nrow=3)
    print('output img was saved as %s.' % args.output_img)
Пример #4
0
def _unwrap(model, data_parallel):
    """Return the module inside a DataParallel wrapper (or *model* itself)."""
    return model.module if data_parallel else model


def _set_requires_grad(model, flag):
    """Enable or disable gradient accumulation for every parameter of *model*."""
    for param in model.parameters():
        param.requires_grad_(flag)


def _save_test_completions(model_cn, test_dset, mpv, gpu, batch_size, imgpath):
    """Inpaint a random test batch and save it as an image grid at *imgpath*.

    The grid has three rows of ``batch_size`` images each: originals, masked
    inputs and completions. *model_cn* is put in eval mode for the forward pass
    and returned to train mode afterwards.
    """
    model_cn.eval()
    with torch.no_grad():
        x = sample_random_batch(test_dset, batch_size=batch_size).to(gpu)
        mask = gen_input_mask(
            shape=(x.shape[0], 1, x.shape[2], x.shape[3])).to(gpu)
        x_mask = x - x * mask + mpv * mask
        output = model_cn(torch.cat((x_mask, mask), dim=1))
        completed = rejoiner(x_mask, output, mask)
        imgs = torch.cat((x.cpu(), x_mask.cpu(), completed.cpu()), dim=0)
        save_image(imgs, imgpath, nrow=len(x))
    model_cn.train()


def main(args):
    """Train a completion network (CN) and context discriminator (CD) in 3 phases.

    Phases:
        1. Train the CN alone with ``completion_network_loss``.
        2. Train the CD alone: CN outputs are "fake", dataset images are "real".
        3. Train both jointly. The CD's BCE loss and the CN's adversarial term
           are weighted by ``args.alpha``.

    Gradients are accumulated over ``args.bdivs`` mini-batches of size
    ``args.bsize // args.bdivs`` before each optimizer step.

    Every ``args.snaperiod_{1,2,3}`` steps, a grid of test completions and the
    weights trained in that phase are written under
    ``args.result_dir/phase_{1,2,3}``. The run configuration, including the
    mean pixel value (mpv) actually used, is written to
    ``args.result_dir/config.json``.

    Raises:
        Exception: if no CUDA device is available.
    """
    # ================================================
    # Preparation
    # ================================================
    if not torch.cuda.is_available():
        raise Exception('At least one gpu must be available.')
    gpu = torch.device('cuda:0')

    # create result directories (if necessary); makedirs also creates result_dir
    for phase in ['phase_1', 'phase_2', 'phase_3']:
        os.makedirs(os.path.join(args.result_dir, phase), exist_ok=True)

    # load dataset
    trnsfm = transforms.Compose([
        transforms.Resize(args.cn_input_size),
        transforms.RandomCrop((args.cn_input_size, args.cn_input_size)),
        transforms.ToTensor(),
    ])
    print('loading dataset... (it may take a few minutes)')
    train_dset = ImageDataset(os.path.join(args.data_dir, 'train'),
                              trnsfm,
                              recursive_search=args.recursive_search)
    test_dset = ImageDataset(os.path.join(args.data_dir, 'test'),
                             trnsfm,
                             recursive_search=args.recursive_search)
    train_loader = DataLoader(train_dset,
                              batch_size=(args.bsize // args.bdivs),
                              shuffle=True)

    # compute mpv (mean pixel value) of training dataset; a single value is
    # accumulated, i.e. this path assumes single-channel images
    if args.mpv is None:
        mpv = np.zeros(shape=(1, ))
        pbar = tqdm(total=len(train_dset.imgpaths),
                    desc='computing mean pixel value of training dataset...')
        for imgpath in train_dset.imgpaths:
            with Image.open(imgpath) as img:
                x = np.array(img) / 255.
            mpv += x.mean(axis=(0, 1))
            pbar.update()
        mpv /= len(train_dset.imgpaths)
        pbar.close()
    else:
        mpv = np.array(args.mpv)

    # Save training config. The mpv actually used is recorded (in a copy, so
    # `args` is not mutated) so inference can fill holes with the same value.
    args_dict = dict(vars(args))
    args_dict['mpv'] = [float(v) for v in np.ravel(mpv)]
    with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f:
        json.dump(args_dict, f)

    # make mpv & alpha tensors
    mpv = torch.tensor(mpv.reshape(1, 1, 1, 1), dtype=torch.float32).to(gpu)
    alpha = torch.tensor(args.alpha, dtype=torch.float32).to(gpu)

    # ================================================
    # Training Phase 1
    # ================================================
    # load completion network
    model_cn = CompletionNetwork()
    if args.init_model_cn is not None:
        model_cn.load_state_dict(
            torch.load(args.init_model_cn, map_location='cpu'))
    if args.data_parallel:
        model_cn = DataParallel(model_cn)
    model_cn = model_cn.to(gpu)
    opt_cn = Adadelta(model_cn.parameters())

    # training
    phase_dir = os.path.join(args.result_dir, 'phase_1')
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_1)
    while pbar.n < args.steps_1:
        for x in train_loader:
            # forward
            x = x.to(gpu)
            mask = gen_input_mask(
                shape=(x.shape[0], 1, x.shape[2], x.shape[3])).to(gpu)
            x_mask = x - x * mask + mpv * mask
            output = model_cn(torch.cat((x_mask, mask), dim=1))
            loss = completion_network_loss(x, output, mask)

            # backward (accumulated over args.bdivs mini-batches)
            loss.backward()
            cnt_bdivs += 1
            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0

                # optimize
                opt_cn.step()
                opt_cn.zero_grad()
                pbar.set_description('phase 1 | train loss: %.5f' % loss.cpu())
                pbar.update()

                # test
                if pbar.n % args.snaperiod_1 == 0:
                    _save_test_completions(
                        model_cn, test_dset, mpv, gpu,
                        args.num_test_completions,
                        os.path.join(phase_dir, 'step%d.png' % pbar.n))
                    torch.save(
                        _unwrap(model_cn, args.data_parallel).state_dict(),
                        os.path.join(phase_dir, 'model_cn_step%d' % pbar.n))
                if pbar.n >= args.steps_1:
                    break
    pbar.close()

    # ================================================
    # Training Phase 2
    # ================================================
    # load context discriminator
    model_cd = ContextDiscriminator(
        local_input_shape=(1, args.ld_input_size, args.ld_input_size),
        global_input_shape=(1, args.cn_input_size, args.cn_input_size),
    )
    if args.init_model_cd is not None:
        model_cd.load_state_dict(
            torch.load(args.init_model_cd, map_location='cpu'))
    if args.data_parallel:
        model_cd = DataParallel(model_cd)
    model_cd = model_cd.to(gpu)
    opt_cd = Adadelta(model_cd.parameters(), lr=0.1)
    bceloss = BCELoss()

    # training
    phase_dir = os.path.join(args.result_dir, 'phase_2')
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_2)
    while pbar.n < args.steps_2:
        for x in train_loader:
            # fake forward
            x = x.to(gpu)
            hole_area_fake = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            mask = gen_input_mask(
                shape=(x.shape[0], 1, x.shape[2], x.shape[3])).to(gpu)
            fake = torch.zeros((len(x), 1)).to(gpu)
            x_mask = x - x * mask + mpv * mask
            # The CN is not trained in this phase and its output is only used
            # as a constant "fake" sample, so skip building its autograd graph.
            with torch.no_grad():
                input_gd_fake = model_cn(torch.cat((x_mask, mask), dim=1))
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake, input_gd_fake))
            loss_fake = bceloss(output_fake, fake)

            # real forward
            hole_area_real = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            real = torch.ones((len(x), 1)).to(gpu)
            input_ld_real = crop(x, hole_area_real)
            output_real = model_cd((input_ld_real, x))
            loss_real = bceloss(output_real, real)

            # reduce
            loss = (loss_fake + loss_real) / 2.

            # backward (accumulated over args.bdivs mini-batches)
            loss.backward()
            cnt_bdivs += 1
            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0

                # optimize
                opt_cd.step()
                opt_cd.zero_grad()
                pbar.set_description('phase 2 | train loss: %.5f' % loss.cpu())
                pbar.update()

                # test
                if pbar.n % args.snaperiod_2 == 0:
                    _save_test_completions(
                        model_cn, test_dset, mpv, gpu,
                        args.num_test_completions,
                        os.path.join(phase_dir, 'step%d.png' % pbar.n))
                    torch.save(
                        _unwrap(model_cd, args.data_parallel).state_dict(),
                        os.path.join(phase_dir, 'model_cd_step%d' % pbar.n))
                if pbar.n >= args.steps_2:
                    break
    pbar.close()

    # ================================================
    # Training Phase 3
    # ================================================
    phase_dir = os.path.join(args.result_dir, 'phase_3')
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_3)
    while pbar.n < args.steps_3:
        for x in train_loader:
            # forward model_cd
            x = x.to(gpu)
            hole_area_fake = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            mask = gen_input_mask(
                shape=(x.shape[0], 1, x.shape[2], x.shape[3])).to(gpu)

            # fake forward
            fake = torch.zeros((len(x), 1)).to(gpu)
            x_mask = x - x * mask + mpv * mask
            output_cn = model_cn(torch.cat((x_mask, mask), dim=1))
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake, input_gd_fake))
            loss_cd_fake = bceloss(output_fake, fake)

            # real forward
            hole_area_real = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            real = torch.ones((len(x), 1)).to(gpu)
            input_ld_real = crop(x, hole_area_real)
            output_real = model_cd((input_ld_real, x))
            loss_cd_real = bceloss(output_real, real)

            # reduce
            loss_cd = (loss_cd_fake + loss_cd_real) * alpha / 2.

            # backward model_cd
            loss_cd.backward()
            cnt_bdivs += 1
            if cnt_bdivs >= args.bdivs:
                # optimize
                opt_cd.step()
                opt_cd.zero_grad()

            # forward model_cn
            # Freeze model_cd while back-propagating the CN loss. Gradients still
            # flow through it to output_cn, but they no longer accumulate in
            # model_cd's .grad buffers. Without this, the generator's "label
            # fakes as real" gradients would be applied by the next
            # opt_cd.step().
            _set_requires_grad(model_cd, False)
            loss_cn_1 = completion_network_loss(x, output_cn, mask)
            input_ld_fake = crop(output_cn, hole_area_fake)
            output_fake = model_cd((input_ld_fake, output_cn))
            loss_cn_2 = bceloss(output_fake, real)

            # reduce
            loss_cn = (loss_cn_1 + alpha * loss_cn_2) / 2.

            # backward model_cn
            loss_cn.backward()
            _set_requires_grad(model_cd, True)
            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0

                # optimize
                opt_cn.step()
                opt_cn.zero_grad()
                pbar.set_description(
                    'phase 3 | train loss (cd): %.5f (cn): %.5f' %
                    (loss_cd.cpu(), loss_cn.cpu()))
                pbar.update()

                # test
                if pbar.n % args.snaperiod_3 == 0:
                    _save_test_completions(
                        model_cn, test_dset, mpv, gpu,
                        args.num_test_completions,
                        os.path.join(phase_dir, 'step%d.png' % pbar.n))
                    torch.save(
                        _unwrap(model_cn, args.data_parallel).state_dict(),
                        os.path.join(phase_dir, 'model_cn_step%d' % pbar.n))
                    torch.save(
                        _unwrap(model_cd, args.data_parallel).state_dict(),
                        os.path.join(phase_dir, 'model_cd_step%d' % pbar.n))
                if pbar.n >= args.steps_3:
                    break
    pbar.close()
Пример #5
0
def GAN_patching_inputs(images, predicted):  # images and its predicted tensors
    """Inpaint the Grad-CAM salient region of each image with a pretrained GAN.

    For every image:
        1. Undo the normalization (to [0, 1]).
        2. Get a saliency map from ``gcam`` and threshold it at ``MASK_COND``
           into a binary hole mask.
        3. Fill the hole with a fixed per-channel mean pixel value.
        4. Run the CompletionNetwork and Poisson-blend its output back in.
        5. Re-normalize the result (to [-1, 1]).

    The time spent in Grad-CAM plus GAN inference is appended to
    ``runtime.csv`` as one ``N,milliseconds`` row per image, where ``N`` is the
    global image counter (incremented once per image).

    Args:
        images: indexable sequence of normalized image tensors.
        predicted: not used by this function; kept for interface compatibility.

    Returns:
        A float32 tensor on ``device`` holding the inpainted, re-normalized
        images.
    """
    global N
    model = CompletionNetwork()
    # Load on CPU and then move to `device`, so the checkpoint can also be
    # restored when `device` is not CUDA.
    model.load_state_dict(torch.load("cifar10_inpainting",
                                     map_location='cpu'))
    model.eval()
    model = model.to(device)

    # Per-channel fill value for the holes (presumably the CIFAR-10 training
    # mean). It is loop-invariant, so it is built once.
    mpv = torch.tensor(
        [0.4914655575466156, 0.4821903321331739, 0.4465675537097454]
    ).view(1, 3, 1, 1).to(device)

    cleanimgs = []  # GAN-inpainted images, one numpy array per input image
    for j in range(len(images)):
        N += 1
        image = unnormalize(images[j])  # unnormalize to [0 1] to feed into GAN
        image = torch.unsqueeze(image, 0)  # add a batch dimension

        start_time = time.time()
        mask = gcam(image)  # get the saliency map through GradCAM

        # binarize and reshape to (1, 1, H, W)
        mask = (mask >= MASK_COND).astype(int)
        mask = torch.tensor(mask[np.newaxis, np.newaxis])
        mask = mask.type(torch.FloatTensor).to(device)
        x = image.to(device)  # original test image

        # inpaint
        with torch.no_grad():
            x_mask = x - x * mask + mpv * mask  # generate the occluded input [0 1]
            inputx = torch.cat((x_mask, mask), dim=1)
            output = model(inputx)  # GAN output for the occluded input [0 1]
            if output.is_cuda:
                # CUDA kernels run asynchronously; wait for them so the timing
                # covers the actual inference, not just the kernel launches.
                torch.cuda.synchronize(output.device)
            end_time = time.time()
            GAN_process_time = round(1000.0 * (end_time - start_time), 3)  # ms
            # Append one "N,time" row. np.savetxt on a file *name* truncates the
            # file, which would keep only the most recent measurement.
            with open('runtime.csv', 'a') as runtime_file:
                np.savetxt(runtime_file, [[N, GAN_process_time]],
                           delimiter=',')

            # image restoration
            inpainted = poisson_blend_old(x_mask, output,
                                          mask)  # this is GAN output [0 1]
            inpainted = inpainted.to(device)

            # store GAN output
            clean_input = normalize_tensor_batch(inpainted)  # normalize to [-1 1]
            clean_input = torch.squeeze(clean_input)  # remove the batch dimension
            cleanimgs.append(clean_input.cpu().numpy())

    # this is the tensor for the GAN output
    cleanimgs_tensor = torch.from_numpy(np.asarray(cleanimgs))
    cleanimgs_tensor = cleanimgs_tensor.type(torch.FloatTensor)
    return cleanimgs_tensor.to(device)
Пример #6
0
def main(args):
    """Train a CompletionNetwork to map masked images back to the originals.

    Each element of ``masked_dataset`` is a ``(normal, masked)`` pair. The
    network receives ``masked`` and is optimized with Adam on the MSE against
    ``normal``. ``args.steps_1`` full passes over the training set are run.

    After every pass:
        * a random test batch is completed and saved as
          ``[masked | normal | output]`` rows to ``result_dir/phase_1/step<n>.png``;
        * the weights are saved to ``result_dir/phase_1/model_cn_step<n>``.

    Losses are logged to Weights & Biases unless ``args.wandb == "tmp"``.

    Raises:
        Exception: if no CUDA device is available.
    """
    use_wandb = args.wandb != "tmp"
    if use_wandb:
        wandb.init(project=args.wandb, config=args)

    if not torch.cuda.is_available():
        raise Exception('At least one gpu must be available.')
    gpu = torch.device('cuda:0')

    # create result directories (if necessary); makedirs also creates result_dir
    for phase in ['phase_1', 'phase_2', 'phase_3']:
        os.makedirs(os.path.join(args.result_dir, phase), exist_ok=True)

    # load dataset
    print('loading dataset... (it may take a few minutes)')
    train_dset = masked_dataset(os.path.join(args.data_dir, 'train'),
                                args.max_train)
    test_dset = masked_dataset(os.path.join(args.data_dir, 'test'),
                               args.max_test)
    train_loader = DataLoader(train_dset,
                              batch_size=args.batch_size,
                              shuffle=True)

    model_cn = CompletionNetwork()
    if args.init_model_cn is not None:
        model_cn.load_state_dict(
            torch.load(args.init_model_cn, map_location='cpu'))
    model_cn = model_cn.to(gpu)
    model_cn.train()
    opt_cn = Adam(model_cn.parameters(), lr=args.learning_rate)

    # ================================================
    # Training Phase 1 (one pbar step per pass over the training set)
    # ================================================
    phase_dir = os.path.join(args.result_dir, 'phase_1')
    pbar = tqdm(total=args.steps_1)
    epochs = 0
    while epochs < args.steps_1:
        for normal, masked in tqdm(train_loader):
            # forward
            output = model_cn(masked.to(gpu))
            loss = torch.nn.functional.mse_loss(output, normal.to(gpu))

            # backward + optimize
            loss.backward()
            opt_cn.step()
            opt_cn.zero_grad()

            # one host transfer per step, reused for logging and display
            loss_value = loss.item()
            if use_wandb:
                wandb.log({"phase_1_train_loss": loss_value})
            pbar.set_description('phase 1 | train loss: %.5f' % loss_value)
        pbar.update()

        # test
        model_cn.eval()
        with torch.no_grad():
            normal, masked = sample_random_batch(
                test_dset, batch_size=args.num_test_completions)
            normal = normal.to(gpu)
            masked = masked.to(gpu)
            output = model_cn(masked)

            imgs = torch.cat((masked.cpu(), normal.cpu(), output.cpu()), dim=0)
            save_image(imgs,
                       os.path.join(phase_dir, 'step%d.png' % pbar.n),
                       nrow=len(masked))
            torch.save(model_cn.state_dict(),
                       os.path.join(phase_dir, 'model_cn_step%d' % pbar.n))
        model_cn.train()
        epochs += 1
    pbar.close()