def blockmatrix_irs(self):
    """Build a collection of sample BlockMatrix IR nodes.

    Returns:
        A list of BlockMatrix IR nodes: reads, element-wise maps,
        value-to-matrix conversions, broadcasts, a dot product,
        sparsify/densify nodes and a slice.
    """
    # Value IRs that are later lifted into block matrices.
    f64_two = ir.F64(2)
    f64_pair = ir.MakeArray([ir.F64(3), ir.F64(2)], hl.tarray(hl.tfloat64))

    # Read from the example resource; reused as the child of most nodes below.
    native_reader = ir.BlockMatrixNativeReader(resource('blockmatrix_example/0'))
    bm = ir.BlockMatrixRead(native_reader)

    # Element-wise maps over one or two block matrices.
    sum_body = ir.ApplyBinaryPrimOp('+', ir.Ref('l'), ir.Ref('r'))
    bm_plus_bm = ir.BlockMatrixMap2(bm, bm, 'l', 'r', sum_body, "Union")

    negate_body = ir.ApplyUnaryPrimOp('-', ir.Ref('element'))
    negated = ir.BlockMatrixMap(bm, 'element', negate_body, False)

    sqrt_body = hl.sqrt(construct_expr(ir.Ref('element'), hl.tfloat64))._ir
    square_rooted = ir.BlockMatrixMap(bm, 'element', sqrt_body, False)

    # Read through a persist reader that refers to the name 'x'.
    persist_reader = ir.BlockMatrixPersistReader('x', bm)
    from_persisted = ir.BlockMatrixRead(persist_reader)

    # Values converted to block matrices with explicit shapes.
    scalar_bm = ir.ValueToBlockMatrix(f64_two, [1, 1], 1)
    column_bm = ir.ValueToBlockMatrix(f64_pair, [2, 1], 1)
    row_bm = ir.ValueToBlockMatrix(f64_pair, [1, 2], 1)

    # Broadcasts to a [2, 2] target with different index expressions.
    scalar_bcast = ir.BlockMatrixBroadcast(scalar_bm, [], [2, 2], 256)
    column_bcast = ir.BlockMatrixBroadcast(column_bm, [0], [2, 2], 256)
    row_bcast = ir.BlockMatrixBroadcast(row_bm, [1], [2, 2], 256)
    transposed = ir.BlockMatrixBroadcast(scalar_bcast, [1, 0], [2, 2], 256)

    product = ir.BlockMatrixDot(scalar_bcast, transposed)

    # Literal arguments for the three sparsifier flavours.
    rectangle_lit = ir.Literal(hl.tarray(hl.tint64), [0, 1, 5, 6])
    band_lit = ir.Literal(hl.ttuple(hl.tint64, hl.tint64), (-1, 1))
    intervals_type = hl.ttuple(hl.tarray(hl.tint64), hl.tarray(hl.tint64))
    intervals_lit = ir.Literal(intervals_type, ([0, 1, 5, 6], [5, 6, 8, 9]))

    by_rectangle = ir.BlockMatrixSparsify(bm, rectangle_lit, ir.RectangleSparsifier)
    by_band = ir.BlockMatrixSparsify(bm, band_lit, ir.BandSparsifier(True))
    by_row_intervals = ir.BlockMatrixSparsify(
        bm, intervals_lit, ir.RowIntervalSparsifier(True))

    densified = ir.BlockMatrixDensify(bm)

    # l ** r, built through the expression API and unwrapped to its IR.
    base_expr = construct_expr(ir.Ref('l'), hl.tfloat64)
    exponent_expr = construct_expr(ir.Ref('r'), hl.tfloat64)
    power_body = (base_expr ** exponent_expr)._ir
    powered = ir.BlockMatrixMap2(
        scalar_bm, scalar_bm, 'l', 'r', power_body, "NeedsDense")

    slices = [slice(0, 2, 1), slice(0, 1, 1)]
    sliced = ir.BlockMatrixSlice(product, slices)

    return [
        bm,
        from_persisted,
        bm_plus_bm,
        negated,
        square_rooted,
        scalar_bm,
        column_bm,
        row_bm,
        scalar_bcast,
        column_bcast,
        row_bcast,
        powered,
        transposed,
        by_rectangle,
        by_band,
        by_row_intervals,
        densified,
        product,
        sliced,
    ]
def test_parses(self):
    """Check that the backend parses the string form of each sample BlockMatrix IR.

    A 1x1 block matrix is persisted under the name 'x' first, then every IR
    from ``blockmatrix_irs`` plus one extra persisted-read node is rendered
    with ``str`` and handed to the backend parser. The persisted matrix is
    released afterwards, whether or not parsing succeeded.
    """
    backend = Env.spark_backend('BlockMatrixIRTests.test_parses')

    bmir = hl.linalg.BlockMatrix.fill(1, 1, 0.0)._bmir
    backend.execute(
        ir.BlockMatrixWrite(bmir, ir.BlockMatrixPersistWriter('x', 'MEMORY_ONLY')))

    # From here on 'x' is held by the backend; release it even when building
    # or parsing an IR raises, so a failure does not leave it persisted.
    try:
        persist = ir.BlockMatrixRead(ir.BlockMatrixPersistReader('x', bmir))

        for x in (self.blockmatrix_irs() + [persist]):
            backend._parse_blockmatrix_ir(str(x))
    finally:
        backend.unpersist_block_matrix('x')