# NOTE(review): this test name is defined again later in this module; Python rebinds
# the name, so pytest only collects the last definition.
def test_gmsd_loss_modes(prediction: torch.Tensor, target: torch.Tensor, device: str) -> None:
    """Check that ``GMSDLoss`` accepts the known reductions and rejects unknown ones.

    ``prediction``, ``target`` and ``device`` are pytest fixtures defined outside this view.
    """
    for reduction in ['mean', 'sum', 'none']:
        GMSDLoss(reduction=reduction)(prediction.to(device), target.to(device))

    # An unknown reduction must raise ValueError, consistent with the other test of
    # the same name in this module (previously this variant expected KeyError).
    for reduction in ['DEADBEEF', 'random']:
        with pytest.raises(ValueError):
            GMSDLoss(reduction=reduction)(prediction.to(device), target.to(device))
def test_gmsd_loss_modes(x, y, device: str) -> None:
    """Run ``GMSDLoss`` under every supported reduction and ensure bogus ones raise.

    ``x``, ``y`` and ``device`` are pytest fixtures defined outside this view.
    """
    supported = ('mean', 'sum', 'none')
    unsupported = ('DEADBEEF', 'random')

    def _run(mode: str):
        # Build a fresh loss per mode and move inputs to the target device per call.
        return GMSDLoss(reduction=mode)(x.to(device), y.to(device))

    for mode in supported:
        _run(mode)

    for mode in unsupported:
        with pytest.raises(ValueError):
            _run(mode)
# NOTE(review): this test name is defined again later in this module; Python rebinds
# the name, so pytest only collects the last definition.
def test_gmsd_loss_supports_different_data_ranges(
        prediction: torch.Tensor, target: torch.Tensor, device: str) -> None:
    """Check that ``GMSDLoss(data_range=255)`` on uint8 input matches ``GMSDLoss()`` on [0, 1] input.

    Both losses are fed the *same* quantized image. The uint8 tensors go to the
    255-range loss, and those very tensors divided by 255 go to the default loss.
    The previous version compared against the unquantized ``prediction``/``target``.
    That mixed uint8 truncation error into the check and needed a loose 1e-4
    tolerance.

    NOTE(review): assumes the ``prediction``/``target`` fixtures hold values in [0, 1]
    so that ``* 255`` fits in uint8 — confirm against the fixture definitions.
    """
    prediction_255 = (prediction * 255).type(torch.uint8)
    target_255 = (target * 255).type(torch.uint8)

    loss_255 = GMSDLoss(data_range=255)
    measure_255 = loss_255(prediction_255.to(device), target_255.to(device))

    # Same quantized values, rescaled to [0, 1] for the default loss.
    loss = GMSDLoss()
    measure = loss(prediction_255.to(device) / 255., target_255.to(device) / 255.)

    diff = torch.abs(measure_255 - measure)
    assert diff <= 1e-6, f'Result for same tensor with different data_range should be the same, got {diff}'
# NOTE(review): this test name is defined again later in this module; Python rebinds
# the name, so pytest only collects the last definition.
def test_gmsd_loss_supports_different_data_ranges(x, y, data_range, device: str) -> None:
    """Compare ``GMSDLoss(data_range=...)`` on uint8 input with the default loss on rescaled input.

    Both losses see the same quantized tensors. The default loss receives them
    divided by ``data_range``, so the two scores should agree to within 1e-6.
    """
    x_uint8 = (x * data_range).type(torch.uint8)
    y_uint8 = (y * data_range).type(torch.uint8)

    # Loss configured with the explicit range, fed the raw uint8 tensors.
    measure_scaled = GMSDLoss(data_range=data_range)(x_uint8.to(device), y_uint8.to(device))

    # Default loss, fed the same tensors divided by data_range.
    divisor = float(data_range)
    measure = GMSDLoss()(x_uint8.to(device) / divisor, y_uint8.to(device) / divisor)

    diff = (measure_scaled - measure).abs()
    assert diff <= 1e-6, f'Result for same tensor with different data_range should be the same, got {diff}'
# NOTE(review): this test name is defined again later in this module; Python rebinds
# the name, so pytest only collects the last definition.
def test_gmsd_loss_supports_different_data_ranges(
        prediction: torch.Tensor, target: torch.Tensor, data_range, device: str) -> None:
    """Verify ``GMSDLoss`` gives the same score for uint8 input and its rescaled float version.

    ``prediction``/``target`` are quantized to uint8 after scaling by ``data_range``.
    The range-aware loss scores them directly; the default loss scores the same
    tensors divided by ``data_range``.
    """
    def to_uint8(t: torch.Tensor) -> torch.Tensor:
        return (t * data_range).type(torch.uint8)

    prediction_q = to_uint8(prediction)
    target_q = to_uint8(target)

    range_aware = GMSDLoss(data_range=data_range)
    score_ranged = range_aware(prediction_q.to(device), target_q.to(device))

    default_loss = GMSDLoss()
    score_default = default_loss(
        prediction_q.to(device) / float(data_range),
        target_q.to(device) / float(data_range),
    )

    diff = torch.abs(score_ranged - score_default)
    assert diff <= 1e-6, f'Result for same tensor with different data_range should be the same, got {diff}'
# NOTE(review): this test name is defined again later in this module; Python rebinds
# the name, so pytest only collects the last definition.
def test_gmsd_loss_zero_for_equal_tensors(prediction: torch.Tensor, device: str) -> None:
    """GMSD between a tensor and an exact copy of itself must be zero (within 1e-6)."""
    target = prediction.clone()
    measure = GMSDLoss()(prediction.to(device), target.to(device))
    assert measure.abs() <= 1e-6, f'GMSD for equal tensors must be 0, got {measure}'
# NOTE(review): this test name is defined again later in this module; Python rebinds
# the name, so pytest only collects the last definition.
def test_gmsd_loss_forward_backward(prediction: torch.Tensor, target: torch.Tensor, device: str) -> None:
    """Check that gradients flow back through ``GMSDLoss`` to ``prediction`` and are finite.

    ``requires_grad_()`` is called before ``.to(device)``. The device move is part
    of the autograd graph, so the gradient accumulates on ``prediction`` itself.

    NOTE(review): ``requires_grad_()`` mutates the fixture tensor in place. That is
    fine if the fixture is function-scoped — confirm.
    """
    prediction.requires_grad_()
    loss_value = GMSDLoss()(prediction.to(device), target.to(device))
    loss_value.backward()
    # Without this check, a missing gradient would surface as a TypeError from
    # torch.isfinite(None) rather than as a failed assertion with the intended message.
    assert prediction.grad is not None, LEAF_VARIABLE_ERROR_MESSAGE
    assert torch.isfinite(prediction.grad).all(), LEAF_VARIABLE_ERROR_MESSAGE
# NOTE(review): this test name is defined again later in this module; Python rebinds
# the name, so pytest only collects the last definition.
def test_gmsd_loss_supports_greyscale_tensors(device: str) -> None:
    """Check that ``GMSDLoss`` runs on single-channel input of shape (2, 1, 96, 96).

    Previously the result was discarded, so a NaN/inf output would pass silently.
    The score is now required to be finite.
    """
    loss = GMSDLoss()
    target = torch.ones(2, 1, 96, 96)
    prediction = torch.zeros(2, 1, 96, 96)
    measure = loss(prediction.to(device), target.to(device))
    assert torch.isfinite(measure).all(), f'GMSD for greyscale tensors must be finite, got {measure}'
# NOTE(review): this test name is defined again later in this module; Python rebinds
# the name, so pytest only collects the last definition.
def test_gmsd_loss_raises_if_tensors_have_different_types(target: torch.Tensor, device: str) -> None:
    """Non-tensor predictions (a Python list, a NumPy array) must trigger an AssertionError."""
    for bad_prediction in (list(range(10)), np.arange(10)):
        with pytest.raises(AssertionError):
            GMSDLoss()(bad_prediction, target.to(device))
def test_gmsd_loss_supports_greyscale_tensors(device: str) -> None:
    """Check that ``GMSDLoss`` runs on single-channel input of shape (2, 1, 96, 96).

    Previously the result was discarded, so a NaN/inf output would pass silently.
    The score is now required to be finite.
    """
    loss = GMSDLoss()
    y = torch.ones(2, 1, 96, 96)
    x = torch.zeros(2, 1, 96, 96)
    measure = loss(x.to(device), y.to(device))
    assert torch.isfinite(measure).all(), f'GMSD for greyscale tensors must be finite, got {measure}'
def test_gmsd_loss_raises_if_tensors_have_different_types(y, device: str) -> None:
    """Passing a list or a NumPy array as ``x`` must make ``GMSDLoss`` raise AssertionError.

    ``y`` and ``device`` are pytest fixtures defined outside this view.
    """
    candidates = (list(range(10)), np.arange(10))
    for not_a_tensor in candidates:
        with pytest.raises(AssertionError):
            GMSDLoss()(not_a_tensor, y.to(device))
def test_gmsd_loss_zero_for_equal_tensors(x, device: str) -> None:
    """A tensor compared with its own clone must give a GMSD of zero (within 1e-6)."""
    identical = x.clone()
    score = GMSDLoss()(x.to(device), identical.to(device))
    assert score.abs() <= 1e-6, f'GMSD for equal tensors must be 0, got {score}'
def test_gmsd_loss_forward_backward(x, y, device: str) -> None:
    """Check that gradients flow back through ``GMSDLoss`` to ``x`` and are finite.

    ``requires_grad_()`` is called before ``.to(device)``. The device move is part
    of the autograd graph, so the gradient accumulates on ``x`` itself.

    NOTE(review): ``requires_grad_()`` mutates the fixture tensor in place. That is
    fine if the fixture is function-scoped — confirm.
    """
    x.requires_grad_()
    loss_value = GMSDLoss()(x.to(device), y.to(device))
    loss_value.backward()
    # Without this check, a missing gradient would surface as a TypeError from
    # torch.isfinite(None) rather than as a failed assertion with the intended message.
    assert x.grad is not None, LEAF_VARIABLE_ERROR_MESSAGE
    assert torch.isfinite(x.grad).all(), LEAF_VARIABLE_ERROR_MESSAGE