def tril(a):
    """Return the lower-triangular part of a 2-D block-distributed array.

    A new ``DistArray`` with the same dtype and shape as ``a`` is built
    block by block:

    * blocks strictly below the block diagonal (``i > j``) are copied,
    * blocks on the block diagonal (``i == j``) keep only their lower
      triangle,
    * blocks strictly above the block diagonal (``i < j``) are zero-filled.

    Args:
        a: distributed array exposing ``shape``, ``dtype``, ``num_blocks``
            and ``blocks``.

    Returns:
        A new ``DistArray``; ``a`` itself is not modified by this function.

    Raises:
        Exception: if ``a`` does not have exactly two dimensions.
    """
    if len(a.shape) != 2:
        raise Exception("input must have dimension 2, but len(a.shape) is " + str(len(a.shape)))
    dist_array = DistArray(a.dtype, a.shape)
    # NOTE(review): treating block (i, i) as "the diagonal" presumes the
    # block grid uses square blocks, so that the element diagonal passes
    # through the diagonal blocks only -- confirm against DistArray.
    for i in range(a.num_blocks[0]):
        for j in range(a.num_blocks[1]):
            if i > j:
                # Entirely below the diagonal: keep the block as is.
                dist_array.blocks[i, j] = single.copy(a.blocks[i, j])
            elif i == j:
                # Diagonal block: keep the lower triangle (the previous code
                # called single.triu here, which kept the wrong half).
                # NOTE(review): assumes single.tril exists alongside
                # single.triu -- confirm.
                dist_array.blocks[i, j] = single.tril(a.blocks[i, j])
            else:
                # Entirely above the diagonal: zero block. Its shape comes
                # from the result array so edge blocks get their real size
                # (the previous code referenced an undefined `block_size`).
                dist_array.blocks[i, j] = single.zeros(dist_array.compute_block_shape((i, j)))
    return dist_array
def zeros(shape, dtype):
    """Create a block-distributed array whose blocks come from ``single.zeros``.

    Args:
        shape: overall shape handed to the ``DistArray`` constructor.
        dtype: dtype handed to the ``DistArray`` constructor.

    Returns:
        The newly built ``DistArray`` with every block assigned.
    """
    result = DistArray(dtype, shape)
    block_grid = result.num_blocks
    # Visit every position in the block grid and fill it with a zero block
    # of the shape the array itself reports for that position.
    for block_index in np.ndindex(*block_grid):
        block_shape = result.compute_block_shape(block_index)
        # NOTE(review): dtype is not forwarded to single.zeros here --
        # confirm whether the per-block dtype should match `dtype`.
        result.blocks[block_index] = single.zeros(block_shape)
    return result