def test_GroupNormalizer(kwargs, groups):
    """Check that GroupNormalizer round-trips a value or rejects bad options.

    ``kwargs`` overrides the default constructor options built below and
    ``groups`` is passed through as the ``groups`` option. Both are presumably
    supplied by a pytest parametrization that is not visible in this chunk.

    Three outcomes are exercised:

    * ``coerce_positive`` together with ``log_scale`` must make the
      constructor raise ``AssertionError``;
    * with ``coerce_positive`` alone, the value obtained by calling the
      normalizer on the encoded sample must be non-negative;
    * otherwise, calling the normalizer on the encoded sample must give back
      the first original value of column ``b`` (within ``atol=1e-5``).

    NOTE(review): a second function with this exact name is defined further
    down in this file; that later definition rebinds the name at import time,
    so this variant would not be collected by pytest. Confirm whether this
    older variant should be removed or renamed.
    """
    data = pd.DataFrame({"a": [1, 1, 2, 2, 3], "b": [1.1, 1.1, 1.0, 5.0, 1.1]})

    # Defaults first, caller overrides next, ``groups`` last so it always wins.
    params = {
        "method": "standard",
        "log_scale": False,
        "coerce_positive": False,
        "center": True,
        "log_zero_value": 0.0,
        "scale_by_group": False,
        **kwargs,
        "groups": groups,
    }
    # Scaling by group is only requested when at least one group is given.
    params["scale_by_group"] = params["scale_by_group"] and len(groups) > 0

    if params["coerce_positive"] and params["log_scale"]:
        # The two options are expected to be rejected in combination; the
        # constructed object is irrelevant, so it is not bound to a name.
        with pytest.raises(AssertionError):
            GroupNormalizer(**params)
        return

    if params["coerce_positive"]:
        # Shift the target so that it contains negative values.
        data["b"] = data["b"] - 2.0

    normalizer = GroupNormalizer(**params)
    encoded = normalizer.fit_transform(data["b"], data)

    # Feed the first encoded value back together with the parameters
    # looked up for group key ``[1]``.
    test_data = {
        "prediction": torch.tensor([encoded.iloc[0]]),
        "target_scale": torch.tensor(normalizer.get_parameters([1])).unsqueeze(0),
    }
    decoded = normalizer(test_data)

    if params["coerce_positive"]:
        assert (decoded >= 0).all(), "Inverse transform should yield only positive values"
    else:
        expected = torch.tensor(data.b.iloc[0])
        assert torch.isclose(decoded, expected, atol=1e-5).all(), "Inverse transform should reverse transform"
def test_GroupNormalizer(kwargs, groups):
    """Check GroupNormalizer's inverse behaviour for a given configuration.

    ``kwargs`` overrides the default constructor options built below and
    ``groups`` is passed through as the ``groups`` option. Both are presumably
    supplied by a pytest parametrization that is not visible in this chunk.

    For the ``"relu"``, ``"softplus"`` and ``"log1p"`` transformations the
    test only requires the value obtained by calling the normalizer on the
    encoded sample to be non-negative. For every other configuration that
    value must equal the first original entry of column ``b`` within
    ``atol=1e-5``.

    NOTE(review): an earlier function with this exact name appears above in
    this file; this later definition is the one bound to the name after
    import. Confirm the earlier one is an intentional leftover.
    """
    frame = pd.DataFrame({"a": [1, 1, 2, 2, 3], "b": [1.1, 1.1, 1.0, 5.0, 1.1]})

    # Defaults first, caller overrides next, ``groups`` last so it always wins.
    params = {
        "method": "standard",
        "transformation": None,
        "center": True,
        "scale_by_group": False,
        **kwargs,
        "groups": groups,
    }
    # Scaling by group is only requested when at least one group is given.
    params["scale_by_group"] = params["scale_by_group"] and len(groups) > 0
    transformation = params.get("transformation")

    if transformation in ["relu", "softplus"]:
        # Shift the target so that it contains negative values.
        frame["b"] = frame["b"] - 2.0

    normalizer = GroupNormalizer(**params)
    encoded = normalizer.fit_transform(frame["b"], frame)

    # Feed the first encoded value back together with the parameters
    # looked up for group key ``[1]``.
    test_data = {
        "prediction": torch.tensor([encoded[0]]),
        "target_scale": torch.tensor(normalizer.get_parameters([1])).unsqueeze(0),
    }
    decoded = normalizer(test_data)

    if transformation in ["relu", "softplus", "log1p"]:
        assert (decoded >= 0).all(), "Inverse transform should yield only positive values"
    else:
        expected = torch.tensor(frame.b.iloc[0])
        assert torch.isclose(decoded, expected, atol=1e-5).all(), "Inverse transform should reverse transform"