def elemwise_log_mul_on_buffers(lhs, rhs, out):
    """Emit a unary ``log`` op and a binary ``mul`` op that both write ``out``.

    Both helper calls pass ``emit_generic=True``. The second op reads ``out``
    (the destination of the first op) as its first input, so the two ops are
    chained through ``out``. Nothing is returned; the ops are created purely
    for their effect on the module being built.

    Args:
        lhs: Input of the unary op.
        rhs: Second input of the binary op.
        out: Destination operand of both ops.

    NOTE(review): the name suggests the operands are buffers supplied by an
    enclosing function-building decorator that is not visible here -- confirm
    against the surrounding test.
    """
    # Step 1: log(lhs) written into out.
    linalg.elemwise_unary(
        lhs,
        outs=[out],
        fun=UnaryFn.log,
        emit_generic=True,
    )
    # Step 2: out combined with rhs via mul, written back into out.
    linalg.elemwise_binary(
        out,
        rhs,
        outs=[out],
        fun=BinaryFn.mul,
        emit_generic=True,
    )
def named_form(lhs, rhs):
    """Build one elementwise unary op and one elementwise binary op.

    A 4x8 ``f32`` init tensor is created and its result is passed as the
    destination operand of both ops. The unary op is called without ``fun`` or
    ``cast`` arguments; the binary op is called with ``fun=BinaryFn.mul`` and
    ``cast=TypeFn.cast_unsigned``. The directive comments in the body describe
    the IR the test expects to be printed; their text and relative order are
    significant to the test tool and must not be changed.

    Args:
        lhs: Input of the unary op and first input of the binary op.
        rhs: Second input of the binary op.

    Returns:
        A tuple ``(unary, binary)`` holding whatever the two ``linalg``
        helper calls return, in that order.
    """
    init_tensor = linalg.InitTensorOp([4, 8], f32)

    # Expect the named (non-generic) form printed with its custom format.
    # CHECK: linalg.elemwise_unary
    # CHECK-SAME: cast = #linalg.type_fn<cast_signed>
    # CHECK-SAME: fun = #linalg.unary_fn<exp>
    # CHECK-SAME: ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
    unary = linalg.elemwise_unary(lhs, outs=[init_tensor.result])

    # CHECK: linalg.elemwise_binary
    # CHECK-SAME: cast = #linalg.type_fn<cast_unsigned>
    # CHECK-SAME: fun = #linalg.binary_fn<mul>
    # CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
    # CHECK: return
    binary = linalg.elemwise_binary(
        lhs,
        rhs,
        outs=[init_tensor.result],
        fun=BinaryFn.mul,
        cast=TypeFn.cast_unsigned,
    )

    return unary, binary
def elemwise_exp_add_on_buffers(lhs, rhs, out):
    """Emit a unary op and a binary op, both with default ``fun``, into ``out``.

    Both helper calls pass ``emit_generic=True`` and no ``fun`` argument, so
    each helper's default function is used. The second op reads ``out`` (the
    destination of the first op) as its first input. Nothing is returned.

    Args:
        lhs: Input of the unary op.
        rhs: Second input of the binary op.
        out: Destination operand of both ops.

    NOTE(review): the function name suggests the defaults are ``exp`` and
    ``add`` and that the operands are buffers -- neither is visible in this
    block; confirm against the ``linalg`` helper definitions and the caller.
    """
    # Step 1: default unary function applied to lhs, written into out.
    linalg.elemwise_unary(
        lhs,
        outs=[out],
        emit_generic=True,
    )
    # Step 2: default binary function applied to (out, rhs), written into out.
    linalg.elemwise_binary(
        out,
        rhs,
        outs=[out],
        emit_generic=True,
    )
def elemwise_log_mul_on_buffers(lhs, rhs, out):
    """Emit a unary ``log`` op and a binary ``mul`` op that both write ``out``.

    Neither call passes ``emit_generic``, so the helpers' default emission
    mode applies. The second op reads ``out`` (the destination of the first
    op) as its first input, chaining the two ops through ``out``. Nothing is
    returned.

    Args:
        lhs: Input of the unary op.
        rhs: Second input of the binary op.
        out: Destination operand of both ops.

    NOTE(review): the name suggests the operands are buffers supplied by an
    enclosing function-building decorator that is not visible here -- confirm
    against the surrounding test.
    """
    # Step 1: log(lhs) written into out.
    linalg.elemwise_unary(
        lhs,
        outs=[out],
        fun=UnaryFn.log,
    )
    # Step 2: out combined with rhs via mul, written back into out.
    linalg.elemwise_binary(
        out,
        rhs,
        outs=[out],
        fun=BinaryFn.mul,
    )
def elemwise_exp_add_on_buffers(lhs, rhs, out):
    """Emit a unary op and a binary op, both with default arguments, into ``out``.

    Neither call passes ``fun`` or ``emit_generic``, so the helpers' defaults
    apply for both. The second op reads ``out`` (the destination of the first
    op) as its first input. Nothing is returned.

    Args:
        lhs: Input of the unary op.
        rhs: Second input of the binary op.
        out: Destination operand of both ops.

    NOTE(review): the function name suggests the default functions are
    ``exp`` and ``add`` and that the operands are buffers -- neither is
    visible in this block; confirm against the ``linalg`` helper definitions
    and the caller.
    """
    # Step 1: default unary function applied to lhs, written into out.
    linalg.elemwise_unary(
        lhs,
        outs=[out],
    )
    # Step 2: default binary function applied to (out, rhs), written into out.
    linalg.elemwise_binary(
        out,
        rhs,
        outs=[out],
    )