def test_stratified_mean_equal1(self):
     """Check stratified_mean with one prototype per class (pdist [1, 1]).

     The returned prototypes must match the expected vectors to five
     decimal places. The second return value (presumably the prototype
     labels) is ignored by this test.
     """
     pdist = torch.tensor([1, 1])
     # NOTE(review): the positional ``False`` presumably disables one-hot
     # label handling, since ``self.y`` appears to hold integer class
     # indices -- confirm against the ``stratified_mean`` signature.
     actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
     desired = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0]])
     # The numpy helper raises AssertionError on mismatch and always returns
     # None, so it is called directly rather than asserting on its result.
     np.testing.assert_array_almost_equal(actual, desired, decimal=5)
示例#2
0
 def test_stratified_mean_equal2(self):
     """Check stratified_mean with two prototypes per class (pdist [2, 2]).

     Both prototypes of a class are expected to be identical, so the
     expected output holds each class vector twice. Values are compared to
     five decimal places; the second return value is ignored.
     """
     pdist = torch.tensor([2, 2])
     # NOTE(review): test_stratified_mean_equal1 passes an explicit
     # ``False`` with the same labels ``self.y``, but this call relies on
     # the default mode -- confirm the default handles non-one-hot labels.
     actual, _ = initializers.stratified_mean(self.x, self.y, pdist)
     desired = torch.tensor([
         [5.0, 5.0, 5.0],
         [5.0, 5.0, 5.0],
         [1.0, 1.0, 1.0],
         [1.0, 1.0, 1.0],
     ])
     # The numpy helper raises AssertionError on mismatch and always returns
     # None, so it is called directly rather than asserting on its result.
     np.testing.assert_array_almost_equal(actual, desired, decimal=5)
 def test_stratified_mean_unequal_one_hot(self):
     """Check stratified_mean with one-hot labels and pdist [1, 3].

     Both return values are compared to five decimal places:
     - the prototypes (one for the first class, three for the second);
     - the prototype labels, which are expected back in one-hot form.
     """
     pdist = torch.tensor([1, 3])
     # Row-index the 2x2 identity to one-hot encode the labels; assumes
     # self.y holds integer class indices in {0, 1} -- TODO confirm.
     y = torch.eye(2)[self.y]
     actual_protos, actual_labels = initializers.stratified_mean(
         self.x, y, pdist)
     desired_protos = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0],
                                    [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
     desired_labels = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
     # Each call raises AssertionError on mismatch and returns None, so
     # both checks are made directly; there is no result to assert on.
     np.testing.assert_array_almost_equal(actual_protos,
                                          desired_protos,
                                          decimal=5)
     np.testing.assert_array_almost_equal(actual_labels,
                                          desired_labels,
                                          decimal=5)