def test_NormalDistributionLoss(log_scale, center, coerce_positive):
    """Check rescaling and sampling of ``NormalDistributionLoss``.

    A large sample is drawn from the loss's distribution class, normalized
    with a ``TorchNormalizer`` built from the given flags, and passed back
    through ``rescale_parameters``. When neither ``log_scale`` nor
    ``coerce_positive`` is set, samples drawn from the rescaled parameters
    must have a mean (and, if ``center`` is set, a standard deviation) close
    to the values used to generate the target. When exactly one of the two
    flags is set, ``rescale_parameters`` is expected to raise
    ``AssertionError``. When both are set the test returns early.

    NOTE(review): a second function with this exact name is defined later in
    this file; at import time that later definition rebinds the name, so this
    one is presumably not collected - confirm whether it should be removed
    or renamed.

    Args:
        log_scale: forwarded to ``TorchNormalizer(log_scale=...)``.
        center: forwarded to ``TorchNormalizer(center=...)``; also enables
            the standard-deviation check.
        coerce_positive: forwarded to
            ``TorchNormalizer(coerce_positive=...)``.
    """
    expected_mean = 1000.0
    expected_std = 200.0
    sample_count = 100000
    requires_positive = log_scale or coerce_positive

    source_distribution = NormalDistributionLoss.distribution_class(
        loc=expected_mean, scale=expected_std)
    target = source_distribution.sample_n(sample_count)
    if requires_positive:
        target = target.abs()
    if log_scale and coerce_positive:
        # combination invalid for normalizer (tested somewhere else)
        return

    normalizer = TorchNormalizer(
        log_scale=log_scale,
        center=center,
        coerce_positive=coerce_positive,
    )
    normalized_target = normalizer.fit_transform(target).view(1, -1)
    target_scale = normalizer.get_parameters().unsqueeze(0)
    # Use the empirical std of the normalized target as a constant scale.
    constant_scale = normalized_target.std() * torch.ones_like(
        normalized_target)
    parameters = torch.stack([normalized_target, constant_scale], dim=-1)
    loss = NormalDistributionLoss()

    if requires_positive:
        # Rescaling is expected to be rejected for these normalizer settings.
        with pytest.raises(AssertionError):
            loss.rescale_parameters(parameters,
                                    target_scale=target_scale,
                                    transformer=normalizer)
        return

    rescaled_parameters = loss.rescale_parameters(parameters,
                                                  target_scale=target_scale,
                                                  transformer=normalizer)
    samples = loss.sample_n(rescaled_parameters, 1)

    mean_matches = torch.isclose(torch.as_tensor(expected_mean),
                                 samples.mean(),
                                 atol=0.1,
                                 rtol=0.2)
    assert mean_matches
    if center:  # if not centered, softplus distorts std too much for testing
        std_matches = torch.isclose(torch.as_tensor(expected_std),
                                    samples.std(),
                                    atol=0.1,
                                    rtol=0.7)
        assert std_matches
def test_NormalDistributionLoss(center, transformation):
    """Check rescaling and sampling of ``NormalDistributionLoss``.

    A large sample is drawn from the loss's distribution class, normalized
    with a ``TorchNormalizer`` configured with ``center`` and
    ``transformation``, and passed back through ``rescale_parameters``.

    * For the transformations listed in ``rejected_transformations`` the call
      to ``rescale_parameters`` is expected to raise ``AssertionError``.
    * For any other value, samples drawn from the rescaled parameters must
      have a mean close to the generating mean and, when ``center`` is set,
      a standard deviation close to the generating one.

    Args:
        center: forwarded to ``TorchNormalizer(center=...)``; also enables
            the standard-deviation check.
        transformation: forwarded to
            ``TorchNormalizer(transformation=...)``. Presumably supplied by a
            pytest parametrization that is not visible here - confirm.
    """
    mean = 1000.0
    std = 200.0
    n = 100000
    # Transformations for which the raw target is made non-negative first.
    positive_only_transformations = ("log", "log1p", "relu", "softplus")
    # Transformations for which rescale_parameters must fail its assertion.
    # (The original list named "logit" twice; the duplicate had no effect.)
    rejected_transformations = ("logit", "log", "log1p", "softplus", "relu")

    target = NormalDistributionLoss.distribution_class(loc=mean,
                                                       scale=std).sample_n(n)
    if transformation in positive_only_transformations:
        target = target.abs()

    normalizer = TorchNormalizer(center=center, transformation=transformation)
    normalized_target = normalizer.fit_transform(target).view(1, -1)
    target_scale = normalizer.get_parameters().unsqueeze(0)
    # Use the empirical std of the normalized target as a constant scale.
    scale = torch.ones_like(normalized_target) * normalized_target.std()
    parameters = torch.stack(
        [normalized_target, scale],
        dim=-1,
    )
    loss = NormalDistributionLoss()

    if transformation in rejected_transformations:
        # Only the raised exception matters here, so the result is not bound.
        with pytest.raises(AssertionError):
            loss.rescale_parameters(parameters,
                                    target_scale=target_scale,
                                    encoder=normalizer)
    else:
        rescaled_parameters = loss.rescale_parameters(
            parameters, target_scale=target_scale, encoder=normalizer)
        samples = loss.sample_n(rescaled_parameters, 1)
        assert torch.isclose(torch.as_tensor(mean),
                             samples.mean(),
                             atol=0.1,
                             rtol=0.2)
        if center:  # if not centered, softplus distorts std too much for testing
            assert torch.isclose(torch.as_tensor(std),
                                 samples.std(),
                                 atol=0.1,
                                 rtol=0.7)