Esempio n. 1
0
def test_shapes(
    batch_shape: Tuple, num_pieces: int, num_samples: int, serialize_fn
):
    """
    Check that a (serialized) PiecewiseLinear distribution reports the
    expected shapes for its parameters, CRPS, quantiles and samples.

    All parameters are filled with ones; the knot spacings are divided by
    ``num_pieces`` so that they are positive and sum to 1 along the last axis.
    """
    piece_shape = (*batch_shape, num_pieces)
    sample_shape = (num_samples, *batch_shape)

    gamma = mx.nd.ones(shape=(*batch_shape,))
    slopes = mx.nd.ones(shape=piece_shape)  # strictly positive slopes
    # equal spacings: positive and summing to 1 over the pieces
    knot_spacings = mx.nd.ones(shape=piece_shape) / num_pieces
    target = mx.nd.ones(shape=batch_shape)  # same shape as gamma

    distr = serialize_fn(
        PiecewiseLinear(gamma=gamma, slopes=slopes, knot_spacings=knot_spacings)
    )

    # sanity checks on the inputs themselves
    assert gamma.shape == target.shape
    assert knot_spacings.shape == slopes.shape
    assert len(knot_spacings.shape) == len(gamma.shape) + 1

    # the distribution infers its batch shape from gamma
    assert distr.batch_shape == batch_shape

    # derived parameters keep the per-piece shape
    assert distr.b.shape == slopes.shape
    assert distr.knot_positions.shape == knot_spacings.shape

    # CRPS is evaluated per batch element
    assert distr.crps(target).shape == batch_shape

    # quantiles at the knot positions (as used for a_tilde) keep one value per piece
    knot_quantiles = distr.quantile_internal(knot_spacings, axis=-2)
    assert knot_quantiles.shape == piece_shape

    # a single draw (num_samples=None) has the batch shape
    single_draw = distr.sample()
    assert single_draw.shape == batch_shape
    assert distr.quantile_internal(single_draw).shape == batch_shape

    # multiple draws prepend a sample axis
    multi_draw = distr.sample(num_samples)
    assert multi_draw.shape == sample_shape
    assert distr.quantile_internal(multi_draw, axis=0).shape == sample_shape
def test_piecewise_linear(
    gamma: float,
    slopes: np.ndarray,
    knot_spacings: np.ndarray,
    hybridize: bool,
) -> None:
    """
    Test to check that minimizing the CRPS recovers the quantile function.

    Samples are drawn from a PiecewiseLinear distribution with the given
    parameters (replicated ``num_samples`` times). A PiecewiseLinearOutput is
    fitted to them with ``maximum_likelihood_estimate_sgd``, starting from
    slightly perturbed parameters. The fitted quantile function is then
    compared with the true one at levels 0.1, 0.2, ..., 0.9, within a
    relative tolerance of ``TOL``.

    The input arrays are never modified. This matters because pytest reuses
    the same parametrized objects across test cases (e.g. both ``hybridize``
    values).
    """
    num_samples = 500  # keep the sample count small to avoid test timeouts

    gammas = mx.nd.zeros((num_samples,)) + gamma
    slopess = mx.nd.zeros((num_samples, len(slopes))) + mx.nd.array(slopes)
    knot_spacingss = mx.nd.zeros(
        (num_samples, len(knot_spacings))
    ) + mx.nd.array(knot_spacings)

    pwl_sqf = PiecewiseLinear(gammas, slopess, knot_spacingss)

    samples = pwl_sqf.sample()

    # Parameter initialization: start slightly away from the true values.
    delta = START_TOL_MULTIPLE * TOL
    gamma_init = gamma - delta * gamma
    slopes_init = slopes - delta * slopes
    # Work on a float copy. Assigning into ``knot_spacings`` directly would
    # mutate the caller's (parametrized) array and leak the perturbation into
    # later test cases.
    knot_spacings_init = np.array(knot_spacings, dtype=float)
    # Shrink the first half of the spacings and grow the second half by
    # exactly the removed mass, so the perturbed spacings still sum to 1.
    # When both halves carry equal mass this is the same as scaling the
    # second half up by (1 + delta). Assumes positive spacings.
    mid = len(knot_spacings_init) // 2
    removed_mass = delta * knot_spacings_init[:mid].sum()
    knot_spacings_init[:mid] -= delta * knot_spacings_init[:mid]
    knot_spacings_init[mid:] += (
        removed_mass
        * knot_spacings_init[mid:]
        / knot_spacings_init[mid:].sum()
    )

    init_biases = [gamma_init, slopes_init, knot_spacings_init]

    # Fit the distribution parameters to the samples, starting from the
    # perturbed initial values.
    gamma_hat, slopes_hat, knot_spacings_hat = maximum_likelihood_estimate_sgd(
        PiecewiseLinearOutput(len(slopes)),
        samples,
        init_biases=init_biases,
        hybridize=hybridize,
        learning_rate=PositiveFloat(0.01),
        num_epochs=PositiveInt(20),
    )

    # Since the problem is highly non-convex we may not be able to recover the
    # exact parameters. Instead, check that the estimated parameters yield
    # similar quantile values at several quantile levels.
    quantile_levels = np.arange(0.1, 1.0, 0.1)

    # Build a PiecewiseLinear (batch of one) from the estimated parameters to
    # get access to its quantile function.
    pwl_sqf_hat = PiecewiseLinear(
        mx.nd.array(gamma_hat),
        mx.nd.array(slopes_hat).expand_dims(axis=0),
        mx.nd.array(knot_spacings_hat).expand_dims(axis=0),
    )

    # Compute quantiles with the estimated parameters.
    quantiles_hat = np.squeeze(
        pwl_sqf_hat.quantile_internal(
            mx.nd.array(quantile_levels).expand_dims(axis=0),
            axis=1).asnumpy())

    # Compute quantiles with the original parameters. The parameters are
    # replicated across samples, so only the first row is needed.
    quantiles = np.squeeze(
        pwl_sqf.quantile_internal(
            mx.nd.array(quantile_levels).expand_dims(axis=0).repeat(
                axis=0, repeats=num_samples),
            axis=1,
        ).asnumpy()[0, :])

    for level, quantile, quantile_hat in zip(
        quantile_levels, quantiles, quantiles_hat
    ):
        # Relative tolerance on the magnitude, so that negative quantiles can
        # pass as well (``TOL * quantile`` would be negative for them).
        assert np.abs(quantile_hat - quantile) < TOL * np.abs(quantile), (
            f"quantile level {level} didn't match:"
            f" "
            f"q = {quantile}, q_hat = {quantile_hat}")