Esempio n. 1
0
def _save_completion_snapshot(model_cn, test_dset, mpv, gpu, batch_size,
                              imgpath):
    """Run ``model_cn`` on a random test batch and save a preview image.

    The saved image concatenates, along the batch axis, the sampled test
    images, the masked inputs (hole filled with ``mpv``) and the result of
    ``rejoiner``. ``model_cn`` is put in eval mode for the forward pass and
    switched back to train mode before returning.
    """
    model_cn.eval()
    with torch.no_grad():
        x = sample_random_batch(test_dset, batch_size=batch_size).to(gpu)
        mask = gen_input_mask(
            shape=(x.shape[0], 1, x.shape[2], x.shape[3]), ).to(gpu)
        x_mask = x - x * mask + mpv * mask
        output = model_cn(torch.cat((x_mask, mask), dim=1))
        completed = rejoiner(x_mask, output, mask)
        imgs = torch.cat((x.cpu(), x_mask.cpu(), completed.cpu()), dim=0)
        save_image(imgs, imgpath, nrow=len(x))
    model_cn.train()


def _save_state_dict(model, path, data_parallel):
    """Save ``model``'s weights, unwrapping ``DataParallel`` when in use."""
    target = model.module if data_parallel else model
    torch.save(target.state_dict(), path)


def _set_requires_grad(model, flag):
    """Enable or disable gradient tracking for every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad_(flag)


def main(args):
    """Train the completion network and context discriminator in 3 phases.

    Phase 1 trains the completion network alone with
    ``completion_network_loss``; phase 2 trains the context discriminator
    against the (fixed) completion network; phase 3 trains both jointly.
    Every ``args.snaperiod_N`` optimizer steps a preview image and the
    current weights are written to ``<args.result_dir>/phase_N``.

    Gradients are accumulated over ``args.bdivs`` mini-batches of size
    ``args.bsize // args.bdivs`` before each optimizer step.

    Args:
        args: namespace providing ``result_dir``, ``data_dir``,
            ``recursive_search``, ``cn_input_size``, ``ld_input_size``,
            ``bsize``, ``bdivs``, ``mpv``, ``alpha``, ``init_model_cn``,
            ``init_model_cd``, ``data_parallel``, ``num_test_completions``,
            ``steps_1..3`` and ``snaperiod_1..3``.

    Raises:
        Exception: if no CUDA device is available.
    """
    # ================================================
    # Preparation
    # ================================================
    if not torch.cuda.is_available():
        raise Exception('At least one gpu must be available.')
    gpu = torch.device('cuda:0')

    # create result directories (exist_ok avoids the check-then-create race)
    os.makedirs(args.result_dir, exist_ok=True)
    for phase in ['phase_1', 'phase_2', 'phase_3']:
        os.makedirs(os.path.join(args.result_dir, phase), exist_ok=True)

    # load dataset
    trnsfm = transforms.Compose([
        transforms.Resize(args.cn_input_size),
        transforms.RandomCrop((args.cn_input_size, args.cn_input_size)),
        transforms.ToTensor(),
    ])
    print('loading dataset... (it may take a few minutes)')
    train_dset = ImageDataset(os.path.join(args.data_dir, 'train'),
                              trnsfm,
                              recursive_search=args.recursive_search)
    test_dset = ImageDataset(os.path.join(args.data_dir, 'test'),
                             trnsfm,
                             recursive_search=args.recursive_search)
    train_loader = DataLoader(train_dset,
                              batch_size=(args.bsize // args.bdivs),
                              shuffle=True)

    # compute mpv (mean pixel value) of training dataset
    if args.mpv is None:
        # A single accumulator: the mean over axes (0, 1) must be a scalar
        # (or length-1) value, i.e. this assumes single-channel images.
        mpv = np.zeros(shape=(1, ))
        pbar = tqdm(total=len(train_dset.imgpaths),
                    desc='computing mean pixel value of training dataset...')
        for imgpath in train_dset.imgpaths:
            # Close each file handle as soon as the pixels are read.
            with Image.open(imgpath) as img:
                x = np.array(img) / 255.
            mpv += x.mean(axis=(0, 1))
            pbar.update()
        mpv /= len(train_dset.imgpaths)
        pbar.close()
    else:
        mpv = np.array(args.mpv)

    # save training config
    # NOTE(review): the config stores ``args`` as given; a computed mpv is
    # not written back (the assignment that did so was disabled in the
    # original). Confirm whether downstream tooling needs it.
    with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f:
        json.dump(vars(args), f)

    # make mpv & alpha tensors
    mpv = torch.tensor(mpv.reshape(1, 1, 1, 1), dtype=torch.float32).to(gpu)
    alpha = torch.tensor(args.alpha, dtype=torch.float32).to(gpu)

    # ================================================
    # Training Phase 1
    # ================================================
    # load completion network
    model_cn = CompletionNetwork()
    if args.init_model_cn is not None:
        model_cn.load_state_dict(
            torch.load(args.init_model_cn, map_location='cpu'))
    if args.data_parallel:
        model_cn = DataParallel(model_cn)
    model_cn = model_cn.to(gpu)
    opt_cn = Adadelta(model_cn.parameters())

    # training
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_1)
    while pbar.n < args.steps_1:
        for x in train_loader:
            # forward
            x = x.to(gpu)
            mask = gen_input_mask(shape=(x.shape[0], 1, x.shape[2],
                                         x.shape[3]), ).to(gpu)
            x_mask = x - x * mask + mpv * mask
            output = model_cn(torch.cat((x_mask, mask), dim=1))
            loss = completion_network_loss(x, output, mask)

            # backward
            loss.backward()
            cnt_bdivs += 1
            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0

                # optimize
                opt_cn.step()
                opt_cn.zero_grad()
                pbar.set_description('phase 1 | train loss: %.5f' %
                                     loss.item())
                pbar.update()

                # test
                if pbar.n % args.snaperiod_1 == 0:
                    _save_completion_snapshot(
                        model_cn, test_dset, mpv, gpu,
                        args.num_test_completions,
                        os.path.join(args.result_dir, 'phase_1',
                                     'step%d.png' % pbar.n))
                    _save_state_dict(
                        model_cn,
                        os.path.join(args.result_dir, 'phase_1',
                                     'model_cn_step%d' % pbar.n),
                        args.data_parallel)
                if pbar.n >= args.steps_1:
                    break
    pbar.close()

    # ================================================
    # Training Phase 2
    # ================================================
    # load context discriminator
    model_cd = ContextDiscriminator(
        local_input_shape=(1, args.ld_input_size, args.ld_input_size),
        global_input_shape=(1, args.cn_input_size, args.cn_input_size),
    )
    if args.init_model_cd is not None:
        model_cd.load_state_dict(
            torch.load(args.init_model_cd, map_location='cpu'))
    if args.data_parallel:
        model_cd = DataParallel(model_cd)
    model_cd = model_cd.to(gpu)
    opt_cd = Adadelta(model_cd.parameters(), lr=0.1)
    bceloss = BCELoss()

    # training
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_2)
    while pbar.n < args.steps_2:
        for x in train_loader:
            # fake forward
            x = x.to(gpu)
            hole_area_fake = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            mask = gen_input_mask(shape=(x.shape[0], 1, x.shape[2],
                                         x.shape[3]), ).to(gpu)
            fake = torch.zeros((len(x), 1)).to(gpu)
            x_mask = x - x * mask + mpv * mask
            # The completion network is not updated in this phase, so its
            # output is only needed as data: skip building its graph.
            with torch.no_grad():
                input_gd_fake = model_cn(torch.cat((x_mask, mask), dim=1))
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake, input_gd_fake))
            loss_fake = bceloss(output_fake, fake)

            # real forward
            hole_area_real = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            real = torch.ones((len(x), 1)).to(gpu)
            input_gd_real = x
            input_ld_real = crop(input_gd_real, hole_area_real)
            output_real = model_cd((input_ld_real, input_gd_real))
            loss_real = bceloss(output_real, real)

            # reduce
            loss = (loss_fake + loss_real) / 2.

            # backward
            loss.backward()
            cnt_bdivs += 1
            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0

                # optimize
                opt_cd.step()
                opt_cd.zero_grad()
                pbar.set_description('phase 2 | train loss: %.5f' %
                                     loss.item())
                pbar.update()

                # test
                if pbar.n % args.snaperiod_2 == 0:
                    _save_completion_snapshot(
                        model_cn, test_dset, mpv, gpu,
                        args.num_test_completions,
                        os.path.join(args.result_dir, 'phase_2',
                                     'step%d.png' % pbar.n))
                    _save_state_dict(
                        model_cd,
                        os.path.join(args.result_dir, 'phase_2',
                                     'model_cd_step%d' % pbar.n),
                        args.data_parallel)
                if pbar.n >= args.steps_2:
                    break
    pbar.close()

    # ================================================
    # Training Phase 3
    # ================================================
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_3)
    while pbar.n < args.steps_3:
        for x in train_loader:
            # forward model_cd
            x = x.to(gpu)
            hole_area_fake = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            mask = gen_input_mask(shape=(x.shape[0], 1, x.shape[2],
                                         x.shape[3]), ).to(gpu)

            # fake forward (output_cn keeps its graph: it is reused below
            # for the completion-network update)
            fake = torch.zeros((len(x), 1)).to(gpu)
            x_mask = x - x * mask + mpv * mask
            output_cn = model_cn(torch.cat((x_mask, mask), dim=1))
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake, input_gd_fake))
            loss_cd_fake = bceloss(output_fake, fake)

            # real forward
            hole_area_real = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            real = torch.ones((len(x), 1)).to(gpu)
            input_gd_real = x
            input_ld_real = crop(input_gd_real, hole_area_real)
            output_real = model_cd((input_ld_real, input_gd_real))
            loss_cd_real = bceloss(output_real, real)

            # reduce
            loss_cd = (loss_cd_fake + loss_cd_real) * alpha / 2.

            # backward model_cd
            loss_cd.backward()
            cnt_bdivs += 1
            step_now = cnt_bdivs >= args.bdivs
            if step_now:
                # optimize
                opt_cd.step()
                opt_cd.zero_grad()

            # forward model_cn
            loss_cn_1 = completion_network_loss(x, output_cn, mask)
            # Freeze the discriminator while back-propagating the generator
            # loss; otherwise that loss deposits gradients in model_cd's
            # buffers and the next opt_cd.step() would apply them.
            _set_requires_grad(model_cd, False)
            input_ld_fake = crop(output_cn, hole_area_fake)
            output_fake = model_cd((input_ld_fake, output_cn))
            loss_cn_2 = bceloss(output_fake, real)

            # reduce
            loss_cn = (loss_cn_1 + alpha * loss_cn_2) / 2.

            # backward model_cn
            loss_cn.backward()
            _set_requires_grad(model_cd, True)
            if step_now:
                cnt_bdivs = 0

                # optimize
                opt_cn.step()
                opt_cn.zero_grad()
                pbar.set_description(
                    'phase 3 | train loss (cd): %.5f (cn): %.5f' %
                    (loss_cd.item(), loss_cn.item()))
                pbar.update()

                # test
                if pbar.n % args.snaperiod_3 == 0:
                    _save_completion_snapshot(
                        model_cn, test_dset, mpv, gpu,
                        args.num_test_completions,
                        os.path.join(args.result_dir, 'phase_3',
                                     'step%d.png' % pbar.n))
                    _save_state_dict(
                        model_cn,
                        os.path.join(args.result_dir, 'phase_3',
                                     'model_cn_step%d' % pbar.n),
                        args.data_parallel)
                    _save_state_dict(
                        model_cd,
                        os.path.join(args.result_dir, 'phase_3',
                                     'model_cd_step%d' % pbar.n),
                        args.data_parallel)
                if pbar.n >= args.steps_3:
                    break
    pbar.close()
Esempio n. 2
0
def main(args):
    """Train the completion network on (normal, masked) image pairs.

    Each pass of the outer loop iterates once over the whole training
    loader, minimising the MSE between the network output for the masked
    image and the normal image. After every pass a preview image and the
    current weights are written to ``<args.result_dir>/phase_1``. Note that
    ``args.steps_1`` therefore counts full passes over the loader here, not
    individual optimizer steps.

    Args:
        args: namespace providing ``wandb`` (project name, or ``"tmp"`` to
            disable logging), ``result_dir``, ``data_dir``, ``max_train``,
            ``max_test``, ``batch_size``, ``alpha``, ``init_model_cn``,
            ``learning_rate``, ``steps_1`` and ``num_test_completions``.

    Raises:
        Exception: if no CUDA device is available.
    """

    if args.wandb != "tmp":
        wandb.init(project=args.wandb, config=args)

    if not torch.cuda.is_available():
        raise Exception('At least one gpu must be available.')
    gpu = torch.device('cuda:0')

    # create result directories (exist_ok avoids the check-then-create race)
    os.makedirs(args.result_dir, exist_ok=True)
    for phase in ['phase_1', 'phase_2', 'phase_3']:
        os.makedirs(os.path.join(args.result_dir, phase), exist_ok=True)

    # load dataset

    print('loading dataset... (it may take a few minutes)')
    train_dset = masked_dataset(os.path.join(args.data_dir, 'train'),
                                args.max_train)

    test_dset = masked_dataset(os.path.join(args.data_dir, 'test'),
                               args.max_test)

    train_loader = DataLoader(train_dset,
                              batch_size=(args.batch_size),
                              shuffle=True)

    # NOTE(review): alpha is not used by the phase-1 code below; kept in
    # case later code in this function relies on it.
    alpha = torch.tensor(args.alpha, dtype=torch.float32).to(gpu)

    model_cn = CompletionNetwork()
    if args.init_model_cn is not None:
        model_cn.load_state_dict(
            torch.load(args.init_model_cn, map_location='cpu'))

    model_cn = model_cn.to(gpu)
    model_cn.train()
    opt_cn = Adam(model_cn.parameters(), lr=args.learning_rate)

    # training
    # ================================================
    # Training Phase 1
    # ================================================
    pbar = tqdm(total=args.steps_1)

    epochs = 0
    while epochs < args.steps_1:
        # Wrapping the loader directly lets the inner bar show its length.
        for normal, masked in tqdm(train_loader):
            # forward
            output = model_cn(masked.to(gpu))
            loss = torch.nn.functional.mse_loss(output, normal.to(gpu))

            # backward
            loss.backward()
            # optimize
            opt_cn.step()
            opt_cn.zero_grad()

            # Read the scalar once: a single device sync, and no autograd
            # tensor is handed to the logger.
            loss_value = loss.item()
            if args.wandb != "tmp":
                wandb.log({"phase_1_train_loss": loss_value})
            pbar.set_description('phase 1 | train loss: %.5f' % loss_value)
        pbar.update()

        # test

        model_cn.eval()
        with torch.no_grad():
            normal, masked = sample_random_batch(
                test_dset, batch_size=args.num_test_completions)
            normal = normal.to(gpu)
            masked = masked.to(gpu)
            output = model_cn(masked)

            # preview: masked inputs, ground truth, then network output
            imgs = torch.cat((masked.cpu(), normal.cpu(), output.cpu()), dim=0)
            imgpath = os.path.join(args.result_dir, 'phase_1',
                                   'step%d.png' % pbar.n)
            model_cn_path = os.path.join(args.result_dir, 'phase_1',
                                         'model_cn_step%d' % pbar.n)
            save_image(imgs, imgpath, nrow=len(masked))

            torch.save(model_cn.state_dict(), model_cn_path)
        model_cn.train()
        epochs += 1
    pbar.close()