def test_get_smooth_mask_correct(labels):
    """Check the mask and the smoothed mask of a label-smoothed ``Binned``.

    The test asserts, in order, that:

    * ``_get_mask`` returns the same values with ``label_smoothing=0.2``
      as it does without label smoothing;
    * the smoothed mask sums to one along its last axis;
    * the smoothed mask has its argmax where the raw mask does;
    * the smallest entry of each smoothed row equals ``0.2 / 7``.
    """
    alpha = 0.2
    # Divisor used for the expected minimum ("alpha / K"); presumably the
    # number of bins in COMMON_KWARGS, which is defined outside this view.
    num_bins = 7

    smoothed_dist = Binned(**COMMON_KWARGS, label_smoothing=alpha)
    plain_dist = Binned(**COMMON_KWARGS)

    expanded_labels = labels.expand_dims(-1)

    mask = smoothed_dist._get_mask(expanded_labels)
    mask_np = mask.asnumpy()
    reference_np = plain_dist._get_mask(expanded_labels).asnumpy()
    assert np.allclose(mask_np, reference_np)

    smooth_mask = smoothed_dist._smooth_mask(
        mx.nd, mask, alpha=mx.nd.array([alpha])
    )
    smooth_np = smooth_mask.asnumpy()

    # check smooth mask adds to one
    row_totals = smooth_np.sum(axis=-1)
    assert np.allclose(row_totals, np.ones(2))

    # check smooth mask peaks same
    smooth_peaks = np.argmax(smooth_np, axis=-1)
    mask_peaks = np.argmax(mask_np, axis=-1)
    assert np.allclose(smooth_peaks, mask_peaks)

    # check smooth mask mins correct
    expected_floor = np.ones(2) * alpha / num_bins  # alpha / K
    assert np.allclose(smooth_np.min(axis=-1), expected_floor)
def test_smooth_mask_adds_to_one(K, alpha):
    """Check that the smoothed mask sums to one for a given ``K`` and ``alpha``.

    Builds a ``Binned`` with ``K`` equal-probability bins centred at
    ``0 .. K-1``, draws 12 uniform labels from ``[0, K)``, and asserts that
    every row of the smoothed mask sums to one (``atol=1e-6``).

    Note: the distribution itself is always constructed with
    ``label_smoothing=0.2``; the parametrized ``alpha`` is only passed
    explicitly to ``_smooth_mask``.
    """
    num_samples = 12

    uniform_log_probs = mx.nd.log_softmax(mx.nd.ones(K))
    centers = mx.nd.arange(K)
    dist = Binned(
        bin_log_probs=uniform_log_probs,
        bin_centers=centers,
        label_smoothing=0.2,
    )

    raw_labels = mx.random.uniform(low=0, high=K, shape=(num_samples,))
    labels = raw_labels.expand_dims(-1)

    mask = dist._get_mask(labels)
    smooth_mask = dist._smooth_mask(mx.nd, mask, alpha=mx.nd.array([alpha]))

    # check smooth mask adds to one
    row_totals = smooth_mask.asnumpy().sum(axis=-1)
    assert np.allclose(row_totals, np.ones(num_samples), atol=1e-6)