Beispiel #1
0
    def __matmul__(self, right: 'DNDArray') -> 'DNDArray':
        """Multiply this blocked matrix by `right` (``self @ right``).

        One table row is generated per (row-block ``r``, column-block ``c``,
        inner-block ``k``) triple. Each row joins in the left block at
        ``(r, k)`` and the right block at ``(k, c)`` and multiplies them with
        ``@``. The partial products are then summed per ``(r, c)`` key to form
        the output blocks.

        Parameters
        ----------
        right : DNDArray
            Right-hand operand. It must have the same ``block_size`` as
            ``self``. Its row count and row-block count must equal this
            array's column count and column-block count.

        Returns
        -------
        DNDArray
            The product, keyed by ``(r, c)`` with a single ``block`` row field.

        Raises
        ------
        AssertionError
            If the block sizes or inner dimensions of the operands disagree
            (checked with ``assert``, so skipped under ``python -O``).
        """
        left = self
        assert left.block_size == right.block_size
        assert left.n_cols == right.n_rows
        assert left.n_block_cols == right.n_block_rows

        n_rows = left.n_rows
        n_cols = right.n_cols
        block_size = left.block_size

        n_block_rows = left.n_block_rows
        n_block_inner = left.n_block_cols
        n_block_cols = right.n_block_cols
        n_multiplies = n_block_rows * n_block_cols * n_block_inner

        # One row per block multiplication. idx is decoded in (r, c, k) order,
        # so k varies fastest.
        o = hl.utils.range_table(n_multiplies, n_partitions=n_multiplies)
        o = o.key_by(
            r=o.idx // (n_block_cols * n_block_inner),
            c=(o.idx % (n_block_cols * n_block_inner)) // n_block_inner,
            k=o.idx % n_block_inner
        ).select()
        o = o._key_by_assert_sorted('r', 'c', 'k')
        # NOTE(review): idx order is (r, c, k), not (r, k, c). This
        # assert-sorted re-key presumably relies on each partition holding a
        # single row (n_partitions == n_multiplies). Confirm before changing.
        o = o._key_by_assert_sorted('r', 'k', 'c')
        o = o.annotate(left=left.m[o.r, o.k].block)
        o = o._key_by_assert_sorted('k', 'c', 'r')
        o = o.annotate(right=right.m[o.k, o.c].block)
        o = o.annotate(product=o.left @ o.right)

        # FIXME: use ndarray sum / fma
        def ndarray_to_array(ndarray):
            """Flatten the 2-D ndarray expression `ndarray` into a column-major array."""
            return hl.rbind(
                ndarray.shape[0],
                ndarray.shape[1],
                lambda block_rows, block_cols: hl.range(hl.int(block_rows * block_cols)).map(
                    # Column-major: consecutive indices walk down a column.
                    lambda absolute: ndarray[absolute % block_rows, absolute // block_rows]))

        # Flatten each partial product so the per-key aggregation can sum the
        # blocks elementwise with array_sum. Keep the shape so the ndarray can
        # be rebuilt afterwards.
        o = o.annotate(shape=o.product.shape,
                       product=ndarray_to_array(o.product))
        o = o._key_by_assert_sorted('r', 'c', 'k')
        o = o._key_by_assert_sorted('r', 'c')

        import hail.methods.misc as misc
        misc.require_key(o, 'collect_by_key')
        import hail.ir as ir

        # Sum the partial products over k for each (r, c) output block. Every
        # partial product for one (r, c) has the same shape (rows of left
        # block r x columns of right block c), so any one shape will do.
        o = Table(ir.TableAggregateByKey(
            o._tir,
            hl.struct(
                shape=hl.agg.take(o.shape, 1)[0],
                block=hl.agg.array_sum(o.product))._ir))
        o = o.annotate(block=hl.nd.from_column_major(o.block, o.shape))
        o = o.select('block')
        o = o.select_globals(
            r_field='r',
            c_field='c',
            n_rows=n_rows,
            n_cols=n_cols,
            n_block_rows=n_block_rows,
            n_block_cols=n_block_cols,
            block_size=block_size)
        return DNDArray(o)
Beispiel #2
0
    def _block_inner_product(self,
                             right: 'DNDArray',
                             block_product: Callable[[Expression, Expression], Expression],
                             block_aggregate: Callable[[Expression], Expression]
                             ) -> 'DNDArray':
        """Combine this array with `right` block-by-block, shaped like a matrix product.

        For every output block ``(r, c)`` and inner block index ``k``,
        `block_product` is applied to the left block at ``(r, k)`` and the
        right block at ``(k, c)``. All results for one ``(r, c)`` are then
        reduced by the aggregation expression that `block_aggregate` builds.

        Parameters
        ----------
        right : DNDArray
            Right-hand operand. It must share ``block_size`` with ``self``. Its
            row count and row-block count must equal this array's column count
            and column-block count.
        block_product : callable
            Builds the per-pair expression from a (left block, right block)
            pair of expressions.
        block_aggregate : callable
            Builds the aggregation that reduces the per-pair expressions for
            one ``(r, c)`` key into that key's output ``block``.

        Returns
        -------
        DNDArray
            The result, keyed by ``(r, c)`` with a single ``block`` row field.
        """
        lhs, rhs = self, right
        assert lhs.block_size == rhs.block_size
        assert lhs.n_cols == rhs.n_rows
        assert lhs.n_block_cols == rhs.n_block_rows

        out_block_rows = lhs.n_block_rows
        inner_blocks = lhs.n_block_cols
        out_block_cols = rhs.n_block_cols
        # Number of (c, k) combinations that share a single row-block r.
        per_row_block = out_block_cols * inner_blocks
        total = out_block_rows * per_row_block

        # One table row per (r, c, k) block pairing, decoded from idx with k
        # varying fastest.
        t = hl.utils.range_table(total, n_partitions=total)
        idx = t.idx
        t = t.key_by(
            r=idx // per_row_block,
            c=(idx % per_row_block) // inner_blocks,
            k=idx % inner_blocks,
        ).select()

        # The re-key steps below are order-sensitive. Each one lines up the
        # key prefix used by the join that follows it.
        t = t._key_by_assert_sorted('r', 'c', 'k')._key_by_assert_sorted('r', 'k', 'c')
        t = t.annotate(left=lhs.m[t.r, t.k].block)
        t = t._key_by_assert_sorted('k', 'c', 'r')
        t = t.annotate(right=rhs.m[t.k, t.c].block)

        t = t.annotate(product=block_product(t.left, t.right))
        t = t._key_by_assert_sorted('r', 'c', 'k')._key_by_assert_sorted('r', 'c')

        from hail.methods.misc import require_key
        require_key(t, 'collect_by_key')
        import hail.ir as ir

        # Reduce every k for a given (r, c) into that key's output block.
        aggregated = Table(ir.TableAggregateByKey(
            t._tir,
            hl.struct(block=block_aggregate(t.product))._ir))
        result = aggregated.select('block').select_globals(
            n_rows=lhs.n_rows,
            n_cols=rhs.n_cols,
            n_block_rows=out_block_rows,
            n_block_cols=out_block_cols,
            block_size=lhs.block_size)
        return DNDArray(result)