def test_complex_slices(self):
    """Check block selection on a 3-D matrix whose edge blocks are ragged.

    Axes 1 and 2 have lengths 67 and 53, which are not multiples of their
    shard sizes (16 and 11), so the last block along each of those axes is
    expected to be truncated (rows 64:67 and columns 44:53).
    """
    X = np.random.randn(21, 67, 53)
    shard_sizes = [21, 16, 11]
    X_sharded = BigMatrix("test_5", shape=X.shape, shard_sizes=shard_sizes)
    shard_matrix(X_sharded, X)
    # assert_array_equal fails on shape mismatches, which a bare
    # `np.all(a == b)` can silently hide through broadcasting.
    np.testing.assert_array_equal(
        X_sharded.submatrix(0, 0, 0).numpy(), X[:, :16, :11])
    # Block 4 on axis 1 starts at 4*16=64; block 4 on axis 2 starts at 4*11=44.
    np.testing.assert_array_equal(
        X_sharded.submatrix(0, 4, 4).numpy(), X[:, 64:67, 44:53])
def test_simple_slices(self):
    """Check selecting ranges of row/column blocks with list arguments.

    With 32x32 shards on a 128x128 matrix, the expected values below treat
    a list argument like Python slice arguments over *block* indices:
    ``[2]`` -> blocks 0..1, ``[2, None]`` -> blocks 2.., ``[0, 3]`` ->
    blocks 0..2, and ``None`` -> every block along that axis.
    """
    X = np.random.randn(128, 128)
    shard_sizes = [32, 32]
    X_sharded = BigMatrix("test_3", shape=X.shape, shard_sizes=shard_sizes)
    shard_matrix(X_sharded, X)
    # assert_array_equal also enforces matching shapes (no broadcasting).
    np.testing.assert_array_equal(
        X_sharded.submatrix([2]).numpy(), X[0:64])
    np.testing.assert_array_equal(
        X_sharded.submatrix([2, None]).numpy(), X[64:128])
    np.testing.assert_array_equal(
        X_sharded.submatrix(None, [0, 3]).numpy(), X[:, 0:96])
    np.testing.assert_array_equal(
        X_sharded.submatrix(None, [3, None]).numpy(), X[:, 96:128])
def test_multiple_shard_index_get(self):
    """Check fetching individual 64x64 blocks of a 2x2-block matrix.

    Covers a submatrix fixed on one axis plus a block index passed to
    ``get_block``, a submatrix fixed on both axes with no index, and a
    submatrix spanning all rows (``None``) fixed on column block 0.
    """
    X = np.random.randn(128, 128)
    shard_sizes = [64, 64]
    X_sharded = BigMatrix("test_2", shape=X.shape, shard_sizes=shard_sizes)
    shard_matrix(X_sharded, X)
    # assert_array_equal also enforces matching shapes (no broadcasting).
    np.testing.assert_array_equal(
        X_sharded.submatrix(0).get_block(0), X[0:64, 0:64])
    np.testing.assert_array_equal(
        X_sharded.submatrix(1, 1).get_block(), X[64:128, 64:128])
    np.testing.assert_array_equal(
        X_sharded.submatrix(0, 1).get_block(), X[0:64, 64:128])
    # Row block 1 within the all-rows, column-block-0 submatrix.
    np.testing.assert_array_equal(
        X_sharded.submatrix(None, 0).get_block(1), X[64:128, 0:64])
def test_step_slices(self):
    """Check stepped block ranges ``[start, stop, step]`` on 16x16 shards.

    Each stepped selection concatenates whole 16-row (or 16-column) blocks,
    so the result is sub-sampled with ``[::16]`` to keep only the first
    row/column of each selected block before comparing with a strided view
    of the original matrix.
    """
    X = np.random.randn(128, 128)
    shard_sizes = [16, 16]
    X_sharded = BigMatrix("test_4", shape=X.shape, shard_sizes=shard_sizes)
    shard_matrix(X_sharded, X)
    # assert_array_equal also enforces matching shapes (no broadcasting).
    # Row blocks 0, 2, 4, 6 -> first rows 0, 32, 64, 96.
    np.testing.assert_array_equal(
        X_sharded.submatrix([None, None, 2]).numpy()[::16], X[::32])
    # Row blocks 1, 3, 5, 7 -> first rows 16, 48, 80, 112.
    np.testing.assert_array_equal(
        X_sharded.submatrix([1, None, 2]).numpy()[::16], X[16::32])
    # Column blocks 0, 4 -> first columns 0, 64.
    np.testing.assert_array_equal(
        X_sharded.submatrix(None, [0, 6, 4]).numpy()[:, ::16],
        X[:, 0:96:64])
    # Column block 6 only -> first column 96.
    np.testing.assert_array_equal(
        X_sharded.submatrix(None, [6, 8, 4]).numpy()[:, ::16],
        X[:, 96:128:64])
def test_single_shard_index_get(self):
    """Check that a matrix stored as a single shard round-trips via get_block.

    ``shard_sizes`` equals the full shape, so block (0, 0) should be the
    entire matrix.
    """
    X = np.random.randn(128, 128)
    X_sharded = BigMatrix("test_0", shape=X.shape, shard_sizes=X.shape)
    shard_matrix(X_sharded, X)
    X_sharded_local = X_sharded.submatrix(0, 0).get_block()
    # assert_array_equal also enforces matching shapes (no broadcasting).
    np.testing.assert_array_equal(X_sharded_local, X)
def test_single_shard_index_put(self):
    """Check that put_block on the only shard is visible through numpy().

    ``shard_sizes`` equals the full shape, so writing block (0, 0) should
    populate the entire matrix.
    """
    X = np.random.randn(128, 128)
    X_sharded = BigMatrix("test_1", shape=X.shape, shard_sizes=X.shape)
    X_sharded.submatrix(0, 0).put_block(X)
    # assert_array_equal also enforces matching shapes (no broadcasting).
    np.testing.assert_array_equal(X_sharded.numpy(), X)