Example #1
def test_ssim_reduction_and_full(reduction: str, full: bool, expectation: Any,
                                 prediction: torch.Tensor,
                                 target: torch.Tensor, device: str) -> None:
    prediction = prediction.to(device)
    target = target.to(device)
    with expectation:
        ssim(prediction, target, data_range=1., reduction=reduction, full=full)
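The test snippets in this listing omit their module-level imports. A typical header for them, assuming the image assets are read with scikit-image (the TensorFlow comparison test further below additionally needs import tensorflow as tf), would be:

import itertools
from typing import Any, Tuple

import pytest
import torch
from skimage.io import imread

from piq import ssim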
Example #2
def test_ssim_fails_for_incorrect_data_range(x: torch.Tensor, y: torch.Tensor,
                                             device: str) -> None:
    # Scale to [0, 255]
    x_scaled = (x * 255).type(torch.uint8)
    y_scaled = (y * 255).type(torch.uint8)
    with pytest.raises(AssertionError):
        ssim(x_scaled.to(device), y_scaled.to(device), data_range=1.0)
Example #3
def test_ssim_simmular_to_matlab_implementation():
    # Greyscale images
    goldhill = torch.tensor(imread('tests/assets/goldhill.gif'))[None, None,
                                                                 ...]
    goldhill_jpeg = torch.tensor(
        imread('tests/assets/goldhill_jpeg.gif'))[None, None, ...]

    score = ssim(goldhill_jpeg, goldhill, data_range=255, reduction='none')
    # Output of http://www.cns.nyu.edu/~lcv/ssim/ssim.m
    score_baseline = torch.tensor(0.8202)

    assert torch.isclose(score, score_baseline, atol=1e-4), \
        f'Expected PyTorch score to be equal to MATLAB prediction. Got {score} and {score_baseline}'

    # RGB images
    I01 = torch.tensor(imread('tests/assets/I01.BMP')).permute(2, 0, 1)[None,
                                                                        ...]
    i1_01_5 = torch.tensor(imread('tests/assets/i01_01_5.bmp')).permute(
        2, 0, 1)[None, ...]

    score = ssim(i1_01_5, I01, data_range=255, reduction='none')
    # Output of http://www.cns.nyu.edu/~lcv/ssim/ssim.m
    # score_baseline = torch.tensor(0.7820)
    score_baseline = torch.tensor(0.7842)

    assert torch.isclose(score, score_baseline, atol=1e-2), \
        f'Expected PyTorch score to be equal to MATLAB prediction. Got {score} and {score_baseline}'
Example #4
def test_ssim_symmetry(x_y_4d_5d, device: str) -> None:
    x = x_y_4d_5d[0].to(device)
    y = x_y_4d_5d[1].to(device)
    measure = ssim(x, y, data_range=1., reduction='none')
    reverse_measure = ssim(y, x, data_range=1., reduction='none')
    assert torch.allclose(measure, reverse_measure), f'Expect: SSIM(a, b) == SSIM(b, a), ' \
                                                     f'got {measure} != {reverse_measure}'
Example #5
def test_ssim_reduction(x: torch.Tensor, y: torch.Tensor, device: str) -> None:
    for mode in ['mean', 'sum', 'none']:
        ssim(x.to(device), y.to(device), reduction=mode)

    for mode in [None, 'n', 2]:
        with pytest.raises(KeyError):
            ssim(x.to(device), y.to(device), reduction=mode)
Example #6
def test_ssim_reduction(prediction: torch.Tensor, target: torch.Tensor,
                        device: str) -> None:
    for mode in ['mean', 'sum', 'none']:
        ssim(prediction.to(device), target.to(device), reduction=mode)

    for mode in [None, 'n', 2]:
        with pytest.raises(KeyError):
            ssim(prediction.to(device), target.to(device), reduction=mode)
Example #7
def test_ssim_symmetry(prediction_target_4d_5d: Tuple[torch.Tensor,
                                                      torch.Tensor],
                       device: str) -> None:
    prediction = prediction_target_4d_5d[0].to(device)
    target = prediction_target_4d_5d[1].to(device)
    measure = ssim(prediction, target, data_range=1., reduction='none')
    reverse_measure = ssim(target, prediction, data_range=1., reduction='none')
    assert torch.allclose(measure, reverse_measure), f'Expect: SSIM(a, b) == SSIM(b, a), ' \
                                                     f'got {measure} != {reverse_measure}'
Example #8
def test_ssim_raises_if_kernel_size_greater_than_image(x_y_4d_5d,
                                                       device: str) -> None:
    x = x_y_4d_5d[0].to(device)
    y = x_y_4d_5d[1].to(device)
    kernel_size = 11
    wrong_size_x = x[:, :, :kernel_size - 1, :kernel_size - 1]
    wrong_size_y = y[:, :, :kernel_size - 1, :kernel_size - 1]
    with pytest.raises(ValueError):
        ssim(wrong_size_x, wrong_size_y, kernel_size=kernel_size)
Example #9
def test_ssim_fails_for_incorrect_data_range(prediction: torch.Tensor,
                                             target: torch.Tensor,
                                             device: str) -> None:
    # Scale to [0, 255]
    prediction_scaled = (prediction * 255).type(torch.uint8)
    target_scaled = (target * 255).type(torch.uint8)
    with pytest.raises(AssertionError):
        ssim(prediction_scaled.to(device),
             target_scaled.to(device),
             data_range=1.0)
Example #10
def test_ssim_check_kernel_size_is_passed(x_y_4d_5d, device: str) -> None:
    x = x_y_4d_5d[0].to(device)
    y = x_y_4d_5d[1].to(device)
    kernel_sizes = list(range(0, 50))
    for kernel_size in kernel_sizes:
        if kernel_size % 2:
            ssim(x, y, kernel_size=kernel_size)
        else:
            with pytest.raises(AssertionError):
                ssim(x, y, kernel_size=kernel_size)
Example #11
def test_ssim_check_kernel_size_is_passed(prediction_target_4d_5d: Tuple[
    torch.Tensor, torch.Tensor], device: str) -> None:
    prediction = prediction_target_4d_5d[0].to(device)
    target = prediction_target_4d_5d[1].to(device)
    kernel_sizes = list(range(0, 50))
    for kernel_size in kernel_sizes:
        if kernel_size % 2:
            ssim(prediction, target, kernel_size=kernel_size)
        else:
            with pytest.raises(AssertionError):
                ssim(prediction, target, kernel_size=kernel_size)
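As the test above encodes, only odd, positive kernel sizes are accepted. A minimal sketch of a call with a non-default kernel (shapes and values here are illustrative):

import torch
from piq import ssim

x, y = torch.rand(1, 3, 96, 96), torch.rand(1, 3, 96, 96)
# An odd kernel_size such as 7 passes validation; an even one raises AssertionError.
score = ssim(x, y, data_range=1., kernel_size=7, kernel_sigma=1.5)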
Example #12
def test_ssim_raises_if_kernel_size_greater_than_image(
        prediction_target_4d_5d: Tuple[torch.Tensor,
                                       torch.Tensor], device: str) -> None:
    prediction = prediction_target_4d_5d[0].to(device)
    target = prediction_target_4d_5d[1].to(device)
    kernel_size = 11
    wrong_size_prediction = prediction[:, :, :kernel_size - 1, :kernel_size -
                                       1]
    wrong_size_target = target[:, :, :kernel_size - 1, :kernel_size - 1]
    with pytest.raises(ValueError):
        ssim(wrong_size_prediction, wrong_size_target, kernel_size=kernel_size)
Example #13
def test_ssim_check_available_dimensions() -> None:
    custom_x = torch.rand(256, 256)
    custom_y = torch.rand(256, 256)
    for _ in range(10):
        if custom_x.dim() < 5:
            try:
                ssim(custom_x, custom_y)
            except Exception as e:
                pytest.fail(f"Unexpected error occurred: {e}")
        else:
            with pytest.raises(AssertionError):
                ssim(custom_x, custom_y)
        custom_x.unsqueeze_(0)
        custom_y.unsqueeze_(0)
Example #14
def test_ssim_supports_different_data_ranges(input_tensors: Tuple[
    torch.Tensor, torch.Tensor], data_range, device: str) -> None:
    x, y = input_tensors
    x_scaled = (x * data_range).type(torch.uint8)
    y_scaled = (y * data_range).type(torch.uint8)

    measure_scaled = ssim(x_scaled.to(device),
                          y_scaled.to(device),
                          data_range=data_range)
    measure = ssim(x_scaled.to(device) / float(data_range),
                   y_scaled.to(device) / float(data_range),
                   data_range=1.0)
    diff = torch.abs(measure_scaled - measure)
    assert diff <= 1e-6, f'Result for same tensor with different data_range should be the same, got {diff}'
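A standalone sketch of the same invariance, assuming nothing beyond torch and piq (shapes are illustrative): scaling the inputs together with data_range should leave the score unchanged up to floating point error.

import torch
from piq import ssim

x, y = torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)
# Same image pair expressed on a [0, 255] scale and on a [0, 1] scale.
score_255 = ssim((x * 255).round(), (y * 255).round(), data_range=255)
score_1 = ssim((x * 255).round() / 255, (y * 255).round() / 255, data_range=1.)
assert torch.isclose(score_255, score_1, atol=1e-6)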
Example #15
def test_ssim_raises_if_tensors_have_different_shapes(x_y_4d_5d,
                                                      device) -> None:
    y = x_y_4d_5d[1].to(device)
    dims = [[3], [2, 3], [161, 162], [161, 162]]
    if y.dim() == 5:
        dims += [[2, 3]]
    for size in list(itertools.product(*dims)):
        wrong_shape_x = torch.rand(size).to(y)
        if wrong_shape_x.size() == y.size():
            try:
                ssim(wrong_shape_x, y)
            except Exception as e:
                pytest.fail(f"Unexpected error occurred: {e}")
        else:
            with pytest.raises(AssertionError):
                ssim(wrong_shape_x, y)
Example #16
    def eval(self, gt, pred):
        with torch.no_grad():
            gt_tensor = torch.Tensor(gt).clamp(0, 1).permute(0, 3, 1,
                                                             2).to('cuda:0')
            pred_tensor = torch.Tensor(pred).clamp(0,
                                                   1).permute(0, 3, 1,
                                                              2).to('cuda:0')
            psnr_index = piq.psnr(pred_tensor,
                                  gt_tensor,
                                  data_range=1.,
                                  reduction='none').item()
            _, _, h, w = gt_tensor.shape

            lpipsAlex = 0
            lpipsVGG = 0
            msssim_index = 0
            ssim_index = 0
            n = 1
            for i in range(n):
                for j in range(n):
                    xstart = w // n * j
                    ystart = h // n * i
                    xend = w // n * (j + 1)
                    yend = h // n * (i + 1)
                    ssim_index += piq.ssim(pred_tensor[:, :, ystart:yend,
                                                       xstart:xend],
                                           gt_tensor[:, :, ystart:yend,
                                                     xstart:xend],
                                           data_range=1.,
                                           reduction='none').item()
                    # Accumulate like the other metrics; the total is divided by n * n below.
                    msssim_index += piq.multi_scale_ssim(
                        pred_tensor[:, :, ystart:yend, xstart:xend],
                        gt_tensor[:, :, ystart:yend, xstart:xend],
                        data_range=1.,
                        reduction='none').item()
                    lpipsVGG += self.lpipsVGG(
                        pred_tensor[:, :, ystart:yend, xstart:xend],
                        gt_tensor[:, :, ystart:yend, xstart:xend]).item()
                    lpipsAlex += self.lpipsAlex(
                        pred_tensor[:, :, ystart:yend, xstart:xend],
                        gt_tensor[:, :, ystart:yend, xstart:xend]).item()
            msssim_index /= n * n
            ssim_index /= n * n
            lpipsVGG /= n * n
            lpipsAlex /= n * n
            # dists = piq.DISTS(reduction='none')(pred_tensor, gt_tensor).item()

            # with torch.no_grad():
            #     lpips_index = piq.LPIPS(reduction='none')(pred_tensor, gt_tensor).item()
            rmse = ((gt - pred)**2).mean()**0.5
            # relmse = (((gt - pred) ** 2).mean() / (gt ** 2).mean() + 1e-5) ** 0.5
            # return {'rmse':rmse,'relmse':relmse,'psnr':psnr_index,'ssim':ssim_index,'msssim':msssim_index,'lpips':lpips_index}
        return {
            'rmse': rmse,
            'psnr': psnr_index,
            'ssim': ssim_index,
            'msssim': msssim_index,
            'lpipsVGG': lpipsVGG,
            'lpipsAlex': lpipsAlex
        }
Example #17
    def update(self, output):
        y_pred = output[0]
        y = output[1]
        y_pred = torch.clamp_min(y_pred, min=0.0)
        y = torch.clamp_min(y, min=0.0)
        # print("CrowdCountingMeanSSIMclamp ")
        # print("y_pred", y_pred.shape)
        # print("y", y.shape)

        y_pred = F.interpolate(y_pred, scale_factor=8) / 64
        pad_density_map_tensor = torch.zeros((1, 1, y.shape[2], y.shape[3])).cuda()
        pad_density_map_tensor[:, 0, :y_pred.shape[2], :y_pred.shape[3]] = y_pred
        y_pred = pad_density_map_tensor

        # y_max = torch.max(y)
        # y_pred_max = torch.max(y_pred)
        # max_value = torch.max(y_max, y_pred_max)

        y = y / torch.max(y) * 255
        y_pred = y_pred / torch.max(y_pred) * 255

        ssim_metric = piq.ssim(y, y_pred, reduction="sum", data_range=255)

        self._sum += ssim_metric.item()
        # reduction="sum" adds one SSIM score per image in the batch,
        # so dividing the accumulated sum by self._num_examples later yields the mean.

        self._num_examples += y.shape[0]
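The snippet does not show how the accumulator is read out. A minimal sketch of a matching compute step, using only the two attributes seen above (the method name compute and the empty-state error are assumptions):

    def compute(self):
        # Mean SSIM over all images seen so far: update() accumulates a
        # per-image sum, so dividing by the example count gives the average.
        if self._num_examples == 0:
            raise ValueError('SSIM must have at least one example before it can be computed.')
        return self._sum / self._num_examples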
Example #18
 def val_iter(self, final=True):
     with torch.no_grad():
         self.model.eval()
         t = tqdm(self.loader_val)
         if final:
             t.set_description("Validation")
         else:
             t.set_description(f"Epoch {self.epoch} val   ")
         psnr_avg = AverageMeter()
         ssim_avg = AverageMeter()
         l1_avg = AverageMeter()
         l2_avg = AverageMeter()
         for hr, lr in t:
             hr, lr = hr.to(self.dtype).to(self.device), lr.to(self.dtype).to(self.device)
             sr = self.model(lr).clamp(0, 1)
             l1_loss = torch.nn.functional.l1_loss(sr, hr).item()
             l2_loss = torch.sqrt(torch.nn.functional.mse_loss(sr, hr)).item()
             psnr = piq.psnr(hr, sr)
             ssim = piq.ssim(hr, sr)
             l1_avg.update(l1_loss)
             l2_avg.update(l2_loss)
             psnr_avg.update(psnr)
             ssim_avg.update(ssim)
             t.set_postfix(PSNR=f'{psnr_avg.get():.2f}', SSIM=f'{ssim_avg.get():.4f}')
         if self.writer is not None:
             self.writer.add_scalar('PSNR', psnr_avg.get(), self.epoch)
             self.writer.add_scalar('SSIM', ssim_avg.get(), self.epoch)
             self.writer.add_scalar('L1', l1_avg.get(), self.epoch)
             self.writer.add_scalar('L2', l2_avg.get(), self.epoch)
         return psnr_avg.get(), ssim_avg.get()
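AverageMeter is not shown in the snippet. A minimal running-average helper compatible with the update()/get() calls above could look like the following (an assumption, not the original implementation):

class AverageMeter:
    """Keeps a running mean of the values passed to update()."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        # Accept Python floats as well as 0-dim torch tensors.
        self.sum += float(value) * n
        self.count += n

    def get(self):
        return self.sum / max(self.count, 1)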
Example #19
    def forward(self,
                reference_observations: torch.Tensor,
                generated_observations: torch.Tensor,
                range=1.0) -> torch.Tensor:
        '''
        Computes the ssim between the reference and the generated observations

        :param reference_observations: (bs, observations_count, channels, height, width) tensor with reference observations
        :param generated_observations: (bs, observations_count, channels, height, width) tensor with generated observations
        :param range: The maximum value used to represent each pixel
        :return: (bs, observations_count) tensor with ssim for each observation
        '''

        # Flattens observations and then folds the results
        observations_count = reference_observations.size(1)
        flattened_reference_observations = TensorFolder.flatten(
            reference_observations)
        flattened_generated_observations = TensorFolder.flatten(
            generated_observations)
        # Pass the data range by keyword: the third positional argument of
        # piq.ssim is kernel_size, not data_range.
        flattened_ssim = ssim(flattened_generated_observations,
                              flattened_reference_observations,
                              data_range=range,
                              reduction="none")
        folded_ssim = TensorFolder.fold(flattened_ssim, observations_count)

        return folded_ssim
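TensorFolder is not part of piq. Assuming it only merges and restores the batch and observations dimensions, a minimal sketch consistent with the calls above would be:

import torch

class TensorFolder:
    @staticmethod
    def flatten(observations: torch.Tensor) -> torch.Tensor:
        # (bs, observations_count, c, h, w) -> (bs * observations_count, c, h, w)
        return observations.reshape(-1, *observations.shape[2:])

    @staticmethod
    def fold(flattened: torch.Tensor, observations_count: int) -> torch.Tensor:
        # (bs * observations_count, ...) -> (bs, observations_count, ...)
        return flattened.reshape(-1, observations_count, *flattened.shape[1:])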
Example #20
def test_ssim_measure_is_less_or_equal_to_one(
        ones_zeros_4d_5d: Tuple[torch.Tensor,
                                torch.Tensor], device: str) -> None:
    # Create two maximally different tensors.
    ones = ones_zeros_4d_5d[0].to(device)
    zeros = ones_zeros_4d_5d[1].to(device)
    measure = ssim(ones, zeros, data_range=1., reduction='none')
    assert torch.le(measure, 1).all(), f'SSIM must be <= 1, got {measure}'
Example #21
def test_ssim_measure_is_one_for_equal_tensors(y: torch.Tensor,
                                               device: str) -> None:
    y = y.to(device)
    x = y.clone()
    measure = ssim(x, y, data_range=1., reduction='none')
    assert torch.allclose(measure, torch.ones_like(measure)), f'If equal tensors are passed SSIM must be equal to 1 ' \
                                                              f'(considering floating point error up to 1 * 10^-6), '\
                                                              f'got {measure}'
Example #22
def test_ssim_raises_if_tensors_have_different_shapes(
        prediction_target_4d_5d: Tuple[torch.Tensor,
                                       torch.Tensor], device) -> None:
    target = prediction_target_4d_5d[1].to(device)
    dims = [[3], [2, 3], [161, 162], [161, 162]]
    if target.dim() == 5:
        dims += [[2, 3]]
    for size in list(itertools.product(*dims)):
        wrong_shape_prediction = torch.rand(size).to(target)
        if wrong_shape_prediction.size() == target.size():
            try:
                ssim(wrong_shape_prediction, target)
            except Exception as e:
                pytest.fail(f"Unexpected error occurred: {e}")
        else:
            with pytest.raises(AssertionError):
                ssim(wrong_shape_prediction, target)
Example #23
    def eval(self,gt,pred,imformat='BHWC',dtype='jax'):
        if(dtype == 'jax'):
            gt = np.array(gt)
            pred = np.array(pred)

        with torch.no_grad():
            if(imformat == 'BHWC'):
                gt_tensor = torch.Tensor(gt).permute(0,3,1,2).to(self.device)
                pred_tensor = torch.Tensor(pred).permute(0,3,1,2).to(self.device)
            elif(imformat == 'HWC'):
                gt_tensor = torch.Tensor(gt[None,...]).permute(0,3,1,2).to(self.device)
                pred_tensor = torch.Tensor(pred[None,...]).permute(0,3,1,2).to(self.device)
            else:
                raise ValueError(f'Unknown image dimension format: {imformat}')

            pred_tensor = torch.clamp(pred_tensor,0,1)
            gt_tensor = torch.clamp(gt_tensor,0,1)
            _,_,h,w = gt_tensor.shape
            
            lpipsAlex = 0
            lpipsVGG = 0
            msssim_index = 0
            ssim_index = 0
            n = 1
            for i in range(n):
                for j in range(n):
                    xstart = w//n * j
                    ystart = h//n * i
                    xend = w//n * (j+1)
                    yend = h//n * (i+1)
                    if('ssim' in self.metrics):
                        ssim_index += piq.ssim(pred_tensor[:,:,ystart:yend,xstart:xend], gt_tensor[:,:,ystart:yend,xstart:xend], data_range=1., reduction='mean').item()
                    if('msssim' in self.metrics):
                        msssim_index += piq.multi_scale_ssim(pred_tensor[:,:,ystart:yend,xstart:xend], gt_tensor[:,:,ystart:yend,xstart:xend], data_range=1., reduction='mean').item()
                    if('lpipsVGG' in self.metrics):
                        lpipsVGG +=  self.lpipsVGG(pred_tensor[:,:,ystart:yend,xstart:xend], gt_tensor[:,:,ystart:yend,xstart:xend]).item()
                    if('lpipsAlex' in self.metrics):
                        lpipsAlex +=  self.lpipsAlex(pred_tensor[:,:,ystart:yend,xstart:xend], gt_tensor[:,:,ystart:yend,xstart:xend]).item()

            res = {}
            if 'ssim' in self.metrics:
                # Average over the n*n tiles; parentheses matter here,
                # "x / n*n" would divide and then multiply back.
                res['ssim'] = ssim_index / (n * n)
            if 'msssim' in self.metrics:
                res['msssim'] = msssim_index / (n * n)
            if 'lpipsVGG' in self.metrics:
                res['lpipsVGG'] = lpipsVGG / (n * n)
            if 'lpipsAlex' in self.metrics:
                res['lpipsAlex'] = lpipsAlex / (n * n)
            if 'rmse' in self.metrics:
                res['rmse'] = ((gt_tensor - pred_tensor) ** 2).mean() ** 0.5

            res['mse'] = float(((gt_tensor - pred_tensor) ** 2).mean().cpu().numpy())
            # PSNR for data in [0, 1]: -10 * log10(MSE).
            res['psnr'] = -10. * np.log10(res['mse'])
            
        return res
Example #24
    def validation_epoch_end(
            self, outputs: List[Tuple[torch.Tensor, torch.Tensor]]) -> None:
        if isinstance(self.val_dataloader().dataset, ImageLoader):
            self.val_dataloader().dataset.val = False
        else:
            self.val_dataloader().dataset.dataset.val = False

        fid_score = fid(self.forged_images, self.reference_images,
                        self.hparams.feature_dimensionality_fid, self.device)
        ssim_score = ssim(self.forged_images,
                          self.reference_images,
                          data_range=255)
        psnr_score = psnr(self.forged_images,
                          self.reference_images,
                          data_range=255)

        self.log('FID_score', fid_score, on_step=False, on_epoch=True)
        self.log('SSIM', ssim_score, on_step=False, on_epoch=True)
        self.log('PSNR', psnr_score, on_step=False, on_epoch=True)
Example #25
def image_metrics_from_dataset(dataset, output_addr='/tmp/psnr.csv'):

    cols = ['file_name', 'psnr', 'ssim']
    cols_str = ','.join(cols)

    with open(output_addr, 'wt') as f:
        f.write(f'{cols_str}\n')

    dl = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
    for i, data in tqdm(enumerate(dl), total=int(len(dataset))):
        x, y, file_name = data
        psnr = piq.psnr(x, y).item()
        ssim = piq.ssim(x, y).item()
        vals = [file_name[0], str(psnr), str(ssim)]
        val_str = ','.join(vals)
        with open(output_addr, 'at') as f:
            f.write(f'{val_str}\n')

    return pd.read_csv(output_addr)
Example #26
def usage():
    import numpy as np
    import torch
    from piq import ssim, SSIMLoss

    # piq.ssim expects 4D (N, C, H, W) or 5D inputs.
    x = torch.rand(4, 3, 256, 256, requires_grad=True)
    y = torch.rand(4, 3, 256, 256)

    ssim_index: torch.Tensor = ssim(x, y, data_range=1.)
    ssim_value = ssim_index.detach().numpy()
    assert isinstance(ssim_value, np.ndarray), type(ssim_value)
    print(f"ssim_index: {ssim_value}")

    loss = SSIMLoss(data_range=1.)
    output: torch.Tensor = loss(x, y)
    print(output.item())
    output.backward()
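Beyond a single backward call, SSIMLoss is typically used as a training objective. A minimal sketch under assumed placeholders (the toy model, optimizer and tensor shapes are not from the original example):

import torch
from piq import SSIMLoss

# Toy single-layer model and random data, purely illustrative.
model = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = SSIMLoss(data_range=1.)

x = torch.rand(4, 3, 64, 64)
target = torch.rand(4, 3, 64, 64)

optimizer.zero_grad()
prediction = torch.sigmoid(model(x))  # keep outputs in [0, 1] to match data_range
loss = criterion(prediction, target)  # SSIMLoss returns 1 - SSIM, so lower is better
loss.backward()
optimizer.step()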
Example #27
def test_ssim_raise_if_wrong_value_is_estimated(
        test_images: Tuple[torch.Tensor, torch.Tensor], device: str) -> None:
    for x, y in test_images:
        piq_ssim = ssim(x.to(device),
                        y.to(device),
                        kernel_size=11,
                        kernel_sigma=1.5,
                        data_range=255,
                        reduction='none')
        tf_x = tf.convert_to_tensor(x.permute(0, 2, 3, 1).numpy())
        tf_y = tf.convert_to_tensor(y.permute(0, 2, 3, 1).numpy())
        with tf.device('/CPU'):
            tf_ssim = torch.tensor(
                tf.image.ssim(tf_x, tf_y, max_val=255).numpy()).to(piq_ssim)

        match_accuracy = 2e-4 + 1e-8
        assert torch.allclose(piq_ssim, tf_ssim, rtol=0, atol=match_accuracy), \
            f'The estimated value must be equal to tensorflow provided one ' \
            f'(considering floating point operation error up to {match_accuracy}), ' \
            f'got difference {(piq_ssim - tf_ssim).abs()}'
Example #28
def grid_search(x, y, rec_func, grid):
    """ Grid search utility for tuning hyper-parameters. """
    err_min = np.inf
    grid_param = None

    grid_shape = [len(val) for val in grid.values()]
    err = torch.zeros(grid_shape)
    err_psnr = torch.zeros(grid_shape)
    err_ssim = torch.zeros(grid_shape)

    for grid_val, nidx in zip(itertools.product(*grid.values()),
                              np.ndindex(*grid_shape)):
        grid_param_cur = dict(zip(grid.keys(), grid_val))
        print(
            "Current grid parameters (" + str(list(nidx)) + " / " +
            str(grid_shape) + "): " + str(grid_param_cur),
            flush=True,
        )
        x_rec = rec_func(y, **grid_param_cur)
        err[nidx], _ = l2_error(x_rec, x, relative=True, squared=False)
        err_psnr[nidx] = psnr(
            rotate_real(x_rec)[:, 0:1, ...],
            rotate_real(x)[:, 0:1, ...],
            data_range=rotate_real(x)[:, 0:1, ...].max(),
            reduction="mean",
        )
        err_ssim[nidx] = ssim(
            rotate_real(x_rec)[:, 0:1, ...],
            rotate_real(x)[:, 0:1, ...],
            data_range=rotate_real(x)[:, 0:1, ...].max(),
            size_average=True,
        )
        print("Rel. recovery error: {:1.2e}".format(err[nidx]), flush=True)
        print("PSNR: {:.2f}".format(err_psnr[nidx]), flush=True)
        print("SSIM: {:.2f}".format(err_ssim[nidx]), flush=True)
        if err[nidx] < err_min:
            grid_param = grid_param_cur
            err_min = err[nidx]

    return grid_param, err_min, err, err_psnr, err_ssim
Example #29
 def forward(self, predict, target):
     # Build the summary incrementally so that disabled metrics are simply
     # omitted instead of referencing variables that were never assigned.
     metric_summary = {}
     if self.l1_norm:
         metric_summary['l1_norm'] = nn.functional.l1_loss(predict, target)
     if self.mse:
         metric_summary['mse'] = nn.functional.mse_loss(predict, target)
     if self.pearsonr:
         metric_summary['pearsonr_metric'] = audtorch.metrics.functional.pearsonr(predict, target).mean()
     if self.cc:
         metric_summary['cc'] = audtorch.metrics.functional.concordance_cc(predict, target).mean()
     if self.psnr:
         metric_summary['psnr'] = piq.psnr(predict, target, data_range=1., reduction='none').mean()
     if self.ssim:
         metric_summary['ssim'] = piq.ssim(predict, target, data_range=1.)
     if self.mssim:
         metric_summary['mssim'] = piq.multi_scale_ssim(predict, target, data_range=1.)
     return metric_summary
Example #30
def main(args):

    input_shape = (3, 380, 380)
    if not os.path.exists(args.checkpoints_output):
        os.makedirs(args.checkpoints_output)

    if not os.path.exists(args.logs):
        os.makedirs(args.logs)

    images_output = os.path.join(args.logs, 'images')
    if not os.path.exists(images_output):
        os.makedirs(images_output)

    if not args.model in models:
        print(f"Model name {args.model} must be one of: {model_names}")
        return 1

    print(f"Seting up training for model: {args.model}")
    print(f"Train X Root: {args.train_x_root}")
    print(f"Train Y Root: {args.train_y_root}")

    if args.test_x is not None and args.test_y is not None:
        print(f"Test X Root: {args.test_x}")
        print(f"Test Y Root: {args.test_y}")

    normalize_transform = transforms.Normalize((0.5, 0.5, 0.5),
                                               (0.5, 0.5, 0.5))
    if args.test_x is None or args.test_y is None:
        dataset = EnumPairedDataset(args.train_x_root,
                                    args.train_y_root,
                                    transform=normalize_transform)
        train_d, test_d = train_val_dataset(dataset)
    else:
        train_d = EnumPairedDataset(args.train_x_root,
                                    args.train_y_root,
                                    transform=normalize_transform)
        test_d = EnumPairedDataset(args.test_x,
                                   args.test_y,
                                   transform=normalize_transform)

    train_batch_size = args.train_batch_size
    test_batch_size = args.test_batch_size
    train_dl = DataLoader(train_d,
                          batch_size=train_batch_size,
                          shuffle=True,
                          num_workers=0)
    test_dl = DataLoader(test_d,
                         batch_size=test_batch_size,
                         shuffle=True,
                         num_workers=0)

    if args.show_dataset:
        x_batch, y_batch, names = next(iter(train_dl))
        plt.subplot(2, 1, 1)
        plt.imshow(torchvision.utils.make_grid(x_batch).permute(1, 2, 0))
        plt.subplot(2, 1, 2)
        plt.imshow(torchvision.utils.make_grid(y_batch).permute(1, 2, 0))
        plt.show()

    model = models[args.model]

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    device = f"cuda:{model.device_ids[0]}"
    #device = 'cpu'
    model.to(device)
    summary(model, input_shape)

    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    #criterion = loss_fun()

    if args.load is not None:
        try:
            pretrained_dict = torch.load(args.load)
            model_dict = model.state_dict()
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)
            print(f"Weights loaded from {args.load}")
        except Exception as e:
            print(f"Couldn't load weights from {args.load}")
            print(e)

    best_training_loss = math.inf
    test_loss_for_best_training_loss = math.inf

    cols = [
        'epoch', 'training_loss', 'test_loss', 'train_psnr', 'test_psnr',
        'train_ssim', 'test_ssim'
    ]
    logs_addr = os.path.join(args.logs, 'logs.csv')
    add_line_to_csv(logs_addr, cols)

    print(f"Logs to: {args.logs}")
    epochs = args.epochs
    if epochs <= args.from_epoch:
        epochs += args.from_epoch

    for epoch in range(args.from_epoch, epochs + 1):

        print(f"Epoch {epoch}/{epochs}")

        training_loss = 0.0
        test_loss = 0.0

        train_psnr = 0.0
        test_psnr = 0.0
        train_ssim = 0.0
        test_ssim = 0.0

        if args.load is not None:
            try:
                fn, ext = os.path.splitext(os.path.basename(args.load))
                loss_vals = fn.split('_')
                best_training_loss = float(loss_vals[2])
                test_loss_for_best_training_loss = float(loss_vals[3])
            except Exception as e:
                print(f"Couldn't load best training loss from {args.load}")
                print(e)

        print("Training:")
        for i, data in tqdm(enumerate(train_dl),
                            total=int(len(train_d) / train_batch_size)):
            w, m, file_name = data
            x = w.to(device)
            y = m.to(device)
            del w
            del m

            optimizer.zero_grad()
            y_hat = model(x)

            loss = loss_fun(y_hat, y)
            loss.backward()
            optimizer.step()

            training_loss += float(loss.item())
            del x
            '''
            train_psnr = piq.psnr(y_hat[0], y[0],data_range=1.,
                                   reduction='none')
            train_ssim = piq.ssim(y_hat[0], y[0], data_range=1.,
                                   reduction='none')
            '''

            del y
            del y_hat

        training_loss /= (i + 1)
        #train_psnr /= (i+1)
        #train_ssim /= (i+1)

        with torch.no_grad():
            print("Testing:")
            for i, data in tqdm(enumerate(test_dl),
                                total=int(len(test_d) / test_batch_size)):
                w, m, file_name = data
                x = w.to(device)
                y = m.to(device)
                y_hat = model(x)
                loss = loss_fun(y_hat, y)
                test_loss += float(loss.item())
                del x

                try:
                    test_psnr += piq.psnr(y_hat, y)
                    test_ssim += piq.ssim(y_hat, y)
                except Exception:
                    # piq raises e.g. if outputs fall outside the expected data range.
                    pass

                if args.show_output_images and i < 5:
                    imgs_dir = os.path.join(images_output, f"epoch_{epoch}")
                    if not os.path.exists(imgs_dir):
                        os.makedirs(imgs_dir)
                    for j, y_hat_i in enumerate(y_hat):
                        fn = os.path.splitext(os.path.basename(
                            file_name[j]))[0]
                        y_gt = y[j]
                        img_i_addr = os.path.join(
                            imgs_dir,
                            f'{epoch}_{fn}_{i}_{j}.{dataset.images_extension}')
                        img_i_gt_addr = os.path.join(
                            imgs_dir,
                            f'{epoch}_{fn}_{i}_{j}_gt.{dataset.images_extension}'
                        )
                        torchvision.utils.save_image(y_hat_i, img_i_addr)
                        torchvision.utils.save_image(y_gt, img_i_gt_addr)
                        del y_gt
                del y
                del y_hat

        test_loss /= (i + 1)
        test_psnr /= (i + 1)
        test_ssim /= (i + 1)

        print(f"Completed Epoch: {epoch}/{args.epochs}")
        print(f"\tTrain loss: {training_loss}")
        print(f"\tTest loss: {test_loss}")
        print(f"\tTrain PSNR: {train_psnr}")
        print(f"\tTest PSNR: {test_psnr}")
        print(f"\tTrain SSIM: {train_ssim}")
        print(f"\tTest SSIM: {test_ssim}")
        print(f"\tBest training loss so far: {best_training_loss}")
        print(f"\tTest loss for: {test_loss_for_best_training_loss}")

        add_line_to_csv(logs_addr, [
            str(epoch),
            str(training_loss),
            str(test_loss),
            str(train_psnr),
            str(test_psnr),
            str(train_ssim),
            str(test_ssim)
        ])

        if best_training_loss > training_loss:
            best_training_loss = training_loss
            test_loss_for_best_training_loss = test_loss
            save_file_name = f"{args.model}_epoch_{epoch}_{best_training_loss:.3f}_{test_loss_for_best_training_loss:.3f}.pth"
            checkpoint_path = os.path.join(args.checkpoints_output,
                                           save_file_name)
            torch.save(model.state_dict(), checkpoint_path)