Beispiel #1
0
def test_ssim_loss_grad(x_y_4d_5d, device: str) -> None:
    """Back-propagating the mean SSIM loss must give finite gradients for the first input."""
    distorted = x_y_4d_5d[0].to(device)
    reference = x_y_4d_5d[1].to(device)
    distorted.requires_grad_(True)

    criterion = SSIMLoss(data_range=1.)
    criterion(distorted, reference).mean().backward()

    grad = distorted.grad
    assert torch.isfinite(grad).all(), f'Expected finite gradient values, got {grad}'
Beispiel #2
0
def test_ssim_loss_check_kernel_size_is_passed(x: torch.Tensor,
                                               y: torch.Tensor) -> None:
    """Odd kernel sizes in [0, 50) are accepted; even ones must raise AssertionError."""
    for size in range(50):
        is_odd = size % 2 == 1
        if not is_odd:
            with pytest.raises(AssertionError):
                SSIMLoss(kernel_size=size)(x, y)
            continue
        SSIMLoss(kernel_size=size)(x, y)
Beispiel #3
0
def test_ssim_loss_grad(prediction_target_4d_5d: Tuple[torch.Tensor,
                                                       torch.Tensor],
                        device: str) -> None:
    """Gradients of the mean SSIM loss w.r.t. the prediction must all be finite."""
    pred = prediction_target_4d_5d[0].to(device).requires_grad_(True)
    tgt = prediction_target_4d_5d[1].to(device)

    mean_loss = SSIMLoss(data_range=1.)(pred, tgt).mean()
    mean_loss.backward()

    pred_grad = pred.grad
    assert torch.isfinite(pred_grad).all(), \
        f'Expected finite gradient values, got {pred_grad}'
Beispiel #4
0
def test_ssim_loss_raises_if_tensors_have_different_shapes(x_y_4d_5d,
                                                           device) -> None:
    """SSIMLoss must raise AssertionError unless both inputs have the same shape.

    Every combination of the candidate per-dimension sizes below is tried
    against the fixture's second tensor; only the exactly matching shape may
    pass without an assertion.
    """
    target = x_y_4d_5d[1].to(device)
    candidate_dims = [[3], [2, 3], [161, 162], [161, 162]]
    # 5D fixtures need one extra trailing dimension in the candidate shapes.
    if target.dim() == 5:
        candidate_dims.append([2, 3])

    for shape in itertools.product(*candidate_dims):
        candidate = torch.rand(shape).to(target)
        if candidate.size() != target.size():
            with pytest.raises(AssertionError):
                SSIMLoss()(candidate, target)
        else:
            SSIMLoss()(candidate, target)
Beispiel #5
0
def test_ssim_loss_check_available_dimensions() -> None:
    """Inputs with fewer than 5 dims must work; 5 or more must raise AssertionError.

    Starts from 256x256 tensors and prepends a singleton dimension after each
    of the 10 checks, so ranks 2 through 11 are exercised.
    """
    prediction = torch.rand(256, 256)
    target = torch.rand(256, 256)
    for _ in range(10):
        if prediction.dim() >= 5:
            with pytest.raises(AssertionError):
                SSIMLoss()(prediction, target)
        else:
            try:
                SSIMLoss()(prediction, target)
            except Exception as e:
                pytest.fail(f"Unexpected error occurred: {e}")
        prediction = prediction.unsqueeze(0)
        target = target.unsqueeze(0)
Beispiel #6
0
def test_ssim_loss_raises_if_tensors_have_different_shapes(
        prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor],
        device) -> None:
    """SSIMLoss must reject predictions whose shape differs from the target's.

    Builds every shape from the per-dimension candidates below; the one equal
    to the target's shape must be accepted, all others must raise
    AssertionError.
    """
    target = prediction_target_4d_5d[1].to(device)
    per_dim_sizes = [[3], [2, 3], [161, 162], [161, 162]]
    # A 5D target needs an extra trailing dimension in the candidate shapes.
    if target.dim() == 5:
        per_dim_sizes.append([2, 3])

    for shape in itertools.product(*per_dim_sizes):
        candidate = torch.rand(shape).to(target)
        if candidate.size() == target.size():
            SSIMLoss()(candidate, target)
            continue
        with pytest.raises(AssertionError):
            SSIMLoss()(candidate, target)
Beispiel #7
0
def test_ssim_loss_equality(y: torch.Tensor, device: str) -> None:
    """SSIMLoss between a tensor and an exact copy of itself must be ~0.

    The absolute tolerance (1e-6) is the one stated in the failure message.
    """
    y = y.to(device)
    x = y.clone()
    loss = SSIMLoss()(x, y)
    # torch.allclose defaults to atol=1e-8; pass the 1e-6 the message promises
    # so float32 round-off in the SSIM computation does not make this flaky.
    assert torch.allclose(loss, torch.zeros_like(loss), atol=1e-6), \
        'If equal tensors are passed SSIM loss must be equal to 0 ' \
        f'(considering floating point operation error up to 1 * 10^-6), got {loss}'
Beispiel #8
0
def test_ssim_loss_is_less_or_equal_to_one(
        ones_zeros_4d_5d: Tuple[torch.Tensor,
                                torch.Tensor], device: str) -> None:
    """SSIMLoss between an all-ones and an all-zeros tensor must not exceed 1.

    These two inputs are maximally different, so they probe the upper bound.
    """
    ones = ones_zeros_4d_5d[0].to(device)
    zeros = ones_zeros_4d_5d[1].to(device)
    value = SSIMLoss()(ones, zeros)
    within_bound = (value <= 1).all()
    assert within_bound, f'SSIM loss must be <= 1, got {value}'
Beispiel #9
0
def test_ssim_loss_symmetry(x_y_4d_5d, device: str) -> None:
    """SSIMLoss(a, b) must equal SSIMLoss(b, a) up to torch.allclose defaults."""
    first = x_y_4d_5d[0].to(device)
    second = x_y_4d_5d[1].to(device)
    criterion = SSIMLoss()
    loss_ab = criterion(first, second)
    loss_ba = criterion(second, first)
    assert torch.allclose(loss_ab, loss_ba), \
        f'Expect: SSIMLoss(a, b) == SSIMLoss(b, a), got {loss_ab} != {loss_ba}'
Beispiel #10
0
def test_ssim_loss_raises_if_kernel_size_greater_than_image(
        x_y_4d_5d, device: str) -> None:
    """A kernel larger than the cropped inputs must make SSIMLoss raise ValueError."""
    kernel_size = 11
    crop = kernel_size - 1
    # Crop dims 2 and 3 (presumably the spatial ones) to one less than the kernel.
    small_x = x_y_4d_5d[0].to(device)[:, :, :crop, :crop]
    small_y = x_y_4d_5d[1].to(device)[:, :, :crop, :crop]
    with pytest.raises(ValueError):
        SSIMLoss(kernel_size=kernel_size)(small_x, small_y)
Beispiel #11
0
def test_ssim_loss_symmetry(prediction_target_4d_5d: Tuple[torch.Tensor,
                                                           torch.Tensor],
                            device: str) -> None:
    """Swapping prediction and target must not change the SSIM loss."""
    pred, tgt = (prediction_target_4d_5d[0].to(device),
                 prediction_target_4d_5d[1].to(device))
    criterion = SSIMLoss()
    forward_value = criterion(pred, tgt)
    swapped_value = criterion(tgt, pred)
    assert torch.allclose(forward_value, swapped_value), \
        f'Expect: SSIMLoss(a, b) == SSIMLoss(b, a), got {forward_value} != {swapped_value}'
Beispiel #12
0
def test_ssim_loss_raises_if_kernel_size_greater_than_image(
        prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor],
        device: str) -> None:
    """SSIMLoss must raise ValueError when inputs are smaller than the kernel."""
    kernel_size = 11
    limit = kernel_size - 1
    pred = prediction_target_4d_5d[0].to(device)
    tgt = prediction_target_4d_5d[1].to(device)
    # Keep only the first `limit` entries along dims 2 and 3.
    cropped_pred = pred[:, :, :limit, :limit]
    cropped_tgt = tgt[:, :, :limit, :limit]
    with pytest.raises(ValueError):
        SSIMLoss(kernel_size=kernel_size)(cropped_pred, cropped_tgt)
Beispiel #13
0
def test_ssim_loss_raise_if_wrong_value_is_estimated(
        test_images: Tuple[torch.Tensor, torch.Tensor], device: str) -> None:
    """SSIMLoss must agree with ``1 - tf.image.ssim`` on reference image pairs.

    Each ``(x, y)`` pair is scored with ``kernel_size=11``, ``kernel_sigma=1.5``
    and ``data_range=255`` (mirrored by TensorFlow's ``max_val=255``); the two
    values must agree within ``match_accuracy`` (absolute).

    NOTE(review): the annotation says a single tuple, but the loop unpacks
    pairs from ``test_images`` -- presumably the fixture yields a sequence of
    (x, y) pairs; confirm and fix the annotation.
    """
    # Loop-invariant absolute tolerance.
    match_accuracy = 2e-4 + 1e-8
    for x, y in test_images:
        ssim_loss = SSIMLoss(kernel_size=11,
                             kernel_sigma=1.5,
                             data_range=255,
                             reduction='mean')(x.to(device), y.to(device))
        # TensorFlow expects channels-last input; .numpy() needs a CPU tensor.
        tf_x = tf.convert_to_tensor(x.permute(0, 2, 3, 1).cpu().numpy())
        tf_y = tf.convert_to_tensor(y.permute(0, 2, 3, 1).cpu().numpy())
        with tf.device('/CPU'):
            tf_ssim = torch.tensor(
                tf.image.ssim(tf_x, tf_y,
                              max_val=255).numpy()).mean().to(device)
        assert torch.isclose(ssim_loss, 1. - tf_ssim, rtol=0, atol=match_accuracy), \
            'The estimated value must be equal to tensorflow provided one ' \
            f'(considering floating point operation error up to {match_accuracy}), ' \
            f'got difference {(ssim_loss - 1. + tf_ssim).abs()}'
Beispiel #14
0
def usage():
    """Demonstrate piq's functional ``ssim`` and the ``SSIMLoss`` module.

    Computes the SSIM index between two random tensors and prints it, then
    evaluates ``SSIMLoss`` on a dissimilar pair and on a tensor against
    itself, and back-propagates through the latter.
    """
    import numpy as np
    import torch
    from piq import ssim, SSIMLoss

    x = torch.rand(4, 256, 256, requires_grad=True)
    y = torch.rand(4, 256, 256)

    ssim_index: torch.Tensor = ssim(x, y, data_range=1.)
    # detach() is required before .numpy() because the result tracks gradients.
    ssim_value = ssim_index.detach().numpy()
    assert isinstance(ssim_value, np.ndarray), type(ssim_value)
    print(f"ssim_index: {ssim_value}")

    loss = SSIMLoss(data_range=1.)
    cross_loss = loss(x, y)
    print(f"SSIMLoss(x, y): {cross_loss.item()}")
    # A tensor compared with itself should give a loss of ~0.
    output: torch.Tensor = loss(x, x)
    print(output.item())
    output.backward()
Beispiel #15
0
    "out_channels": 2,
}
# Sub-network architecture used inside the iterative network.
subnet = Tiramisu

# Keyword configuration for the iterative network wrapper.
it_net_params = dict(
    num_iter=1,
    lam=1.0,
    lam_learnable=False,
    final_dc=False,
    resnet_factor=1.0,
    concat_mask=False,
    multi_slice=False,
)

# ----- training configuration -----
# data_range=10.0 -- presumably the expected dynamic range of the inputs; confirm.
ssimloss = SSIMLoss(data_range=1e1)
maeloss = torch.nn.L1Loss(reduction="mean")


def ssim_l1_loss(pred, tar):
    """Weighted SSIM + L1 loss on the first channel after ``rotate_real``.

    Returns ``0.7 * ssimloss(p, t) + 0.3 * maeloss(p, t)``, where ``p`` and
    ``t`` are channel 0 (kept as a singleton channel dimension via ``0:1``)
    of ``rotate_real(pred)`` and ``rotate_real(tar)``.
    """
    # Apply rotate_real and the channel slice once per input, then reuse the
    # result for both loss terms instead of recomputing it.
    pred_ch0 = rotate_real(pred)[:, 0:1, ...]
    tar_ch0 = rotate_real(tar)[:, 0:1, ...]
    return 0.7 * ssimloss(pred_ch0, tar_ch0) + 0.3 * maeloss(pred_ch0, tar_ch0)


train_params = {
    "num_epochs": [40],
Beispiel #16
0
def _build_network(encoder_type, img_shape, latent_dim, device):
    """Instantiate the autoencoder named by ``encoder_type`` and move it to ``device``.

    Raises:
        ValueError: if ``encoder_type`` is neither ``'pca'`` nor ``'mlp'``.
    """
    if encoder_type == 'pca':
        network = autoencoders.PCA(img_shape, latent_dim)
    elif encoder_type == 'mlp':
        network = autoencoders.MLP(img_shape, latent_dim)
    else:
        raise ValueError(
            f"Unsupported encoder type {encoder_type!r}; expected 'pca' or 'mlp'")
    return network.to(device)


def _train_epoch(network, loader, optimizer, mse_loss, ssim_loss, device):
    """Run one optimisation pass over ``loader`` and return per-batch MSE values.

    The optimised objective is ``MSE + 0.1 * SSIM loss``; only the MSE part is
    returned so it is directly comparable with the validation MSE.
    """
    network.train()
    mse_values = []
    for x_t in loader:
        x_t = x_t.to(device)
        _, decoding = network(x_t)
        step_mse = mse_loss(decoding, x_t)
        loss = step_mse + 0.1 * ssim_loss(decoding, x_t)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        mse_values.append(step_mse.item())
    return mse_values


def _validate_epoch(network, loader, mse_loss, device, out_dir, viz_samples):
    """Evaluate ``network`` on ``loader`` without gradients; return per-batch MSE values.

    For the first ``viz_samples`` batches, the first input is stacked with its
    reconstruction and written to ``out_dir / f'result_{i}.png'``.
    """
    network.eval()
    mse_values = []
    with torch.no_grad():
        for i, x_t in enumerate(loader):
            x_t = x_t.to(device)
            _, decoding = network(x_t)
            mse_values.append(mse_loss(decoding, x_t).item())

            if i < viz_samples:
                input_image = np.array(ToPILImage()(x_t.cpu()[0]))
                # Clamp to [0, 1] (the range SSIMLoss is configured for) so
                # out-of-range reconstructions survive the 8-bit conversion.
                decoded_image = np.array(
                    ToPILImage()(decoding.cpu()[0].clamp(0, 1)))
                merged_image = np.concatenate((input_image, decoded_image))
                imsave(out_dir / f'result_{i}.png', merged_image)
    return mse_values


def main():
    """Train an autoencoder on Super Mario Kart frames with early stopping.

    Reads CLI arguments and the JSON config section for ``args.encoder_type``,
    trains with ``MSE + 0.1 * SSIM`` loss, writes per-epoch reconstructions
    under the save directory, checkpoints the weights with the best validation
    MSE to ``best_encoder.pt`` and stops after ``args.patience`` consecutive
    epochs without improvement.
    """
    args = parse_args()
    config = load_json('./configs', args.config_name)[args.encoder_type]

    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
        print('CUDA is not available, using CPU instead...')

    # ===================Create data loaders=====================
    transforms = get_transforms(roi=config['roi'], grayscale=config['grayscale'])
    # CLI paths may be plain strings; wrap them so '/' and mkdir() work.
    data_path = Path(args.data_path)
    train_set = SuperMarioKartDataset(data_path / 'train_files.txt',
                                      transforms=transforms)
    valid_set = SuperMarioKartDataset(data_path / 'valid_files.txt',
                                      transforms=transforms)

    train_loader = DataLoader(train_set, batch_size=config['batch_size'],
                              shuffle=True, num_workers=4)
    # Validate on the held-out split so early stopping measures generalisation.
    valid_loader = DataLoader(valid_set, batch_size=1, num_workers=4)

    # ===================Output directory=====================
    if args.save_path:
        save_path = Path(args.save_path)
    else:
        save_path = data_path / (args.encoder_type + '_results')
    save_path.mkdir(parents=True, exist_ok=True)

    # ===================Instantiate models=====================
    network = _build_network(args.encoder_type, train_set.image_shape,
                             config['latent_dim'], device)
    optimizer = torch.optim.Adam(network.parameters())
    ssim_loss = SSIMLoss(data_range=1.)
    mse_loss = nn.MSELoss()

    if args.load_path:
        load_path = Path(args.load_path)
        if load_path.exists():
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            network.load_state_dict(
                torch.load(load_path / 'best_encoder.pt', map_location=device))

    best_val_loss = float('inf')
    cnt_bad_epochs = 0

    for epoch in range(1, config['n_epochs'] + 1):
        epc_save_path = save_path / f'epoch_{epoch}'
        epc_save_path.mkdir(exist_ok=True)

        train_losses = _train_epoch(network, train_loader, optimizer,
                                    mse_loss, ssim_loss, device)
        valid_losses = _validate_epoch(network, valid_loader, mse_loss, device,
                                       epc_save_path, args.viz_samples)

        # Keep full precision for checkpoint decisions; round only for display.
        avg_trn_loss = float(np.mean(train_losses))
        avg_val_loss = float(np.mean(valid_losses))
        print(f'Epoch - {epoch} | Avg Train MSE - {avg_trn_loss:.4f} '
              f'| Avg Val MSE - {avg_val_loss:.4f}')

        # ===================Checkpointing / early stopping=====================
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            cnt_bad_epochs = 0
            torch.save(network.state_dict(), save_path / 'best_encoder.pt')
        else:
            # Any epoch that fails to improve (ties included) counts toward patience.
            cnt_bad_epochs += 1

        if cnt_bad_epochs >= args.patience:
            print('Network is not improving. Stopping training...')
            break
Beispiel #17
0
def test_ssim_loss_raises_if_tensors_have_different_types(
        y: torch.Tensor) -> None:
    """Passing a plain Python list instead of a tensor must raise AssertionError."""
    not_a_tensor = [*range(10)]
    with pytest.raises(AssertionError):
        SSIMLoss()(not_a_tensor, y)