def test_computes_mean_with_weights_and_mask(self, backend_name):
    """Check a weighted mean in which one position is masked via ``mask_id``.

    The expected value corresponds to dropping position 0, whose target
    equals ``mask_id`` (1). The remaining positions give
    (2 * 4 + 4 * 1) / (4 + 1) = 2.4.

    Args:
      backend_name: name of the backend to activate for this test case.
    """
    expected = 2.4
    with backend.use_backend(backend_name):
        preds = [np.array([1, 2, 4])]
        labels = [np.array([1, 0, 0])]
        per_item_weights = [np.array([10, 4, 1])]
        result = trax.masked_mean(preds, labels, per_item_weights, mask_id=1)
        onp.testing.assert_allclose(result, expected)
def test_computes_mean_with_mask(self):
    """Check an unweighted mean in which one position is masked via ``mask_id``.

    The expected value corresponds to dropping position 0, whose target
    equals ``mask_id`` (1). The remaining inputs average to
    (2 + 3) / 2 = 2.5. ``weights`` holds a plain scalar rather than a
    per-position array (presumably broadcast by ``masked_mean``).
    """
    # Create the arrays inside the backend context, as the parameterized
    # tests do, so they are presumably built by the backend under test
    # rather than whichever backend was active before the switch.
    with backend.use_backend("numpy"):
        inputs = [np.array([1, 2, 3])]
        targets = [np.array([1, 0, 0])]
        weights = [1]
        mean = trax.masked_mean(inputs, targets, weights, mask_id=1)
        # Assert with plain numpy (``onp``), matching the other tests here.
        onp.testing.assert_allclose(mean, 2.5)
def test_computes_mean_with_weights(self, backend_name):
    """Check a weighted mean with no mask.

    With weights (3, 1, 0) the expected value is
    (1 * 3 + 2 * 1 + 3 * 0) / (3 + 1 + 0) = 1.25.

    Args:
      backend_name: name of the backend to activate for this test case.
    """
    with backend.use_backend(backend_name):
        weighted_mean = trax.masked_mean(
            [np.array([1, 2, 3])],
            [np.zeros(3)],
            [np.array([3, 1, 0])],
        )
        onp.testing.assert_allclose(weighted_mean, 1.25)
def test_computes_basic_mean(self):
    """Check a plain mean with no mask and a single scalar weight.

    The expected value is the arithmetic mean of the inputs:
    (1 + 2 + 3) / 3 = 2. ``weights`` holds a plain scalar rather than a
    per-position array (presumably broadcast by ``masked_mean``).
    """
    # Create the arrays inside the backend context, as the parameterized
    # tests do, so they are presumably built by the backend under test
    # rather than whichever backend was active before the switch.
    with backend.use_backend("numpy"):
        inputs = [np.array([1, 2, 3])]
        targets = [np.zeros(3)]
        weights = [1]
        mean = trax.masked_mean(inputs, targets, weights)
        # Assert with plain numpy (``onp``), matching the other tests here.
        onp.testing.assert_allclose(mean, 2)