Exemplo n.º 1
0
def test_values(
    distr: PiecewiseLinear,
    target: List[float],
    expected_target_cdf: List[float],
    expected_target_crps: List[float],
    serialize_fn,
):
    """Compare cdf/crps values of ``distr`` with reference numbers.

    Parameters
    ----------
    distr
        Distribution under test; it is passed through ``serialize_fn``
        before any value is computed.
    target
        Points at which ``cdf`` and ``crps`` are evaluated.
    expected_target_cdf
        Reference cdf values, one per entry of ``target``.
    expected_target_crps
        Reference crps values, one per entry of ``target``.
    serialize_fn
        Callable applied to ``distr`` first (presumably a serialization
        round-trip fixture -- confirm against the test configuration).
    """
    pwl = serialize_fn(distr)

    # Flatten the inputs explicitly to 1-D arrays of matching length.
    target_nd = mx.nd.array(target).reshape(shape=(len(target), ))
    cdf_reference = np.array(expected_target_cdf).reshape(
        (len(expected_target_cdf), ))
    crps_reference = np.array(expected_target_crps).reshape(
        (len(expected_target_crps), ))

    cdf_values = pwl.cdf(target_nd).asnumpy()
    assert all(np.isclose(cdf_values, cdf_reference))

    crps_values = pwl.crps(target_nd).asnumpy()
    assert all(np.isclose(crps_values, crps_reference))

    # Cross-check the analytic cdf against an empirical cdf built from a
    # large number of samples; every sample must be finite.
    sample_count = 100_000
    draws = pwl.sample(sample_count).asnumpy()
    assert np.all(np.isfinite(draws))

    # NOTE(review): empirical_cdf is defined outside this block; it is
    # assumed to return (cdf values, bin edges) with one more edge than
    # cdf values, hence the first row is dropped below -- confirm.
    empirical, bin_edges = empirical_cdf(draws)
    analytic = pwl.cdf(mx.nd.array(bin_edges)).asnumpy()
    assert np.allclose(analytic[1:, :], empirical, atol=1e-2)
Exemplo n.º 2
0
def test_simple_symmetric():
    """Check cdf and crps of a small symmetric piecewise-linear example.

    The distribution is built with ``gamma=-1``, two slopes of ``2`` and
    two knot spacings of ``0.5``. The cdf must be exactly 0 at -2 and
    exactly 1 at +2, and the crps must be (approximately) the same value,
    ``1 + 2/3``, at both of those points.
    """
    distr = PiecewiseLinear(
        gamma=mx.nd.array([-1.0]),
        slopes=mx.nd.array([[2.0, 2.0]]),
        knot_spacings=mx.nd.array([[0.5, 0.5]]),
    )

    # Both evaluation points are compared with exact equality on purpose.
    for point, cdf_at_point in ((-2.0, 0.0), (+2.0, 1.0)):
        observed = distr.cdf(mx.nd.array([point])).asnumpy().item()
        assert observed == cdf_at_point

    expected_crps = np.array([1.0 + 2.0 / 3.0])

    # The crps is expected to be identical at the two mirrored points.
    for point in (-2.0, 2.0):
        observed_crps = distr.crps(mx.nd.array([point])).asnumpy()
        assert np.allclose(observed_crps, expected_crps)
Exemplo n.º 3
0
def test_shapes(
    batch_shape: Tuple, num_pieces: int, num_samples: int, serialize_fn
):
    """Verify the shapes exposed by ``PiecewiseLinear`` for a given batch.

    Parameters
    ----------
    batch_shape
        Shape shared by ``gamma`` and the target.
    num_pieces
        Size of the trailing axis of ``slopes`` and ``knot_spacings``.
    num_samples
        Number of samples requested in the batched-sampling check.
    serialize_fn
        Callable applied to the distribution before the checks
        (presumably a serialization round-trip fixture -- confirm).
    """
    per_piece_shape = (*batch_shape, num_pieces)
    per_sample_shape = (num_samples, *batch_shape)

    gamma = mx.nd.ones(shape=(*batch_shape,))
    # Every slope equals one, so all slopes are positive.
    slopes = mx.nd.ones(shape=per_piece_shape)
    # Each spacing equals 1 / num_pieces: positive, summing to one along
    # the last axis.
    spacings = mx.nd.ones(shape=per_piece_shape) / num_pieces
    # The target has the same shape as gamma.
    target = mx.nd.ones(shape=batch_shape)

    pwl = serialize_fn(
        PiecewiseLinear(gamma=gamma, slopes=slopes, knot_spacings=spacings)
    )

    # Sanity-check the inputs themselves: the per-piece parameters carry
    # exactly one more axis than gamma / target.
    assert gamma.shape == target.shape
    assert spacings.shape == slopes.shape
    assert len(spacings.shape) == len(gamma.shape) + 1

    # The distribution must report the batch shape it was built with.
    assert pwl.batch_shape == batch_shape

    # Attributes exposed by the distribution keep the per-piece shape.
    assert pwl.b.shape == slopes.shape
    assert pwl.knot_positions.shape == spacings.shape

    # crps yields one value per batch element.
    crps_shape = pwl.crps(target).shape
    assert crps_shape == batch_shape

    # Quantiles evaluated at the knot positions (used for a_tilde) keep
    # the per-piece shape.
    knot_quantiles = pwl.quantile_internal(spacings, axis=-2)
    assert knot_quantiles.shape == per_piece_shape

    # Without num_samples, samples and their quantiles have batch shape.
    single_draw = pwl.sample()
    assert single_draw.shape == batch_shape
    assert pwl.quantile_internal(single_draw).shape == batch_shape

    # With num_samples, a leading sample axis is prepended to both.
    many_draws = pwl.sample(num_samples)
    assert many_draws.shape == per_sample_shape
    batched_quantiles = pwl.quantile_internal(many_draws, axis=0)
    assert batched_quantiles.shape == per_sample_shape