Ejemplo n.º 1
0
 def testGatherNumerical2(self):
     """Gather along dim 0 with an index that has more rows than the source.

     The 2x3 source is gathered with a 3x3 index, and the expected result
     is 3x3 (the shape of ``idx``). The expected values follow
     ``out[i][j] = t[idx[i][j]][j]``.
     """
     t = TensorBase(np.array([[47, 74, 44], [56, 9, 37]]))
     idx = TensorBase(np.array([[0, 0, 1], [1, 1, 0], [0, 1, 0]]))
     dim = 0
     result = t.gather(dim=dim, index=idx)
     expected = np.array([[47, 74, 37], [56, 9, 44], [47, 9, 44]])
     # assert_array_equal checks shape and reports the mismatching elements
     # on failure, unlike assertTrue(np.array_equal(...)) which only says
     # "False is not true".
     np.testing.assert_array_equal(result.data, expected)
Ejemplo n.º 2
0
 def test_gather_numerical_1(self):
     """Gather along dim 1, selecting one column per row.

     The 3x2 source is gathered with a 3x1 index, and the expected result
     is 3x1. The expected values follow ``out[i][j] = t[i][idx[i][j]]``.
     """
     t = TensorBase(np.array([[65, 17], [14, 25], [76, 22]]))
     idx = TensorBase(np.array([[0], [1], [0]]))
     dim = 1
     result = t.gather(dim=dim, index=idx)
     expected = np.array([[65], [25], [76]])
     # assert_array_equal checks shape and reports the mismatching elements
     # on failure, which a bare boolean assertTrue cannot do.
     np.testing.assert_array_equal(result.data, expected)