コード例 #1
0
def test_multi_scale_ssim_loss_grad(x_y_4d_5d, device: str) -> None:
    """Backprop through the MS-SSIM loss must produce finite gradients."""
    x = x_y_4d_5d[0].to(device)
    y = x_y_4d_5d[1].to(device)
    x.requires_grad_()
    criterion = MultiScaleSSIMLoss(data_range=1.)
    criterion(x, y).mean().backward()
    assert torch.isfinite(x.grad).all(), f'Expected finite gradient values, got {x.grad}'
コード例 #2
0
ファイル: test_ssim.py プロジェクト: akamaus/piq
def test_multi_scale_ssim_loss_equality(target: torch.Tensor,
                                        device: str) -> None:
    """MS-SSIM loss of a tensor against its own copy must be ~0."""
    target = target.to(device)
    duplicate = target.clone()
    loss = MultiScaleSSIMLoss()(duplicate, target)
    tolerance = 1e-6
    assert (loss.abs() <= tolerance).all(), \
        f'If equal tensors are passed SSIM loss must be equal to 0 ' \
        f'(considering floating point operation error up to 1 * 10^-6), got {loss}'
コード例 #3
0
def test_multi_scale_ssim_loss_raises_if_tensors_have_different_types(
        x, y) -> None:
    """Non-tensor inputs and non-tensor scale weights must be rejected."""
    # Renamed from `wrong_type_y`: this value is passed in the first (x)
    # argument slot, matching the sibling test's `wrong_type_prediction`.
    wrong_type_x = list(range(10))
    with pytest.raises(AssertionError):
        MultiScaleSSIMLoss()(wrong_type_x, y)
    # A bool is not a valid scale_weights container.
    wrong_type_scale_weights = True
    with pytest.raises(AssertionError):
        MultiScaleSSIMLoss(scale_weights=wrong_type_scale_weights)(x, y)
コード例 #4
0
ファイル: test_ssim.py プロジェクト: akamaus/piq
def test_multi_scale_ssim_loss_raises_if_wrong_kernel_size_is_passed(
        prediction: torch.Tensor, target: torch.Tensor) -> None:
    """Odd kernel sizes are accepted; even ones (including 0) must raise."""
    for kernel_size in range(13):
        if kernel_size % 2 == 0:
            with pytest.raises(AssertionError):
                MultiScaleSSIMLoss(kernel_size=kernel_size)(prediction, target)
        else:
            MultiScaleSSIMLoss(kernel_size=kernel_size)(prediction, target)
コード例 #5
0
ファイル: test_ssim.py プロジェクト: akamaus/piq
def test_multi_scale_ssim_loss_raises_if_tensors_have_different_types(
        prediction: torch.Tensor, target: torch.Tensor) -> None:
    """Non-tensor inputs and non-tensor scale weights must trigger assertions."""
    bad_prediction = list(range(10))  # plain list, not a torch.Tensor
    with pytest.raises(AssertionError):
        MultiScaleSSIMLoss()(bad_prediction, target)
    bad_weights = True  # bool is not a valid scale_weights value
    with pytest.raises(AssertionError):
        MultiScaleSSIMLoss(scale_weights=bad_weights)(prediction, target)
コード例 #6
0
def test_multi_scale_ssim_loss_raises_if_wrong_kernel_size_is_passed(
        x, y) -> None:
    """Even kernel sizes must be rejected; odd ones must be accepted."""
    odd_sizes = [k for k in range(13) if k % 2]
    even_sizes = [k for k in range(13) if not k % 2]
    for kernel_size in odd_sizes:
        MultiScaleSSIMLoss(kernel_size=kernel_size)(x, y)
    for kernel_size in even_sizes:
        with pytest.raises(AssertionError):
            MultiScaleSSIMLoss(kernel_size=kernel_size)(x, y)
コード例 #7
0
ファイル: test_ssim.py プロジェクト: akamaus/piq
def test_multi_scale_ssim_loss_grad(prediction_target_4d_5d: Tuple[
    torch.Tensor, torch.Tensor], device: str) -> None:
    """Gradients flowing back through the MS-SSIM loss must be finite."""
    prediction = prediction_target_4d_5d[0].to(device)
    target = prediction_target_4d_5d[1].to(device)
    prediction.requires_grad_()
    criterion = MultiScaleSSIMLoss(data_range=1.)
    criterion(prediction, target).mean().backward()
    assert torch.isfinite(prediction.grad).all(), \
        f'Expected finite gradient values, got {prediction.grad}'
コード例 #8
0
def test_multi_scale_ssim_loss_check_available_dimensions() -> None:
    """Inputs below 5 dims are accepted; 5 or more dims must raise."""
    custom_x = torch.rand(256, 256)
    custom_y = torch.rand(256, 256)
    # Keep unsqueezing a leading dim and re-check the loss each time.
    for _ in range(10):
        if custom_x.dim() >= 5:
            with pytest.raises(AssertionError):
                MultiScaleSSIMLoss()(custom_x, custom_y)
        else:
            try:
                MultiScaleSSIMLoss()(custom_x, custom_y)
            except Exception as e:
                pytest.fail(f"Unexpected error occurred: {e}")
        custom_x.unsqueeze_(0)
        custom_y.unsqueeze_(0)
コード例 #9
0
def test_multi_scale_ssim_loss_is_less_or_equal_to_one(ones_zeros_4d_5d: Tuple[torch.Tensor, torch.Tensor],
                                                       device: str) -> None:
    """Even maximally different inputs must keep the loss bounded by 1."""
    ones = ones_zeros_4d_5d[0].to(device)
    zeros = ones_zeros_4d_5d[1].to(device)
    loss = MultiScaleSSIMLoss()(ones, zeros)
    upper_bound = 1
    assert (loss <= upper_bound).all(), f'MS-SSIM loss must be <= 1, got {loss}'
コード例 #10
0
def test_multi_scale_ssim_loss_raises_if_tensors_have_different_shapes(x_y_4d_5d, device: str) -> None:
    """Mismatched input shapes and 2D scale_weights must raise assertions."""
    # BUG FIX: `x` was referenced in the scale_weights check below but never
    # defined in this function, causing a NameError (cf. the sibling test
    # which defines `prediction`/`target` from the fixture).
    x = x_y_4d_5d[0].to(device)
    y = x_y_4d_5d[1].to(device)
    dims = [[3], [2, 3], [161, 162], [161, 162]]
    if y.dim() == 5:
        dims += [[2, 3]]
    for size in list(itertools.product(*dims)):
        wrong_shape_x = torch.rand(size).to(y)
        if wrong_shape_x.size() == y.size():
            MultiScaleSSIMLoss()(wrong_shape_x, y)
        else:
            with pytest.raises(AssertionError):
                MultiScaleSSIMLoss()(wrong_shape_x, y)

    # scale_weights must be 1D; a 2D tensor is invalid.
    scale_weights = torch.rand(2, 2)
    with pytest.raises(AssertionError):
        MultiScaleSSIMLoss(scale_weights=scale_weights)(x, y)
コード例 #11
0
def test_multi_scale_ssim_loss_symmetry(x_y_4d_5d, device: str) -> None:
    """MS-SSIM loss must be symmetric in its two arguments."""
    x = x_y_4d_5d[0].to(device)
    y = x_y_4d_5d[1].to(device)
    criterion = MultiScaleSSIMLoss()
    forward = criterion(x, y)
    reverse = criterion(y, x)
    assert (forward == reverse).all(), \
        f'Expect: MS-SSIM(a, b) == MS-SSIM(b, a), got {forward} != {reverse}'
コード例 #12
0
ファイル: test_ssim.py プロジェクト: akamaus/piq
def test_multi_scale_ssim_loss_raises_if_tensors_have_different_shapes(
        prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor],
        device: str) -> None:
    """Mismatched input shapes and 2D scale_weights must raise assertions."""
    # BUG FIX: `prediction` was used in the scale_weights check below but
    # never defined, causing a NameError; take it from the fixture.
    prediction = prediction_target_4d_5d[0].to(device)
    target = prediction_target_4d_5d[1].to(device)
    dims = [[3], [2, 3], [161, 162], [161, 162]]
    if target.dim() == 5:
        dims += [[2, 3]]
    for size in list(itertools.product(*dims)):
        wrong_shape_prediction = torch.rand(size).to(target)
        if wrong_shape_prediction.size() == target.size():
            MultiScaleSSIMLoss()(wrong_shape_prediction, target)
        else:
            with pytest.raises(AssertionError):
                MultiScaleSSIMLoss()(wrong_shape_prediction, target)
    # scale_weights must be 1D; a 2D tensor is invalid.
    scale_weights = torch.rand(2, 2)
    with pytest.raises(AssertionError):
        MultiScaleSSIMLoss(scale_weights=scale_weights)(prediction, target)
コード例 #13
0
ファイル: test_ssim.py プロジェクト: akamaus/piq
def test_multi_scale_ssim_loss_symmetry(prediction_target_4d_5d: Tuple[
    torch.Tensor, torch.Tensor], device: str) -> None:
    """Swapping the argument order must not change the loss value."""
    prediction = prediction_target_4d_5d[0].to(device)
    target = prediction_target_4d_5d[1].to(device)
    criterion = MultiScaleSSIMLoss()
    loss_value = criterion(prediction, target)
    reverse_loss_value = criterion(target, prediction)
    symmetric = (loss_value == reverse_loss_value).all()
    assert symmetric, \
        f'Expect: MS-SSIM(a, b) == MS-SSIM(b, a), got {loss_value} != {reverse_loss_value}'
コード例 #14
0
def test_ms_ssim_loss_raises_if_kernel_size_greater_than_image(x_y_4d_5d, device: str) -> None:
    """Images too small for the coarsest scale must raise ValueError."""
    x = x_y_4d_5d[0].to(device)
    y = x_y_4d_5d[1].to(device)
    kernel_size = 11
    levels = 5
    # Minimal spatial size at which `levels` downsampling steps still fit.
    min_size = (kernel_size - 1) * 2 ** (levels - 1) + 1
    crop = min_size - 1  # one pixel too small in each spatial dim
    wrong_size_x = x[:, :, :crop, :crop]
    wrong_size_y = y[:, :, :crop, :crop]
    with pytest.raises(ValueError):
        MultiScaleSSIMLoss(kernel_size=kernel_size)(wrong_size_x, wrong_size_y)
コード例 #15
0
ファイル: test_ssim.py プロジェクト: akamaus/piq
def test_ms_ssim_loss_raises_if_kernel_size_greater_than_image(
        prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor],
        device: str) -> None:
    """Inputs smaller than the multi-scale minimum must raise ValueError."""
    prediction = prediction_target_4d_5d[0].to(device)
    target = prediction_target_4d_5d[1].to(device)
    kernel_size = 11
    levels = 5
    # Smallest spatial extent supporting `levels` halvings of the image.
    min_size = (kernel_size - 1) * 2**(levels - 1) + 1
    too_small = min_size - 1
    wrong_size_prediction = prediction[:, :, :too_small, :too_small]
    wrong_size_target = target[:, :, :too_small, :too_small]
    with pytest.raises(ValueError):
        MultiScaleSSIMLoss(kernel_size=kernel_size)(wrong_size_prediction,
                                                    wrong_size_target)
コード例 #16
0
def test_multi_scale_ssim_loss_raise_if_wrong_value_is_estimated(
        test_images: List, scale_weights: List, device: str) -> None:
    """piq's MS-SSIM loss must match TensorFlow's reference implementation."""
    for x, y in test_images:
        criterion = MultiScaleSSIMLoss(kernel_size=11,
                                       kernel_sigma=1.5,
                                       data_range=255,
                                       scale_weights=scale_weights)
        piq_ms_ssim_loss = criterion(x.to(device), y.to(device))
        # TensorFlow expects NHWC layout.
        tf_x = tf.convert_to_tensor(x.permute(0, 2, 3, 1).numpy())
        tf_y = tf.convert_to_tensor(y.permute(0, 2, 3, 1).numpy())
        with tf.device('/CPU'):
            tf_ms_ssim = torch.tensor(
                tf.image.ssim_multiscale(
                    tf_x, tf_y, power_factors=scale_weights,
                    max_val=255).numpy()).mean().to(device)
        match_accuracy = 1e-5 + 1e-8
        # piq returns a loss (1 - MS-SSIM), TF returns the similarity itself.
        assert torch.isclose(piq_ms_ssim_loss, 1. - tf_ms_ssim, rtol=0, atol=match_accuracy), \
            f'The estimated value must be equal to tensorflow provided one' \
            f'(considering floating point operation error up to {match_accuracy}), ' \
            f'got difference {(piq_ms_ssim_loss - 1. + tf_ms_ssim).abs()}'
コード例 #17
0
    "out_channels": 2,
}
# Backbone architecture used inside the iterative network.
subnet = UNet

# Hyper-parameters of the iterative (unrolled) reconstruction network.
it_net_params = {
    "num_iter": 8,
    # Per-iteration data-consistency weights (one per unrolled step).
    "lam": [0.5, 0.6, 0.7, 0.8, 0.9, 0.9, 0.8, 0.15],
    "lam_learnable": True,
    "final_dc": True,
    "resnet_factor": 1.0,
    "concat_mask": False,
    "multi_slice": False,
}

# ----- training configuration -----
# NOTE(review): data_range=1e1 (i.e. 10.0) is unusual for image data —
# confirm the reconstructions are actually scaled to [0, 10].
msssimloss = MultiScaleSSIMLoss(data_range=1e1)
maeloss = torch.nn.L1Loss(reduction="mean")

def msssim_l1_loss(pred, tar):
    """Weighted combination (0.7 MS-SSIM + 0.3 L1) on the real channel."""
    ms_term = msssimloss(
        rotate_real(pred)[:, 0:1, ...],
        rotate_real(tar)[:, 0:1, ...],
    )
    l1_term = maeloss(
        rotate_real(pred)[:, 0:1, ...],
        rotate_real(tar)[:, 0:1, ...],
    )
    return 0.7 * ms_term + 0.3 * l1_term


train_params = {
    "num_epochs": [8],
コード例 #18
0
def test_multi_scale_ssim_loss_equality(y, device: str) -> None:
    """Identical tensors must yield a numerically-zero MS-SSIM loss."""
    y = y.to(device)
    x = y.clone()
    loss = MultiScaleSSIMLoss()(x, y)
    is_zero = (loss.abs() <= 1e-6).all()
    assert is_zero, f'If equal tensors are passed SSIM loss must be equal to 0 ' \
                    f'(considering floating point operation error up to 1 * 10^-6), got {loss}'
コード例 #19
0
ファイル: train.py プロジェクト: vvvityaaa/mcmrirecon
def train_net(params):
    """Train a WNet MRI-reconstruction model per the ``params`` dict.

    Runs a train/validation loop, selects the training loss from
    ``params.lossFunction``, logs losses and image-quality metrics
    (SSIM/PSNR/VIF) to tensorboard, and saves a checkpoint every epoch
    (flagged as best when the weighted validation metric improves).
    """
    # Initialize Parameters
    params = DotDict(params)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Per-phase running logs, reset at the end of every epoch.
    verbose = {}
    verbose['loss_train'], verbose['loss_valid'], verbose['psnr_train'], verbose['psnr_valid'], \
        verbose['ssim_train'], verbose['ssim_valid'], verbose['vif_train'], verbose['vif_valid'] = ([] for i in range(8))

    log_metrics = True
    ssim_module = SSIM()
    msssim_module = MSSSIM()
    vifLoss = VIFLoss(sigma_n_sq=0.4, data_range=1.)
    msssimLoss = MultiScaleSSIMLoss(data_range=1.)
    # NOTE(review): this score is updated below when the weighted
    # PSNR/SSIM/VIF value *decreases*; if that combined metric is meant to
    # be maximized the comparison direction looks inverted — confirm.
    best_validation_metrics = 100

    train_generator, val_generator = data_loaders(params)
    loaders = {"train": train_generator, "valid": val_generator}

    # Unique run identifier encoding the main hyper-parameters + timestamp.
    wnet_identifier = params.mask_URate[0:2] + "WNet_dense=" + str(int(params.dense)) + "_" + params.architecture + "_" \
                      + params.lossFunction + '_lr=' + str(params.lr) + '_ep=' + str(params.num_epochs) + '_complex=' \
                      + str(int(params.complex_net)) + '_' + 'edgeModel=' + str(int(params.edge_model)) \
                      + '(' + str(params.num_edge_slices) + ')_date=' + (datetime.now()).strftime("%d-%m-%Y_%H-%M-%S")

    if not os.path.isdir(params.model_save_path):
        os.mkdir(params.model_save_path)
    print("\n\nModel will be saved at:\n", params.model_save_path)
    print("WNet ID: ", wnet_identifier)

    wnet, optimizer, best_validation_loss, preTrainedEpochs = generate_model(
        params, device)

    # data = (iter(train_generator)).next()

    # Adding writer for tensorboard. Also start tensorboard, which tries to access logs in the runs directory
    writer = init_tensorboard(iter(train_generator), wnet, wnet_identifier,
                              device)

    # Resume from preTrainedEpochs if a checkpoint was loaded.
    for epoch in trange(preTrainedEpochs, params.num_epochs):
        for phase in ['train', 'valid']:
            if phase == 'train':
                wnet.train()
            else:
                wnet.eval()

            for i, data in enumerate(loaders[phase]):

                # for i in range(10000):
                x, y_true, _, _, fname, slice_num = data
                x, y_true = x.to(device, dtype=torch.float), y_true.to(
                    device, dtype=torch.float)

                optimizer.zero_grad()

                # Gradients only tracked during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    y_pred = wnet(x)
                    # Select the training loss per params.lossFunction.
                    if params.lossFunction == 'mse':
                        loss = F.mse_loss(y_pred, y_true)
                    elif params.lossFunction == 'l1':
                        loss = F.l1_loss(y_pred, y_true)
                    elif params.lossFunction == 'ssim':
                        # standard SSIM
                        loss = 0.16 * F.l1_loss(y_pred, y_true) + 0.84 * (
                            1 - ssim_module(y_pred, y_true))
                    elif params.lossFunction == 'msssim':
                        # loss = 0.16 * F.l1_loss(y_pred, y_true) + 0.84 * (1 - msssim_module(y_pred, y_true))
                        # Magnitude image from interleaved real/imag channels
                        # (even channels = real, odd channels = imaginary).
                        prediction_abs = torch.sqrt(
                            torch.square(y_pred[:, 0::2]) +
                            torch.square(y_pred[:, 1::2]))
                        target_abs = torch.sqrt(
                            torch.square(y_true[:, 0::2]) +
                            torch.square(y_true[:, 1::2]))
                        # Fold batch and coil dims together, add channel dim.
                        prediction_abs_flat = (torch.flatten(
                            prediction_abs, start_dim=0,
                            end_dim=1)).unsqueeze(1)
                        target_abs_flat = (torch.flatten(
                            target_abs, start_dim=0, end_dim=1)).unsqueeze(1)
                        loss = msssimLoss(prediction_abs_flat, target_abs_flat)
                    elif params.lossFunction == 'vif':
                        prediction_abs = torch.sqrt(
                            torch.square(y_pred[:, 0::2]) +
                            torch.square(y_pred[:, 1::2]))
                        target_abs = torch.sqrt(
                            torch.square(y_true[:, 0::2]) +
                            torch.square(y_true[:, 1::2]))
                        prediction_abs_flat = (torch.flatten(
                            prediction_abs, start_dim=0,
                            end_dim=1)).unsqueeze(1)
                        target_abs_flat = (torch.flatten(
                            target_abs, start_dim=0, end_dim=1)).unsqueeze(1)
                        loss = vifLoss(prediction_abs_flat, target_abs_flat)
                    elif params.lossFunction == 'mse+vif':
                        prediction_abs = torch.sqrt(
                            torch.square(y_pred[:, 0::2]) +
                            torch.square(y_pred[:, 1::2])).to(device)
                        target_abs = torch.sqrt(
                            torch.square(y_true[:, 0::2]) +
                            torch.square(y_true[:, 1::2])).to(device)
                        prediction_abs_flat = (torch.flatten(
                            prediction_abs, start_dim=0,
                            end_dim=1)).unsqueeze(1)
                        target_abs_flat = (torch.flatten(
                            target_abs, start_dim=0, end_dim=1)).unsqueeze(1)
                        loss = 0.15 * F.mse_loss(
                            prediction_abs_flat,
                            target_abs_flat) + 0.85 * vifLoss(
                                prediction_abs_flat, target_abs_flat)
                    elif params.lossFunction == 'l1+vif':
                        prediction_abs = torch.sqrt(
                            torch.square(y_pred[:, 0::2]) +
                            torch.square(y_pred[:, 1::2])).to(device)
                        target_abs = torch.sqrt(
                            torch.square(y_true[:, 0::2]) +
                            torch.square(y_true[:, 1::2])).to(device)
                        prediction_abs_flat = (torch.flatten(
                            prediction_abs, start_dim=0,
                            end_dim=1)).unsqueeze(1)
                        target_abs_flat = (torch.flatten(
                            target_abs, start_dim=0, end_dim=1)).unsqueeze(1)
                        # NOTE(review): the L1 term uses the raw channels while
                        # the VIF term uses magnitudes — confirm intentional.
                        loss = 0.146 * F.l1_loss(
                            y_pred, y_true) + 0.854 * vifLoss(
                                prediction_abs_flat, target_abs_flat)
                    elif params.lossFunction == 'msssim+vif':
                        prediction_abs = torch.sqrt(
                            torch.square(y_pred[:, 0::2]) +
                            torch.square(y_pred[:, 1::2])).to(device)
                        target_abs = torch.sqrt(
                            torch.square(y_true[:, 0::2]) +
                            torch.square(y_true[:, 1::2])).to(device)
                        prediction_abs_flat = (torch.flatten(
                            prediction_abs, start_dim=0,
                            end_dim=1)).unsqueeze(1)
                        target_abs_flat = (torch.flatten(
                            target_abs, start_dim=0, end_dim=1)).unsqueeze(1)
                        loss = 0.66 * msssimLoss(
                            prediction_abs_flat,
                            target_abs_flat) + 0.33 * vifLoss(
                                prediction_abs_flat, target_abs_flat)

                    if not math.isnan(loss.item()) and loss.item(
                    ) < 2 * best_validation_loss:  # avoid nan/spike values
                        verbose['loss_' + phase].append(loss.item())
                        writer.add_scalar(
                            'Loss/' + phase + '_epoch_' + str(epoch),
                            loss.item(), i)

                    # Periodically compute image-quality metrics for logging.
                    if log_metrics and (
                        (i % params.verbose_gap == 0) or
                        (phase == 'valid' and epoch > params.verbose_delay)):
                        y_true_copy = y_true.detach().cpu().numpy()
                        y_pred_copy = y_pred.detach().cpu().numpy()
                        # Recombine interleaved channels into complex arrays.
                        y_true_copy = y_true_copy[:, ::
                                                  2, :, :] + 1j * y_true_copy[:,
                                                                              1::
                                                                              2, :, :]
                        y_pred_copy = y_pred_copy[:, ::
                                                  2, :, :] + 1j * y_pred_copy[:,
                                                                              1::
                                                                              2, :, :]
                        if params.architecture[-1] == 'k':
                            # transform kspace to image domain
                            y_true_copy = np.fft.ifft2(y_true_copy,
                                                       axes=(2, 3))
                            y_pred_copy = np.fft.ifft2(y_pred_copy,
                                                       axes=(2, 3))

                        # Sum of squares
                        sos_true = np.sqrt(
                            (np.abs(y_true_copy)**2).sum(axis=1))
                        sos_pred = np.sqrt(
                            (np.abs(y_pred_copy)**2).sum(axis=1))
                        '''
                        # Normalization according to: extract_challenge_metrics.ipynb
                        sos_true_max = sos_true.max(axis = (1,2),keepdims = True)
                        sos_true_org = sos_true/sos_true_max
                        sos_pred_org = sos_pred/sos_true_max
                        # Normalization by normalzing with ref with max_ref and rec with max_rec, respectively
                        sos_true_max = sos_true.max(axis = (1,2),keepdims = True)
                        sos_true_mod = sos_true/sos_true_max
                        sos_pred_max = sos_pred.max(axis = (1,2),keepdims = True)
                        sos_pred_mod = sos_pred/sos_pred_max
                        '''
                        '''
                        # normalization by mean and std
                        std = sos_pred.std(axis=(1, 2), keepdims=True)
                        mean = sos_pred.mean(axis=(1, 2), keepdims=True)
                        sos_pred_std = (sos_pred-mean) / std
                        std = sos_true.std(axis=(1, 2), keepdims=True)
                        mean = sos_pred.mean(axis=(1, 2), keepdims=True)
                        sos_true_std = (sos_true-mean) / std
                        '''
                        '''
                        ssim, psnr, vif = metrics(sos_pred_org, sos_true_org)
                        ssim_mod, psnr_mod, vif_mod = metrics(sos_pred_mod, sos_true_mod)
                        '''
                        # Normalize both volumes by the reference maximum.
                        sos_true_max = sos_true.max(axis=(1, 2), keepdims=True)
                        sos_true_org = sos_true / sos_true_max
                        sos_pred_org = sos_pred / sos_true_max

                        ssim, psnr, vif = metrics(sos_pred, sos_true)
                        ssim_normed, psnr_normed, vif_normed = metrics(
                            sos_pred_org, sos_true_org)

                        verbose['ssim_' + phase].append(np.mean(ssim_normed))
                        verbose['psnr_' + phase].append(np.mean(psnr_normed))
                        verbose['vif_' + phase].append(np.mean(vif_normed))
                        '''
                        print("===Normalization according to: extract_challenge_metrics.ipynb===")
                        print("SSIM: ", verbose['ssim_'+phase][-1])
                        print("PSNR: ", verbose['psnr_'+phase][-1])
                        print("VIF: ",  verbose['vif_' +phase][-1])
                        print("===Normalization by normalzing with ref with max_ref and rec with max_rec, respectively===")
                        print("SSIM_mod: ", np.mean(ssim_mod))
                        print("PSNR_mod: ", np.mean(psnr_mod))
                        print("VIF_mod: ",  np.mean(vif_mod))
                        print("===Normalization by dividing by the standard deviation of ref and rec, respectively===")
                        '''
                        print("Epoch: ", epoch)
                        print("SSIM: ", np.mean(ssim))
                        print("PSNR: ", np.mean(psnr))
                        print("VIF: ", np.mean(vif))

                        print("SSIM_normed: ", verbose['ssim_' + phase][-1])
                        print("PSNR_normed: ", verbose['psnr_' + phase][-1])
                        print("VIF_normed: ", verbose['vif_' + phase][-1])
                        '''
                        if True: #verbose['vif_' + phase][-1] < 0.4:
                            plt.figure(figsize=(9, 6), dpi=150)
                            gs1 = gridspec.GridSpec(3, 2)
                            gs1.update(wspace=0.002, hspace=0.1)
                            plt.subplot(gs1[0])
                            plt.imshow(sos_true[0], cmap="gray")
                            plt.axis("off")
                            plt.subplot(gs1[1])
                            plt.imshow(sos_pred[0], cmap="gray")
                            plt.axis("off")
                            plt.show()
                            # plt.pause(10)
                            # plt.close()
                        '''
                        writer.add_scalar(
                            'SSIM/' + phase + '_epoch_' + str(epoch),
                            verbose['ssim_' + phase][-1], i)
                        writer.add_scalar(
                            'PSNR/' + phase + '_epoch_' + str(epoch),
                            verbose['psnr_' + phase][-1], i)
                        writer.add_scalar(
                            'VIF/' + phase + '_epoch_' + str(epoch),
                            verbose['vif_' + phase][-1], i)

                    print('Loss ' + phase + ': ', loss.item())

                    if phase == 'train':
                        # Skip optimizer steps on loss spikes.
                        if loss.item() < 2 * best_validation_loss:
                            loss.backward()
                            optimizer.step()

        # Calculate Averages
        psnr_mean = np.mean(verbose['psnr_valid'])
        ssim_mean = np.mean(verbose['ssim_valid'])
        vif_mean = np.mean(verbose['vif_valid'])
        validation_metrics = 0.2 * psnr_mean + 0.4 * ssim_mean + 0.4 * vif_mean

        valid_avg_loss_of_current_epoch = np.mean(verbose['loss_valid'])
        writer.add_scalar('AvgLoss/+train_epoch_' + str(epoch),
                          np.mean(verbose['loss_train']), epoch)
        writer.add_scalar('AvgLoss/+valid_epoch_' + str(epoch),
                          np.mean(verbose['loss_valid']), epoch)
        writer.add_scalar('AvgSSIM/+train_epoch_' + str(epoch),
                          np.mean(verbose['ssim_train']), epoch)
        writer.add_scalar('AvgSSIM/+valid_epoch_' + str(epoch), ssim_mean,
                          epoch)
        writer.add_scalar('AvgPSNR/+train_epoch_' + str(epoch),
                          np.mean(verbose['psnr_train']), epoch)
        writer.add_scalar('AvgPSNR/+valid_epoch_' + str(epoch), psnr_mean,
                          epoch)
        writer.add_scalar('AvgVIF/+train_epoch_' + str(epoch),
                          np.mean(verbose['vif_train']), epoch)
        writer.add_scalar('AvgVIF/+valid_epoch_' + str(epoch), vif_mean, epoch)

        # Reset the per-epoch logs.
        verbose['loss_train'], verbose['loss_valid'], verbose['psnr_train'], verbose['psnr_valid'], \
        verbose['ssim_train'], verbose['ssim_valid'], verbose['vif_train'], verbose['vif_valid'] = ([] for i in
                                                                                                    range(8))

        # Save Networks/Checkpoints
        if best_validation_metrics > validation_metrics:
            best_validation_metrics = validation_metrics
            best_validation_loss = valid_avg_loss_of_current_epoch
            save_checkpoint(
                wnet, params.model_save_path, wnet_identifier, {
                    'epoch': epoch + 1,
                    'state_dict': wnet.state_dict(),
                    'best_validation_loss': best_validation_loss,
                    'optimizer': optimizer.state_dict(),
                }, True)
        else:
            save_checkpoint(
                wnet, params.model_save_path, wnet_identifier, {
                    'epoch': epoch + 1,
                    'state_dict': wnet.state_dict(),
                    'best_validation_loss': best_validation_loss,
                    'optimizer': optimizer.state_dict(),
                }, False)