def build_matmul_tensors_func(func_name, m, k, n, dtype):
    """Build a function named ``func_name`` that matmuls two tensor arguments.

    The emitted function takes an ``[m, k]`` and a ``[k, n]`` ranked tensor
    with element type ``dtype``. It applies ``linalg.MatmulOp`` to them and
    returns the ``[m, n]`` tensor result. The IR is emitted as a side effect;
    this helper itself returns nothing.

    NOTE(review): ``FuncOp`` here is assumed to be a helper that creates the
    function and returns a pair whose second element is its entry block.
    Confirm this against its definition.
    """
    lhs_type, rhs_type, result_type = [
        RankedTensorType.get(dims, dtype) for dims in ([m, k], [k, n], [m, n])
    ]
    # TODO: There should be a one-liner for this.
    signature = FunctionType.get([lhs_type, rhs_type], [result_type])
    _, entry_block = FuncOp(func_name, signature)
    lhs_arg, rhs_arg = entry_block.arguments

    with InsertionPoint(entry_block):
        matmul = linalg.MatmulOp([lhs_arg, rhs_arg], results=[result_type])
        # TODO: Implement support for SingleBlockImplicitTerminator
        # Until then, populate the op's body region by hand: append an empty
        # block and terminate it with an operand-less yield.
        region_block = matmul.regions[0].blocks.append()
        with InsertionPoint(region_block):
            linalg.YieldOp(values=[])
        std.ReturnOp([matmul.result])
def testStructuredOpOnTensors():
    """Build and print a module containing ``linalg.matmul`` on tensors.

    A fresh module receives one function, ``matmul_test``. It takes two
    ``tensor<2x3x4xf32>`` arguments, feeds them to ``linalg.MatmulOp`` with a
    tensor result type, and returns that result. The printed module is matched
    by the FileCheck directive just before ``print``.

    NOTE(review): linalg.matmul is normally defined on 2-D operands. The
    rank-3 shape here appears to exercise only construction and printing,
    not verification. Confirm this is intended.
    """
    with Context(), Location.unknown():
        module = Module.create()
        element_type = F32Type.get()
        operand_type = RankedTensorType.get((2, 3, 4), element_type)
        signature = FunctionType.get(
            inputs=[operand_type, operand_type], results=[operand_type])
        with InsertionPoint(module.body):
            matmul_func = builtin.FuncOp(name="matmul_test", type=signature)
            with InsertionPoint(matmul_func.add_entry_block()):
                arg_lhs, arg_rhs = matmul_func.entry_block.arguments
                matmul = linalg.MatmulOp([arg_lhs, arg_rhs],
                                         results=[operand_type])
                std.ReturnOp([matmul.result])
        # CHECK: %[[R:.*]] = linalg.matmul ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
        print(module)
def testStructuredOpOnBuffers():
    """Build and print a module containing ``linalg.matmul`` on memrefs.

    A fresh module receives one function, ``matmul_test``. It takes three
    ``memref<2x3x4xf32>`` arguments: two inputs and an output buffer passed
    via ``outputs``. The op therefore declares no results, and the function
    returns nothing. The printed module is matched by the FileCheck directive
    just before ``print``.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        memref_type = MemRefType.get((2, 3, 4), f32)
        # Insert at the end of the module body, the same way
        # testStructuredOpOnTensors does. The bindings accept either this
        # form or at_block_terminator(), depending on whether the module body
        # has a terminator, so both tests must agree.
        # NOTE(review): this assumes ModuleOp has no terminator in the MLIR
        # version in use, as the tensor test also assumes. With an implicit
        # module terminator, InsertionPoint.at_block_terminator(module.body)
        # would be required in both tests instead.
        with InsertionPoint(module.body):
            func = builtin.FuncOp(
                name="matmul_test",
                type=FunctionType.get(
                    inputs=[memref_type, memref_type, memref_type],
                    results=[]))
            with InsertionPoint(func.add_entry_block()):
                lhs, rhs, result = func.entry_block.arguments
                # Buffer semantics: the third argument is supplied as the
                # output operand rather than as a result type.
                linalg.MatmulOp([lhs, rhs], outputs=[result])
                std.ReturnOp([])
        # CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>)
        print(module)