def test_NormalDistributionLoss(log_scale, center, coerce_positive):
    """Check that NormalDistributionLoss rescales parameters through a TorchNormalizer.

    A normally distributed target is sampled and normalized. Distribution
    parameters are then built in normalized space and mapped back with
    ``loss.rescale_parameters``.

    - With ``log_scale`` or ``coerce_positive``, rescaling is expected to raise
      ``AssertionError``.
    - Otherwise, samples from the rescaled distribution must match the
      original mean. When ``center`` is set, they must also match the
      original std.

    Args:
        log_scale: passed to ``TorchNormalizer(log_scale=...)``.
        center: passed to ``TorchNormalizer(center=...)``.
        coerce_positive: passed to ``TorchNormalizer(coerce_positive=...)``.
    """
    if log_scale and coerce_positive:
        # Combination invalid for normalizer (tested somewhere else). Skip
        # before the expensive sampling, and report it as skipped rather than
        # letting a bare `return` show up as a pass.
        pytest.skip("log_scale and coerce_positive cannot be combined")

    mean = 1000.0
    std = 200.0
    n = 100000
    # `.sample((n,))` is the non-deprecated equivalent of `.sample_n(n)`.
    target = NormalDistributionLoss.distribution_class(loc=mean, scale=std).sample((n,))
    if log_scale or coerce_positive:
        # Feed these normalizers a non-negative copy of the target.
        target = target.abs()

    normalizer = TorchNormalizer(log_scale=log_scale, center=center, coerce_positive=coerce_positive)
    normalized_target = normalizer.fit_transform(target).view(1, -1)
    target_scale = normalizer.get_parameters().unsqueeze(0)
    # Constant scale equal to the empirical std of the normalized target.
    scale = torch.ones_like(normalized_target) * normalized_target.std()
    parameters = torch.stack([normalized_target, scale], dim=-1)

    loss = NormalDistributionLoss()
    if log_scale or coerce_positive:
        # Rescaling is expected to be rejected for these normalizer settings.
        with pytest.raises(AssertionError):
            loss.rescale_parameters(parameters, target_scale=target_scale, transformer=normalizer)
        return

    rescaled_parameters = loss.rescale_parameters(
        parameters, target_scale=target_scale, transformer=normalizer
    )
    samples = loss.sample_n(rescaled_parameters, 1)
    assert torch.isclose(torch.as_tensor(mean), samples.mean(), atol=0.1, rtol=0.2)
    if center:  # if not centered, softplus distorts std too much for testing
        assert torch.isclose(torch.as_tensor(std), samples.std(), atol=0.1, rtol=0.7)
def test_NormalDistributionLoss(center, transformation):
    """Check that NormalDistributionLoss rescales parameters through a TorchNormalizer.

    A normally distributed target is sampled and normalized with the given
    ``transformation``. Distribution parameters are then built in normalized
    space and mapped back with ``loss.rescale_parameters``.

    - For transformations in the unsupported set, rescaling is expected to
      raise ``AssertionError``.
    - Otherwise, samples from the rescaled distribution must match the
      original mean. When ``center`` is set, they must also match the
      original std.

    Args:
        center: passed to ``TorchNormalizer(center=...)``.
        transformation: passed to ``TorchNormalizer(transformation=...)``.
    """
    # Transformations that are fed a non-negative copy of the target.
    positive_transformations = {"log", "log1p", "relu", "softplus"}
    # Transformations that NormalDistributionLoss.rescale_parameters is expected
    # to reject. The original list contained "logit" twice.
    unsupported_transformations = {"logit", "log", "log1p", "softplus", "relu"}

    mean = 1000.0
    std = 200.0
    n = 100000
    # `.sample((n,))` is the non-deprecated equivalent of `.sample_n(n)`.
    target = NormalDistributionLoss.distribution_class(loc=mean, scale=std).sample((n,))
    if transformation in positive_transformations:
        target = target.abs()

    normalizer = TorchNormalizer(center=center, transformation=transformation)
    normalized_target = normalizer.fit_transform(target).view(1, -1)
    target_scale = normalizer.get_parameters().unsqueeze(0)
    # Constant scale equal to the empirical std of the normalized target.
    scale = torch.ones_like(normalized_target) * normalized_target.std()
    parameters = torch.stack([normalized_target, scale], dim=-1)

    loss = NormalDistributionLoss()
    if transformation in unsupported_transformations:
        with pytest.raises(AssertionError):
            loss.rescale_parameters(parameters, target_scale=target_scale, encoder=normalizer)
        return

    rescaled_parameters = loss.rescale_parameters(
        parameters, target_scale=target_scale, encoder=normalizer
    )
    samples = loss.sample_n(rescaled_parameters, 1)
    assert torch.isclose(torch.as_tensor(mean), samples.mean(), atol=0.1, rtol=0.2)
    if center:  # if not centered, softplus distorts std too much for testing
        assert torch.isclose(torch.as_tensor(std), samples.std(), atol=0.1, rtol=0.7)