Example #1
def main(args):
    G_1 = Generator_lr(in_channels=3)
    SR = EDSR(n_colors=3)

    # load pretrained weights
    G_1.load_state_dict(
        torch.load(os.path.join(args.weights_dir, 'final_weights_G_1.pkl')))
    SR.load_state_dict(
        torch.load(os.path.join(args.weights_dir, 'final_weights_SR.pkl')))

    G_1.cuda()
    G_1.eval()
    SR.cuda()
    SR.eval()

    # predict
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'clean'), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'SR'), exist_ok=True)
    for image_name in tqdm(os.listdir(args.data_path)):
        # read file
        image = Image.open(os.path.join(args.data_path, image_name))
        # denoise
        clean_image = resolv_deonoise(G_1, image)
        clean_image.save(os.path.join(args.output_dir, 'clean', image_name))
        # SR
        sr_image = resolv_sr(G_1, SR, image)
        sr_image.save(os.path.join(args.output_dir, 'SR', image_name))
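The helper functions `resolv_deonoise` and `resolv_sr` are not shown in this example. A minimal sketch of what they might look like, inferred from the commented-out tensor conversion in Example #2, is given below; the exact bodies and the assumption of [0, 1]-range tensors are guesses, not the project's actual implementation.

# Hypothetical sketch of the resolv_* helpers used above; assumes PIL images in
# and out, CUDA-resident models, and pixel values scaled to [0, 1].
import torch
import torchvision.transforms.functional as TF


def resolv_deonoise(G_1, image):
    # PIL -> (1, C, H, W) tensor on GPU, run the denoising generator, back to PIL.
    with torch.no_grad():
        clean = G_1(TF.to_tensor(image).unsqueeze(0).cuda())
    return TF.to_pil_image(clean[0].clamp(0, 1).cpu())


def resolv_sr(G_1, SR, image):
    # Denoise first, then super-resolve with the EDSR-based SR network.
    with torch.no_grad():
        sr = SR(G_1(TF.to_tensor(image).unsqueeze(0).cuda()))
    return TF.to_pil_image(sr[0].clamp(0, 1).cpu())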
Example #2
def main(args):
    os.makedirs(args.log_dir, exist_ok=True)

    # create models
    G_1 = Generator_lr(in_channels=3)
    G_2 = Generator_lr(in_channels=3)
    D_1 = Discriminator_lr(in_channels=3, in_h=16, in_w=16)
    SR = EDSR(n_colors=3)
    G_3 = Generator_sr(in_channels=3)
    D_2 = Discriminator_sr(in_channels=3, in_h=64, in_w=64)

    for model in [G_1, G_2, D_1, SR, G_3, D_2]:
        model.cuda()
        model.train()

    # tensorboard
    writer = SummaryWriter(log_dir=args.log_dir)

    # create optimizers
    optim = {
        'G_1':
        torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                       G_1.parameters()),
                         lr=args.lr * 5),
        'G_2':
        torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                       G_2.parameters()),
                         lr=args.lr * 5),
        'D_1':
        torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                       D_1.parameters()),
                         lr=args.lr),
        'SR':
        torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                       SR.parameters()),
                         lr=args.lr * 5),
        'G_3':
        torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                       G_3.parameters()),
                         lr=args.lr),
        'D_2':
        torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                       D_2.parameters()),
                         lr=args.lr)
    }
    for key in optim.keys():
        optim[key].zero_grad()

    # get dataloader
    train_dataset = DIV2KDataset(root=args.data_path)
    trainloader = DataLoader(train_dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=3)

    print('-' * 20)
    print('Start training')
    print('-' * 20)
    iter_index = 0
    for epoch in range(args.epochs):
        G_1.train()
        SR.train()
        start = timeit.default_timer()
        for _, batch in enumerate(trainloader):
            iter_index += 1
            image, label_hr, label_lr = batch
            image = image.cuda()
            label_hr = label_hr.cuda()
            label_lr = label_lr.cuda()
            '''loss for lr GAN'''
            '''update G_1 and G_2'''
            for key in optim.keys():
                optim[key].zero_grad()
            # D loss for D_1
            image_clean = G_1(image)
            loss_D1 = discriminator_loss(discriminator=D_1,
                                         fake=image_clean,
                                         real=label_lr)
            loss_D1.backward()
            optim['D_1'].step()

            # GD loss for G_1
            loss_G1 = generator_discriminator_loss(generator=G_1,
                                                   discriminator=D_1,
                                                   input=image)
            loss_G1.backward()

            # cycle loss for G_1 and G_2
            loss_cycle = 10 * cycle_loss(G_1, G_2, image)
            loss_cycle.backward()

            # idt loss for G_1
            loss_idt = 5 * identity_loss(clean_image=label_lr, generator=G_1)
            loss_idt.backward()

            # tvloss for G_1
            loss_tv = 0.5 * tvloss(input=image, generator=G_1)
            loss_tv.backward()

            # optimize G_1 and G_2
            optim['G_1'].step()
            optim['G_2'].step()

            if iter_index % 100 == 0:
                print(
                    'iter {}: LR: loss_D1={}, loss_GD={}, loss_cycle={}, loss_idt={}, loss_tv={}'
                    .format(iter_index, loss_D1.item(), loss_G1.item(),
                            loss_cycle.item(), loss_idt.item(),
                            loss_tv.item()))
                writer.add_scalar('LR/loss_D1', loss_D1.item(),
                                  iter_index // 100)
                writer.add_scalar('LR/loss_GD', loss_G1.item(),
                                  iter_index // 100)
                writer.add_scalar('LR/loss_cycle', loss_cycle.item(),
                                  iter_index // 100)
                writer.add_scalar('LR/loss_idt', loss_idt.item(),
                                  iter_index // 100)
                writer.add_scalar('LR/loss_tv', loss_tv.item(),
                                  iter_index // 100)
                writer.add_image('LR/origin', image[0], iter_index // 100)
                writer.add_image('LR/denoise',
                                 G_1(image)[0], iter_index // 100)
            '''loss for sr GAN'''
            '''update G_1, SR and G_3'''
            for key in optim.keys():
                optim[key].zero_grad()
            image_clean = G_1(image).detach()
            # D loss for D_2
            image_sr = SR(image_clean)
            loss_D2 = discriminator_loss(discriminator=D_2,
                                         fake=image_sr,
                                         real=label_hr)
            loss_D2.backward()
            optim['D_2'].step()

            # GD loss for SR
            loss_SR = generator_discriminator_loss(generator=SR,
                                                   discriminator=D_2,
                                                   input=image_clean)
            loss_SR.backward()

            # cycle loss for SR and G_3
            loss_cycle = 10 * cycle_loss(SR, G_3, image_clean)
            loss_cycle.backward()

            # idt loss for SR
            loss_idt = 5 * identity_loss_sr(
                clean_image_lr=label_lr, clean_image_hr=label_hr, generator=SR)
            loss_idt.backward()

            # tvloss for SR
            loss_tv = 0.5 * tvloss(input=image_clean, generator=SR)
            loss_tv.backward()

            # optimize G_1, SR and G_3
            optim['G_1'].step()
            optim['SR'].step()
            optim['G_3'].step()

            if iter_index % 100 == 0:
                print(
                    '         SR: loss_D2={}, loss_SR={}, loss_cycle={}, loss_idt={}, loss_tv={}'
                    .format(loss_D2.item(), loss_SR.item(), loss_cycle.item(),
                            loss_idt.item(), loss_tv.item()))
                writer.add_scalar('SR/loss_D2', loss_D2.item(),
                                  iter_index // 100)
                writer.add_scalar('SR/loss_SR', loss_SR.item(),
                                  iter_index // 100)
                writer.add_scalar('SR/loss_cycle', loss_cycle.item(),
                                  iter_index // 100)
                writer.add_scalar('SR/loss_idt', loss_idt.item(),
                                  iter_index // 100)
                writer.add_scalar('SR/loss_tv', loss_tv.item(),
                                  iter_index // 100)
                writer.add_image('SR/origin', image[0], iter_index // 100)
                writer.add_image('SR/clean_image',
                                 G_1(image)[0], iter_index // 100)
                writer.add_image('SR/SR', SR(G_1(image))[0], iter_index // 100)
                writer.flush()

        end = timeit.default_timer()
        print('epoch {}, using {} seconds'.format(epoch, end - start))

        G_1.eval()
        SR.eval()
        image = Image.open('/data/data/DIV2K/unsupervised/lr/0001x4d.png')
        sr_image = resolv_sr(G_1, SR, image)
        # image_tensor = torchvision.transforms.functional.to_tensor(image).unsqueeze(0).cuda()
        # sr_image_tensor = SR(G_1(image_tensor).detach())
        # sr_image = torchvision.transforms.functional.to_pil_image(sr_image_tensor[0].cpu())
        sr_image.save(
            os.path.join(args.log_dir, '0001x4d_sr_{}.png'.format(str(epoch))))

        torch.save(G_1.state_dict(),
                   os.path.join(args.log_dir, 'ep-' + str(epoch) + '_G_1.pkl'))
        torch.save(G_2.state_dict(),
                   os.path.join(args.log_dir, 'ep-' + str(epoch) + '_G_2.pkl'))
        torch.save(D_1.state_dict(),
                   os.path.join(args.log_dir, 'ep-' + str(epoch) + '_D_1.pkl'))
        torch.save(SR.state_dict(),
                   os.path.join(args.log_dir, 'ep-' + str(epoch) + '_SR.pkl'))
        torch.save(G_3.state_dict(),
                   os.path.join(args.log_dir, 'ep-' + str(epoch) + '_G_3.pkl'))
        torch.save(D_2.state_dict(),
                   os.path.join(args.log_dir, 'ep-' + str(epoch) + '_D_2.pkl'))

    writer.close()
    print('Training done.')
    torch.save(G_1.state_dict(),
               os.path.join(args.log_dir, 'final_weights_G_1.pkl'))
    torch.save(G_2.state_dict(),
               os.path.join(args.log_dir, 'final_weights_G_2.pkl'))
    torch.save(D_1.state_dict(),
               os.path.join(args.log_dir, 'final_weights_D_1.pkl'))
    torch.save(SR.state_dict(),
               os.path.join(args.log_dir, 'final_weights_SR.pkl'))
    torch.save(G_3.state_dict(),
               os.path.join(args.log_dir, 'final_weights_G_3.pkl'))
    torch.save(D_2.state_dict(),
               os.path.join(args.log_dir, 'final_weights_D_2.pkl'))

    image = Image.open('/data/data/DIV2K/unsupervised/lr/0001x4d.png')
    image.save(os.path.join(args.log_dir, '0001x4d.png'))
    sr_image = resolv_sr(G_1, SR, image)
    # image_tensor = torchvision.transforms.functional.to_tensor(image).unsqueeze(0).cuda()
    # sr_image_tensor = SR(G_1(image_tensor))
    # sr_image = torchvision.transforms.functional.to_pil_image(sr_image_tensor[0].cpu())
    sr_image.save(os.path.join(args.log_dir, '0001x4d_sr.png'))
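The loss helpers called in this training loop (`discriminator_loss`, `generator_discriminator_loss`, `cycle_loss`, `identity_loss`, `tvloss`) are defined elsewhere in the project. As a rough illustration only, the cycle-consistency, identity, and total-variation terms might follow the standard CycleGAN-style L1 formulations sketched here; the signatures match the calls above, but the bodies are assumptions.

import torch.nn.functional as F


def cycle_loss(G_1, G_2, image):
    # Map noisy -> clean -> noisy and penalize the L1 reconstruction error,
    # as in CycleGAN's cycle-consistency term.
    return F.l1_loss(G_2(G_1(image)), image)


def identity_loss(clean_image, generator):
    # An already-clean image passed through the denoiser should stay unchanged.
    return F.l1_loss(generator(clean_image), clean_image)


def tvloss(input, generator):
    # Total-variation regularizer on the generator output to encourage smoothness.
    out = generator(input)
    dh = (out[:, :, 1:, :] - out[:, :, :-1, :]).abs().mean()
    dw = (out[:, :, :, 1:] - out[:, :, :, :-1]).abs().mean()
    return dh + dw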
Example #3
def create_reconstruction():
    input = Input(shape=(32, 32, 3))
    model = Model(inputs=input,
                  outputs=EDSR(input, config.filters, config.nBlocks))
    return model
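`create_reconstruction` reads `config.filters` and `config.nBlocks` at call time, so a `config` object must exist first. A short usage sketch follows; the optimizer and loss are illustrative, not taken from the project.

# Hypothetical usage; config.filters and config.nBlocks must already be set.
reconstruction_model = create_reconstruction()
reconstruction_model.compile(optimizer='adam', loss='mae')
reconstruction_model.summary()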
def load_edsr(device, n_resblocks=16, n_feats=64, model_details=True):
    """
    Loads the EDSR model

    Parameters
    ----------
    device : str
        device type.
    n_resblocks : int, optional
        number of res_blocks. The default is 16.
    n_feats : int, optional
        number of features. The default is 64.
    model_details : bool, optional
        whether to print model details (currently unused). The default is True.

    Returns
    -------
    model : torch.nn.Module
        EDSR model.

    """
    args = {
        "G0": 64,
        "RDNconfig": "B",
        "RDNkSize": 3,
        "act": "relu",
        "batch_size": 16,
        "betas": (0.9, 0.999),
        "chop": True,
        "cpu": True,
        "data_range": "1-800/801-810",
        "data_test": ["Demo"],
        "data_train": ["DIV2K"],
        "debug": False,
        "decay": "200",
        "dilation": False,
        "dir_data": "../../../dataset",
        "dir_demo": "../test",
        "epochs": 300,
        "epsilon": 1e-08,
        "ext": "sep",
        "extend": ".",
        "gamma": 0.5,
        "gan_k": 1,
        "gclip": 0,
        "load": "",
        "loss": "1*L1",
        "lr": 0.0001,
        "model": "EDSR",
        "momentum": 0.9,
        "n_GPUs": 1,
        "n_colors": 3,
        "n_feats": 64,
        "n_resblocks": 16,
        "n_resgroups": 10,
        "n_threads": 6,
        "no_augment": False,
        "optimizer": "ADAM",
        "patch_size": 192,
        "pre_train": "download",
        "precision": "single",
        "print_every": 100,
        "reduction": 16,
        "res_scale": 1,
        "reset": False,
        "resume": 0,
        "rgb_range": 255,
        "save": "test",
        "save_gt": False,
        "save_models": False,
        "save_results": True,
        "scale": [4],
        "seed": 1,
        "self_ensemble": False,
        "shift_mean": True,
        "skip_threshold": 100000000.0,
        "split_batch": 1,
        "template": ".",
        "test_every": 1000,
        "test_only": True,
        "weight_decay": 0,
    }
    model = edsr.make_model(args).to(device)
    edsr.load(model)
    if model_details:
        # Placeholder: printing the model summary is not implemented here.
        pass
    return model
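A minimal usage sketch for `load_edsr`. The dummy input range follows the `rgb_range` of 255 and the x4 `scale` in the argument dictionary above; what `edsr.make_model` and `edsr.load` actually do depends on the EDSR codebase this wraps.

import torch

device = "cpu"  # consistent with "cpu": True in the args above
model = load_edsr(device)
model.eval()
with torch.no_grad():
    lr = torch.rand(1, 3, 48, 48) * 255  # pixel values in the 0-255 rgb_range
    sr = model(lr)
print(sr.shape)  # roughly (1, 3, 192, 192) for an x4 model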
if not os.path.exists("data"):
    print("Downloading flower dataset...")
    subprocess.check_output(
        "mkdir data && curl https://storage.googleapis.com/wandb/flower-enhance.tar.gz | tar xz -C data",
        shell=True)

config.steps_per_epoch = len(
    glob.glob(config.train_dir + "/*-in.jpg")) // config.batch_size
config.val_steps_per_epoch = len(
    glob.glob(config.val_dir + "/*-in.jpg")) // config.batch_size

# Neural network
input1 = Input(shape=(config.input_height, config.input_width, 3),
               dtype='float32')
model = Model(inputs=input1,
              outputs=EDSR(input1, config.filters, config.nBlocks))

#print(model.summary())
#model.load_weights('edsr.h5')

#es = EarlyStopping(monitor='val_perceptual_distance', mode='min', verbose = 1, patience=2)
mc = ModelCheckpoint('edsr.h5',
                     monitor='val_perceptual_distance',
                     mode='min',
                     save_best_only=True)

## DON'T ALTER metrics=[perceptual_distance]
model.compile(optimizer='adam',
              loss=[perceptual_distance],
              metrics=[perceptual_distance])
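The actual training call is not part of this snippet. A hedged sketch of how it could be wired up, assuming an `image_generator(batch_size, img_dir)` data generator and a `config.num_epochs` field exist in the surrounding script; only the checkpoint callback `mc` and the step counts above come from the code shown here.

# Hypothetical training call; image_generator and config.num_epochs are assumptions.
model.fit(image_generator(config.batch_size, config.train_dir),
          steps_per_epoch=config.steps_per_epoch,
          epochs=config.num_epochs,
          validation_data=image_generator(config.batch_size, config.val_dir),
          validation_steps=config.val_steps_per_epoch,
          callbacks=[mc])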
def main(args):
    # Create directories if it's not a hyper-optimisation round.
    if not args.is_optimisation:
        results_directory = f'results/result_{args.experiment_num}'
        os.makedirs('images', exist_ok=True)
        os.makedirs(results_directory, exist_ok=True)
        # Save arguments for experiment reproducibility.
        with open(os.path.join(results_directory, 'arguments.txt'), 'w') \
                as file:
            json.dump(args.__dict__, file, indent=2)

    # Set size for plots.
    plt.rcParams['figure.figsize'] = (10, 10)

    # Select the device to train the model on.
    device = torch.device(args.device)

    # Load the dataset.
    # TODO : Add normalisation  transforms.Normalize(
    #   torch.tensor(-4.4713e-07).float(),
    #   torch.tensor(0.1018).float())
    # TODO: Add more data augmentation transforms.
    data_transforms = transforms.Compose([
        #  RandomHorizontalFlip(),
        ToTensor()
    ])

    dataset = Data(args.filename_x,
                   args.filename_y,
                   args.data_root,
                   transform=data_transforms)

    if not args.is_optimisation:
        print(f"Data sizes, input: {dataset.input_dim}, output: "
              f"{dataset.output_dim}, Fk: {dataset.output_dim_fk}")

    train_data, test_data = split_dataset(
        dataset, args.test_percentage + args.val_percentage)
    test_data, val_data = split_dataset(test_data, 0.5)

    # Initialize generator model.
    if args.model == 'SRCNN':
        generator = SRCNN(input_dim=dataset.input_dim,
                          output_dim=dataset.output_dim).to(device)
    elif args.model == 'EDSR':
        generator = EDSR(args.latent_dim,
                         args.num_res_blocks,
                         output_dim=dataset.output_dim).to(device)
    elif args.model == 'VDSR':
        generator = VDSR(args.latent_dim,
                         args.num_res_blocks,
                         output_dim=dataset.output_dim).to(device)

    # Optimizers
    optim_G = optim.Adam(generator.parameters(), lr=args.lr)

    scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer=optim_G, patience=args.scheduler_patience, verbose=True)

    # Initialize optional Fk discriminator and optimizer.

    # Reconstruction loss criteria
    criterion_dictionary = {
        "MSE": nn.MSELoss(),
        "L1": nn.L1Loss(),
    }
    reconstruction_criterion = criterion_dictionary[args.criterion_type]

    # Initialize a dict of empty lists for plotting.
    plot_log = defaultdict(list)

    for epoch in range(args.n_epochs):
        # Train model for one epoch.
        loss = iter_epoch(generator, optim_G,
                          train_data,
                          device,
                          batch_size=args.batch_size,
                          reconstruction_criterion=reconstruction_criterion,
                          use_fk_loss=args.use_fk_loss)

        # Report model performance.
        if not args.is_optimisation:
            print(f"Epoch: {epoch}, Loss: {loss['G']}, "
                  f"PSNR: {loss['psnr']}")  # SSIM: {loss['ssim']}")
        plot_log['G'].append(loss['G'])

        # Model evaluation every eval_iteration and last iteration.
        if epoch % args.eval_interval == 0 \
                or (args.is_optimisation and epoch == args.n_epochs - 1):
            loss_val = iter_epoch(
                generator, None,
                val_data,
                device,
                batch_size=args.batch_size,
                eval=True,
                reconstruction_criterion=reconstruction_criterion,
                use_fk_loss=args.use_fk_loss)
            if not args.is_optimisation:
                print(f"Validation on epoch: {epoch}, Loss: {loss_val['G']}, "
                      f" PSNR: {loss_val['psnr']}"
                      )  #, SSIM: {loss_val['ssim']}")

            plot_log['G_val'].append(loss_val['G'])
            plot_log['psnr_val'].append(loss_val['psnr'])
            # plot_log['ssim_val'].append(loss_val['ssim'])

            # Update scheduler based on PSNR or separate model losses.
            if args.is_psnr_step:
                scheduler_g.step(loss_val['psnr'])

            else:
                scheduler_g.step(loss_val['G'])

            if not args.is_optimisation:
                pass
                # save_loss_plot(plot_log['G_val'], results_directory, is_val=True)

        if not args.is_optimisation:
            # Plot results.
            if epoch % args.save_interval == 0:
                plot_samples(generator, val_data, epoch, device,
                             results_directory)
                plot_samples(generator,
                             train_data,
                             epoch,
                             device,
                             results_directory,
                             is_train=True)

            save_loss_plot(plot_log['G'], results_directory)

    if not args.is_optimisation:
        # Save the trained generator model.
        torch.save(generator, os.path.join(results_directory, 'generator.pth'))

        if args.save_test_dataset:
            sets_name = ['test', 'val', 'train']
            sets = [test_data, val_data, train_data]
            for name, d_set in zip(sets_name, sets):
                list_x = []
                list_y = []
                for sample in d_set:
                    list_x.append(sample['x'].unsqueeze(0))
                    list_y.append(sample['y'].unsqueeze(0))
                tensor_x = torch.cat(list_x, 0)
                tensor_y = torch.cat(list_y, 0)
                data_folder_for_results = 'final/data'
                os.makedirs(data_folder_for_results, exist_ok=True)
                torch.save(
                    tensor_x,
                    f'{data_folder_for_results}/{name}_data_x_{args.experiment_num}.pt'
                )
                torch.save(
                    tensor_y,
                    f'{data_folder_for_results}/{name}_data_y_{args.experiment_num}.pt'
                )

        return plot_log, generator, test_data
    if args.is_optimisation:
        __, test_data = random_split(test_data, [len(test_data) - 2, 2])
        return plot_log, generator, test_data
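`main` expects an `args` namespace with many fields. A partial argparse entry point, reconstructed from the attributes accessed above (flag names follow the code; the defaults are illustrative placeholders):

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # Flags mirror the args.* attributes used in main(); defaults are placeholders.
    parser.add_argument('--model', choices=['SRCNN', 'EDSR', 'VDSR'], default='EDSR')
    parser.add_argument('--filename_x', type=str)
    parser.add_argument('--filename_y', type=str)
    parser.add_argument('--data_root', type=str)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--n_epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--latent_dim', type=int, default=64)
    parser.add_argument('--num_res_blocks', type=int, default=16)
    parser.add_argument('--criterion_type', choices=['MSE', 'L1'], default='MSE')
    parser.add_argument('--test_percentage', type=float, default=0.1)
    parser.add_argument('--val_percentage', type=float, default=0.1)
    parser.add_argument('--scheduler_patience', type=int, default=10)
    parser.add_argument('--eval_interval', type=int, default=10)
    parser.add_argument('--save_interval', type=int, default=10)
    parser.add_argument('--experiment_num', type=int, default=0)
    parser.add_argument('--use_fk_loss', action='store_true')
    parser.add_argument('--is_psnr_step', action='store_true')
    parser.add_argument('--is_optimisation', action='store_true')
    parser.add_argument('--save_test_dataset', action='store_true')
    main(parser.parse_args())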