def main_worker(gpu, args):
    """Run single-image super-resolution and save the results under ``tests/``.

    Reads the low-resolution image at ``args.lr``, builds a bicubic upscale of
    it as a visual baseline, runs the model on it, and writes LR, bicubic, SR
    and a side-by-side comparison image. When ``args.hr`` is set, the reference
    image is also saved, appended to the comparison strip, and scored with
    ``iqa``.

    Args:
        gpu: GPU index to run on, or ``None`` for CPU. Stored on ``args.gpu``.
        args: Parsed command-line namespace (reads ``lr``, ``hr``,
            ``upscale_factor`` and whatever ``configure`` needs).
    """
    args.gpu = gpu
    use_gpu = args.gpu is not None

    if use_gpu:
        logger.info(f"Use GPU: {args.gpu} for testing.")

    model = configure(args)

    if not torch.cuda.is_available():
        logger.warning("Using CPU, this will be slow.")
    if use_gpu:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    # Inference only; let cuDNN auto-tune convolution algorithms.
    model.eval()
    cudnn.benchmark = True

    base_name = os.path.basename(args.lr)

    def _output(prefix):
        # Every artefact is written to tests/ as "<prefix>_<input file name>".
        return os.path.join("tests", f"{prefix}_{base_name}")

    # The bicubic upscale of the input is the baseline shown next to SR.
    lr_image = Image.open(args.lr)
    lr_width, lr_height = lr_image.size
    scale = args.upscale_factor
    bicubic_image = transforms.Resize((lr_height * scale, lr_width * scale),
                                      InterpolationMode.BICUBIC)(lr_image)
    lr = process_image(lr_image, args.gpu)
    bicubic = process_image(bicubic_image, args.gpu)

    with torch.no_grad():
        sr = model(lr)

    if not args.hr:
        images = torch.cat([bicubic, sr], dim=-1)
    else:
        hr = process_image(Image.open(args.hr), args.gpu)
        vutils.save_image(hr, _output("hr"))
        images = torch.cat([bicubic, sr, hr], dim=-1)

        value = iqa(sr, hr, args.gpu)
        print("Performance avg results:\n")
        print("indicator Score\n")
        print("--------- -----\n")
        print(f"MSE       {value[0]:6.4f}\n"
              f"RMSE      {value[1]:6.4f}\n"
              f"PSNR      {value[2]:6.2f}\n"
              f"SSIM      {value[3]:6.4f}\n"
              f"LPIPS     {value[4]:6.4f}\n"
              f"GMSD      {value[5]:6.4f}\n")

    vutils.save_image(lr, _output("lr"))
    vutils.save_image(bicubic, _output("bicubic"))
    vutils.save_image(sr, _output("sr"))
    vutils.save_image(images, _output("compare"), padding=10)
def main_worker(gpu, args):
    """Evaluate every ``Generator_epoch*.pth`` checkpoint and report the best.

    Each checkpoint in ``args.model_dir`` is run on the LR image at ``args.lr``
    and scored against ``args.hr`` via ``inference``. The checkpoint with the
    highest PSNR (``value[2]``) wins, and its metrics are printed as a table.

    Args:
        gpu: GPU index to run on, or ``None`` for CPU. Stored on ``args.gpu``.
        args: Parsed command-line namespace (reads ``lr``, ``hr``,
            ``model_dir`` and whatever ``configure`` needs).

    Raises:
        FileNotFoundError: If ``args.model_dir`` holds no matching checkpoint.

    Side effects:
        Updates the module-level ``best_*_value`` globals. A checkpoint must
        strictly beat the initial ``best_psnr_value`` to be recorded.
    """
    global best_mse_value, best_rmse_value, best_psnr_value, best_ssim_value, best_lpips_value, best_gmsd_value
    args.gpu = gpu

    if args.gpu is not None:
        logger.info(f"Use GPU: {args.gpu} for training.")

    cudnn.benchmark = True

    model = configure(args)

    if not torch.cuda.is_available():
        logger.warning("Using CPU, this will be slow.")

    # Read all pictures.
    lr = process_image(Image.open(args.lr), args.gpu)
    hr = process_image(Image.open(args.hr), args.gpu)

    # Sort so the evaluation order (and the winner among ties) is reproducible;
    # glob() order is filesystem-dependent.
    model_paths = sorted(
        glob(os.path.join(f"{args.model_dir}", "Generator_epoch*.pth")))
    if not model_paths:
        raise FileNotFoundError(
            f"No `Generator_epoch*.pth` checkpoints found in `{args.model_dir}`.")
    # Report basenames consistently, including for this fallback.
    best_model = os.path.basename(model_paths[0])

    for model_path in model_paths:
        print(f"Process `{model_path}`")
        value = inference(lr, hr, model, model_path, args.gpu)

        # value is indexed as (MSE, RMSE, PSNR, SSIM, LPIPS, GMSD); selection is by PSNR.
        if value[2] > best_psnr_value:
            best_model = os.path.basename(model_path)
            best_mse_value = value[0]
            best_rmse_value = value[1]
            best_psnr_value = value[2]
            best_ssim_value = value[3]
            best_lpips_value = value[4]
            best_gmsd_value = value[5]

    print("\n##################################################")
    print(f"Best model: `{best_model}`.")
    print("indicator Score")
    print("--------- -----")
    # One metric per line; RMSE uses the same precision as elsewhere in the file.
    print(f"MSE       {best_mse_value:6.4f}\n"
          f"RMSE      {best_rmse_value:6.4f}\n"
          f"PSNR      {best_psnr_value:6.2f}\n"
          f"SSIM      {best_ssim_value:6.4f}\n"
          f"LPIPS     {best_lpips_value:6.4f}\n"
          f"GMSD      {best_gmsd_value:6.4f}")
    print("--------- -----")
    print("##################################################\n")
def main(args) -> None:
    """Time and profile one forward pass of the model on a random input.

    Builds the model with ``configure``, feeds it a ``(1, 3, image_size,
    image_size)`` standard-normal tensor, and prints the wall-clock time of a
    single forward pass. It then profiles a second pass with the autograd
    profiler, prints the summary table, and exports a Chrome trace to
    ``profile.json``. Open that trace at ``chrome://tracing``.

    Args:
        args: Parsed command-line namespace (reads ``seed``, ``image_size``,
            ``gpu`` and whatever ``configure`` needs).
    """
    if args.seed is not None:
        # Repeatable runs: seed the RNGs, force deterministic cuDNN kernels,
        # and disable the cuDNN auto-tuner. The auto-tuner may pick different
        # algorithms from run to run.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.benchmark = False
        cudnn.deterministic = True

    # Build the super-resolution model (configure loads weights if requested).
    model = configure(args)
    model.eval()

    # Random input drawn from a standard normal distribution.
    data = torch.randn([1, 3, args.image_size, args.image_size],
                       requires_grad=False)

    use_gpu = args.gpu is not None
    if use_gpu:
        # Make args.gpu the current device: model and data must share it, and
        # torch.cuda.synchronize() below waits on the *current* device.
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        data = data.cuda(args.gpu, non_blocking=True)

    with torch.no_grad():
        start = time.perf_counter()
        _ = model(data)
        # CUDA kernels run asynchronously; wait for them before stopping the clock.
        # Guarded so CPU-only PyTorch builds do not raise here.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print(f"Time:{(time.perf_counter() - start) * 1000:.2f}ms.")

    # use_cuda needs a real bool: passing args.gpu would disable CUDA
    # profiling for GPU index 0. Profile the same gradient-free pass as above.
    with torch.no_grad(), torch.autograd.profiler.profile(
            enabled=True, use_cuda=use_gpu) as profile:
        _ = model(data)
    print(profile.table())
    # Open Chrome and enter `chrome://tracing` in the address bar to view.
    profile.export_chrome_trace("profile.json")
Example #4
0
def main_worker(ngpus_per_node, args):
    """Train a generator/discriminator pair on one worker process.

    Training runs in two stages. The first is a PSNR-oriented stage that
    optimizes only the generator with a pixel (L1) loss. The second is a GAN
    stage that alternately optimizes both networks via ``train_gan``. After
    every epoch the generator is evaluated with ``test``, and checkpoints are
    written to ``weights/``.

    Args:
        ngpus_per_node: Number of GPUs on this node. Used to derive the global
            rank and the per-process batch size / worker count when
            distributed.
        args: Parsed command-line namespace. ``args.gpu`` is read but NOT set
            here, so the caller must set it.

    Side effects:
        Updates the module-level ``best_psnr``, ``best_ssim`` and ``fixed_lr``.
    """
    global best_psnr, best_ssim, fixed_lr

    if args.gpu is not None:
        logger.info(f"Use GPU: {args.gpu} for training.")

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes.
            args.rank = args.rank * ngpus_per_node + args.gpu
        dist.init_process_group(args.dist_backend,
                                args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # With one process per GPU, only the first process on each node writes
    # checkpoints; otherwise every process is the writer. The rank is final here.
    is_main_process = (not args.multiprocessing_distributed
                       or args.rank % ngpus_per_node == 0)

    # Create models.
    generator = configure(args)
    discriminator = discriminator_for_vgg(args.image_size)

    if not torch.cuda.is_available():
        logger.warning("Using CPU, this will be slow.")
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            discriminator.cuda(args.gpu)
            generator.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have.
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            discriminator = nn.parallel.DistributedDataParallel(
                discriminator, device_ids=[args.gpu])
            generator = nn.parallel.DistributedDataParallel(
                generator, device_ids=[args.gpu])
        else:
            discriminator.cuda()
            generator.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set.
            discriminator = nn.parallel.DistributedDataParallel(discriminator)
            generator = nn.parallel.DistributedDataParallel(generator)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        discriminator = discriminator.cuda(args.gpu)
        generator = generator.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs.
        # NOTE(review): the alexnet/vgg branch assumes both networks expose a
        # `.features` submodule; confirm this holds for the generator.
        if args.arch.startswith("alexnet") or args.arch.startswith("vgg"):
            discriminator.features = torch.nn.DataParallel(
                discriminator.features)
            generator.features = torch.nn.DataParallel(generator.features)
            discriminator.cuda()
            generator.cuda()
        else:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()

    # Loss = 10 * pixel loss + content loss + 0.005 * adversarial loss
    pixel_criterion = nn.L1Loss().cuda(args.gpu)
    content_criterion = ContentLoss().cuda(args.gpu)
    adversarial_criterion = nn.BCEWithLogitsLoss().cuda(args.gpu)

    if args.gpu is not None:
        fixed_lr = fixed_lr.cuda(args.gpu)

    # Optimizers and schedulers.
    psnr_optimizer = torch.optim.Adam(generator.parameters(),
                                      lr=args.psnr_lr,
                                      betas=(0.9, 0.99))
    # Halve the PSNR-stage LR every quarter of the stage. Clamp to 1 so short
    # runs (psnr_epochs < 4) don't give StepLR a zero step (ZeroDivisionError).
    psnr_epoch_indices = max(1, math.floor(args.psnr_epochs // 4))
    psnr_scheduler = torch.optim.lr_scheduler.StepLR(psnr_optimizer,
                                                     psnr_epoch_indices, 0.5)
    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(),
                                               lr=args.gan_lr,
                                               betas=(0.9, 0.99))
    generator_optimizer = torch.optim.Adam(generator.parameters(),
                                           lr=args.gan_lr,
                                           betas=(0.9, 0.99))
    # Milestones at 1/8, 2/8, 4/8 and 6/8 of the GAN stage. True division is
    # needed for ceil() to have any effect. The clamp avoids milestones at
    # epoch 0, which would cut the LR 16x before training even starts.
    interval_epoch = max(1, math.ceil(args.gan_epochs / 8))
    gan_epoch_indices = [
        interval_epoch, interval_epoch * 2, interval_epoch * 4,
        interval_epoch * 6
    ]
    discriminator_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        discriminator_optimizer, gan_epoch_indices, 0.5)
    generator_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        generator_optimizer, gan_epoch_indices, 0.5)

    # Build datasets and data loaders.
    train_dataset = BaseTrainDataset(os.path.join(args.data, "train"),
                                     args.image_size, args.upscale_factor)
    test_dataset = BaseTestDataset(os.path.join(args.data, "test"),
                                   args.image_size, args.upscale_factor)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        pin_memory=True,
        sampler=train_sampler,
        num_workers=args.workers)
    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  pin_memory=True,
                                                  num_workers=args.workers)

    # Optionally resume from pre-trained weights.
    if args.netD != "":
        discriminator.load_state_dict(torch.load(args.netD))
    if args.netG != "":
        generator.load_state_dict(torch.load(args.netG))

    # Training below runs unconditionally. It was previously nested under
    # `if args.netG != ""`, so nothing trained without a pretrained generator.

    # Mixed-precision training.
    scaler = amp.GradScaler()
    logger.info("Turn on mixed precision training.")

    # Create the SummaryWriters at the beginning of training.
    psnr_writer = SummaryWriter(f"runs/{args.arch}_psnr_logs")
    gan_writer = SummaryWriter(f"runs/{args.arch}_gan_logs")

    # Stage 1: PSNR-oriented training.
    for epoch in range(args.start_psnr_epoch, args.psnr_epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        train_psnr(train_dataloader, generator, pixel_criterion,
                   psnr_optimizer, epoch, scaler, psnr_writer, args)
        # Update the PSNR-oriented optimizer learning rate.
        psnr_scheduler.step()

        # Evaluate on the test dataset.
        psnr, ssim, gmsd = test(test_dataloader, generator, args.gpu)
        psnr_writer.add_scalar("PSNR_Test/PSNR", psnr, epoch)
        psnr_writer.add_scalar("PSNR_Test/SSIM", ssim, epoch)
        psnr_writer.add_scalar("PSNR_Test/GMSD", gmsd, epoch)

        # This stage tracks the best model by PSNR.
        is_best = psnr > best_psnr
        best_psnr = max(psnr, best_psnr)
        if is_main_process:
            torch.save(generator.state_dict(),
                       os.path.join("weights", f"PSNR_epoch{epoch}.pth"))
            if is_best:
                torch.save(generator.state_dict(),
                           os.path.join("weights", "PSNR-best.pth"))

    # Save the final PSNR-stage weights (one writer only, avoiding a file race).
    if is_main_process:
        torch.save(generator.state_dict(),
                   os.path.join("weights", "PSNR-last.pth"))

    # Stage 2: GAN training.
    for epoch in range(args.start_gan_epoch, args.gan_epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        train_gan(train_dataloader, discriminator, discriminator_optimizer,
                  generator, generator_optimizer, pixel_criterion,
                  content_criterion, adversarial_criterion, epoch, scaler,
                  gan_writer, args)
        # Update the GAN-stage optimizer learning rates.
        discriminator_scheduler.step()
        generator_scheduler.step()

        # Evaluate on the test dataset.
        psnr, ssim, gmsd = test(test_dataloader, generator, args.gpu)
        gan_writer.add_scalar("GAN_Test/PSNR", psnr, epoch)
        gan_writer.add_scalar("GAN_Test/SSIM", ssim, epoch)
        gan_writer.add_scalar("GAN_Test/GMSD", gmsd, epoch)

        # This stage tracks the best model by SSIM.
        is_best = ssim > best_ssim
        best_ssim = max(ssim, best_ssim)
        if is_main_process:
            torch.save(
                discriminator.state_dict(),
                os.path.join("weights", f"Discriminator_epoch{epoch}.pth"))
            torch.save(
                generator.state_dict(),
                os.path.join("weights", f"Generator_epoch{epoch}.pth"))
            if is_best:
                torch.save(generator.state_dict(),
                           os.path.join("weights", "GAN-best.pth"))

    # Save the final GAN-stage weights (one writer only, avoiding a file race).
    if is_main_process:
        torch.save(generator.state_dict(),
                   os.path.join("weights", "GAN-last.pth"))
def main_worker(gpu, args):
    """Super-resolve every frame of a video and write SR and comparison videos.

    For each frame of ``args.file`` the model output is written to
    ``videos/sr_<scale>x_<name>``. A mosaic is written to
    ``videos/compare_<scale>x_<name>``: the bicubic-upscaled input next to the
    SR frame on top, and five matching crops of each below. With ``args.view``
    the mosaic is also shown live; press ``q`` to stop early.

    Args:
        gpu: GPU index to run on, or ``None`` for CPU. Stored on ``args.gpu``.
        args: Parsed command-line namespace (reads ``file``, ``upscale_factor``,
            ``view`` and whatever ``configure`` needs).
    """
    args.gpu = gpu

    if args.gpu is not None:
        logger.info(f"Use GPU: {args.gpu} for testing.")

    model = configure(args)

    if not torch.cuda.is_available():
        logger.warning("Using CPU, this will be slow.")
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    cudnn.benchmark = True

    # Set eval mode.
    model.eval()

    # Get video filename.
    filename = os.path.basename(args.file)

    # Converts numpy HWC uint8 arrays to PIL images for torchvision transforms.
    tensor2pil = transforms.ToPILImage()

    video_capture = cv2.VideoCapture(args.file)
    # Prepare to write the processed image into the video.
    fps = video_capture.get(cv2.CAP_PROP_FPS)
    total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
    # Input frame size as (width, height), which is what cv2 expects.
    size = (int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    sr_size = (size[0] * args.upscale_factor, size[1] * args.upscale_factor)
    pare_size = (sr_size[0] * 2 + 10, sr_size[1] + 10 + sr_size[0] // 5 - 9)
    # Video writers.
    sr_writer = cv2.VideoWriter(
        os.path.join("videos", f"sr_{args.upscale_factor}x_{filename}"),
        cv2.VideoWriter_fourcc(*"MPEG"), fps, sr_size)
    compare_writer = cv2.VideoWriter(
        os.path.join("videos", f"compare_{args.upscale_factor}x_{filename}"),
        cv2.VideoWriter_fourcc(*"MPEG"), fps, pare_size)

    try:
        with torch.no_grad():
            success, raw_frame = video_capture.read()
            progress_bar = tqdm(
                range(total_frames),
                desc="[processing video and saving/view result videos]")
            for _ in progress_bar:
                if not success:
                    # CAP_PROP_FRAME_COUNT can overestimate; stop once decoding fails.
                    break

                # Read image to tensor and transfer it to the processing device.
                lr = process_image(raw_frame, args.gpu)

                sr = model(lr)

                # First batch item, C*H*W float -> H*W*C uint8. Clip first so
                # out-of-range values saturate instead of wrapping around.
                sr = sr.cpu().data[0].numpy()
                sr = np.clip(sr * 255.0, 0, 255)
                sr = np.uint8(sr).transpose((1, 2, 0))
                # Save the SR video frame.
                sr_writer.write(sr)

                # Comparison mosaic: crops from the top-left, top-right, center,
                # bottom-left and bottom-right.
                sr = tensor2pil(sr)
                crop_sr_images = transforms.FiveCrop(size=sr.width // 5 -
                                                     9)(sr)
                crop_sr_images = [
                    np.asarray(transforms.Pad(padding=(10, 5, 0, 0))(image))
                    for image in crop_sr_images
                ]
                sr = transforms.Pad(padding=(5, 0, 0, 5))(sr)
                # The same five crops of the bicubic-upscaled input frame.
                compare_image = transforms.Resize(
                    (sr_size[1], sr_size[0]),
                    interpolation=Mode.BICUBIC)(tensor2pil(raw_frame))
                crop_compare_images = transforms.FiveCrop(
                    size=compare_image.width // 5 - 9)(compare_image)
                crop_compare_images = [
                    np.asarray(transforms.Pad((0, 5, 10, 0))(image))
                    for image in crop_compare_images
                ]
                compare_image = transforms.Pad(padding=(0, 0, 5,
                                                        5))(compare_image)
                # 1. Put the upscaled input and the SR frame side by side.
                top_image = np.concatenate(
                    (np.asarray(compare_image), np.asarray(sr)), axis=1)
                # 2. Join all crops into one strip.
                bottom_image = np.concatenate(crop_compare_images +
                                              crop_sr_images,
                                              axis=1)
                bottom_image_height = int(top_image.shape[1] /
                                          bottom_image.shape[1] *
                                          bottom_image.shape[0])
                bottom_image_width = top_image.shape[1]
                # 3. Scale the strip to the top row's width.
                bottom_image = np.asarray(
                    transforms.Resize(
                        (bottom_image_height,
                         bottom_image_width))(tensor2pil(bottom_image)))
                # 4. Stack the top row and the bottom strip.
                final_image = np.concatenate((top_image, bottom_image))

                # Save the comparison video frame.
                compare_writer.write(final_image)

                if args.view:
                    cv2.imshow("LR video convert SR video ", final_image)
                    if cv2.waitKey(1) & 0xFF == ord("q"):
                        break

                # Next frame.
                success, raw_frame = video_capture.read()
    finally:
        # Release handles so the output containers are finalized even on error.
        video_capture.release()
        sr_writer.release()
        compare_writer.release()
        if args.view:
            cv2.destroyAllWindows()
Example #6
0
def main(args):
    """Super-resolve every frame of ``args.file`` and write SR and comparison videos.

    The SR frames go to ``videos/sr_<scale>x_<name>``. A mosaic goes to
    ``videos/compare_<scale>x_<name>``: the bicubic-upscaled input next to the
    SR frame on top, and five matching crops of each below. With ``args.view``
    the mosaic is also shown live; press ``q`` to stop early.

    Args:
        args: Parsed command-line namespace (reads ``seed``, ``gpu``, ``file``,
            ``upscale_factor``, ``view`` and whatever ``configure`` needs).
    """
    if args.seed is not None:
        # Repeatable runs: seed the RNGs, force deterministic cuDNN kernels,
        # and disable the cuDNN auto-tuner. The auto-tuner may pick different
        # algorithms from run to run.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        logger.warning("You have chosen to seed testing. "
                       "This will turn on the CUDNN deterministic setting, "
                       "which can slow down your testing considerably! "
                       "You may see unexpected behavior when restarting "
                       "from checkpoints.")
        cudnn.benchmark = False
        cudnn.deterministic = True

    # Build the super-resolution model (configure loads weights if requested).
    model = configure(args)
    model.eval()

    if not torch.cuda.is_available():
        logger.warning("Using CPU, this will be slow.")
    if args.gpu is not None:
        # Frames are passed to process_image with gpu=args.gpu (presumably
        # placing them on that device), so the model must live there too.
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    # Name the outputs after the video actually being read.
    filename = os.path.basename(args.file)

    video_capture = cv2.VideoCapture(args.file)
    fps = video_capture.get(cv2.CAP_PROP_FPS)
    total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
    # Frame sizes as (width, height), which is what cv2.VideoWriter expects.
    raw_video_size = (int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH)),
                      int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    sr_video_size = (raw_video_size[0] * args.upscale_factor, raw_video_size[1] * args.upscale_factor)
    compare_video_size = (sr_video_size[0] * 2 + 10, sr_video_size[1] + 10 + sr_video_size[0] // 5 - 9)
    # Video writers.
    sr_writer_path = os.path.join("videos", f"sr_{args.upscale_factor}x_{filename}")
    compare_writer_path = os.path.join("videos", f"compare_{args.upscale_factor}x_{filename}")
    sr_writer = cv2.VideoWriter(sr_writer_path, cv2.VideoWriter_fourcc(*"MPEG"), fps, sr_video_size)
    compare_writer = cv2.VideoWriter(compare_writer_path, cv2.VideoWriter_fourcc(*"MPEG"), fps, compare_video_size)

    try:
        with torch.no_grad():
            success, raw_frame = video_capture.read()
            for _ in tqdm(range(total_frames), desc="[processing video and saving/view result videos]"):
                if not success:
                    # CAP_PROP_FRAME_COUNT can overestimate; stop once decoding fails.
                    break

                # Reconstruct the SR frame from the low-resolution frame.
                sr = model(process_image(raw_frame, norm=False, gpu=args.gpu))

                # First batch item, C*H*W float -> H*W*C uint8. Clip first so
                # out-of-range values saturate instead of wrapping around.
                sr = sr.cpu().data[0].numpy()
                sr = np.clip(sr * 255.0, 0, 255)
                sr = np.uint8(sr).transpose((1, 2, 0))
                sr_writer.write(sr)

                # Comparison mosaic: crops from the top-left, top-right, center,
                # bottom-left and bottom-right.
                sr = transforms.ToPILImage()(sr)
                crop_sr_images = transforms.FiveCrop(sr.width // 5 - 9)(sr)
                crop_sr_images = [np.asarray(transforms.Pad(padding=(10, 5, 0, 0))(image)) for image in crop_sr_images]
                sr = transforms.Pad(padding=(5, 0, 0, 5))(sr)
                # Bicubic-upscaled input frame. torchvision's Resize does not
                # accept numpy arrays, so convert to PIL before resizing.
                compare_image_size = (sr_video_size[1], sr_video_size[0])
                compare_image = transforms.ToPILImage()(raw_frame)
                compare_image = transforms.Resize(compare_image_size, interpolation=Mode.BICUBIC)(compare_image)
                crop_compare_images = transforms.FiveCrop(compare_image.width // 5 - 9)(compare_image)
                crop_compare_images = [np.asarray(transforms.Pad((0, 5, 10, 0))(image)) for image in
                                       crop_compare_images]
                compare_image = transforms.Pad(padding=(0, 0, 5, 5))(compare_image)
                # 1. Put the upscaled input and the SR frame side by side.
                top_image = np.concatenate((np.asarray(compare_image), np.asarray(sr)), axis=1)
                # 2. Join all crops into one strip.
                bottom_image = np.concatenate(crop_compare_images + crop_sr_images, axis=1)
                bottom_image_height = int(top_image.shape[1] / bottom_image.shape[1] * bottom_image.shape[0])
                bottom_image_width = top_image.shape[1]
                # 3. Scale the strip to the top row's width.
                bottom_image_size = (bottom_image_height, bottom_image_width)
                bottom_image = np.asarray(transforms.Resize(bottom_image_size)(transforms.ToPILImage()(bottom_image)))
                # 4. Stack the top row and the bottom strip.
                images = np.concatenate((top_image, bottom_image))

                compare_writer.write(images)

                if args.view:
                    cv2.imshow("LR video convert SR video ", images)
                    if cv2.waitKey(1) & 0xFF == ord("q"):
                        break

                # Read the next frame.
                success, raw_frame = video_capture.read()
    finally:
        # Release handles so the output containers are finalized even on error.
        video_capture.release()
        sr_writer.release()
        compare_writer.release()
        if args.view:
            cv2.destroyAllWindows()
Example #7
0
def main(args):
    """Super-resolve ``args.lr`` and save the results under ``tests/``.

    Writes LR, bicubic, SR and a side-by-side comparison image. When
    ``args.hr`` is set, the reference image is also saved, appended to the
    comparison strip, and the SR output is scored against it with ``iqa``.

    Args:
        args: Parsed command-line namespace (reads ``seed``, ``model_path``,
            ``gpu``, ``lr``, ``hr``, ``upscale_factor`` and whatever
            ``configure`` needs).
    """
    seed = args.seed
    if seed is not None:
        # Seed Python's and PyTorch's RNGs, then configure cuDNN.
        random.seed(seed)
        torch.manual_seed(seed)
        logger.warning("You have chosen to seed testing. "
                       "This will turn on the CUDNN deterministic setting, "
                       "which can slow down your testing considerably! "
                       "You may see unexpected behavior when restarting "
                       "from checkpoints.")
        cudnn.benchmark = True
        cudnn.deterministic = True

    # Build the model; explicitly requested weights are loaded on top.
    model = configure(args)
    weights_path = args.model_path
    if weights_path is not None:
        logger.info(
            f"You loaded the specified weight. Load weights from `{os.path.abspath(weights_path)}`."
        )
        state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
    model.eval()

    if not torch.cuda.is_available():
        logger.warning("Using CPU, this will be slow.")
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        cudnn.benchmark = True
        cudnn.deterministic = True

    image_name = os.path.basename(args.lr)

    # The bicubic upscale of the input is the baseline shown next to SR.
    lr_image = Image.open(args.lr)
    lr_width, lr_height = lr_image.size
    target_size = (lr_height * args.upscale_factor,
                   lr_width * args.upscale_factor)
    bicubic_image = transforms.Resize(target_size,
                                      InterpolationMode.BICUBIC)(lr_image)
    lr = process_image(lr_image, norm=False, gpu=args.gpu)
    bicubic = process_image(bicubic_image, norm=True, gpu=args.gpu)

    with torch.no_grad():
        sr = model(lr)

    if not args.hr:
        images = torch.cat([bicubic, sr], dim=-1)
    else:
        hr = process_image(Image.open(args.hr), norm=True, gpu=args.gpu)
        vutils.save_image(hr,
                          os.path.join("tests", f"hr_{image_name}"),
                          normalize=True)
        images = torch.cat([bicubic, sr, hr], dim=-1)

        # Score the reconstruction against the reference once.
        value = iqa(sr, hr, args.gpu)
        print("Performance avg results:\n")
        print("Indicator score\n")
        print("--------- -----\n")
        print(f"MSE       {value[0]:6.4f}\n"
              f"RMSE      {value[1]:6.4f}\n"
              f"PSNR      {value[2]:6.2f}\n"
              f"SSIM      {value[3]:6.4f}\n"
              f"GMSD      {value[4]:6.4f}\n")

    vutils.save_image(lr, os.path.join("tests", f"lr_{image_name}"))
    for prefix, tensor in (("bicubic", bicubic), ("sr", sr)):
        vutils.save_image(tensor,
                          os.path.join("tests", f"{prefix}_{image_name}"),
                          normalize=True)
    vutils.save_image(images,
                      os.path.join("tests", f"compare_{image_name}"),
                      padding=10,
                      normalize=True)
Example #8
0
def main_worker(gpu, ngpus_per_node, args):
    """Train the generator in a PSNR-oriented phase followed by an adversarial phase.

    Phase 1 optimises the generator with the pixel (L1) criterion only. Phase 2
    reloads ``weights/PSNR.pth`` and trains generator and discriminator with the
    pixel, content (VGG) and adversarial criteria. After every epoch the generator
    is evaluated on the test split and checkpoints are written to ``weights/`` by
    the process that is allowed to save (every process unless multiprocessing
    distributed training is used, then rank 0 of each node).

    Args:
        gpu: Index of the GPU this worker uses, or ``None``.
        ngpus_per_node: Number of GPUs on this node; used to derive the global
            rank and to split ``args.batch_size`` / ``args.workers`` per process.
        args: Parsed command-line namespace. Mutated in place (``gpu``, and
            depending on the configuration ``rank``, ``batch_size``, ``workers``,
            ``start_psnr_epoch``, ``start_gan_epoch``).
    """
    global best_psnr
    args.gpu = gpu

    if args.gpu is not None:
        logger.info(f"Use GPU: {args.gpu} for training.")

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank)
    # create model
    generator = configure(args)
    discriminator = discriminator_for_vgg(image_size=args.image_size)

    if not torch.cuda.is_available():
        logger.warning("Using CPU, this will be slow.")
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            discriminator.cuda(args.gpu)
            generator.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            discriminator = nn.parallel.DistributedDataParallel(module=discriminator, device_ids=[args.gpu])
            generator = nn.parallel.DistributedDataParallel(module=generator, device_ids=[args.gpu])
        else:
            discriminator.cuda()
            generator.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            discriminator = nn.parallel.DistributedDataParallel(discriminator)
            generator = nn.parallel.DistributedDataParallel(generator)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        discriminator = discriminator.cuda(args.gpu)
        generator = generator.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs.
        # NOTE(review): this branch assumes both networks expose a ``features``
        # submodule when ``args.arch`` starts with "alexnet"/"vgg" — confirm.
        if args.arch.startswith("alexnet") or args.arch.startswith("vgg"):
            discriminator.features = torch.nn.DataParallel(discriminator.features)
            generator.features = torch.nn.DataParallel(generator.features)
            discriminator.cuda()
            generator.cuda()
        else:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()

    # Device used when deserialising checkpoints. ``None`` keeps torch.load's
    # default placement; formatting ``args.gpu`` unconditionally would produce
    # the invalid device string "cuda:None".
    map_location = None if args.gpu is None else f"cuda:{args.gpu}"

    # Loss = 10 * pixel loss + content loss + 0.005 * adversarial loss
    pixel_criterion = nn.L1Loss().cuda(args.gpu)
    content_criterion = VGGLoss().cuda(args.gpu)
    adversarial_criterion = nn.BCEWithLogitsLoss().cuda(args.gpu)
    logger.info(f"Losses function information:\n"
                f"\tPixel:       L1Loss\n"
                f"\tContent:     VGG19_35th\n"
                f"\tAdversarial: BCEWithLogitsLoss")

    # All optimizer function and scheduler function.
    psnr_optimizer = torch.optim.Adam(generator.parameters(), lr=args.psnr_lr, betas=(0.9, 0.99))
    # StepLR computes ``last_epoch % step_size``, so a step size of 0 (fewer than
    # four PSNR epochs) would raise ZeroDivisionError on the first step().
    psnr_epoch_indices = max(1, math.floor(args.psnr_epochs // 4))
    psnr_scheduler = torch.optim.lr_scheduler.StepLR(psnr_optimizer, step_size=psnr_epoch_indices, gamma=0.5)
    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=args.gan_lr, betas=(0.9, 0.99))
    generator_optimizer = torch.optim.Adam(generator.parameters(), args.gan_lr, (0.9, 0.99))
    interval_epoch = math.ceil(args.gan_epochs // 8)
    gan_epoch_indices = [interval_epoch, interval_epoch * 2, interval_epoch * 4, interval_epoch * 6]
    discriminator_scheduler = torch.optim.lr_scheduler.MultiStepLR(discriminator_optimizer, milestones=gan_epoch_indices, gamma=0.5)
    generator_scheduler = torch.optim.lr_scheduler.MultiStepLR(generator_optimizer, milestones=gan_epoch_indices, gamma=0.5)
    logger.info(f"Optimizer information:\n"
                f"\tPSNR learning rate:          {args.psnr_lr}\n"
                f"\tDiscriminator learning rate: {args.gan_lr}\n"
                f"\tGenerator learning rate:     {args.gan_lr}\n"
                f"\tPSNR optimizer:              Adam, [betas=(0.9,0.99)]\n"
                f"\tDiscriminator optimizer:     Adam, [betas=(0.9,0.99)]\n"
                f"\tGenerator optimizer:         Adam, [betas=(0.9,0.99)]\n"
                f"\tPSNR scheduler:              StepLR, [step_size=psnr_epoch_indices, gamma=0.5]\n"
                f"\tDiscriminator scheduler:     MultiStepLR, [milestones=epoch_indices, gamma=0.5]\n"
                f"\tGenerator scheduler:         MultiStepLR, [milestones=epoch_indices, gamma=0.5]")

    logger.info("Load training dataset")
    # Selection of appropriate treatment equipment.
    train_dataset = BaseTrainDataset(os.path.join(args.data, "train"), args.image_size, args.upscale_factor)
    test_dataset = BaseTestDataset(os.path.join(args.data, "test"), args.image_size, args.upscale_factor)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=(train_sampler is None),
                                                   pin_memory=True,
                                                   sampler=train_sampler,
                                                   num_workers=args.workers)
    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  pin_memory=True,
                                                  num_workers=args.workers)

    logger.info(f"Dataset information:\n"
                f"\tTrain Path:              {os.getcwd()}/{args.data}/train\n"
                f"\tTest Path:               {os.getcwd()}/{args.data}/test\n"
                f"\tNumber of train samples: {len(train_dataset)}\n"
                f"\tNumber of test samples:  {len(test_dataset)}\n"
                f"\tNumber of train batches: {len(train_dataloader)}\n"
                f"\tNumber of test batches:  {len(test_dataloader)}\n"
                f"\tShuffle of train:        {train_sampler is None}\n"
                f"\tShuffle of test:         False\n"
                f"\tSampler of train:        {bool(train_sampler)}\n"
                f"\tSampler of test:         None\n"
                f"\tWorkers of train:        {args.workers}\n"
                f"\tWorkers of test:         {args.workers}")

    # optionally resume from a checkpoint
    if args.resume_psnr:
        if os.path.isfile(args.resume_psnr):
            logger.info(f"Loading checkpoint '{args.resume_psnr}'.")
            # Map model to be loaded to specified single gpu (if any).
            checkpoint = torch.load(args.resume_psnr, map_location=map_location)
            args.start_psnr_epoch = checkpoint["epoch"]
            best_psnr = checkpoint["best_psnr"]
            if args.gpu is not None:
                # best_psnr may be from a checkpoint from a different GPU
                # (assumes it was saved as a tensor — TODO confirm).
                best_psnr = best_psnr.to(args.gpu)
            generator.load_state_dict(checkpoint["state_dict"])
            psnr_optimizer.load_state_dict(checkpoint["optimizer"])
            logger.info(f"Loaded checkpoint '{args.resume_psnr}' (epoch {checkpoint['epoch']}).")
        else:
            logger.info(f"No checkpoint found at '{args.resume_psnr}'.")

    if args.resume_d or args.resume_g:
        # Both checkpoints are read below, so both files must exist; checking
        # only one of them would crash in torch.load on the missing file.
        if os.path.isfile(args.resume_d) and os.path.isfile(args.resume_g):
            logger.info(f"Loading checkpoint '{args.resume_d}'.")
            logger.info(f"Loading checkpoint '{args.resume_g}'.")
            # Map model to be loaded to specified single gpu (if any).
            checkpoint_d = torch.load(args.resume_d, map_location=map_location)
            checkpoint_g = torch.load(args.resume_g, map_location=map_location)
            args.start_gan_epoch = checkpoint_g["epoch"]
            best_psnr = checkpoint_g["best_psnr"]
            if args.gpu is not None:
                # best_psnr may be from a checkpoint from a different GPU
                # (assumes it was saved as a tensor — TODO confirm).
                best_psnr = best_psnr.to(args.gpu)
            discriminator.load_state_dict(checkpoint_d["state_dict"])
            discriminator_optimizer.load_state_dict(checkpoint_d["optimizer"])
            generator.load_state_dict(checkpoint_g["state_dict"])
            generator_optimizer.load_state_dict(checkpoint_g["optimizer"])
            logger.info(f"Loaded checkpoint '{args.resume_d}' (epoch {checkpoint_d['epoch']}).")
            logger.info(f"Loaded checkpoint '{args.resume_g}' (epoch {checkpoint_g['epoch']}).")
        else:
            logger.info(f"No checkpoint found at '{args.resume_d}' or '{args.resume_g}'.")

    cudnn.benchmark = True

    # The mixed precision training is used in PSNR-oral.
    scaler = amp.GradScaler()
    logger.info("Turn on mixed precision training.")

    # Create a SummaryWriter at the beginning of training.
    psnr_writer = SummaryWriter(f"runs/{args.arch}_psnr_logs")
    gan_writer = SummaryWriter(f"runs/{args.arch}_gan_logs")

    logger.info(f"Train information:\n"
                f"\tPSNR-oral epochs: {args.psnr_epochs}\n"
                f"\tGAN-oral epochs:  {args.gan_epochs}")

    for epoch in range(args.start_psnr_epoch, args.psnr_epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        train_psnr(dataloader=train_dataloader,
                   model=generator,
                   criterion=pixel_criterion,
                   optimizer=psnr_optimizer,
                   epoch=epoch,
                   scaler=scaler,
                   writer=psnr_writer,
                   args=args)

        psnr_scheduler.step()

        # Test for every epoch. PSNR-phase metrics belong in the PSNR log so they
        # do not collide with the GAN-phase scalars at the same step numbers.
        psnr, ssim, lpips, gmsd = test(dataloader=test_dataloader, model=generator, gpu=args.gpu)
        psnr_writer.add_scalar("Test/PSNR", psnr, epoch + 1)
        psnr_writer.add_scalar("Test/SSIM", ssim, epoch + 1)
        psnr_writer.add_scalar("Test/LPIPS", lpips, epoch + 1)
        psnr_writer.add_scalar("Test/GMSD", gmsd, epoch + 1)

        is_best = psnr > best_psnr
        best_psnr = max(psnr, best_psnr)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
            torch.save({"epoch": epoch + 1,
                        "arch": args.arch,
                        "best_psnr": best_psnr,
                        "state_dict": generator.state_dict(),
                        "optimizer": psnr_optimizer.state_dict(),
                        }, os.path.join("weights", f"PSNR_epoch{epoch}.pth"))
            if is_best:
                torch.save(generator.state_dict(), os.path.join("weights", "PSNR.pth"))

    # Load best model weight.
    # NOTE(review): this also overwrites generator weights restored from a GAN
    # checkpoint (--resume_g); confirm that is intended when resuming phase 2.
    best_psnr = 0.0
    generator.load_state_dict(torch.load(os.path.join("weights", "PSNR.pth"), map_location=map_location))

    for epoch in range(args.start_gan_epoch, args.gan_epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train_gan(dataloader=train_dataloader,
                  discriminator=discriminator,
                  discriminator_optimizer=discriminator_optimizer,
                  generator=generator,
                  generator_optimizer=generator_optimizer,
                  pixel_criterion=pixel_criterion,
                  content_criterion=content_criterion,
                  adversarial_criterion=adversarial_criterion,
                  epoch=epoch,
                  scaler=scaler,
                  writer=gan_writer,
                  args=args)

        discriminator_scheduler.step()
        generator_scheduler.step()

        # Test for every epoch.
        psnr, ssim, lpips, gmsd = test(dataloader=test_dataloader, model=generator, gpu=args.gpu)
        gan_writer.add_scalar("Test/PSNR", psnr, epoch + 1)
        gan_writer.add_scalar("Test/SSIM", ssim, epoch + 1)
        gan_writer.add_scalar("Test/LPIPS", lpips, epoch + 1)
        gan_writer.add_scalar("Test/GMSD", gmsd, epoch + 1)

        is_best = psnr > best_psnr
        best_psnr = max(psnr, best_psnr)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
            torch.save({"epoch": epoch + 1,
                        "arch": "vgg",
                        "state_dict": discriminator.state_dict(),
                        "optimizer": discriminator_optimizer.state_dict()
                        }, os.path.join("weights", f"Discriminator_epoch{epoch}.pth"))
            torch.save({"epoch": epoch + 1,
                        "arch": args.arch,
                        "best_psnr": best_psnr,
                        "state_dict": generator.state_dict(),
                        "optimizer": generator_optimizer.state_dict()
                        }, os.path.join("weights", f"Generator_epoch{epoch}.pth"))
            if is_best:
                torch.save(generator.state_dict(), os.path.join("weights", "GAN.pth"))
def main_worker(gpu, args):
    """Evaluate the configured model on ``<args.data>/test`` and print average metrics.

    Each batch's super-resolved output is scored against the high-resolution
    reference with ``iqa``, and a ``bicubic | sr | hr`` strip is saved to
    ``benchmarks/<batch number>.bmp``. The printed averages are per-batch
    means of the values ``iqa`` returns.

    Args:
        gpu: GPU index to run on, or ``None``.
        args: Parsed command-line namespace; ``args.gpu`` is overwritten with ``gpu``.

    Side effects:
        Resets and then accumulates the module-level ``total_*_value`` globals,
        so after the call they hold the metric sums of this run only.
    """
    global total_mse_value, total_rmse_value, total_psnr_value, total_ssim_value, total_lpips_value, total_gmsd_value
    # The averages below divide these sums by this run's batch count, so start
    # from zero; stale values from a previous call (or a missing module-level
    # initialisation) would otherwise corrupt them.
    total_mse_value = total_rmse_value = total_psnr_value = 0.0
    total_ssim_value = total_lpips_value = total_gmsd_value = 0.0
    args.gpu = gpu

    if args.gpu is not None:
        logger.info(f"Use GPU: {args.gpu} for testing.")

    model = configure(args)

    if not torch.cuda.is_available():
        logger.warning("Using CPU, this will be slow.")
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    logger.info("Load testing dataset.")
    # Selection of appropriate treatment equipment.
    dataset = BaseTestDataset(os.path.join(args.data, "test"), args.image_size,
                              args.upscale_factor)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             pin_memory=True,
                                             num_workers=args.workers)
    num_batches = len(dataloader)
    logger.info(f"Dataset information:\n"
                f"\tPath:              {os.getcwd()}/{args.data}/test\n"
                f"\tNumber of samples: {len(dataset)}\n"
                f"\tNumber of batches: {num_batches}\n"
                f"\tShuffle:           False\n"
                f"\tSampler:           None\n"
                f"\tWorkers:           {args.workers}")

    # Nothing to average over; avoid dividing by zero in the summary below.
    if num_batches == 0:
        logger.warning("The test dataset is empty, nothing to evaluate.")
        return

    cudnn.benchmark = True

    # Set eval mode.
    model.eval()

    with torch.no_grad():
        # Start evaluate model performance.
        progress_bar = tqdm(enumerate(dataloader), total=num_batches)
        for i, (lr, bicubic, hr) in progress_bar:
            # Move data to special device.
            if args.gpu is not None:
                lr = lr.cuda(args.gpu, non_blocking=True)
                bicubic = bicubic.cuda(args.gpu, non_blocking=True)
                hr = hr.cuda(args.gpu, non_blocking=True)

            sr = model(lr)

            # Evaluate performance; ``iqa`` here yields six values in the order
            # MSE, RMSE, PSNR, SSIM, LPIPS, GMSD (as used by the summary below).
            value = iqa(sr, hr, args.gpu)

            total_mse_value += value[0]
            total_rmse_value += value[1]
            total_psnr_value += value[2]
            total_ssim_value += value[3]
            total_lpips_value += value[4]
            total_gmsd_value += value[5]

            progress_bar.set_description(
                f"[{i + 1}/{num_batches}] "
                f"PSNR: {total_psnr_value / (i + 1):6.2f} "
                f"SSIM: {total_ssim_value / (i + 1):6.4f}")

            images = torch.cat([bicubic, sr, hr], -1)
            vutils.save_image(images,
                              os.path.join("benchmarks", f"{i + 1}.bmp"),
                              padding=10)

    print(f"Performance average results:\n")
    print(f"indicator Score\n")
    print(f"--------- -----\n")
    print(f"MSE       {total_mse_value / num_batches:6.4f}\n"
          f"RMSE      {total_rmse_value / num_batches:6.4f}\n"
          f"PSNR      {total_psnr_value / num_batches:6.2f}\n"
          f"SSIM      {total_ssim_value / num_batches:6.4f}\n"
          f"LPIPS     {total_lpips_value / num_batches:6.4f}\n"
          f"GMSD      {total_gmsd_value / num_batches:6.4f}")
def main(args):
    """Evaluate a super-resolution model on ``<args.data>/test`` and print average metrics.

    Builds the model with ``configure(args)``, optionally loads weights from
    ``args.model_path``, runs it over ``CustomTestDataset`` and scores each
    batch against the reference with ``iqa``. A ``bicubic | sr | hr`` strip is
    saved per batch to ``benchmarks/<batch number>.bmp``. The printed values
    are per-batch means of what ``iqa`` returns.

    Args:
        args: Parsed command-line namespace (reads ``seed``, ``model_path``,
            ``gpu``, ``data``, ``image_size``, ``upscale_factor``,
            ``batch_size``, ``workers``).
    """
    # Initialize all evaluation criteria.
    total_mse_value, total_rmse_value, total_psnr_value, total_ssim_value, total_gmsd_value = 0.0, 0.0, 0.0, 0.0, 0.0
    if args.seed is not None:
        # In order to make the model repeatable, the first step is to set random seeds, and the second step is to set
        # convolution algorithm.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        logger.warning("You have chosen to seed testing. "
                       "This will turn on the CUDNN deterministic setting, "
                       "which can slow down your testing considerably! "
                       "You may see unexpected behavior when restarting "
                       "from checkpoints.")
        # Benchmark mode lets cuDNN pick convolution algorithms per run, which
        # can differ between runs; disable it so results are reproducible.
        cudnn.benchmark = False
        # Ensure that every time the same input returns the same result.
        cudnn.deterministic = True

    # Build a super-resolution model; if ``model_path`` is given, load that weight file.
    model = configure(args)
    # If special choice model path.
    if args.model_path is not None:
        logger.info(f"You loaded the specified weight. Load weights from `{os.path.abspath(args.model_path)}`.")
        model.load_state_dict(torch.load(args.model_path, map_location=torch.device("cpu")))
    # Switch model to eval mode.
    model.eval()

    if not torch.cuda.is_available():
        logger.warning("Using CPU, this will be slow.")
    # The input batches are moved to ``args.gpu`` below, so the model must be
    # placed on the same device or the forward pass fails with a device mismatch.
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    # Selection of appropriate treatment equipment.
    dataset = CustomTestDataset(root=os.path.join(args.data, "test"),
                                image_size=args.image_size,
                                upscale_factor=args.upscale_factor)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             pin_memory=True,
                                             num_workers=args.workers)
    num_batches = len(dataloader)

    # Nothing to average over; avoid dividing by zero in the summary below.
    if num_batches == 0:
        logger.warning("The test dataset is empty, nothing to evaluate.")
        return

    # Needs to reconstruct the low resolution image without the gradient information of the reconstructed image.
    with torch.no_grad():
        # Start evaluate model performance.
        progress_bar = tqdm(enumerate(dataloader), total=num_batches)
        for i, (lr, bicubic, hr) in progress_bar:
            # Move data to special device.
            if args.gpu is not None:
                lr = lr.cuda(args.gpu, non_blocking=True)
                bicubic = bicubic.cuda(args.gpu, non_blocking=True)
                hr = hr.cuda(args.gpu, non_blocking=True)

            # The low resolution image is reconstructed to the super resolution image.
            sr = model(lr)

            # The reconstructed image and the reference image are evaluated once.
            # Here ``iqa`` yields five values: MSE, RMSE, PSNR, SSIM, GMSD.
            value = iqa(sr, hr, args.gpu)

            # The values of various evaluation indexes are accumulated.
            total_mse_value += value[0]
            total_rmse_value += value[1]
            total_psnr_value += value[2]
            total_ssim_value += value[3]
            total_gmsd_value += value[4]

            # Output as scrollbar style.
            progress_bar.set_description(f"[{i + 1}/{num_batches}] "
                                         f"PSNR: {total_psnr_value / (i + 1):6.2f} "
                                         f"SSIM: {total_ssim_value / (i + 1):6.4f}")

            # Merge three images into one line for visualization.
            # Save a series of reconstruction results.
            vutils.save_image(torch.cat([bicubic, sr, hr], dim=-1),
                              os.path.join("benchmarks", f"{i + 1}.bmp"),
                              padding=10,
                              normalize=True)

    print(f"Performance average results:\n")
    print(f"Indicator score\n")
    print(f"--------- -----\n")
    print(f"MSE       {total_mse_value / num_batches:6.4f}\n"
          f"RMSE      {total_rmse_value / num_batches:6.4f}\n"
          f"PSNR      {total_psnr_value / num_batches:6.2f}\n"
          f"SSIM      {total_ssim_value / num_batches:6.4f}\n"
          f"GMSD      {total_gmsd_value / num_batches:6.4f}\n")
from rfb_esrgan_pytorch.utils.common import configure

# Public, lower-case callables exported by ``models`` are the selectable architectures.
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

logger = logging.getLogger(__name__)
logging.basicConfig(format="[ %(levelname)s ] %(message)s", level=logging.INFO)

parser = argparse.ArgumentParser(
    description=
    "Perceptual Extreme Super Resolution Network with Receptive Field Block.")
parser.add_argument("-a",
                    "--arch",
                    metavar="ARCH",
                    default="rfb",
                    choices=model_names,
                    help="Model architecture: " + " | ".join(model_names) +
                    ". (Default: `rfb`)")
parser.add_argument("--model-path",
                    type=str,
                    metavar="PATH",
                    required=True,
                    help="Path to latest checkpoint for model.")

if __name__ == "__main__":
    # Convert a training checkpoint (a dict holding "state_dict") into a bare
    # generator state dict saved as ``Generator.pth``.
    args = parser.parse_args()

    model = configure(args)
    # Training checkpoints may hold CUDA tensors; map them to the CPU so the
    # conversion also works on machines without a GPU.
    model.load_state_dict(torch.load(args.model_path, map_location=torch.device("cpu"))["state_dict"])
    torch.save(model.state_dict(), "Generator.pth")
    logger.info("Model convert done.")