Пример #1
0
def main(args):
    """Evaluate a pretrained ``GeneratorRRDB`` or export selected samples.

    The dataset and generator are built from ``args`` and the generator
    weights are loaded from ``args.model``.

    * If ``args.manual_image`` is not ``None``, every index found in the
      module-level ``manual_image`` is super-resolved and four ``.npy`` files
      (truth HR/LR, generated HR/LR) are written, each prefixed with
      ``args.output``; the function then returns.
    * Otherwise every batch of the dataloader is super-resolved, the HR and LR
      L1 losses plus an energy-histogram MSE are printed, and the images are
      handed to ``show``.

    Args:
        args: Parsed options. Attributes read here: ``dataset_type``,
            ``input``, ``hw``, ``factor``, ``pre_factor``, ``E_thres``,
            ``n_hardest``, ``batch_size``, ``no_shuffle``,
            ``residual_blocks``, ``scaling_power``, ``res_scale``,
            ``use_transposed_conv``, ``fully_transposed_conv``,
            ``num_final_res_blocks``, ``threshold``, ``model``,
            ``manual_image`` and ``output``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = get_dataset(args.dataset_type, args.input, *args.hw, args.factor, pre=args.pre_factor, threshold=args.E_thres, N=args.n_hardest)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=not args.no_shuffle,
        num_workers=0
    )
    generator = GeneratorRRDB(1, filters=64, num_res_blocks=args.residual_blocks, num_upsample=int(np.log2(args.factor)), power=args.scaling_power, res_scale=args.res_scale, use_transposed_conv=args.use_transposed_conv, fully_tconv_upsample=args.fully_transposed_conv, num_final_layer_res=args.num_final_res_blocks).to(device).eval()
    generator.thres = args.threshold
    generator.load_state_dict(torch.load(args.model, map_location=device))
    criterion = torch.nn.L1Loss()
    mse = torch.nn.MSELoss()
    sumpool = SumPool2d(args.factor)

    if args.manual_image is not None:
        # NOTE(review): the indices come from the module-level `manual_image`,
        # not from `args.manual_image` - confirm that this is intended.
        global manual_image
        for ii in manual_image:
            print("Proccessing image at location: " + str(ii))
            truth_imgs = dataset[ii]
            truth_hr = truth_imgs["hr"].numpy()
            # The generator lives on `device`, so its input must be moved there.
            truth_lr = truth_imgs["lr"].unsqueeze(0).to(device)
            with torch.no_grad():
                gen_hr = generator(truth_lr)
                gen_lr = sumpool(gen_hr)
            # Bring everything back to host memory before converting to numpy.
            gen_lr = gen_lr.squeeze(0).cpu().numpy()
            gen_hr = gen_hr.squeeze(0).cpu().numpy()
            truth_lr = truth_lr.squeeze(0).cpu().numpy()
            print(truth_hr.shape, truth_lr.shape, gen_hr.shape, gen_lr.shape)
            prefix = args.output + "image_" + str(ii)
            np.save(prefix + "_truth_hr.npy", truth_hr)
            np.save(prefix + "_truth_lr.npy", truth_lr)
            np.save(prefix + "_gen_hr.npy", gen_hr)
            np.save(prefix + "_gen_lr.npy", gen_lr)
        return

    for imgs in dataloader:
        # Configure model input
        imgs_lr = imgs["lr"].to(device)
        imgs_hr = imgs["hr"].to(device)

        # Pure evaluation: no autograd graph is needed for any of this.
        with torch.no_grad():
            # Generate a high resolution image from low resolution input
            gen_hr = generator(imgs_lr)
            gen_lr = sumpool(gen_hr)

            # Compare the distributions of the strictly positive pixel values
            # of generated and real HR images on a shared 10-bin histogram.
            gen_nnz = gen_hr[gen_hr > 0].view(-1)
            en_loss = 0
            if gen_nnz.numel() > 0:
                real_nnz = imgs_hr[imgs_hr > 0].view(-1)
                all_nnz = torch.cat((gen_nnz, real_nnz), 0)
                e_min = torch.min(all_nnz).item()
                e_max = torch.max(all_nnz).item()
                gen_hist = torch.histc(gen_nnz, 10, min=e_min, max=e_max).float()
                real_hist = torch.histc(real_nnz, 10, min=e_min, max=e_max).float()
                en_loss = mse(gen_hist, real_hist).item()
            print("HR L1Loss: %.3e, LR L1Loss: %.3e, Energy distribution loss %.3e" % (criterion(gen_hr, imgs_hr).item(), criterion(gen_lr, imgs_lr).item(), en_loss))

        show(imgs_lr, imgs_hr, gen_hr, gen_lr)
Пример #2
0
    def _set_model(self, device, hr_shape):
        """Create the networks and loss modules and place them on ``device``.

        Settings (``channels``, ``residual_blocks``) are read from the
        module-level ``opt``.

        Args:
            device: Device every module built here is moved to.
            hr_shape: Sizes unpacked after the channel count to form the
                discriminator's ``input_shape``.
        """
        # Networks
        generator = GeneratorRRDB(opt.channels,
                                  filters=64,
                                  num_res_blocks=opt.residual_blocks)
        self.generator = generator.to(device)

        discriminator = Discriminator(input_shape=(opt.channels, *hr_shape))
        self.discriminator = discriminator.to(device)

        # The feature extractor is switched to inference mode right away
        feature_extractor = FeatureExtractor().to(device)
        feature_extractor.eval()
        self.feature_extractor = feature_extractor

        # Loss modules
        l1_loss = torch.nn.L1Loss
        self.criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device)
        self.criterion_content = l1_loss().to(device)
        self.criterion_pixel = l1_loss().to(device)
Пример #3
0
def SuperResolution(f_name, ori):
    """Super-resolve ``ori``, save it, run OCR on it and record the metrics.

    The generator weights are read from ``./generator.pth``. The upscaled
    image is written to ``./data/sr_img/<f_name>``, passed to ``OCR``, and
    compared via ``PSNR`` with ``./data/preprocessed_img/<f_name>``; the
    results are forwarded to ``save_result``.

    Side effects:
        Sets the module-level ``ocr_result`` to the OCR output.

    Args:
        f_name: File name used for the saved image and for the reference image.
        ori: Input image accepted by ``transforms.ToTensor``.
    """
    global ocr_result

    pth = "./generator.pth"
    channels = 3
    residual_blocks = 23

    # GPU index 3 is the preferred device; fall back to the default GPU when
    # the host exposes fewer devices, and to the CPU when CUDA is unavailable.
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    elif torch.cuda.device_count() > 3:
        device = torch.device("cuda:3")
    else:
        device = torch.device("cuda")

    # Define model and load model checkpoint
    generator = GeneratorRRDB(channels,
                              filters=64,
                              num_res_blocks=residual_blocks).to(device)
    # map_location lets a checkpoint saved on another device load here.
    generator.load_state_dict(torch.load(pth, map_location=device))
    generator.eval()

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean, std)])

    # Prepare input: add the batch dimension expected by the generator
    image_tensor = transform(ori).to(device).unsqueeze(0)

    # Upsample image
    with torch.no_grad():
        sr_image = denormalize(generator(image_tensor)).cpu()

    # Save image
    path = os.path.join("./data/sr_img/", f_name)
    oripath = os.path.join("./data/preprocessed_img/", f_name)
    save_image(sr_image, path)

    # OCR and quality metrics on the saved result
    result = OCR(path)
    ocr_result = result
    mse, psnr = PSNR(oripath, path)
    save_result(f_name, result, mse, psnr)
Пример #4
0
def upsample_empty(args):
    """Upsample an empty image and a sparse-noise image and plot the results.

    An all-zero HR image and an HR image containing at most 150 non-zero noise
    pixels (values in ``(0, 1]``) are sum-pooled by ``args.factor``, passed
    through the pretrained generator, and the number of pixels above
    ``args.threshold`` is printed for the generated images and the noisy HR
    image. The four images are then shown in a 2x2 matplotlib figure, using
    the module-level ``vmax`` as the upper colour limit.

    Args:
        args: Parsed options. Attributes read here: ``residual_blocks``,
            ``factor``, ``scaling_power``, ``res_scale``,
            ``use_transposed_conv``, ``fully_transposed_conv``,
            ``num_final_res_blocks``, ``threshold``, ``model`` and ``hw``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    generator = GeneratorRRDB(1, filters=64, num_res_blocks=args.residual_blocks, num_upsample=int(np.log2(args.factor)), power=args.scaling_power, res_scale=args.res_scale, use_transposed_conv=args.use_transposed_conv, fully_tconv_upsample=args.fully_transposed_conv, num_final_layer_res=args.num_final_res_blocks).to(device).eval()
    generator.thres = args.threshold
    generator.load_state_dict(torch.load(args.model, map_location=device))
    sumpool = SumPool2d(args.factor)

    empty_hr = torch.zeros([1, 1, *args.hw])
    empty_lr = sumpool(empty_hr)

    # Half-normal noise scaled so that its maximum is exactly 1.
    noise = torch.abs(torch.randn(empty_hr.shape))
    noise = noise / (torch.max(noise).item())

    # Zero out randomly chosen pixels so that at most 150 stay non-zero.
    # Clamping keeps images with fewer than 150 pixels from requesting a
    # negative sample size.
    n_pixels = noise.numel()
    indices = np.random.choice(np.arange(n_pixels), replace=False, size=max(n_pixels - 150, 0))
    noise[np.unravel_index(indices, noise.shape)] = 0
    noise_hr = empty_hr + noise
    noise_lr = sumpool(noise_hr)

    # The generator lives on `device`; move its inputs there and skip autograd.
    with torch.no_grad():
        empty_sr = generator(empty_lr.to(device))
        noise_sr = generator(noise_lr.to(device))
    print(empty_hr.shape, empty_lr.shape, empty_sr.shape)  # debug output

    # Count pixels above the threshold with tensor ops; this works on any
    # device and avoids a Python-level loop over every pixel.
    nnz = int((empty_sr > args.threshold).sum().item())
    noisennz = int((noise_sr > args.threshold).sum().item())
    hrnoisennz = int((noise_hr > args.threshold).sum().item())
    print("upsampled empty picture nnz: {}".format(nnz))
    print("upsampled soft noise picture nnz: {}".format(noisennz))
    print("hr soft noise picture nnz: {}".format(hrnoisennz))

    panels = (
        (221, "empty hr image", empty_hr),
        (222, "empty sr image", empty_sr),
        (223, "soft noise hr image", noise_hr),
        (224, "soft noise sr image", noise_sr),
    )
    plt.figure()
    for position, title, image in panels:
        plt.subplot(position)
        plt.title(title)
        plt.imshow(toArray(image).squeeze(), cmap='gray', vmax=vmax)
    plt.show()
Пример #5
0
 def network_initializers(self, hr_shape, use_LeakyReLU_Mish=False):
     """Build the generator, discriminator and feature extractor.

     All three modules are moved to ``self.device`` with
     ``non_blocking=True``; the feature extractor is additionally put into
     inference mode.

     Args:
         hr_shape: Sizes unpacked after ``self.opt.channels`` to form the
             discriminator's ``input_shape``.
         use_LeakyReLU_Mish: Forwarded unchanged to ``GeneratorRRDB`` and
             ``Discriminator``.

     Returns:
         The tuple ``(discriminator, feature_extractor, generator)``.
     """
     generator = GeneratorRRDB(
         self.opt.channels,
         filters=64,
         num_res_blocks=self.opt.residual_blocks,
         use_LeakyReLU_Mish=use_LeakyReLU_Mish,
     )
     generator = generator.to(self.device, non_blocking=True)

     discriminator = Discriminator(
         input_shape=(self.opt.channels, *hr_shape),
         use_LeakyReLU_Mish=use_LeakyReLU_Mish,
     )
     discriminator = discriminator.to(self.device, non_blocking=True)

     feature_extractor = FeatureExtractor()
     feature_extractor = feature_extractor.to(self.device, non_blocking=True)
     # The feature extractor is used for inference only
     feature_extractor.eval()

     return discriminator, feature_extractor, generator
Пример #6
0
                    default=3,
                    help="Number of image channels")
parser.add_argument("--residual_blocks",
                    type=int,
                    default=23,
                    help="Number of residual blocks in G")
opt = parser.parse_args()
print(opt)

os.makedirs("images/outputs", exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define model and load model checkpoint
generator = GeneratorRRDB(opt.channels,
                          filters=64,
                          num_res_blocks=opt.residual_blocks).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only host.
generator.load_state_dict(
    torch.load(opt.checkpoint_model, map_location=device))
generator.eval()

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(mean, std)])

# Prepare input: the context manager closes the image file once it has been
# converted to a tensor; unsqueeze adds the batch dimension.
with Image.open(opt.image_path) as input_image:
    image_tensor = transform(input_image).to(device).unsqueeze(0)

# Upsample image
with torch.no_grad():
    sr_image = denormalize(generator(image_tensor)).cpu()
Пример #7
0
# -

os.makedirs(demo_out_dir, exist_ok=True)

weight_path = "/workspace/output/cat_face/weight/generator_3900.pth"

# # data

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

hr_shape = (opt.hr_height, opt.hr_width)

# +
# Initialize the generator and restore its trained weights
generator = GeneratorRRDB(opt.channels,
                          filters=64,
                          num_res_blocks=opt.residual_blocks).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only host.
generator.load_state_dict(torch.load(weight_path, map_location=device))

# Tensor type matching the selected device
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
# -

# Demo images are read in a fixed order (no shuffling)
demo_dataloader = DataLoader(
    DemoImageDataset(demo_in_dir),
    batch_size=opt.batch_size,
    shuffle=False,
    num_workers=opt.n_cpu,
)

# # generate hr image
Пример #8
0

# # main

# +
# def main(opt):
# -

opt = Opt()

# Use the GPU when one is available, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
hr_shape = opt.hr_height, opt.hr_width

# Networks: generator, discriminator and feature extractor
generator = GeneratorRRDB(
    opt.channels,
    filters=64,
    num_res_blocks=opt.residual_blocks,
)
generator = generator.to(device)

discriminator = Discriminator(input_shape=(opt.channels,) + hr_shape)
discriminator = discriminator.to(device)

feature_extractor = FeatureExtractor()
feature_extractor = feature_extractor.to(device)

# The feature extractor is used for inference only
feature_extractor.eval()

# Loss modules, all placed on the same device
criterion_GAN, criterion_content, criterion_pixel = (
    loss_module.to(device)
    for loss_module in (
        torch.nn.BCEWithLogitsLoss(),
        torch.nn.L1Loss(),
        torch.nn.L1Loss(),
    )
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(),
                               lr=opt.lr,
Пример #9
0
class ESRGAN():
    def __init__(self, opt):
        """Store the options and build the networks and losses.

        Args:
            opt: Options object; ``hr_height`` and ``hr_width`` are read here
                and the object is kept as ``self.opt``.
        """
        self.opt = opt
        use_cuda = torch.cuda.is_available()
        self._set_model(
            torch.device("cuda" if use_cuda else "cpu"),
            (opt.hr_height, opt.hr_width),
        )

    def _set_model(self, device, hr_shape):
        """Create the networks and loss modules and place them on ``device``.

        Settings are taken from ``self.opt`` (the options passed to the
        constructor) rather than from a module-level ``opt``, so the instance
        works regardless of which globals exist.

        Args:
            device: Device every module built here is moved to.
            hr_shape: Sizes unpacked after the channel count to form the
                discriminator's ``input_shape``.
        """
        opt = self.opt

        # Initialize generator and discriminator
        self.generator = GeneratorRRDB(
            opt.channels, filters=64,
            num_res_blocks=opt.residual_blocks).to(device)
        self.discriminator = Discriminator(input_shape=(opt.channels,
                                                        *hr_shape)).to(device)
        self.feature_extractor = FeatureExtractor().to(device)

        # Set feature extractor to inference mode
        self.feature_extractor.eval()

        # Losses
        self.criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device)
        self.criterion_content = torch.nn.L1Loss().to(device)
        self.criterion_pixel = torch.nn.L1Loss().to(device)

    def _set_param(self):
        """Log every attribute of ``self.opt`` as an MLflow parameter."""
        # Use the instance's options, not a module-level ``opt``.
        for key, value in vars(self.opt).items():
            mlflow.log_param(key, value)

    def _load_weigth(self):
        """Optionally restore checkpoints, then create both Adam optimizers.

        When ``self.opt.epoch`` is non-zero, the generator and discriminator
        weights saved for that epoch are loaded from the module-level
        ``weight_save_dir``. ``self.optimizer_G`` and ``self.optimizer_D`` are
        created in every case.
        """
        # Use the instance's options, not a module-level ``opt``.
        opt = self.opt

        if opt.epoch != 0:
            # Load pretrained models
            load_g_weight_path = osp.join(weight_save_dir,
                                          "generator_%d.pth" % opt.epoch)
            load_d_weight_path = osp.join(weight_save_dir,
                                          "discriminator_%d.pth" % opt.epoch)

            # Deserialize onto the CPU so checkpoints written on a GPU also
            # load on hosts without one; load_state_dict then copies the
            # values into the parameters on whatever device they live on.
            self.generator.load_state_dict(
                torch.load(load_g_weight_path, map_location="cpu"))
            self.discriminator.load_state_dict(
                torch.load(load_d_weight_path, map_location="cpu"))

        # Optimizers
        self.optimizer_G = torch.optim.Adam(self.generator.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.b1, opt.b2))
        self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.b1, opt.b2))

    # ----------
    #  Training
    # ----------
    def train(self, dataloader, opt):
        """Run the ESRGAN training loop over ``dataloader``.

        While ``batches_done <= opt.warmup_batches`` only the generator is
        updated, using the pixel-wise loss. Afterwards the generator is
        trained on content + adversarial + pixel loss and the discriminator on
        the relativistic average GAN loss. Image grids and checkpoints are
        written only after warm-up, whenever ``batches_done`` is a multiple of
        ``opt.sample_interval`` / ``opt.checkpoint_interval``. Losses are
        logged to MLflow on every batch.

        The networks, losses and optimizers are taken from the instance
        (created by ``_set_model`` and ``_load_weigth``), not from module-level
        globals. The output directories ``image_train_save_dir`` and
        ``weight_save_dir`` are module-level names.

        Args:
            dataloader: Iterable of batches; each batch is indexed with
                ``"lr"`` and ``"hr"``. Must support ``len()``.
            opt: Options. Attributes read here: ``epoch``, ``n_epochs``,
                ``warmup_batches``, ``lambda_adv``, ``lambda_pixel``,
                ``sample_interval`` and ``checkpoint_interval``.
        """
        # Bind the instance's components once; the loop below operates on the
        # same models and optimizers that the other methods create and load.
        generator = self.generator
        discriminator = self.discriminator
        feature_extractor = self.feature_extractor
        criterion_GAN = self.criterion_GAN
        criterion_content = self.criterion_content
        criterion_pixel = self.criterion_pixel
        optimizer_G = self.optimizer_G
        optimizer_D = self.optimizer_D

        # The tensor type does not change between batches, so pick it once.
        Tensor = torch.cuda.FloatTensor if torch.cuda.is_available(
        ) else torch.Tensor

        for epoch in range(opt.epoch + 1, opt.n_epochs + 1):
            for batch_num, imgs in enumerate(dataloader):
                batches_done = (epoch - 1) * len(dataloader) + batch_num

                # Configure model input
                imgs_lr = imgs["lr"].type(Tensor)
                imgs_hr = imgs["hr"].type(Tensor)

                # Adversarial ground truths (constant targets, no gradient)
                valid = Tensor(
                    np.ones((imgs_lr.size(0), *discriminator.output_shape)))
                fake = Tensor(
                    np.zeros((imgs_lr.size(0), *discriminator.output_shape)))

                # ------------------
                #  Train Generators
                # ------------------

                optimizer_G.zero_grad()

                # Generate a high resolution image from low resolution input
                gen_hr = generator(imgs_lr)

                # Measure pixel-wise loss against ground truth
                loss_pixel = criterion_pixel(gen_hr, imgs_hr)

                # Warm-up (pixel-wise loss only)
                if batches_done <= opt.warmup_batches:
                    loss_pixel.backward()
                    optimizer_G.step()
                    log_info = "[Epoch {}/{}] [Batch {}/{}] [G pixel: {}]".format(
                        epoch, opt.n_epochs, batch_num, len(dataloader),
                        loss_pixel.item())

                    sys.stdout.write("\r{}".format(log_info))
                    sys.stdout.flush()

                    mlflow.log_metric('train_loss_pixel',
                                      loss_pixel.item(),
                                      step=batches_done)
                else:
                    # Extract validity predictions from discriminator
                    pred_real = discriminator(imgs_hr).detach()
                    pred_fake = discriminator(gen_hr)

                    # Adversarial loss (relativistic average GAN)
                    loss_GAN = criterion_GAN(
                        pred_fake - pred_real.mean(0, keepdim=True), valid)

                    # Content loss
                    gen_features = feature_extractor(gen_hr)
                    real_features = feature_extractor(imgs_hr).detach()
                    loss_content = criterion_content(gen_features,
                                                     real_features)

                    # Total generator loss
                    loss_G = loss_content + opt.lambda_adv * loss_GAN + opt.lambda_pixel * loss_pixel

                    loss_G.backward()
                    optimizer_G.step()

                    # ---------------------
                    #  Train Discriminator
                    # ---------------------

                    optimizer_D.zero_grad()

                    pred_real = discriminator(imgs_hr)
                    # Detach so the discriminator update does not
                    # backpropagate into the generator.
                    pred_fake = discriminator(gen_hr.detach())

                    # Adversarial loss for real and fake images (relativistic average GAN)
                    loss_real = criterion_GAN(
                        pred_real - pred_fake.mean(0, keepdim=True), valid)
                    loss_fake = criterion_GAN(
                        pred_fake - pred_real.mean(0, keepdim=True), fake)

                    # Total loss
                    loss_D = (loss_real + loss_fake) / 2

                    loss_D.backward()
                    optimizer_D.step()

                    # --------------
                    #  Log Progress
                    # --------------

                    log_info = "[Epoch {}/{}] [Batch {}/{}] [D loss: {}] [G loss: {}, content: {}, adv: {}, pixel: {}]".format(
                        epoch,
                        opt.n_epochs,
                        batch_num,
                        len(dataloader),
                        loss_D.item(),
                        loss_G.item(),
                        loss_content.item(),
                        loss_GAN.item(),
                        loss_pixel.item(),
                    )

                    # Start a fresh console line at batch 1; otherwise
                    # overwrite the current line in place.
                    if batch_num == 1:
                        sys.stdout.write("\n{}".format(log_info))
                    else:
                        sys.stdout.write("\r{}".format(log_info))

                    sys.stdout.flush()

                    if batches_done % opt.sample_interval == 0:
                        # Save image grid with upsampled inputs and ESRGAN outputs
                        # NOTE(review): scale_factor=4 is hard-coded; it
                        # presumably matches the generator's upscaling factor.
                        imgs_lr_up = nn.functional.interpolate(imgs_lr,
                                                               scale_factor=4)
                        img_grid = denormalize(torch.cat((imgs_lr_up, gen_hr),
                                                         -1))

                        image_batch_save_dir = osp.join(
                            image_train_save_dir, '{:07}'.format(batches_done))
                        os.makedirs(osp.join(image_batch_save_dir, "hr_image"),
                                    exist_ok=True)
                        save_image(img_grid,
                                   osp.join(image_batch_save_dir, "hr_image",
                                            "%d.png" % batches_done),
                                   nrow=1,
                                   normalize=False)

                    if batches_done % opt.checkpoint_interval == 0:
                        # Save model checkpoints (files are named by epoch)
                        torch.save(
                            generator.state_dict(),
                            osp.join(weight_save_dir,
                                     "generator_%d.pth" % epoch))
                        torch.save(
                            discriminator.state_dict(),
                            osp.join(weight_save_dir,
                                     "discriminator_%d.pth" % epoch))

                    mlflow.log_metric('train_loss_D',
                                      loss_D.item(),
                                      step=batches_done)
                    mlflow.log_metric('train_loss_G',
                                      loss_G.item(),
                                      step=batches_done)
                    mlflow.log_metric('train_loss_content',
                                      loss_content.item(),
                                      step=batches_done)
                    mlflow.log_metric('train_loss_GAN',
                                      loss_GAN.item(),
                                      step=batches_done)
                    mlflow.log_metric('train_loss_pixel',
                                      loss_pixel.item(),
                                      step=batches_done)