def test_get_smooth_mask_correct(labels):
    """Check the smoothed bin mask against the plain (unsmoothed) mask.

    Builds two ``Binned`` distributions from ``COMMON_KWARGS`` -- one with
    ``label_smoothing=0.2`` and one without -- and verifies that:

    * both produce the same mask from ``_get_mask``;
    * the smoothed mask sums to one along the last axis;
    * the smoothed mask has its maximum at the same position as the mask;
    * the smallest smoothed entry equals ``0.2 / 7``.

    The constants ``2`` and ``7`` are hard-coded here; presumably they are
    the batch size and the number of bins in ``COMMON_KWARGS`` -- confirm
    against the fixtures if those change.
    """
    smoothing_dist = Binned(**COMMON_KWARGS, label_smoothing=0.2)
    reference_dist = Binned(**COMMON_KWARGS)

    # Add a trailing axis to the labels before building the masks.
    expanded_labels = labels.expand_dims(-1)

    hard_mask = smoothing_dist._get_mask(expanded_labels)
    hard_mask_np = hard_mask.asnumpy()

    # Constructing with label smoothing must not alter the mask itself.
    reference_mask_np = reference_dist._get_mask(expanded_labels).asnumpy()
    assert np.allclose(hard_mask_np, reference_mask_np)

    smoothed = smoothing_dist._smooth_mask(
        mx.nd, hard_mask, alpha=mx.nd.array([0.2])
    )
    smoothed_np = smoothed.asnumpy()

    # Each row of the smoothed mask must add up to one.
    row_totals = smoothed_np.sum(axis=-1)
    assert np.allclose(row_totals, np.ones(2))

    # Smoothing must leave the location of the peak unchanged.
    smoothed_peaks = np.argmax(smoothed_np, axis=-1)
    hard_peaks = np.argmax(hard_mask_np, axis=-1)
    assert np.allclose(smoothed_peaks, hard_peaks)

    # The smallest entry in each row is alpha / K.
    expected_min = np.full(2, 0.2 / 7)
    assert np.allclose(smoothed_np.min(axis=-1), expected_min)
def test_smooth_mask_adds_to_one(K, alpha):
    """Check that the smoothed mask sums to one for ``K`` bins and ``alpha``.

    A ``Binned`` distribution is built with uniform bin log-probabilities
    (``log_softmax`` of a vector of ones) and bin centers ``0 .. K - 1``.
    Twelve labels are sampled uniformly between ``0`` and ``K``, turned into
    a mask, and smoothed with the given ``alpha``; every row of the result
    must sum to one within ``atol=1e-6``.

    Note that the distribution is constructed with ``label_smoothing=0.2``
    regardless of ``alpha``; ``alpha`` is passed explicitly to
    ``_smooth_mask``.
    """
    num_labels = 12

    uniform_log_probs = mx.nd.log_softmax(mx.nd.ones(K))
    centers = mx.nd.arange(K)
    binned_dist = Binned(
        bin_log_probs=uniform_log_probs,
        bin_centers=centers,
        label_smoothing=0.2,
    )

    # Random labels with a trailing axis added before building the mask.
    sampled = mx.random.uniform(low=0, high=K, shape=(num_labels,))
    label_column = sampled.expand_dims(-1)

    hard_mask = binned_dist._get_mask(label_column)
    smoothed = binned_dist._smooth_mask(
        mx.nd, hard_mask, alpha=mx.nd.array([alpha])
    )

    # Each row of the smoothed mask must add up to one.
    row_totals = smoothed.asnumpy().sum(axis=-1)
    assert np.allclose(row_totals, np.ones(num_labels), atol=1e-6)