Exemplo n.º 1
0
def test_NormalDistributionLoss(center, transformation):
    """Check that NormalDistributionLoss parameters survive a TorchNormalizer round trip.

    Draws ``n`` samples from a normal distribution, normalizes them with a
    ``TorchNormalizer(center=center, transformation=transformation)``, builds
    (loc, scale) loss parameters from the normalized values and verifies that
    ``rescale_parameters`` followed by ``sample`` recovers the original mean
    (and, for centered normalizers, the original std).

    For the transformations in ``unsupported_transformations``,
    ``rescale_parameters`` is expected to raise ``AssertionError`` instead.
    """
    mean = 1000.0
    std = 200.0
    n = 100000
    # The target is made non-negative before fitting these transformations.
    positive_transformations = ("log", "log1p", "relu", "softplus")
    # rescale_parameters is expected to reject these transformations.
    # (The original list contained "logit" twice; deduplicated here.)
    unsupported_transformations = ("logit", "log", "log1p", "softplus", "relu")

    target = NormalDistributionLoss.distribution_class(loc=mean, scale=std).sample((n,))
    if transformation in positive_transformations:
        target = target.abs()
    normalizer = TorchNormalizer(center=center, transformation=transformation)
    normalized_target = normalizer.fit_transform(target).view(1, -1)
    target_scale = normalizer.get_parameters().unsqueeze(0)
    # Use one constant scale: the std of the normalized target.
    scale = torch.ones_like(normalized_target) * normalized_target.std()
    parameters = torch.stack(
        [normalized_target, scale],
        dim=-1,
    )
    loss = NormalDistributionLoss()
    if transformation in unsupported_transformations:
        with pytest.raises(AssertionError):
            loss.rescale_parameters(parameters, target_scale=target_scale, encoder=normalizer)
    else:
        rescaled_parameters = loss.rescale_parameters(parameters, target_scale=target_scale, encoder=normalizer)
        samples = loss.sample(rescaled_parameters, 1)
        assert torch.isclose(torch.as_tensor(mean), samples.mean(), atol=0.1, rtol=0.2)
        # Only compare the std for centered normalizers: without centering the
        # recovered std was too distorted to test reliably.
        if center:
            assert torch.isclose(torch.as_tensor(std), samples.std(), atol=0.1, rtol=0.7)
Exemplo n.º 2
0
def test_NormalDistributionLoss(center, transformation):
    """Round-trip a normal sample through TorchNormalizer and NormalDistributionLoss.

    Before fitting, the raw sample is passed through the normalizer's
    ``inverse_preprocess``. The mean of samples drawn from the rescaled
    distribution must match the target's mean. When ``center`` is set, the std
    must match too.
    """
    loc, spread, num_samples = 1.0, 0.1, 100000
    distribution = NormalDistributionLoss.distribution_class(loc=loc, scale=spread)
    target = distribution.sample((num_samples,))
    normalizer = TorchNormalizer(center=center, transformation=transformation)
    needs_non_negative = transformation in ["log", "log1p", "relu", "softplus"]
    if needs_non_negative:
        target = target.abs()
    # Express the target through the normalizer's inverse preprocessing step.
    target = normalizer.inverse_preprocess(target)

    normalized = normalizer.fit_transform(target).view(1, -1)
    target_scale = normalizer.get_parameters().unsqueeze(0)
    # Use one constant scale: the std of the normalized target.
    constant_scale = torch.ones_like(normalized) * normalized.std()
    stacked_parameters = torch.stack((normalized, constant_scale), dim=-1)

    loss_fn = NormalDistributionLoss()
    rescaled = loss_fn.rescale_parameters(
        stacked_parameters,
        target_scale=target_scale,
        encoder=normalizer,
    )
    drawn = loss_fn.sample(rescaled, 1)
    assert torch.isclose(target.mean(), drawn.mean(), atol=0.1, rtol=0.5)
    if not center:
        # If not centered, softplus distorts the std too much for testing.
        return
    assert torch.isclose(target.std(), drawn.std(), atol=0.1, rtol=0.7)
def test_NormalDistributionLoss(log_scale, center, coerce_positive):
    """Check NormalDistributionLoss rescaling against the legacy TorchNormalizer API.

    Draws ``n`` normal samples and normalizes them with
    ``TorchNormalizer(log_scale=..., center=..., coerce_positive=...)``. It then
    builds (loc, scale) loss parameters from the normalized values:

    * with ``log_scale`` or ``coerce_positive``, ``rescale_parameters`` is
      expected to raise ``AssertionError``;
    * otherwise, samples from the rescaled distribution must recover the
      original mean (and, for centered normalizers, the original std).

    The combination ``log_scale and coerce_positive`` is skipped because it is
    invalid for the normalizer and is tested elsewhere.
    """
    # Check the invalid combination first so no samples are drawn for it, and
    # report it as skipped rather than silently passing.
    if log_scale and coerce_positive:
        pytest.skip("combination of log_scale and coerce_positive is invalid for TorchNormalizer (tested elsewhere)")

    mean = 1000.0
    std = 200.0
    n = 100000
    # sample((n,)) replaces torch's deprecated Distribution.sample_n(n); the
    # other variants of this test call it on the same distribution_class.
    target = NormalDistributionLoss.distribution_class(loc=mean, scale=std).sample((n,))
    if log_scale or coerce_positive:
        target = target.abs()
    normalizer = TorchNormalizer(log_scale=log_scale,
                                 center=center,
                                 coerce_positive=coerce_positive)
    normalized_target = normalizer.fit_transform(target).view(1, -1)
    target_scale = normalizer.get_parameters().unsqueeze(0)
    # Use one constant scale: the std of the normalized target.
    scale = torch.ones_like(normalized_target) * normalized_target.std()
    parameters = torch.stack(
        [normalized_target, scale],
        dim=-1,
    )
    loss = NormalDistributionLoss()
    if log_scale or coerce_positive:
        with pytest.raises(AssertionError):
            loss.rescale_parameters(
                parameters, target_scale=target_scale, transformer=normalizer)
    else:
        rescaled_parameters = loss.rescale_parameters(
            parameters, target_scale=target_scale, transformer=normalizer)
        samples = loss.sample_n(rescaled_parameters, 1)
        assert torch.isclose(torch.as_tensor(mean),
                             samples.mean(),
                             atol=0.1,
                             rtol=0.2)
        # Only compare the std for centered normalizers: without centering the
        # recovered std was too distorted to test reliably.
        if center:
            assert torch.isclose(torch.as_tensor(std),
                                 samples.std(),
                                 atol=0.1,
                                 rtol=0.7)