def test_shapes(batch_shape: Tuple, num_pieces: int, num_samples: int):
    """Check the tensor shapes produced by ``PiecewiseLinear``.

    A distribution is built with all-ones ``gamma`` and ``slopes`` and
    uniform knot spacings (``1 / num_pieces`` each). The test then verifies
    the shapes of:

    - its derived attributes;
    - its CRPS;
    - its quantiles;
    - its samples, both with and without ``num_samples``.
    """
    param_shape = (*batch_shape, num_pieces)

    gamma = mx.nd.ones(shape=(*batch_shape,))
    # Slopes are strictly positive.
    slopes = mx.nd.ones(shape=param_shape)
    # Spacings are positive and add up to one along the last axis.
    knot_spacings = mx.nd.ones(shape=param_shape) / num_pieces
    # The target has the same shape as gamma.
    target = mx.nd.ones(shape=batch_shape)

    distr = PiecewiseLinear(
        gamma=gamma, slopes=slopes, knot_spacings=knot_spacings
    )

    # The inputs themselves must be mutually consistent.
    assert gamma.shape == target.shape
    assert knot_spacings.shape == slopes.shape
    assert len(knot_spacings.shape) == len(gamma.shape) + 1

    # The distribution infers its batch shape from the parameters.
    assert distr.batch_shape == batch_shape

    # Derived per-piece quantities keep the per-piece parameter shape.
    assert distr.b.shape == slopes.shape
    assert distr.knot_positions.shape == knot_spacings.shape

    # CRPS gives one value per batch element.
    assert distr.crps(target).shape == batch_shape

    # Quantiles evaluated at the knot positions (as used for a_tilde).
    knot_quantiles = distr.quantile(knot_spacings, axis=-2)
    assert knot_quantiles.shape == param_shape

    # A single draw (num_samples=None) has the batch shape.
    single_draw = distr.sample()
    assert single_draw.shape == batch_shape
    assert distr.quantile(single_draw).shape == batch_shape

    # Multiple draws are stacked along a new leading axis.
    many_draws = distr.sample(num_samples)
    stacked_shape = (num_samples, *batch_shape)
    assert many_draws.shape == stacked_shape
    assert distr.quantile(many_draws, axis=0).shape == stacked_shape
def test_piecewise_linear(
    gamma: float,
    slopes: np.ndarray,
    knot_spacings: np.ndarray,
    hybridize: bool,
) -> None:
    """
    Test to check that minimizing the CRPS recovers the quantile function.

    The test proceeds in three steps:

    1. Draw samples from a ``PiecewiseLinear`` distribution with the given
       parameters.
    2. Fit a ``PiecewiseLinearOutput`` to those samples, starting from
       perturbed parameters.
    3. Compare the fitted and true quantile functions at levels
       0.1, 0.2, ..., 0.9 within a relative tolerance ``TOL``.

    ``slopes`` and ``knot_spacings`` are not modified.
    """
    # Keep the number of samples small so the test does not time out.
    num_samples = 500

    # Replicate the true parameters once per sample.
    gammas = mx.nd.zeros((num_samples,)) + gamma
    slopess = mx.nd.zeros((num_samples, len(slopes))) + mx.nd.array(slopes)
    knot_spacingss = mx.nd.zeros(
        (num_samples, len(knot_spacings))
    ) + mx.nd.array(knot_spacings)

    pwl_sqf = PiecewiseLinear(gammas, slopess, knot_spacingss)

    samples = pwl_sqf.sample()

    # Parameter initialization: start the fit away from the true values.
    gamma_init = gamma - START_TOL_MULTIPLE * TOL * gamma
    slopes_init = slopes - START_TOL_MULTIPLE * TOL * slopes

    # Perturb a copy. Mutating ``knot_spacings`` in place would corrupt the
    # ground truth of any later parametrized run that shares the same array
    # object (e.g. the other ``hybridize`` value).
    knot_spacings_init = np.copy(knot_spacings)
    # Shrink the first half of the spacings and stretch the second half.
    # The total is unchanged (still 1) only when both halves have equal mass.
    mid = len(knot_spacings) // 2
    knot_spacings_init[:mid] = (
        knot_spacings[:mid] - START_TOL_MULTIPLE * TOL * knot_spacings[:mid]
    )
    knot_spacings_init[mid:] = (
        knot_spacings[mid:] + START_TOL_MULTIPLE * TOL * knot_spacings[mid:]
    )

    init_biases = [gamma_init, slopes_init, knot_spacings_init]

    # Fit the distribution output to the samples, starting from init_biases.
    gamma_hat, slopes_hat, knot_spacings_hat = maximum_likelihood_estimate_sgd(
        PiecewiseLinearOutput(len(slopes)),
        samples,
        init_biases=init_biases,
        hybridize=hybridize,
        learning_rate=PositiveFloat(0.01),
        num_epochs=PositiveInt(20),
    )

    # The problem is highly non-convex, so the exact parameters may not be
    # recovered. Instead, check that the estimated parameters give similar
    # quantile values at several levels.
    quantile_levels = np.arange(0.1, 1.0, 0.1)

    # Build a single-element PiecewiseLinear from the estimates to get .quantile.
    pwl_sqf_hat = PiecewiseLinear(
        mx.nd.array(gamma_hat),
        mx.nd.array(slopes_hat).expand_dims(axis=0),
        mx.nd.array(knot_spacings_hat).expand_dims(axis=0),
    )

    # Quantiles under the estimated parameters.
    quantiles_hat = np.squeeze(
        pwl_sqf_hat.quantile(
            mx.nd.array(quantile_levels).expand_dims(axis=0), axis=1
        ).asnumpy()
    )

    # Quantiles under the true parameters. The parameters are replicated
    # across samples, so only the first row is needed.
    quantiles = np.squeeze(
        pwl_sqf.quantile(
            mx.nd.array(quantile_levels)
            .expand_dims(axis=0)
            .repeat(axis=0, repeats=num_samples),
            axis=1,
        ).asnumpy()[0, :]
    )

    for level, quantile, quantile_hat in zip(
        quantile_levels, quantiles, quantiles_hat
    ):
        # Relative tolerance. Taking abs() keeps the bound positive when the
        # true quantile is negative; otherwise the check could never pass.
        assert np.abs(quantile_hat - quantile) < TOL * np.abs(quantile), (
            f"quantile level {level} didn't match:"
            f" "
            f"q = {quantile}, q_hat = {quantile_hat}"
        )