class ComputeShapeSignature:
    """
    Here we use the base name as the suffix of the signature to avoid
    generating a separate shape function for in-place variants.
    """

    def __init__(self, kernel_name: str, f: NativeFunction):
        self.__schema = LazyIrSchema(f.func)
        self.__dispatch_args = ', '.join(
            [a.decl() for a in dispatcher.arguments(f.func)])
        self.__call_args = ", ".join(
            [f"{arg.name}" for arg in self.__schema.filtered_args(generator=True)])
        self.__kernel_name = kernel_name

    def __decl_suffix(self) -> str:
        return f"{self.__kernel_name}({self.__dispatch_args})"

    def __call_suffix(self) -> str:
        return f"{self.__kernel_name}({self.__call_args})"

    @property
    def shape_decl(self) -> str:
        return f"TORCH_API std::vector<Shape> compute_shape_{self.__decl_suffix()}"

    @property
    def shape_call(self) -> str:
        return f"torch::lazy::compute_shape_{self.__call_suffix()}"
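# A minimal runnable sketch of the two strings ComputeShapeSignature renders.
# The kernel name "add" and the argument lists below are illustrative
# stand-ins for what dispatcher.arguments() and filtered_args() would produce
# for a real op; they are not taken from any actual yaml entry.
_kernel_name = "add"
_dispatch_args = "const at::Tensor& self, const at::Tensor& other"
_call_args = "self, other"

# What shape_decl renders (the header declaration, minus the trailing ';'):
print(f"TORCH_API std::vector<Shape> compute_shape_{_kernel_name}({_dispatch_args})")
# What shape_call renders (the call expression used at the kernel's call site):
print(f"torch::lazy::compute_shape_{_kernel_name}({_call_args})")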
def node_ctor_inputs(schema: LazyIrSchema) -> str:
    """
    Produce a formatted string with the arguments as passed into the constructor of a node class.
    """
    node_ctor_values = [
        node_ctor_arg_rvalue_string(arg) for arg in schema.filtered_args()
    ]
    return ",\n ".join(node_ctor_values)
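# Runnable sketch of the join above, assuming node_ctor_arg_rvalue_string
# produced these (hypothetical) rvalue expressions for args (self, other, alpha):
_node_ctor_values = ["lazy_self->GetIrValue()", "lazy_other->GetIrValue()", "alpha"]
print(",\n ".join(_node_ctor_values))
# Each argument lands on its own line inside the generated MakeNode<...>(...) call.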
def gen_fallback_code(schema: LazyIrSchema, overload_name: str) -> str:
    """
    Generate code that falls back to eager mode, conditioned on a predicate.
    """
    fallback_args = ",\n                ".join(
        [str(arg.name) for arg in schema.filtered_args(generator=True)])
    if len(overload_name):
        aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})"
    else:
        aten_op_str = f"ATEN_OP({schema.aten_name})"
    or_has_generator = ""
    if schema.generator_arg:
        # generators are always optional and there is never more than one, at least currently
        or_has_generator = f" || ({schema.generator_arg.name}.has_value() && {schema.generator_arg.name}->defined())"
    # Emit a guard that re-dispatches to the eager fallback when the
    # force-fallback predicate holds (or a generator arg is defined).
    return f"""
        if (force_eager_fallback({aten_symbol(schema)}){or_has_generator}) {{
            return at::native::call_fallback_fn<
                &ltc_eager_fallback, {aten_op_str}>::call(
                {fallback_args}
            );
        }}
"""
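# Hypothetical rendering of the guard for an op "add" with overload "Tensor"
# and no generator arg (the symbol and argument names are illustrative):
print("""
        if (force_eager_fallback(at::aten::add)) {
            return at::native::call_fallback_fn<
                &ltc_eager_fallback, ATEN_OP2(add, Tensor)>::call(
                self, other, alpha
            );
        }
""")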
def __call__(self, f: NativeFunction) -> List[str]:
    metadata = self.backend_index.get_kernel(f)
    assert metadata is not None
    # Only generate a shape/dtype fn for non-structured kernels,
    # since we just use the meta function for structured kernels
    if not f.structured and f.structured_delegate is None:
        shape_sig = ComputeShapeSignature(metadata.kernel, f)
        return [f"{shape_sig.shape_decl};"]
    else:
        return []
def gen_shape_call(self, func: NativeFunction) -> str:
    metadata = self.backend_index.get_kernel(func)
    assert metadata is not None
    schema = LazyIrSchema(func.func)
    all_args = schema.filtered_args()
    returns_length = len(schema.returns)
    # call the meta kernel if it exists, to compute output shape/dtype for our IR
    if func.structured or func.structured_delegate is not None:
        meta_out = """std::vector<Shape> shapes{Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
        if returns_length > 1:
            def this_shape(i: int) -> str:
                return f"Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())"
            shapes_str = ",".join([this_shape(i) for i in range(returns_length)])
            meta_out = "std::vector<Shape> shapes{" + shapes_str + "};"

        meta_str = f"""auto out_meta = at::meta::{schema.aten_name}({', '.join(str(a.name) for a in all_args)});
        {meta_out}"""
    else:
        shape_sig = ComputeShapeSignature(metadata.kernel, func)
        meta_str = f"""
        auto shapes = {shape_sig.shape_call};"""

    meta_str += f"""
        TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});"""

    # Calculate which dimensions are symbolic
    func_schema_str = "aten::" + str(func.func)
    meta_str += f"""
        if(symbolicShapeEnabled()){{
            std::vector<jit::IValue> inputs = {{ {', '.join(str(a.name) for a in all_args)} }};
            const char* schema_str = "{func_schema_str}";
            applySymbolicShapesOnLT(schema_str, inputs, shapes);
        }}
    """
    return meta_str
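# Hypothetical output of gen_shape_call for a structured, single-return op
# "add" with args (self, other, alpha); all names are illustrative:
print("""auto out_meta = at::meta::add(self, other, alpha);
        std::vector<Shape> shapes{Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
        TORCH_INTERNAL_ASSERT(shapes.size() == 1);""")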
def __call__(self, func: NativeFunction) -> List[str]:
    sig = kernel_signature(func, self.backend_index)
    metadata = self.backend_index.get_kernel(func)
    assert metadata is not None
    schema = LazyIrSchema(func.func)
    value_args = schema.filtered_args(values=True, scalars=False)
    returns_length = len(schema.returns)

    fallback_str = ""
    if self.gen_forced_fallback_code:
        fallback_str = gen_fallback_code(
            schema, overload_name=func.func.name.overload_name)

    value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
    assert len(value_types_names) > 0, "Code below assumes there is at least one tensor arg"
    get_device_str = f"""auto common_device = torch::lazy::GetBackendDevice({', '.join(value_types_names)});
        TORCH_INTERNAL_ASSERT(common_device);
        """

    lazy_tensor_decls_str = lazy_tensor_decls(value_args, self.tensor_class)
    node_ctor_input_str = node_ctor_inputs(schema)
    shape_str = self.gen_shape_call(func)

    node_str = f"""auto node = torch::lazy::MakeNode<{schema.node_name}>({node_ctor_input_str},
                                                                         std::move(shapes));"""
    first_tensor_name = value_types_names[0]
    bridge_str = """auto result = torch::lazy::CreateAtenFromLtcTensor(
                torch::lazy::LazyTensor::Create(std::move(node), *common_device));"""
    if returns_length > 1:
        bridge_str = f"""std::vector<{self.tensor_class}Ptr> lazy_tensors;
        for (int i = 0; i < {returns_length}; i++) {{
            lazy_tensors.push_back(torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device));
        }}
        auto result = torch::lazy::TupleAtenFromLtcTensors<{returns_length}>(lazy_tensors);"""

    if schema.name.name.inplace or func.func.is_out_fn():
        assert returns_length == 1, (
            "We assumed there was no such case where an op is an in-place variant "
            f"and has tuple outputs, but got tuple of len {returns_length}.")
        bridge_str = f"""lazy_{first_tensor_name}->SetInPlaceIrValue(node);
        auto& result = {first_tensor_name};"""

    return [f"""\
    {sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{
        {fallback_str}
        TORCH_LAZY_FN_COUNTER("lazy::");
        {get_device_str}
        {lazy_tensor_decls_str}
        {shape_str}
        {node_str}
        {bridge_str}
        return result;
    }};\n
    """]
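# Hypothetical rendering of bridge_str for a 2-output op, assuming
# self.tensor_class == "LazyTensor" (both values are illustrative), matching
# the returns_length > 1 branch above:
print("""std::vector<LazyTensorPtr> lazy_tensors;
        for (int i = 0; i < 2; i++) {
            lazy_tensors.push_back(torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device));
        }
        auto result = torch::lazy::TupleAtenFromLtcTensors<2>(lazy_tensors);""")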
def gen(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
    # for now, we just want one IR class decl and soon after also the method defs
    # and we use the functional version not out/inplace.
    func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
    schema = LazyIrSchema(func)
    all_args = schema.filtered_args()
    value_args = schema.filtered_args(values=True, scalars=False)
    scalar_args = schema.filtered_args(values=False, scalars=True)

    node_ctor_args = ", ".join(
        [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args])
    scalar_initializers = ",\n        ".join(
        [f"{a.name}({a.name})" for a in scalar_args])
    comma_if_scalar_initializers = ",\n" if len(scalar_initializers) else ""
    scalar_decls = "\n  ".join([
        f"std::string {a.name};"
        if a.lazy_type.cpp_type() == "c10::string_view"
        else f"{a.lazy_type.cpp_type()} {a.name};"
        for a in scalar_args
    ])
    scalar_hashes = ", ".join([f"{a.name}" for a in scalar_args])

    base_ctor_value_args_list = []
    optional_values = []
    for arg in value_args:
        if isinstance(arg.lazy_type, (BaseCType, VectorCType)):
            base_ctor_value_args_list.append(f"{arg.name}")
        elif isinstance(arg.lazy_type, OptionalCType):
            base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)")
            optional_values.append(arg.name)
        else:
            raise AssertionError(
                f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})")
    base_ctor_value_args = ", ".join(base_ctor_value_args_list)

    has_optional_decls = "\n  ".join(
        [f"bool has_{value}: 1;" for value in optional_values])
    has_optional_defs = "\n    ".join(
        [f"has_{value} = !!{value};" for value in optional_values])

    members_to_string = []
    for arg in scalar_args:
        if isinstance(arg.lazy_type, OptionalCType):
            members_to_string.append(f"""if ({arg.name}.has_value()) {{
      ss << ", {arg.name}=" << {arg.name}.value();
    }} else {{
      ss << ", {arg.name}=null";
    }}""")
        else:
            members_to_string.append(f'ss << ", {arg.name}=" << {arg.name};')
    members_to_string_str = "\n    ".join(members_to_string)

    return [f"""\
class {schema.node_name} : public {self.node_base} {{
 public:
  {schema.node_name}({node_ctor_args}, std::vector<Shape>&& shapes)
      : {self.node_base}(torch::lazy::OpKind({aten_symbol(schema)}),
              {{{base_ctor_value_args}}}, std::move(shapes),
              /* num_outputs */ {len(func.returns)},
              torch::lazy::MHash({scalar_hashes})){comma_if_scalar_initializers}
        {scalar_initializers}
  {{
    {has_optional_defs}
  }}

  std::string ToString() const override {{
    std::stringstream ss;
    ss << {self.node_base}::ToString();
    {members_to_string_str}
    return ss.str();
  }}

  {self.lowering_function(f)}

  {scalar_decls}
  {has_optional_decls}

}};
""", ]
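# Hypothetical shape of the emitted IR node class for an op "Add" with two
# value args, one scalar arg "alpha", and node_base "TsNode"; the lowering
# function is omitted and all names are illustrative, not real codegen output:
print("""class Add : public TsNode {
 public:
  Add(const torch::lazy::Value& self, const torch::lazy::Value& other,
      const c10::Scalar& alpha, std::vector<Shape>&& shapes)
      : TsNode(torch::lazy::OpKind(at::aten::add),
              {self, other}, std::move(shapes),
              /* num_outputs */ 1,
              torch::lazy::MHash(alpha)),
        alpha(alpha)
  {
  }

  std::string ToString() const override {
    std::stringstream ss;
    ss << TsNode::ToString();
    ss << ", alpha=" << alpha;
    return ss.str();
  }

  c10::Scalar alpha;
};""")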
def __call__(self, func: NativeFunction) -> List[str]:
    sig = kernel_signature(func, self.backend_index)
    metadata = self.backend_index.get_kernel(func)
    assert metadata is not None
    schema = LazyIrSchema(func.func)
    all_args = schema.filtered_args()
    value_args = schema.filtered_args(values=True, scalars=False)
    returns_length = len(schema.returns)

    fallback_str = gen_fallback_code(schema, overload_name=func.func.name.overload_name)

    value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
    assert len(value_types_names) > 0, "Code below assumes there is at least one tensor arg"
    get_device_str = f"""auto common_device = torch::lazy::GetBackendDevice({', '.join(value_types_names)});
        TORCH_INTERNAL_ASSERT(common_device);
        """

    lazy_tensor_decls_str = lazy_tensor_decls(value_args, self.tensor_class)
    node_ctor_input_str = node_ctor_inputs(schema)

    # call the meta kernel if it exists, to compute output shape/dtype for our IR
    if func.structured or func.structured_delegate is not None:
        meta_out = """std::vector<Shape> shapes{Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
        if returns_length > 1:
            def this_shape(i: int) -> str:
                return f"Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())"
            shapes_str = ','.join([this_shape(i) for i in range(returns_length)])
            meta_out = "std::vector<Shape> shapes{" + shapes_str + "};"

        meta_str = f"""auto out_meta = at::meta::{schema.aten_name}({', '.join(str(a.name) for a in all_args)});
        {meta_out}"""
    else:
        shape_sig = ComputeShapeSignature(metadata.kernel, func)
        meta_str = f"""
        auto shapes = {shape_sig.shape_call};"""

    meta_str += f"""
        TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});"""

    node_str = f"""auto node = torch::lazy::MakeNode<{schema.node_name}>({node_ctor_input_str},
                                                                         std::move(shapes));"""
    first_tensor_name = value_types_names[0]
    bridge_str = """auto result = torch::lazy::CreateAtenFromLtcTensor(
                torch::lazy::LazyTensor::Create(std::move(node), *common_device));"""
    if returns_length > 1:
        bridge_str = f"""std::vector<{self.tensor_class}Ptr> lazy_tensors;
        for (int i = 0; i < {returns_length}; i++) {{
            lazy_tensors.push_back(torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device));
        }}
        auto result = torch::lazy::TupleAtenFromLtcTensors<{returns_length}>(lazy_tensors);"""

    if schema.name.name.inplace or func.func.is_out_fn():
        assert returns_length == 1, (
            "We assumed there was no such case where an op is an in-place variant "
            f"and has tuple outputs, but got tuple of len {returns_length}.")
        bridge_str = f"""lazy_{first_tensor_name}->SetInPlaceIrValue(node);
        auto& result = {first_tensor_name};"""

    return [f"""\
    {sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{
        {fallback_str}
        TORCH_LAZY_FN_COUNTER("lazy::");
        {get_device_str}
        {lazy_tensor_decls_str}
        {meta_str}
        {node_str}
        {bridge_str}
        return result;
    }};\n
    """]