def __matmul__(self, right: 'DNDArray') -> 'DNDArray':
    """Return the blocked matrix product ``self @ right``.

    One table row is generated per ``(r, c, k)`` triple, where ``r`` indexes
    the block rows of ``self``, ``c`` the block columns of ``right`` and
    ``k`` the shared inner block dimension.  Each row multiplies block
    ``(r, k)`` of ``self`` with block ``(k, c)`` of ``right``; the partial
    products are then summed over ``k`` for every ``(r, c)``.

    Parameters
    ----------
    right : DNDArray
        Right-hand operand.  Must have the same ``block_size`` as ``self``,
        ``right.n_rows == self.n_cols`` and
        ``right.n_block_rows == self.n_block_cols``.

    Returns
    -------
    DNDArray
        Wraps a table with a ``block`` field, and globals ``r_field``,
        ``c_field``, ``n_rows``, ``n_cols``, ``n_block_rows``,
        ``n_block_cols`` and ``block_size``.

    Raises
    ------
    AssertionError
        If the operands' block sizes or inner dimensions do not agree.
    """
    left = self
    assert left.block_size == right.block_size
    assert left.n_cols == right.n_rows
    assert left.n_block_cols == right.n_block_rows

    n_rows = left.n_rows
    n_cols = right.n_cols
    block_size = left.block_size

    n_block_rows = left.n_block_rows
    n_block_inner = left.n_block_cols
    n_block_cols = right.n_block_cols
    n_multiplies = n_block_rows * n_block_cols * n_block_inner

    # One row (and one partition) per block multiplication.  The flat index
    # is decomposed as idx = r * (n_block_cols * n_block_inner)
    #                        + c * n_block_inner + k.
    o = hl.utils.range_table(n_multiplies, n_partitions=n_multiplies)
    o = o.key_by(
        r=o.idx // (n_block_cols * n_block_inner),
        c=(o.idx % (n_block_cols * n_block_inner)) // n_block_inner,
        k=o.idx % n_block_inner
    ).select()
    # NOTE(review): _key_by_assert_sorted presumably re-keys without a
    # shuffle by asserting the rows are already ordered -- confirm against
    # its definition.  The key is permuted so each operand can be looked up
    # by its own block coordinates.
    o = o._key_by_assert_sorted('r', 'c', 'k')
    o = o._key_by_assert_sorted('r', 'k', 'c')
    o = o.annotate(left=left.m[o.r, o.k].block)
    o = o._key_by_assert_sorted('k', 'c', 'r')
    o = o.annotate(right=right.m[o.k, o.c].block)

    o = o.annotate(product=o.left @ o.right)

    # FIXME: use ndarray sum / fma
    def ndarray_to_array(ndarray):
        # Flatten a 2-d ndarray into an array in column-major order:
        # position p holds element (p % rows, p // rows).  The helper indexes
        # its own argument, and the bound names are distinct from the outer
        # n_rows / n_cols that are written into the globals below.
        return hl.rbind(
            ndarray.shape[0],
            ndarray.shape[1],
            lambda block_rows, block_cols: hl.range(hl.int(block_rows * block_cols)).map(
                lambda absolute: ndarray[absolute % block_rows, absolute // block_rows]))

    o = o.annotate(shape=o.product.shape,
                   product=ndarray_to_array(o.product))
    o = o._key_by_assert_sorted('r', 'c', 'k')
    # Drop k from the key so the aggregation below groups by (r, c).
    o = o._key_by_assert_sorted('r', 'c')

    import hail.methods.misc as misc
    misc.require_key(o, 'collect_by_key')
    import hail.ir as ir

    # Element-wise sum of the flattened partial products for each (r, c).
    # Only one shape per group is kept; all partial products of a group are
    # presumably the same shape -- TODO confirm.
    o = Table(ir.TableAggregateByKey(
        o._tir,
        hl.struct(
            shape=hl.agg.take(o.shape, 1)[0],
            block=hl.agg.array_sum(o.product))._ir))
    # Rebuild the ndarray from the column-major flattened sum.
    o = o.annotate(block=hl.nd.from_column_major(o.block, o.shape))
    o = o.select('block')
    o = o.select_globals(
        r_field='r',
        c_field='c',
        n_rows=n_rows,
        n_cols=n_cols,
        n_block_rows=n_block_rows,
        n_block_cols=n_block_cols,
        block_size=block_size)
    return DNDArray(o)
def _block_inner_product(self,
                         right: 'DNDArray',
                         block_product: Callable[[Expression, Expression], Expression],
                         block_aggregate: Callable[[Expression], Expression]
                         ) -> 'DNDArray':
    """Generic blocked inner product of ``self`` with ``right``.

    For every ``(r, c, k)`` block triple, ``block_product`` is applied to
    block ``(r, k)`` of ``self`` and block ``(k, c)`` of ``right``.  The
    per-triple results are then combined for each ``(r, c)`` by
    ``block_aggregate``, which is evaluated inside a keyed table
    aggregation.

    Parameters
    ----------
    right : DNDArray
        Right-hand operand; its block size, row count and block-row count
        must match ``self``'s block size, column count and block-column
        count respectively.
    block_product : callable
        Receives the left and right block expressions and returns the
        per-triple product expression.
    block_aggregate : callable
        Receives the per-triple product expression and returns the
        expression stored as ``block`` for each ``(r, c)``.

    Returns
    -------
    DNDArray
        Wraps a table with a ``block`` field, and globals ``n_rows``,
        ``n_cols``, ``n_block_rows``, ``n_block_cols`` and ``block_size``.

    Raises
    ------
    AssertionError
        If the operands' block sizes or inner dimensions do not agree.
    """
    assert self.block_size == right.block_size
    assert self.n_cols == right.n_rows
    assert self.n_block_cols == right.n_block_rows

    n_block_rows = self.n_block_rows
    n_block_inner = self.n_block_cols
    n_block_cols = right.n_block_cols
    # Number of (c, k) pairs belonging to a single block row r.
    triples_per_block_row = n_block_cols * n_block_inner
    n_multiplies = n_block_rows * n_block_cols * n_block_inner

    # One row and one partition per block multiplication; the flat index is
    # split into the (r, c, k) coordinates.
    flat = hl.utils.range_table(n_multiplies, n_partitions=n_multiplies)
    triples = flat.key_by(
        r=flat.idx // triples_per_block_row,
        c=(flat.idx % triples_per_block_row) // n_block_inner,
        k=flat.idx % n_block_inner,
    ).select()

    # NOTE(review): _key_by_assert_sorted presumably changes the key without
    # a shuffle -- confirm against its definition.  The key order is permuted
    # before each block lookup.
    by_rck = triples._key_by_assert_sorted('r', 'c', 'k')
    by_rkc = by_rck._key_by_assert_sorted('r', 'k', 'c')
    with_left = by_rkc.annotate(left=self.m[by_rkc.r, by_rkc.k].block)

    by_kcr = with_left._key_by_assert_sorted('k', 'c', 'r')
    with_both = by_kcr.annotate(right=right.m[by_kcr.k, by_kcr.c].block)

    multiplied = with_both.annotate(
        product=block_product(with_both.left, with_both.right))

    # Restore (r, c, k) ordering, then drop k so aggregation groups by (r, c).
    regrouped = multiplied._key_by_assert_sorted('r', 'c', 'k')
    regrouped = regrouped._key_by_assert_sorted('r', 'c')

    import hail.methods.misc as misc
    import hail.ir as ir
    misc.require_key(regrouped, 'collect_by_key')

    aggregate_struct = hl.struct(block=block_aggregate(regrouped.product))
    aggregated = Table(ir.TableAggregateByKey(regrouped._tir, aggregate_struct._ir))

    result = aggregated.select('block').select_globals(
        n_rows=self.n_rows,
        n_cols=right.n_cols,
        n_block_rows=n_block_rows,
        n_block_cols=n_block_cols,
        block_size=self.block_size,
    )
    return DNDArray(result)