def testArgwhereExecution(self):
    """Execute ``argwhere`` through the executor and compare with ``np.argwhere``."""
    # Tensor holding 0..5, created with chunk_size=2 and reshaped to 2x3.
    flat = arange(6, chunk_size=2)
    data = flat.reshape(2, 3)
    located = argwhere(data > 1)

    # concat=True is passed to the executor; the first returned item is compared.
    outputs = self.executor.execute_tensor(located, concat=True)
    actual = outputs[0]

    # Reference answer computed with plain NumPy on the same values.
    reference = np.argwhere(np.arange(6).reshape(2, 3) > 1)
    self.assertTrue(np.array_equal(actual, reference))
def testArgwhere(self):
    """Check ``argwhere`` shape metadata before tiling and its column splits after."""
    mask = tensor([[True, False], [False, True]], chunk_size=1)
    indices = argwhere(mask)

    # The test expects a NaN first dimension and one column per input dimension.
    leading_dim = indices.shape[0]
    self.assertTrue(np.isnan(leading_dim))
    trailing_dim = indices.shape[1]
    self.assertEqual(trailing_dim, 2)

    indices.tiles()

    # After tiling, the second axis is expected to be split as (1, 1).
    column_splits = indices.nsplits[1]
    self.assertEqual(column_splits, (1, 1))
def testArgwhere(self):
    """Check ``argwhere`` shape, rough-shape and rough-size metadata around tiling."""
    mask = tensor([[True, False], [False, True]], chunk_size=1)
    indices = argwhere(mask)

    # --- tensor-level checks, before tiling ---
    leading_dim = indices.shape[0]
    self.assertTrue(np.isnan(leading_dim))
    trailing_dim = indices.shape[1]
    self.assertEqual(trailing_dim, 2)
    self.assertEqual(calc_shape(indices), indices.shape)

    # Rough shape is derived by the op from the rough shapes of its inputs.
    input_rough_shapes = tuple(inp.rough_shape for inp in indices.inputs)
    rough_shape = indices.op.calc_rough_shape(input_rough_shapes)
    self.assertEqual(rough_shape, (4, 2))

    rough_nbytes = indices.rough_nbytes
    expected_nbytes = 4 * 2 * indices.dtype.itemsize
    self.assertEqual(rough_nbytes, expected_nbytes)

    indices.tiles()

    # --- checks after tiling ---
    column_splits = indices.nsplits[1]
    self.assertEqual(column_splits, (1, 1))

    first_chunk = indices.chunks[0]
    self.assertTrue(np.isnan(calc_shape(first_chunk)[0]))
    self.assertEqual(calc_shape(first_chunk)[1], first_chunk.shape[1])
    self.assertEqual(first_chunk.rough_shape, (2, 1))