def test_where():
    """Check that ``where`` broadcasts and tiles its operands chunk by chunk.

    A 2x2 boolean condition (chunk_size=1) is combined with two length-2
    operands. After tiling there must be exactly four chunks, and each chunk's
    three inputs (cond, x, y) must carry the expected single-element data.
    Also checks that passing only one of x/y raises ValueError and that the
    result dtype of a float tensor mixed with an int scalar is float64.
    """
    cond = tensor([[True, False], [False, True]], chunk_size=1)
    x = tensor([1, 2], chunk_size=1)
    y = tensor([3, 4], chunk_size=1)

    tiled = tile(where(cond, x, y))
    assert len(tiled.chunks) == 4

    # Expected data of (cond, x, y) inputs for each output chunk, in order.
    expected_inputs = [
        ([[True]], [1], [3]),
        ([[False]], [2], [4]),
        ([[False]], [1], [3]),
        ([[True]], [2], [4]),
    ]
    for chunk_idx, per_input in enumerate(expected_inputs):
        chunk = tiled.chunks[chunk_idx]
        # Index explicitly so a missing input still fails loudly.
        for input_idx, want in enumerate(per_input):
            np.testing.assert_equal(chunk.inputs[input_idx].op.data, want)

    # Supplying only one of the two value operands must be rejected.
    with pytest.raises(ValueError):
        where(cond, x)

    floats = arange(9.).reshape(3, 3)
    assert where(floats < 5, floats, -1).dtype == np.float64
def testWhereExecution(self):
    """Execute ``where`` on dense and sparse inputs and compare with NumPy.

    In both cases the x operand has shape (4, 1) while cond and y are (4, 4),
    so the executed result exercises broadcasting across chunk_size=2 chunks.
    """
    # Dense inputs.
    dense_cond = np.random.randint(0, 2, size=(4, 4), dtype='?')
    dense_x = np.random.rand(4, 1)
    dense_y = np.random.rand(4, 4)
    cond = tensor(dense_cond, chunk_size=2)
    x = tensor(dense_x, chunk_size=2)
    y = tensor(dense_y, chunk_size=2)

    dense_result = self.executor.execute_tensor(where(cond, x, y), concat=True)
    dense_expected = np.where(dense_cond, dense_x, dense_y)
    self.assertTrue(np.array_equal(dense_result[0], dense_expected))

    # Sparse inputs (scipy.sparse); compare densified result with NumPy.
    sparse_cond = sps.csr_matrix(
        np.random.randint(0, 2, size=(4, 4), dtype='?'))
    sparse_x = sps.random(4, 1, density=.1)
    sparse_y = sps.random(4, 4, density=.1)
    cond = tensor(sparse_cond, chunk_size=2)
    x = tensor(sparse_x, chunk_size=2)
    y = tensor(sparse_y, chunk_size=2)

    sparse_result = self.executor.execute_tensor(
        where(cond, x, y), concat=True)[0]
    sparse_expected = np.where(
        sparse_cond.toarray(), sparse_x.toarray(), sparse_y.toarray())
    self.assertTrue(np.array_equal(sparse_result.toarray(), sparse_expected))
def testWhere(self):
    """Check chunk-level broadcasting of ``where`` after tiling.

    A 2x2 boolean condition (chunk_size=1) with two length-2 operands must tile
    into four chunks whose (cond, x, y) inputs hold the expected
    single-element data. Also checks the ValueError on a missing operand and
    the float64 result dtype for a float tensor mixed with an int scalar.
    """
    cond = tensor([[True, False], [False, True]], chunk_size=1)
    x = tensor([1, 2], chunk_size=1)
    y = tensor([3, 4], chunk_size=1)

    tiled = where(cond, x, y).tiles()
    self.assertEqual(len(tiled.chunks), 4)

    # Expected data of (cond, x, y) inputs for each output chunk, in order.
    expected_inputs = (
        ([[True]], [1], [3]),
        ([[False]], [2], [4]),
        ([[False]], [1], [3]),
        ([[True]], [2], [4]),
    )
    for chunk_idx, per_input in enumerate(expected_inputs):
        chunk = tiled.chunks[chunk_idx]
        for input_idx, want in enumerate(per_input):
            self.assertTrue(
                np.array_equal(chunk.inputs[input_idx].op.data, want))

    # Supplying only one of the two value operands must be rejected.
    with self.assertRaises(ValueError):
        where(cond, x)

    floats = arange(9.).reshape(3, 3)
    self.assertEqual(where(floats < 5, floats, -1).dtype, np.float64)