def tril(a):
    """Return the lower-triangular part of a 2-D block-distributed array.

    The result is built block by block:
      * blocks strictly below the block diagonal (row index > column index)
        are copied unchanged;
      * blocks on the block diagonal have their own lower triangle taken;
      * blocks strictly above the block diagonal are replaced by zeros.

    Args:
        a: A DistArray exposing ``shape``, ``dtype``, ``num_blocks`` and a
            ``blocks`` grid indexable by ``(i, j)``.

    Returns:
        A new DistArray with the same dtype and shape as ``a``.

    Raises:
        Exception: If ``a`` is not two-dimensional.
    """
    if len(a.shape) != 2:
        raise Exception("input must have dimension 2, but len(a.shape) is " + str(len(a.shape)))
    dist_array = DistArray(a.dtype, a.shape)
    for i, j in np.ndindex(*a.num_blocks):
        if i > j:
            # Entirely below the diagonal: keep the whole block.
            dist_array.blocks[i, j] = single.copy(a.blocks[i, j])
        elif i == j:
            # Straddles the diagonal: keep only its lower triangle.
            # (The original called single.triu here, which produced the
            # wrong triangle for a function named tril.)
            dist_array.blocks[i, j] = single.tril(a.blocks[i, j])
        else:
            # Entirely above the diagonal: zero it out.
            # NOTE(review): assumes every block is block_size x block_size;
            # confirm edge blocks cannot be partial when the shape is not a
            # multiple of block_size.
            dist_array.blocks[i, j] = single.zeros([block_size, block_size])
    return dist_array
def copy(a):
    """Build a new DistArray whose every block is a copy of the matching block of ``a``.

    Args:
        a: A DistArray exposing ``dtype``, ``shape`` and a ``blocks`` grid.

    Returns:
        A fresh DistArray with the same dtype and shape as ``a``. Each block
        is produced by ``single.copy`` on the corresponding block of ``a``.
    """
    result = DistArray(a.dtype, a.shape)
    block_positions = np.ndindex(*result.num_blocks)
    for position in block_positions:
        source_block = a.blocks[position]
        result.blocks[position] = single.copy(source_block)
    return result