def test_stratified_mean_equal1(self):
    """Check ``stratified_mean`` with one prototype per class.

    ``pdist`` requests one prototype for each of the two classes, so two
    prototype rows are expected. The expected rows are presumably the
    per-class means of the ``self.x`` / ``self.y`` fixtures, which are set
    up outside this view.
    """
    pdist = torch.tensor([1, 1])
    # NOTE(review): the positional ``False`` is presumably the one-hot-labels
    # flag of ``stratified_mean`` (labels here are plain class indices);
    # confirm against its signature.
    actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
    desired = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0]])
    # assert_array_almost_equal raises AssertionError on mismatch and always
    # returns None, so its return value is not worth capturing or checking.
    np.testing.assert_array_almost_equal(actual, desired, decimal=5)
def test_stratified_mean_equal2(self):
    """Check ``stratified_mean`` with two prototypes per class.

    ``pdist`` requests two prototypes for each of the two classes, so four
    rows are expected, grouped by class in ``pdist`` order. The expected rows
    are presumably the per-class means of the ``self.x`` / ``self.y``
    fixtures, which are set up outside this view.
    """
    pdist = torch.tensor([2, 2])
    actual, _ = initializers.stratified_mean(self.x, self.y, pdist)
    desired = torch.tensor([
        [5.0, 5.0, 5.0],
        [5.0, 5.0, 5.0],
        [1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0],
    ])
    # assert_array_almost_equal raises AssertionError on mismatch and always
    # returns None, so its return value is not worth capturing or checking.
    np.testing.assert_array_almost_equal(actual, desired, decimal=5)
def test_stratified_mean_unequal_one_hot(self):
    """Check ``stratified_mean`` with an unequal split and one-hot labels.

    ``pdist`` requests one prototype for the first class and three for the
    second. Both returned values are verified:

    * the prototype coordinates (4 rows, grouped by class), presumably the
      per-class means of the ``self.x`` fixture set up outside this view;
    * the prototype labels, returned in one-hot form.
    """
    pdist = torch.tensor([1, 3])
    # Indexing the 2x2 identity with class indices one-hot encodes the labels
    # (assumes self.y holds integer class indices in {0, 1} — TODO confirm).
    y = torch.eye(2)[self.y]
    actual_protos, actual_labels = initializers.stratified_mean(self.x, y, pdist)
    desired_protos = torch.tensor([
        [5.0, 5.0, 5.0],
        [1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0],
    ])
    desired_labels = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
    # Each assert_array_almost_equal raises on mismatch and returns None.
    # The original stored both results in one variable, overwriting the
    # first, and then checked it with assertIsNone, which could never fail.
    np.testing.assert_array_almost_equal(actual_protos, desired_protos, decimal=5)
    np.testing.assert_array_almost_equal(actual_labels, desired_labels, decimal=5)