Example no. 1
    def __init__(self, args):
        self.args = args
        self.model, self.device = configure(args)
        # Image preprocessing operation
        self.tensor2pil = transforms.ToPILImage()

        self.video_capture = cv2.VideoCapture(args.file)
        # Prepare to write the processed image into the video.
        self.fps = self.video_capture.get(cv2.CAP_PROP_FPS)
        self.total_frames = int(
            self.video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
        # Set video size
        self.size = (int(self.video_capture.get(cv2.CAP_PROP_FRAME_WIDTH)),
                     int(self.video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        self.sr_size = (self.size[0] * args.upscale_factor,
                        self.size[1] * args.upscale_factor)
        self.pare_size = (self.sr_size[0] * 2 + 10,
                          self.sr_size[1] + 10 + self.sr_size[0] // 5 - 9)
        # Video writers.
        self.sr_writer = cv2.VideoWriter(
            os.path.join(
                "video",
                f"sr_{args.upscale_factor}x_{os.path.basename(args.file)}"),
            cv2.VideoWriter_fourcc(*"MPEG"), self.fps, self.sr_size)
        self.compare_writer = cv2.VideoWriter(
            os.path.join(
                "video",
                f"compare_{args.upscale_factor}x_{os.path.basename(args.file)}"
            ), cv2.VideoWriter_fourcc(*"MPEG"), self.fps, self.pare_size)
Example no. 2
def main_worker(gpu, args):
    args.gpu = gpu

    if args.gpu is not None:
        logger.info(f"Use GPU: {args.gpu} for training.")

    model = configure(args)

    if not torch.cuda.is_available():
        logger.warning("Using CPU, this will be slow.")
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    # Set eval mode.
    model.eval()

    cudnn.benchmark = True

    # Get image filename.
    filename = os.path.basename(args.lr)

    # Read all pictures.
    lr = Image.open(args.lr)
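    # Bicubic upscaling serves as the comparison baseline.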
    bicubic = transforms.Resize(
        (lr.size[1] * args.upscale_factor, lr.size[0] * args.upscale_factor),
        Mode.BICUBIC)(lr)
    lr = process_image(lr, args.gpu)
    bicubic = process_image(bicubic, args.gpu)

    with torch.no_grad():
        sr = model(lr)

    if args.hr:
        hr = process_image(Image.open(args.hr), args.gpu)
        vutils.save_image(hr, os.path.join("test", f"hr_{filename}"))
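        # Concatenate bicubic, SR and HR along the width for comparison.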
        images = torch.cat([bicubic, sr, hr], dim=-1)

        value = iqa(sr, hr, args.gpu)
        print(f"Performance avg results:\n")
        print(f"indicator Score\n")
        print(f"--------- -----\n")
        print(f"MSE       {value[0]:6.4f}\n"
              f"RMSE      {value[1]:6.4f}\n"
              f"PSNR      {value[2]:6.2f}\n"
              f"SSIM      {value[3]:6.4f}\n"
              f"LPIPS     {value[4]:6.4f}\n"
              f"GMSD      {value[5]:6.4f}\n")
    else:
        images = torch.cat([bicubic, sr], dim=-1)

    vutils.save_image(lr, os.path.join("test", f"lr_{filename}"))
    vutils.save_image(bicubic, os.path.join("test", f"bicubic_{filename}"))
    vutils.save_image(sr, os.path.join("test", f"sr_{filename}"))
    vutils.save_image(images,
                      os.path.join("test", f"compare_{filename}"),
                      padding=10)
Example no. 3
def main_worker(gpu, args):
    global best_mse_value, best_rmse_value, best_psnr_value, best_ssim_value, best_lpips_value, best_gmsd_value
    args.gpu = gpu

    if args.gpu is not None:
        logger.info(f"Use GPU: {args.gpu} for training.")

    cudnn.benchmark = True

    model = configure(args)

    if not torch.cuda.is_available():
        logger.warning("Using CPU, this will be slow.")

    # Read all pictures.
    lr = process_image(Image.open(args.lr), args.gpu)
    hr = process_image(Image.open(args.hr), args.gpu)

    model_paths = glob(
        os.path.join(f"{args.model_dir}", "Generator_epoch*.pth"))
    best_model = model_paths[0]

    for model_path in model_paths:
        print(f"Process `{model_path}`")
        value = inference(lr, hr, model, model_path, gpu)

        is_best = value[2] > best_psnr_value

        if is_best:
            best_model = os.path.basename(model_path)
            best_mse_value = value[0]
            best_rmse_value = value[1]
            best_psnr_value = value[2]
            best_ssim_value = value[3]
            best_lpips_value = value[4]
            best_gmsd_value = value[5]

    print("\n##################################################")
    print(f"Best model: `{best_model}`.")
    print(f"indicator Score")
    print(f"--------- -----")
    print(f"MSE       {best_mse_value:6.4f}"
          f"RMSE      {best_rmse_value:6.2f}"
          f"PSNR      {best_psnr_value:6.2f}\n"
          f"SSIM      {best_ssim_value:6.4f}\n"
          f"LPIPS     {best_lpips_value:6.4f}\n"
          f"GMSD      {best_gmsd_value:6.4f}")
    print(f"--------- -----")
    print("##################################################\n")
Example no. 4
    def __init__(self, args):
        self.args = args
        self.model, self.device = configure(args)

        logger.info("Load testing dataset")
        dataset = BaseTestDataset(root=os.path.join(args.data, "test"),
                                  image_size=args.image_size)
        self.dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            pin_memory=True,
            num_workers=int(args.workers))

        logger.info(f"Dataset information\n"
                    f"\tDataset dir is `{os.getcwd()}/{args.data}/test`\n"
                    f"\tBatch size is {args.batch_size}\n"
                    f"\tWorkers is {int(args.workers)}\n"
                    f"\tLoad dataset to CUDA")
Example no. 5
def main_worker(gpu, args):
    args.gpu = gpu

    if args.gpu is not None:
        logger.info(f"Use GPU: {args.gpu} for testing.")

    model = configure(args)

    if not torch.cuda.is_available():
        logger.warning("Using CPU, this will be slow.")
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    cudnn.benchmark = True

    # Set eval mode.
    model.eval()

    # Get video filename.
    filename = os.path.basename(args.file)

    # Image preprocessing operation
    tensor2pil = transforms.ToPILImage()

    video_capture = cv2.VideoCapture(args.file)
    # Prepare to write the processed image into the video.
    fps = video_capture.get(cv2.CAP_PROP_FPS)
    total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
    # Set video size
    size = (int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    sr_size = (size[0] * args.upscale_factor, size[1] * args.upscale_factor)
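    # Comparison frame size: two SR frames side by side plus padding,
    # with extra height for a strip of five cropped detail patches.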
    pare_size = (sr_size[0] * 2 + 10, sr_size[1] + 10 + sr_size[0] // 5 - 9)
    # Video writers.
    sr_writer = cv2.VideoWriter(
        os.path.join("videos", f"sr_{args.upscale_factor}x_{filename}"),
        cv2.VideoWriter_fourcc(*"MPEG"), fps, sr_size)
    compare_writer = cv2.VideoWriter(
        os.path.join("videos", f"compare_{args.upscale_factor}x_{filename}"),
        cv2.VideoWriter_fourcc(*"MPEG"), fps, pare_size)

    # read frame.
    with torch.no_grad():
        success, raw_frame = video_capture.read()
        progress_bar = tqdm(
            range(total_frames),
            desc="[processing video and saving/view result videos]")
        for _ in progress_bar:
            if success:
                # Read image to tensor and transfer to the specified device for processing.
                lr = process_image(raw_frame, args.gpu)

                sr = model(lr)

                # Convert the SR tensor to an HWC uint8 image for OpenCV.
                sr = sr.cpu()
                sr = sr.data[0].numpy()
                sr *= 255.0
                sr = (np.uint8(sr)).transpose((1, 2, 0))
                # save sr video
                sr_writer.write(sr)

                # Make the comparison video: crop shots of the top-left,
                # top-right, center, bottom-left and bottom-right regions.
                sr = tensor2pil(sr)
                # Five areas are cropped from the SR image for the bottom strip.
                crop_sr_images = transforms.FiveCrop(
                    size=sr.width // 5 - 9)(sr)
                crop_sr_images = [
                    np.asarray(transforms.Pad(padding=(10, 5, 0, 0))(image))
                    for image in crop_sr_images
                ]
                sr = transforms.Pad(padding=(5, 0, 0, 5))(sr)
                # Build the bicubic comparison image and crop the same five areas.
                compare_image = transforms.Resize(
                    (sr_size[1], sr_size[0]),
                    interpolation=Mode.BICUBIC)(tensor2pil(raw_frame))
                crop_compare_images = transforms.FiveCrop(
                    size=compare_image.width // 5 - 9)(compare_image)
                crop_compare_images = [
                    np.asarray(transforms.Pad((0, 5, 10, 0))(image))
                    for image in crop_compare_images
                ]
                compare_image = transforms.Pad(
                    padding=(0, 0, 5, 5))(compare_image)
                # concatenate all the pictures to one single picture
                # 1. Mosaic the left and right images of the video.
                top_image = np.concatenate(
                    (np.asarray(compare_image), np.asarray(sr)), axis=1)
                # 2. Mosaic the bottom left and bottom right images of the video.
                bottom_image = np.concatenate(crop_compare_images +
                                              crop_sr_images,
                                              axis=1)
                bottom_image_height = int(top_image.shape[1] /
                                          bottom_image.shape[1] *
                                          bottom_image.shape[0])
                bottom_image_width = top_image.shape[1]
                # 3. Adjust to the right size.
                bottom_image = np.asarray(
                    transforms.Resize(
                        (bottom_image_height,
                         bottom_image_width))(tensor2pil(bottom_image)))
                # 4. Combine the bottom zone with the upper zone.
                final_image = np.concatenate((top_image, bottom_image))

                # save compare video
                compare_writer.write(final_image)

                if args.view:
                    # display video
                    cv2.imshow("LR video convert SR video ", final_image)
                    if cv2.waitKey(1) & 0xFF == ord("q"):
                        break

                # next frame
                success, raw_frame = video_capture.read()
Example no. 6
def main_worker(gpu, ngpus_per_node, args):
    global best_psnr
    args.gpu = gpu

    if args.gpu is not None:
        logger.info(f"Use GPU: {args.gpu} for training.")

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    # create model
    generator = configure(args)
    discriminator = discriminator_for_vgg(args.image_size)

    if not torch.cuda.is_available():
        logger.warning("Using CPU, this will be slow.")
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            discriminator.cuda(args.gpu)
            generator.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            discriminator = nn.parallel.DistributedDataParallel(
                module=discriminator, device_ids=[args.gpu])
            generator = nn.parallel.DistributedDataParallel(
                module=generator, device_ids=[args.gpu])
        else:
            discriminator.cuda()
            generator.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            discriminator = nn.parallel.DistributedDataParallel(discriminator)
            generator = nn.parallel.DistributedDataParallel(generator)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        discriminator = discriminator.cuda(args.gpu)
        generator = generator.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith("alexnet") or args.arch.startswith("vgg"):
            discriminator.features = torch.nn.DataParallel(
                discriminator.features)
            generator.features = torch.nn.DataParallel(generator.features)
            discriminator.cuda()
            generator.cuda()
        else:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()

    # Loss = content loss + 0.001 * adversarial loss
    pixel_criterion = nn.MSELoss().cuda(args.gpu)
    # VGG19 conv5_4 (VGG5.4) features are used for the content loss by default.
    content_criterion = VGGLoss().cuda(args.gpu)
    adversarial_criterion = nn.BCELoss().cuda(args.gpu)
    logger.info(f"Losses function information:\n"
                f"\tPixel:       MSELoss\n"
                f"\tContent:     VGG19_36th\n"
                f"\tAdversarial: BCELoss")

    # All optimizer function and scheduler function.
    psnr_optimizer = torch.optim.Adam(generator.parameters(),
                                      lr=args.psnr_lr,
                                      betas=(0.9, 0.999))
    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(),
                                               lr=args.gan_lr,
                                               betas=(0.9, 0.999))
    generator_optimizer = torch.optim.Adam(generator.parameters(),
                                           lr=args.gan_lr,
                                           betas=(0.9, 0.999))
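    # Decay both GAN-stage learning rates by 10x every gan_epochs // 2 epochs.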
    discriminator_scheduler = torch.optim.lr_scheduler.StepLR(
        discriminator_optimizer, args.gan_epochs // 2, 0.1)
    generator_scheduler = torch.optim.lr_scheduler.StepLR(
        generator_optimizer, args.gan_epochs // 2, 0.1)
    logger.info(
        f"Optimizer information:\n"
        f"\tPSNR learning rate:          {args.psnr_lr}\n"
        f"\tDiscriminator learning rate: {args.gan_lr}\n"
        f"\tGenerator learning rate:     {args.gan_lr}\n"
        f"\tPSNR optimizer:              Adam, [betas=(0.9,0.999)]\n"
        f"\tDiscriminator optimizer:     Adam, [betas=(0.9,0.999)]\n"
        f"\tGenerator optimizer:         Adam, [betas=(0.9,0.999)]\n"
        f"\tPSNR scheduler:              None\n"
        f"\tDiscriminator scheduler:     StepLR, [step_size=self.gan_epochs // 2, gamma=0.1]\n"
        f"\tGenerator scheduler:         StepLR, [step_size=self.gan_epochs // 2, gamma=0.1]"
    )

    logger.info("Load training dataset")
    # Build the train and test datasets.
    train_dataset = BaseTrainDataset(root=os.path.join(args.data, "train"),
                                     image_size=args.image_size,
                                     upscale_factor=args.upscale_factor)
    test_dataset = BaseTestDataset(root=os.path.join(args.data, "test"),
                                   image_size=args.image_size,
                                   upscale_factor=args.upscale_factor)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        pin_memory=True,
        sampler=train_sampler,
        num_workers=args.workers)
    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  pin_memory=True,
                                                  num_workers=args.workers)

    logger.info(f"Dataset information:\n"
                f"\tTrain Path:              {os.getcwd()}/{args.data}/train\n"
                f"\tTest Path:               {os.getcwd()}/{args.data}/test\n"
                f"\tNumber of train samples: {len(train_dataset)}\n"
                f"\tNumber of test samples:  {len(test_dataset)}\n"
                f"\tNumber of train batches: {len(train_dataloader)}\n"
                f"\tNumber of test batches:  {len(test_dataloader)}\n"
                f"\tShuffle of train:        True\n"
                f"\tShuffle of test:         False\n"
                f"\tSampler of train:        {bool(train_sampler)}\n"
                f"\tSampler of test:         None\n"
                f"\tWorkers of train:        {args.workers}\n"
                f"\tWorkers of test:         {args.workers}")

    # optionally resume from a checkpoint
    if args.resume_psnr:
        if os.path.isfile(args.resume_psnr):
            logger.info(f"Loading checkpoint '{args.resume_psnr}'.")
            if args.gpu is None:
                checkpoint = torch.load(args.resume_psnr)
            else:
                # Map model to be loaded to specified single gpu.
                checkpoint = torch.load(args.resume_psnr,
                                        map_location=f"cuda:{args.gpu}")
            args.start_psnr_epoch = checkpoint["epoch"]
            best_psnr = checkpoint["best_psnr"]
            if args.gpu is not None:
                # best_psnr may be from a checkpoint from a different GPU
                best_psnr = best_psnr.to(args.gpu)
            generator.load_state_dict(checkpoint["state_dict"])
            psnr_optimizer.load_state_dict(checkpoint["optimizer"])
            logger.info(
                f"Loaded checkpoint '{args.resume_psnr}' (epoch {checkpoint['epoch']})."
            )
        else:
            logger.info(f"No checkpoint found at '{args.resume_psnr}'.")

    if args.resume_d or args.resume_g:
        if os.path.isfile(args.resume_d) or os.path.isfile(args.resume_g):
            logger.info(f"Loading checkpoint '{args.resume_d}'.")
            logger.info(f"Loading checkpoint '{args.resume_g}'.")
            if args.gpu is None:
                checkpoint_d = torch.load(args.resume_d)
                checkpoint_g = torch.load(args.resume_g)
            else:
                # Map model to be loaded to specified single gpu.
                checkpoint_d = torch.load(args.resume_d,
                                          map_location=f"cuda:{args.gpu}")
                checkpoint_g = torch.load(args.resume_g,
                                          map_location=f"cuda:{args.gpu}")
            args.start_gan_epoch = checkpoint_g["epoch"]
            best_psnr = checkpoint_g["best_psnr"]
            if args.gpu is not None:
                # best_psnr may be from a checkpoint from a different GPU
                best_psnr = best_psnr.to(args.gpu)
            discriminator.load_state_dict(checkpoint_d["state_dict"])
            discriminator_optimizer.load_state_dict(checkpoint_d["optimizer"])
            generator.load_state_dict(checkpoint_g["state_dict"])
            generator_optimizer.load_state_dict(checkpoint_g["optimizer"])
            logger.info(
                f"Loaded checkpoint '{args.resume_d}' (epoch {checkpoint_d['epoch']})."
            )
            logger.info(
                f"Loaded checkpoint '{args.resume_g}' (epoch {checkpoint_g['epoch']})."
            )
        else:
            logger.info(
                f"No checkpoint found at '{args.resume_d}' or '{args.resume_g}'."
            )

    cudnn.benchmark = True

    # Create a SummaryWriter at the beginning of training.
    psnr_writer = SummaryWriter(f"runs/{args.arch}_psnr_logs")
    gan_writer = SummaryWriter(f"runs/{args.arch}_gan_logs")

    logger.info(f"Train information:\n"
                f"\tPSNR-oral epochs: {args.psnr_epochs}\n"
                f"\tGAN-oral epochs:  {args.gan_epochs}")

    for epoch in range(args.start_psnr_epoch, args.psnr_epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train_psnr(train_dataloader=train_dataloader,
                   generator=generator,
                   pixel_criterion=pixel_criterion,
                   psnr_optimizer=psnr_optimizer,
                   epoch=epoch,
                   writer=psnr_writer,
                   args=args)

        # Test for every epoch.
        psnr, ssim, lpips, gmsd = test(generator, test_dataloader, args.gpu)
        gan_writer.add_scalar("Test/PSNR", psnr, epoch + 1)
        gan_writer.add_scalar("Test/SSIM", ssim, epoch + 1)
        gan_writer.add_scalar("Test/LPIPS", lpips, epoch + 1)
        gan_writer.add_scalar("Test/GMSD", gmsd, epoch + 1)

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            torch.save(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": generator.module.state_dict()
                    if args.multiprocessing_distributed else
                    generator.state_dict(),
                    "optimizer": psnr_optimizer.state_dict(),
                }, os.path.join("weights", f"PSNR_epoch{epoch}.pth"))
            if psnr > best_psnr:
                best_psnr = max(psnr, best_psnr)
                torch.save(generator.state_dict(),
                           os.path.join("weights", "PSNR.pth"))

    # Reset the metric; warm-start the GAN stage from the best PSNR weights.
    best_psnr = 0.0
    generator.load_state_dict(
        torch.load(os.path.join("weights", "PSNR.pth"),
                   map_location=f"cuda:{args.gpu}"))

    for epoch in range(args.start_gan_epoch, args.gan_epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train_gan(train_dataloader=train_dataloader,
                  discriminator=discriminator,
                  generator=generator,
                  content_criterion=content_criterion,
                  adversarial_criterion=adversarial_criterion,
                  discriminator_optimizer=discriminator_optimizer,
                  generator_optimizer=generator_optimizer,
                  epoch=epoch,
                  writer=gan_writer,
                  args=args)

        discriminator_scheduler.step()
        generator_scheduler.step()

        # Test for every epoch.
        psnr, ssim, lpips, gmsd = test(generator, test_dataloader, args.gpu)
        gan_writer.add_scalar("Test/PSNR", psnr, epoch + 1)
        gan_writer.add_scalar("Test/SSIM", ssim, epoch + 1)
        gan_writer.add_scalar("Test/LPIPS", lpips, epoch + 1)
        gan_writer.add_scalar("Test/GMSD", gmsd, epoch + 1)

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            torch.save(
                {
                    "epoch": epoch + 1,
                    "arch": "vgg",
                    "state_dict": discriminator.module.state_dict()
                    if args.multiprocessing_distributed else
                    discriminator.state_dict(),
                    "optimizer": discriminator_optimizer.state_dict()
                }, os.path.join("weights", f"Discriminator_epoch{epoch}.pth"))
            torch.save(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": generator.module.state_dict()
                    if args.multiprocessing_distributed else
                    generator.state_dict(),
                    "optimizer": generator_optimizer.state_dict()
                }, os.path.join("weights", f"Generator_epoch{epoch}.pth"))
            if psnr > best_psnr:
                best_psnr = max(psnr, best_psnr)
                torch.save(generator.state_dict(),
                           os.path.join("weights", "GAN.pth"))
Example no. 7
import argparse
import logging

import torch

import srgan_pytorch.models as models
from srgan_pytorch.utils.common import configure

model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

logger = logging.getLogger(__name__)
logging.basicConfig(format="[ %(levelname)s ] %(message)s", level=logging.INFO)

parser = argparse.ArgumentParser(
    "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network.")
parser.add_argument("-a", "--arch", metavar="ARCH", default="srgan",
                    choices=model_names,
                    help="Model architecture: " +
                         " | ".join(model_names) +
                         ". (Default: srgan)")
parser.add_argument("--model-path", type=str, metavar="PATH", required=True,
                    help="Path to latest checkpoint for model.")
args = parser.parse_args()

model = configure(args)
model.load_state_dict(torch.load(args.model_path)["state_dict"])
torch.save(model.state_dict(), "Generator.pth")
logger.info("Model convert done.")
Example no. 8
    def __init__(self, args):
        self.args = args
        self.model, self.device = configure(args)
Example no. 9
def main_worker(gpu, ngpus_per_node, args):
    global total_mse_value, total_rmse_value, total_psnr_value
    global total_ssim_value, total_lpips_value, total_gmsd_value
    args.gpu = gpu

    if args.gpu is not None:
        logger.info(f"Use GPU: {args.gpu} for training.")

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    model = configure(args)

    if not torch.cuda.is_available():
        logger.warning("Using CPU, this will be slow.")
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = nn.parallel.DistributedDataParallel(module=model,
                                                        device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith("alexnet") or args.arch.startswith("vgg"):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    logger.info("Load testing dataset")
    # Build the test dataset.
    dataset = BaseTestDataset(root=os.path.join(args.data, "test"),
                              image_size=args.image_size,
                              upscale_factor=args.upscale_factor)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             pin_memory=True,
                                             num_workers=args.workers)
    logger.info(f"Dataset information:\n"
                f"\tPath:              {os.getcwd()}/{args.data}/test\n"
                f"\tNumber of samples: {len(dataset)}\n"
                f"\tNumber of batches: {len(dataloader)}\n"
                f"\tShuffle:           False\n"
                f"\tSampler:           None\n"
                f"\tWorkers:           {args.workers}")

    cudnn.benchmark = True

    # Set eval mode.
    model.eval()

    # Start evaluating model performance.
    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))

    for i, (lr, bicubic, hr) in progress_bar:
        # Move data to the specified device.
        if args.gpu is not None:
            lr = lr.cuda(args.gpu, non_blocking=True)
            bicubic = bicubic.cuda(args.gpu, non_blocking=True)
            hr = hr.cuda(args.gpu, non_blocking=True)

        with torch.no_grad():
            sr = model(lr)

        # Evaluate performance
        value = iqa(sr, hr, args.gpu)

        total_mse_value += value[0]
        total_rmse_value += value[1]
        total_psnr_value += value[2]
        total_ssim_value += value[3]
        total_lpips_value += value[4]
        total_gmsd_value += value[5]

        progress_bar.set_description(
            f"[{i + 1}/{len(dataloader)}] "
            f"PSNR: {total_psnr_value / (i + 1):6.2f} "
            f"SSIM: {total_ssim_value / (i + 1):6.4f}")

        images = torch.cat([bicubic, sr, hr], dim=-1)
        vutils.save_image(images,
                          os.path.join("benchmark", f"{i + 1}.bmp"),
                          padding=10)

    print(f"Performance average results:\n")
    print(f"indicator Score\n")
    print(f"--------- -----\n")
    print(f"MSE       {total_mse_value / len(dataloader):6.4f}\n"
          f"RMSE      {total_rmse_value / len(dataloader):6.4f}\n"
          f"PSNR      {total_psnr_value / len(dataloader):6.2f}\n"
          f"SSIM      {total_ssim_value / len(dataloader):6.4f}\n"
          f"LPIPS     {total_lpips_value / len(dataloader):6.4f}\n"
          f"GMSD      {total_gmsd_value / len(dataloader):6.4f}")
Example no. 10
def main_worker(gpu, args):
    global total_mse_value, total_rmse_value, total_psnr_value, total_ssim_value, total_lpips_value, total_gmsd_value
    args.gpu = gpu

    if args.gpu is not None:
        logger.info(f"Use GPU: {args.gpu} for testing.")

    model = configure(args)

    if not torch.cuda.is_available():
        logger.warning("Using CPU, this will be slow.")
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    logger.info("Load testing dataset.")
    # Build the test dataset.
    dataset = BaseTestDataset(os.path.join(args.data, "test"), args.image_size,
                              args.upscale_factor)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             pin_memory=True,
                                             num_workers=args.workers)
    logger.info(f"Dataset information:\n"
                f"\tPath:              {os.getcwd()}/{args.data}/test\n"
                f"\tNumber of samples: {len(dataset)}\n"
                f"\tNumber of batches: {len(dataloader)}\n"
                f"\tShuffle:           False\n"
                f"\tSampler:           None\n"
                f"\tWorkers:           {args.workers}")

    cudnn.benchmark = True

    # Set eval mode.
    model.eval()

    with torch.no_grad():
        # Start evaluating model performance.
        progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
        for i, (lr, bicubic, hr) in progress_bar:
            # Move data to the specified device.
            if args.gpu is not None:
                lr = lr.cuda(args.gpu, non_blocking=True)
                bicubic = bicubic.cuda(args.gpu, non_blocking=True)
                hr = hr.cuda(args.gpu, non_blocking=True)

            sr = model(lr)

            # Evaluate performance
            value = iqa(sr, hr, args.gpu)

            total_mse_value += value[0]
            total_rmse_value += value[1]
            total_psnr_value += value[2]
            total_ssim_value += value[3]
            total_lpips_value += value[4]
            total_gmsd_value += value[5]

            progress_bar.set_description(
                f"[{i + 1}/{len(dataloader)}] "
                f"PSNR: {total_psnr_value / (i + 1):6.2f} "
                f"SSIM: {total_ssim_value / (i + 1):6.4f}")

            images = torch.cat([bicubic, sr, hr], dim=-1)
            vutils.save_image(images,
                              os.path.join("benchmarks", f"{i + 1}.bmp"),
                              padding=10)

    print(f"Performance average results:\n")
    print(f"indicator Score\n")
    print(f"--------- -----\n")
    print(f"MSE       {total_mse_value / len(dataloader):6.4f}\n"
          f"RMSE      {total_rmse_value / len(dataloader):6.4f}\n"
          f"PSNR      {total_psnr_value / len(dataloader):6.2f}\n"
          f"SSIM      {total_ssim_value / len(dataloader):6.4f}\n"
          f"LPIPS     {total_lpips_value / len(dataloader):6.4f}\n"
          f"GMSD      {total_gmsd_value / len(dataloader):6.4f}")