def forward(self, predictions, targets):
     """Compute the peak signal-to-noise ratio between two image batches.

     When the inputs have three channels (``shape[1] == 3``) both tensors are
     converted with ``kc.rgb_to_grayscale`` before the comparison; any other
     channel count is compared as given.

     Args:
         predictions: predicted images, channel axis at index 1.
         targets: reference images, same layout as ``predictions``.

     Returns:
         ``10 * log10(max_val ** 2 / mse)`` where ``mse`` is the mean squared
         error between the two inputs.
     """
     if predictions.shape[1] == 3:
         predictions, targets = (
             kc.rgb_to_grayscale(predictions),
             kc.rgb_to_grayscale(targets),
         )
     mean_squared_error = F.mse_loss(predictions, targets)
     squared_peak = self.max_val ** 2
     return 10 * torch.log10(squared_peak / mean_squared_error)
def main():
    """Evaluate a super-resolution checkpoint on every test dataloader.

    Command-line arguments:
        --model: ``srcnn`` or ``srgan``; selects the model class.
        --ckpt: path to the checkpoint file to restore.

    For each entry of ``model.test_dataloader`` the super-resolved images are
    written next to the checkpoint and the mean PSNR / SSIM (both computed on
    grayscale versions of the images) is printed.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", choices=["srcnn", "srgan"], required=True)
    cli.add_argument("--ckpt", type=str, required=True)
    opt = cli.parse_args()

    # resolve the model class lazily from its attribute name
    class_names = {"srcnn": "SRCNNModel", "srgan": "SRGANModel"}
    if opt.model not in class_names:
        raise RuntimeError(opt.model)
    model_cls = getattr(models, class_names[opt.model])

    # restore the model state from the checkpoint file
    ckpt_path = Path(opt.ckpt)
    model = model_cls.load_from_metrics(
        weights_path=opt.ckpt,
        tags_csv=ckpt_path.parent.parent / "meta_tags.csv",
        map_location=None,
    )
    model.eval()
    model.freeze()

    # output directory: sibling of the checkpoint, named after its stem
    save_dir = ckpt_path.with_name(ckpt_path.stem.replace("_ckpt_", ""))
    save_dir.mkdir(exist_ok=True)

    criterion_psnr = models.losses.PSNR()
    criterion_ssim = SSIM(window_size=11, reduction="mean")

    for dataset, dataloader in model.test_dataloader.items():
        psnr_total = 0
        ssim_total = 0

        for batch in tqdm(dataloader):
            img_name = batch["path"][0]
            img_lr, img_hr = batch["lr"], batch["hr"]
            img_sr = model(img_lr)

            # metrics are evaluated on grayscale versions of the images
            gray_hr = rgb_to_grayscale(img_hr)
            gray_sr = rgb_to_grayscale(img_sr)

            psnr_total += criterion_psnr(gray_hr, gray_sr).item()
            ssim_total += 1 - criterion_ssim(gray_hr, gray_sr).item()

            save_image(img_sr, save_dir / f"{dataset}_{img_name}.png", nrow=1)

        psnr_mean = psnr_total / len(dataloader)
        ssim_mean = ssim_total / len(dataloader)
        print(f"[{dataset}] PSNR: {psnr_mean:.4}, SSIM: {ssim_mean:.4}")
Beispiel #3
0
def main():
    """Evaluate a super-resolution checkpoint on every test dataloader.

    Command-line arguments:
        --model: ``srcnn`` or ``srgan``; selects the model class.
        --ckpt: path to the checkpoint file to restore.

    The model is restored with ``on_gpu=True``. For each entry of
    ``model.test_dataloader`` the super-resolved images are written next to
    the checkpoint and the mean PSNR / SSIM (both computed on grayscale
    versions of the images) is printed.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--model', choices=['srcnn', 'srgan'], required=True)
    cli.add_argument('--ckpt', type=str, required=True)
    opt = cli.parse_args()

    # pick the model class matching the requested architecture
    if opt.model == 'srcnn':
        model_cls = models.SRCNNModel
    elif opt.model == 'srgan':
        model_cls = models.SRGANModel

    # restore the model state from the checkpoint file
    ckpt_path = Path(opt.ckpt)
    model = model_cls.load_from_metrics(
        weights_path=opt.ckpt,
        tags_csv=ckpt_path.parent.parent / 'meta_tags.csv',
        on_gpu=True,
        map_location=None,
    )
    model.eval()
    model.freeze()

    # output directory: sibling of the checkpoint, named after its stem
    save_dir = ckpt_path.with_name(ckpt_path.stem.replace('_ckpt_', ''))
    save_dir.mkdir(exist_ok=True)

    criterion_PSNR = models.losses.PSNR()
    criterion_SSIM = SSIM(window_size=11, reduction='mean')

    for dataset, dataloader in model.test_dataloader.items():
        psnr_total = 0
        ssim_total = 0

        for batch in tqdm(dataloader):
            img_name = batch['path'][0]
            img_lr, img_hr = batch['lr'], batch['hr']
            img_sr = model(img_lr)

            # metrics are evaluated on grayscale versions of the images
            gray_hr = rgb_to_grayscale(img_hr)
            gray_sr = rgb_to_grayscale(img_sr)

            psnr_total += criterion_PSNR(gray_hr, gray_sr).item()
            ssim_total += 1 - criterion_SSIM(gray_hr, gray_sr).item()

            save_image(img_sr, save_dir / f'{dataset}_{img_name}.png', nrow=1)

        psnr_mean = psnr_total / len(dataloader)
        ssim_mean = ssim_total / len(dataloader)
        print(f'[{dataset}] PSNR: {psnr_mean:.4}, SSIM: {ssim_mean:.4}')
Beispiel #4
0
    def test_step(self, batch, batch_idx):
        """Score one test batch with PSNR and SSIM.

        Args:
            batch: mapping providing the ``"lr"`` and ``"hr"`` image tensors.
            batch_idx: index of the batch (unused here).

        Returns:
            dict with the keys ``"psnr"`` and ``"ssim"``; both are computed on
            grayscale versions of the super-resolved and reference images.
        """
        with torch.no_grad():
            low_res, high_res = batch["lr"], batch["hr"]
            super_res = self.forward(low_res)

            gray_hr = rgb_to_grayscale(high_res)
            gray_sr = rgb_to_grayscale(super_res)

            metrics = {
                "psnr": self.criterion_PSNR(gray_sr, gray_hr),
                # invert the criterion output (presumably a loss where lower
                # is better -- confirm against the SSIM criterion in use)
                "ssim": 1 - self.criterion_SSIM(gray_sr, gray_hr),
            }

        return metrics
Beispiel #5
0
    def validation_step(self, batch, batch_nb):
        """Score one validation batch with PSNR and SSIM.

        Args:
            batch: mapping providing the ``'lr'`` and ``'hr'`` image tensors.
            batch_nb: index of the batch (unused here).

        Returns:
            dict with the keys ``'psnr'`` and ``'ssim'``; both are computed on
            grayscale versions of the super-resolved and reference images.
        """
        with torch.no_grad():
            low_res, high_res = batch['lr'], batch['hr']
            super_res = self.forward(low_res)

            gray_hr = rgb_to_grayscale(high_res)
            gray_sr = rgb_to_grayscale(super_res)

            metrics = {
                'psnr': self.criterion_PSNR(gray_sr, gray_hr),
                # invert the criterion output (presumably a loss where lower
                # is better -- confirm against the SSIM criterion in use)
                'ssim': 1 - self.criterion_SSIM(gray_sr, gray_hr),
            }

        return metrics
Beispiel #6
0
def apply_grayscale(input: torch.Tensor,
                    params: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Convert the selected images of a batch to grayscale.

    The input is first passed through ``_transform_input`` and must then be a
    floating point tensor of shape :math:`(*, 3, H, W)`.

    Args:
        input (torch.Tensor): Image tensor to be transformed.
        params (Dict[str, torch.Tensor]):
            - params['batch_prob']: Tensor used to index the batch; the selected
              entries are the ones converted to grayscale.

    Returns:
        torch.Tensor: A copy of the input in which the selected entries are
        replaced by their grayscale version.

    Raises:
        ValueError: If the shape validation of the input fails.
    """
    # TODO: params validation

    input = _transform_input(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    shape_ok = _validate_input_shape(input, 1, 3)
    if not shape_ok:
        raise ValueError(
            f"Input size must have a shape of (*, 3, H, W). Got {input.shape}")

    # Work on a copy so that unselected entries keep their original values.
    output: torch.Tensor = input.clone()
    selection = params['batch_prob'].to(input.device)
    output[selection] = rgb_to_grayscale(input[selection])
    return output
Beispiel #7
0
def apply_grayscale(input: torch.Tensor,
                    params: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Convert the selected images of a batch to grayscale.

    The input is first passed through ``_transform_input`` and must then be a
    floating point tensor of shape :math:`(*, 3, H, W)`.

    Args:
        input (torch.Tensor): Image tensor to be transformed.
        params (dict): A dict that must have {'batch_prob': torch.Tensor}; the
            tensor is used to index the batch and select the entries to convert.

    Returns:
        torch.Tensor: A copy of the input in which the selected entries are
        replaced by their grayscale version.

    Raises:
        ValueError: If the shape validation of the input fails.
    """
    # TODO: params validation

    input = _transform_input(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    shape_ok = _validate_input_shape(input, 1, 3)
    if not shape_ok:
        raise ValueError(
            f"Input size must have a shape of (*, 3, H, W). Got {input.shape}")

    # Work on a copy so that unselected entries keep their original values.
    output: torch.Tensor = input.clone()
    selection = params['batch_prob'].to(input.device)
    output[selection] = rgb_to_grayscale(input[selection])
    return output
Beispiel #8
0
def apply_grayscale(input: torch.Tensor) -> torch.Tensor:
    r"""Convert an image tensor to grayscale while keeping its three channels.

    The input is first passed through ``_transform_input`` and must then be a
    floating point tensor of shape :math:`(*, 3, H, W)`.

    Args:
        input (torch.Tensor): Image tensor to be transformed.

    Returns:
        torch.Tensor: The grayscaled input, with the same shape as the
        validated input.

    Raises:
        ValueError: If the shape validation of the input fails.
    """
    input = _transform_input(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    shape_ok = _validate_input_shape(input, 1, 3)
    if not shape_ok:
        raise ValueError(
            f"Input size must have a shape of (*, 3, H, W). Got {input.shape}")

    gray = rgb_to_grayscale(input)

    # Make sure it returns (*, 3, H, W): the grayscale result is written into
    # a full-size copy of the input.
    output: torch.Tensor = input.clone()
    output[:] = gray
    return output
Beispiel #9
0
 def preprocess(self, image_1: torch.Tensor,
                image_2: torch.Tensor) -> Dict[str, torch.Tensor]:
     """Build the matcher input dictionary from an image pair.

     Args:
         image_1: first image, stored under the key ``"image0"``.
         image_2: second image, stored under the key ``"image1"``.

     Returns:
         Dictionary holding the grayscale version of both images.

     Raises:
         NotImplementedError: if ``self.matcher`` is neither a ``LoFTR`` nor a
             ``LocalFeatureMatcher`` instance.
     """
     # TODO: probably perform histogram matching here.
     if not isinstance(self.matcher, (LoFTR, LocalFeatureMatcher)):
         raise NotImplementedError(
             f"The preprocessor for {self.matcher} has not been implemented."
         )

     # LofTR works on grayscale images only
     gray_1 = rgb_to_grayscale(image_1)
     gray_2 = rgb_to_grayscale(image_2)
     return {"image0": gray_1, "image1": gray_2}
Beispiel #10
0
def get_laf_descriptors(img: torch.Tensor,
                        lafs: torch.Tensor,
                        patch_descriptor: nn.Module,
                        patch_size: int = 32,
                        grayscale_descriptor: bool = True) -> torch.Tensor:
    r"""Describe the patches that correspond to local affine frames (keypoints).

    Args:
        img: image features with shape :math:`(B,C,H,W)`.
        lafs: local affine frames :math:`(B,N,2,3)`.
        patch_descriptor: patch descriptor module, e.g. :class:`~kornia.feature.SIFTDescriptor`
            or :class:`~kornia.feature.HardNet`. It is moved to the device/dtype of ``img``
            and switched to eval mode.
        patch_size: patch size in pixels, which descriptor expects.
        grayscale_descriptor: True if ``patch_descriptor`` expects single-channel image.

    Returns:
        Local descriptors of shape :math:`(B,N,D)` where :math:`D` is descriptor size.
    """
    raise_error_if_laf_is_not_valid(lafs)
    patch_descriptor = patch_descriptor.to(img)
    patch_descriptor.eval()

    # Only three-channel images are converted when the descriptor wants grayscale.
    needs_gray = grayscale_descriptor and img.size(1) == 3
    source: torch.Tensor = rgb_to_grayscale(img) if needs_gray else img

    patches: torch.Tensor = extract_patches_from_pyramid(
        source, lafs, patch_size)

    # The descriptor takes [B, CH, H, W] while the patches come as
    # [B, N, CH, H, W]: fold N into the batch axis, describe, then unfold.
    batch, num_lafs, channels, height, width = patches.size()
    flat_patches = patches.view(batch * num_lafs, channels, height, width)
    descriptors = patch_descriptor(flat_patches)
    return descriptors.view(batch, num_lafs, -1)
Beispiel #11
0
 def apply_transform(self,
                     input: Tensor,
                     params: Dict[str, Tensor],
                     transform: Optional[Tensor] = None) -> Tensor:
     """Return the grayscale version of ``input`` with the input's shape.

     ``params`` and ``transform`` are accepted for interface compatibility
     and are not read here.
     """
     gray = rgb_to_grayscale(input)
     # Make sure it returns (*, 3, H, W): write the grayscale result into a
     # buffer shaped like the input.
     output = torch.ones_like(input)
     output[:] = gray
     return output
Beispiel #12
0
def apply_grayscale(input: torch.Tensor,
                    params: Dict[str, torch.Tensor],
                    return_transform: bool = False) -> UnionType:
    r"""Convert the selected images of a batch to grayscale.

    The input is first passed through ``_transform_input`` and must then be a
    floating point tensor of shape :math:`(*, 3, H, W)`.

    Args:
        input (torch.Tensor): Image tensor to be transformed.
        params (dict): A dict that must have {'batch_prob': torch.Tensor}; the
            tensor is used to index the batch and select the entries to convert.
        return_transform (bool): if ``True`` also return the matrix describing
            the transformation applied to each input tensor.

    Returns:
        torch.Tensor: A copy of the input in which the selected entries are
        replaced by their grayscale version.
        torch.Tensor: One 3x3 identity matrix per batch entry, shape
        :math:`(*, 3, 3)`, only when ``return_transform`` is ``True``.

    Raises:
        ValueError: If the shape validation of the input fails.
        TypeError: If ``return_transform`` is not a bool.
    """
    # TODO: params validation

    input = _transform_input(input)
    _validate_input_dtype(
        input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])

    shape_ok = _validate_input_shape(input, 1, 3)
    if not shape_ok:
        raise ValueError(
            f"Input size must have a shape of (*, 3, H, W). Got {input.shape}")

    if not isinstance(return_transform, bool):
        raise TypeError(
            f"The return_transform flag must be a bool. Got {type(return_transform)}"
        )

    # Work on a copy so that unselected entries keep their original values.
    output: torch.Tensor = input.clone()
    selection = params['batch_prob'].to(input.device)
    output[selection] = rgb_to_grayscale(input[selection])

    if not return_transform:
        return output

    # Grayscale conversion does not move pixels, so the transform is identity.
    eye = torch.eye(3, device=input.device, dtype=input.dtype)
    identity: torch.Tensor = eye.repeat(input.shape[0], 1, 1)
    return output, identity
Beispiel #13
0
def apply_grayscale(input: torch.Tensor) -> torch.Tensor:
    r"""Convert an image tensor to grayscale while keeping its three channels.

    Args:
        input (torch.Tensor): Tensor to be transformed; it must pass the
            :math:`(*, 3, H, W)` shape validation.

    Returns:
        torch.Tensor: The grayscaled input, with the same shape as the input.

    Raises:
        ValueError: If the shape validation of the input fails.
    """
    shape_ok = _validate_input_shape(input, 1, 3)
    if not shape_ok:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {input.shape}")

    gray = rgb_to_grayscale(input)

    # Make sure it returns (*, 3, H, W): the grayscale result is written into
    # a full-size copy of the input.
    output: torch.Tensor = input.clone()
    output[:] = gray
    return output
Beispiel #14
0
def snow(inp, snow_mean, snow_std, snow_zoom, snow_thresh, snow_kernel_rad,
         snow_kernel_std, snow_mixin):
    """Overlay a synthetic snow layer on a batch of images.

    A single-channel Gaussian noise map is drawn on the GPU, rescaled by
    ``snow_zoom``, centre-cropped back to the input's spatial size,
    thresholded and motion-blurred. The input is blended with a brightened
    grayscale version of itself, then the snow layer and its 180-degree
    rotation are added and the result is clamped to ``[0, 1]``.

    Args:
        inp: image batch indexed as ``(batch, channels, height, width)``.
        snow_mean: mean of the noise the snow layer is drawn from.
        snow_std: standard deviation of that noise.
        snow_zoom: rescale factor applied to the noise map (cast to float).
        snow_thresh: noise values below this threshold are set to zero.
        snow_kernel_rad: passed to ``motion_blur`` as its second argument.
        snow_kernel_std: passed to ``motion_blur`` as its third argument.
        snow_mixin: weight of the original image in the blend with the
            brightened grayscale image (``1 - snow_mixin`` for the latter).

    Returns:
        The snowy images, clamped to ``[0, 1]``.
    """
    height, width = inp.shape[2], inp.shape[3]
    snow_shape = [inp.shape[0], 1, height, width]
    snow_layer = ch.randn(*snow_shape).cuda() * snow_std + snow_mean
    # Scaling
    zoomed = geometry.transform.rescale(snow_layer, float(snow_zoom))
    # Centre-crop back to the input size. Rows and columns get their own
    # offset and extent so that non-square inputs are cropped correctly; for
    # square inputs this is identical to using a single offset.
    trim_top = (zoomed.shape[-2] - height) // 2
    trim_left = (zoomed.shape[-1] - width) // 2
    snow_layer = zoomed[:, :, trim_top:trim_top + height,
                        trim_left:trim_left + width]

    # Thresholding
    snow_layer[snow_layer < snow_thresh] = 0.
    snow_layer = motion_blur(snow_layer,
                             snow_kernel_rad,
                             snow_kernel_std,
                             angle_offset=-90.)

    # Whiten the scene: per-pixel max of the image and a brightened grayscale
    # copy, mixed with the original according to snow_mixin.
    grayscaled_inp = ch.max(inp, color.rgb_to_grayscale(inp) * 1.5 + 0.5)
    inp = snow_mixin * inp + (1 - snow_mixin) * grayscaled_inp
    return ch.clamp(inp + snow_layer + geometry.transform.rot180(snow_layer),
                    0, 1)
Beispiel #15
0
def canny(
    input: torch.Tensor,
    low_threshold: float = 0.1,
    high_threshold: float = 0.2,
    kernel_size: Tuple[int, int] = (5, 5),
    sigma: Tuple[float, float] = (1, 1),
    hysteresis: bool = True,
    eps: float = 1e-6,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Finds edges of the input image and filters them using the Canny algorithm.

    .. image:: _static/img/canny.png

    Args:
        input: input image tensor with shape :math:`(B,C,H,W)`.
        low_threshold: lower threshold for the hysteresis procedure.
        high_threshold: upper threshold for the hysteresis procedure.
        kernel_size: the size of the kernel for the gaussian blur.
        sigma: the standard deviation of the kernel for the gaussian blur.
        hysteresis: if True, applies the hysteresis edge tracking.
            Otherwise, the edges are divided between weak (0.5) and strong (1) edges.
        eps: regularization number to avoid NaN during backprop.

    Returns:
        - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`.
        - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`.

    Raises:
        TypeError: if ``input`` is not a ``torch.Tensor``.
        ValueError: if ``input`` is not 4-dimensional, if ``low_threshold`` is
            greater than ``high_threshold``, or if either threshold is below 0
            or above 1.

    Example:
        >>> input = torch.rand(5, 3, 4, 4)
        >>> magnitude, edges = canny(input)  # 5x3x4x4
        >>> magnitude.shape
        torch.Size([5, 1, 4, 4])
        >>> edges.shape
        torch.Size([5, 1, 4, 4])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(input)))

    if not len(input.shape) == 4:
        raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}".format(input.shape))

    if low_threshold > high_threshold:
        raise ValueError(
            "Invalid input thresholds. low_threshold should be smaller than the high_threshold. Got: {}>{}".format(
                low_threshold, high_threshold
            )
        )

    # A value cannot be below 0 and above 1 at once, so these range checks
    # must be joined with `or` to ever trigger.
    if low_threshold < 0 or low_threshold > 1:
        raise ValueError(
            "Invalid input threshold. low_threshold should be in range (0,1). Got: {}".format(low_threshold)
        )

    if high_threshold < 0 or high_threshold > 1:
        raise ValueError(
            "Invalid input threshold. high_threshold should be in range (0,1). Got: {}".format(high_threshold)
        )

    device: torch.device = input.device
    dtype: torch.dtype = input.dtype

    # To Grayscale
    if input.shape[1] == 3:
        input = rgb_to_grayscale(input)

    # Gaussian filter
    blurred: torch.Tensor = gaussian_blur2d(input, kernel_size, sigma)

    # Compute the gradients
    gradients: torch.Tensor = spatial_gradient(blurred, normalized=False)

    # Unpack the edges
    gx: torch.Tensor = gradients[:, :, 0]
    gy: torch.Tensor = gradients[:, :, 1]

    # Compute gradient magnitude and angle
    magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps)
    angle: torch.Tensor = torch.atan2(gy, gx)

    # Radians to Degrees
    angle = rad2deg(angle)

    # Round angle to the nearest 45 degree
    angle = torch.round(angle / 45) * 45

    # Non-maximal suppression
    nms_kernels: torch.Tensor = get_canny_nms_kernel(device, dtype)
    nms_magnitude: torch.Tensor = F.conv2d(magnitude, nms_kernels, padding=nms_kernels.shape[-1] // 2)

    # Get the indices for both directions
    positive_idx: torch.Tensor = (angle / 45) % 8
    positive_idx = positive_idx.long()

    negative_idx: torch.Tensor = ((angle / 45) + 4) % 8
    negative_idx = negative_idx.long()

    # Apply the non-maximum suppresion to the different directions
    channel_select_filtered_positive: torch.Tensor = torch.gather(nms_magnitude, 1, positive_idx)
    channel_select_filtered_negative: torch.Tensor = torch.gather(nms_magnitude, 1, negative_idx)

    channel_select_filtered: torch.Tensor = torch.stack(
        [channel_select_filtered_positive, channel_select_filtered_negative], 1
    )

    # Keep a pixel only when the selected response is positive in both the
    # gradient direction and its opposite.
    is_max: torch.Tensor = channel_select_filtered.min(dim=1)[0] > 0.0

    magnitude = magnitude * is_max

    # Threshold: 0.5 marks weak edges (above low only), 1.0 marks strong
    # edges (above both thresholds).
    low: torch.Tensor = magnitude > low_threshold
    high: torch.Tensor = magnitude > high_threshold

    edges: torch.Tensor = low * 0.5 + high * 0.5
    edges = edges.to(dtype)

    # Hysteresis
    if hysteresis:
        # Start from a map that cannot equal `edges` so the loop body runs at
        # least once; iterate until the edge map stops changing.
        edges_old: torch.Tensor = -torch.ones(edges.shape, device=edges.device, dtype=dtype)
        hysteresis_kernels: torch.Tensor = get_hysteresis_kernel(device, dtype)

        while ((edges_old - edges).abs() != 0).any():
            weak: torch.Tensor = (edges == 0.5).float()
            strong: torch.Tensor = (edges == 1).float()

            hysteresis_magnitude: torch.Tensor = F.conv2d(
                edges, hysteresis_kernels, padding=hysteresis_kernels.shape[-1] // 2
            )
            hysteresis_magnitude = (hysteresis_magnitude == 1).any(1, keepdim=True).to(dtype)
            hysteresis_magnitude = hysteresis_magnitude * weak + strong

            edges_old = edges.clone()
            edges = hysteresis_magnitude + (hysteresis_magnitude == 0) * weak * 0.5

        edges = hysteresis_magnitude

    return magnitude, edges
Beispiel #16
0
def train(arg):
    """Train the decoder that reconstructs images from edge heatmaps.

    A frozen estimator produces heatmaps from the input images, a frozen edge
    model post-processes them, and the decoder is optimised to reproduce the
    input images from the result. Optional GP / CP / feature / regressor loss
    terms are enabled by flags on ``arg``. Checkpoints are written every
    ``arg.save_interval`` epochs and once more after the last epoch.

    Args:
        arg: parsed options; this function reads ``dataset``, ``split``,
            ``batch_size``, ``workers``, ``shuffle``, ``PDB``, ``cuda``,
            ``lr``, ``resume_epoch``, ``max_epoch``, ``loss_type``,
            ``save_logs``, ``save_interval``, ``save_folder``, the
            ``*_loss`` switches with their ``loss_*_lambda`` weights, and the
            ``wingloss_*`` settings.
    """
    log_writer = None
    if arg.save_logs:
        log_path = './logs/decoder_' + arg.dataset + '_' + arg.split
        if not os.path.exists(log_path):
            os.makedirs(log_path)
        log_writer = SummaryWriter(log_dir=log_path)

        def log(tag, scalar, step):
            log_writer.add_scalar(tag, scalar, step)

        def log_img(tag, img, step):
            log_writer.add_image(tag, img, step)
    else:
        # Logging disabled: keep the call sites unconditional with no-ops.

        def log(tag, scalar, step):
            pass

        def log_img(tag, img, step):
            pass

    epoch = None
    devices = get_devices_list(arg)

    print('*****  Training decoder  *****')
    print('Training parameters:\n' + '# Dataset:            ' + arg.dataset +
          '\n' + '# Dataset split:      ' + arg.split + '\n' +
          '# Batchsize:          ' + str(arg.batch_size) + '\n' +
          '# Num workers:        ' + str(arg.workers) + '\n' +
          '# PDB:                ' + str(arg.PDB) + '\n' +
          '# Use GPU:            ' + str(arg.cuda) + '\n' +
          '# Start lr:           ' + str(arg.lr) + '\n' +
          '# Max epoch:          ' + str(arg.max_epoch) + '\n' +
          '# Loss type:          ' + arg.loss_type + '\n' +
          '# Resumed model:      ' + str(arg.resume_epoch > 0))
    if arg.resume_epoch > 0:
        print('# Resumed epoch:      ' + str(arg.resume_epoch))

    print('Creating networks ...')
    # Only the decoder is trained; the other networks stay in eval mode.
    estimator = create_model_estimator(arg, devices, eval=True)
    estimator.eval()

    regressor = None
    if arg.regressor_loss:
        regressor = create_model_regressor(arg, devices, eval=True)
        regressor.eval()

    decoder = create_model_decoder(arg, devices)
    decoder.train()

    edge = create_model_edge(arg, devices, eval=True)
    edge.eval()

    print('Creating networks done!')

    optimizer_decoder, scheduler_decoder = create_optimizer(
        arg, decoder.parameters(), create_scheduler=True)

    criterion_simple = nn.SmoothL1Loss()
    if arg.cuda:
        criterion_simple = criterion_simple.cuda(device=devices[0])

    criterion_gp = None
    if arg.gp_loss:
        criterion_gp = GPLoss()
        if arg.cuda:
            criterion_gp = criterion_gp.cuda(device=devices[0])

    criterion_cp = None
    if arg.cp_loss:
        criterion_cp = CPLoss()
        if arg.cuda:
            criterion_cp = criterion_cp.cuda(device=devices[0])

    criterion_feature = None
    if arg.feature_loss:
        criterion_feature = FeatureLoss(False, arg.feature_loss_type)
        if arg.cuda:
            criterion_feature = criterion_feature.cuda(device=devices[0])

    criterion_regressor = None
    if regressor is not None:
        if arg.loss_type == 'L2':
            criterion_regressor = nn.MSELoss()
        elif arg.loss_type == 'L1':
            criterion_regressor = nn.L1Loss()
        elif arg.loss_type == 'smoothL1':
            criterion_regressor = nn.SmoothL1Loss()
        elif arg.loss_type == 'wingloss':
            criterion_regressor = WingLoss(omega=arg.wingloss_omega,
                                           epsilon=arg.wingloss_epsilon)
        else:
            # Any other loss_type falls back to the adaptive wing loss.
            criterion_regressor = AdaptiveWingLoss(
                arg.wingloss_omega,
                theta=arg.wingloss_theta,
                epsilon=arg.wingloss_epsilon,
                alpha=arg.wingloss_alpha)
        if arg.cuda:
            criterion_regressor = criterion_regressor.cuda(device=devices[0])

    print('Loading dataset ...')
    trainset = DecoderDataset(arg, dataset=arg.dataset, split=arg.split)
    dataloader = torch.utils.data.DataLoader(trainset,
                                             batch_size=arg.batch_size,
                                             shuffle=arg.shuffle,
                                             num_workers=arg.workers,
                                             pin_memory=True)
    steps_per_epoch = len(dataloader)

    # Normalised-space bounds corresponding to the raw pixel range [0, 255].
    mean = torch.FloatTensor(means_color[arg.dataset][arg.split])
    std = torch.FloatTensor(stds_color[arg.dataset][arg.split])
    norm_min = (0 - mean) / std
    norm_max = (255 - mean) / std
    norm_range = norm_max - norm_min

    mean_gray = means_gray[arg.dataset][arg.split]
    std_gray = stds_gray[arg.dataset][arg.split]

    if arg.cuda:
        mean = mean.cuda(device=devices[0])
        std = std.cuda(device=devices[0])
        norm_min = norm_min.cuda(device=devices[0])
        # norm_max = norm_max.cuda(device=devices[0])
        norm_range = norm_range.cuda(device=devices[0])
    print('Loading dataset done!')

    # evolving training
    print('Start training ...')
    for epoch in range(arg.resume_epoch, arg.max_epoch):
        global_step_base = epoch * steps_per_epoch
        forward_times_per_epoch, sum_loss_decoder = 0, 0.

        for data in tqdm.tqdm(dataloader):
            forward_times_per_epoch += 1
            global_step = global_step_base + forward_times_per_epoch

            input_images, input_images_denorm, gt_coords_xy = data

            # show_img(tensor_to_image(rgb_to_bgr(denormalize(input_images_denorm[0].unsqueeze(0), mean, std)).squeeze()), 'target', wait=0, keep=True)

            if arg.cuda:
                input_images = input_images.cuda(device=devices[0])
                input_images_denorm = input_images_denorm.cuda(
                    device=devices[0])
                gt_coords_xy = gt_coords_xy.cuda(device=devices[0])

            with torch.no_grad():
                # Estimator heatmaps -> rescale -> edge model -> rescale.
                # Local names avoid shadowing the builtins min() and max().
                heatmaps_orig = estimator(input_images)[-1]
                heat_min = torch.min(heatmaps_orig)
                heat_max = torch.max(heatmaps_orig)
                rng = heat_max - heat_min
                heatmaps = rescale_0_1(heatmaps_orig, heat_min, rng)
                heatmaps = edge(heatmaps)
                heat_min = torch.min(heatmaps)
                heat_max = torch.max(heatmaps)
                rng = heat_max - heat_min
                heatmaps = rescale_0_1(heatmaps, heat_min, rng).detach()

            fake_images_norm = decoder(heatmaps)
            fake_images_denorm = derescale_0_1(fake_images_norm, norm_min,
                                               norm_range)

            optimizer_decoder.zero_grad()

            loss_simple = criterion_simple(fake_images_denorm,
                                           input_images_denorm)
            loss = loss_simple
            log('loss_simple', loss_simple.item(), global_step)

            if criterion_gp is not None:
                loss_gp = criterion_gp(fake_images_denorm, input_images_denorm)
                loss = loss + arg.loss_gp_lambda * loss_gp
                log('loss_gp', loss_gp.item(), global_step)

            if criterion_cp is not None:
                loss_cp = criterion_cp(fake_images_denorm, input_images_denorm)
                loss = loss + arg.loss_cp_lambda * loss_cp
                log('loss_cp', loss_cp.item(), global_step)

            if criterion_feature is not None:
                loss_feature = criterion_feature(fake_images_denorm,
                                                 input_images_denorm)
                loss = loss + arg.loss_feature_lambda * loss_feature

                log('loss_feature', loss_feature.item(), global_step)

            if regressor is not None:
                # NOTE(review): the regressor output is produced under
                # no_grad, so loss_regressor changes the logged loss value but
                # carries no gradient back to the decoder -- confirm intended.
                with torch.no_grad():
                    fake_images = denormalize(fake_images_norm, mean, std)
                    fake_images = rgb_to_grayscale(fake_images)
                    fake_images = normalize(fake_images, mean_gray, std_gray)
                    #TODO fix estimator.forward
                    regressor_out = regressor(fake_images, heatmaps_orig)

                loss_regressor = criterion_regressor(regressor_out,
                                                     gt_coords_xy)
                loss = loss + arg.loss_regressor_lambda * loss_regressor

                log('loss_regressor', loss_regressor.item(), global_step)

            log('loss', loss.item(), global_step)

            loss.backward()
            optimizer_decoder.step()

            sum_loss_decoder += loss.item()

            if arg.save_logs:
                # NOTE(review): fake_images_denorm is denormalized again here
                # before logging -- confirm this is the intended value range.
                images_to_save = np.uint8(
                    np.clip(
                        denormalize(fake_images_denorm[0, ...].detach(), mean,
                                    std).cpu().numpy(), 0.0, 255.0))
                log_img('fake_image', images_to_save, global_step)

        mean_loss_decoder = sum_loss_decoder / forward_times_per_epoch

        # The scheduler is stepped with the epoch's mean loss.
        scheduler_decoder.step(mean_loss_decoder)

        if (epoch + 1) % arg.save_interval == 0:
            torch.save(
                decoder.state_dict(), arg.save_folder + 'decoder_' +
                arg.dataset + '_' + arg.split + '_' + str(epoch + 1) + '.pth')

        print('\nepoch: {:0>4d} | loss_decoder: {:.10f}'.format(
            epoch,
            mean_loss_decoder,
        ))

    # `epoch` is still None when the range above was empty
    # (resume_epoch >= max_epoch); nothing was trained, so no final checkpoint
    # is written instead of failing on `None + 1`.
    if epoch is not None:
        torch.save(
            decoder.state_dict(), arg.save_folder + 'decoder_' + arg.dataset +
            '_' + arg.split + '_' + str(epoch + 1) + '.pth')

    # Flush and release the event file once training has finished.
    if log_writer is not None:
        log_writer.close()
    print('Training done!')
Beispiel #17
0
# training
step = 0
best_psnr_val = 0.
torch.backends.cudnn.benchmark = True
for epoch in range(num_epoch):
    ''' train '''
    for i, (lr, guide, gt) in enumerate(loader['train']):
        lr, guide, gt = lr.cuda(), guide.cuda(), gt.cuda()
        tmp_ind = int(torch.randint(lr.shape[1], [1]))
        lr = torch.nn.functional.interpolate(lr[:, tmp_ind, :, :].unsqueeze(1),
                                             scale_factor=scale,
                                             mode='bicubic',
                                             align_corners=True)
        gt = gt[:, tmp_ind, :, :].unsqueeze(1)
        guide = rgb_to_grayscale(guide)
        #1. update
        net.train()
        net.zero_grad()
        optimizer.zero_grad()
        imgf = net(lr, guide)
        loss = nn.MSELoss()(gt, imgf)
        loss.backward()
        optimizer.step()

        #2.  print
        print("[%d,%d] Loss: %.4f" % (epoch + 1, i + 1, loss.item()))
        #3. Log the scalar values
        writer.add_scalar('loss', loss.item(), step)
        step += 1
    ''' validation '''