def transform_class_member_access(self, class_member_access: ir0.ClassMemberAccess):
    """Move the "movable" args of a metafunction call onto the accessed member.

    When ``class_member_access`` is ``F<args...>::member`` and ``F`` is an ``ir0.AtomicTypeLiteral`` whose
    ``cpp_type`` is a key of ``self.movable_arg_indexes_by_template_name``, the access is rewritten as
    ``F<kept args...>::member<moved args...>``:
      * args whose index is in the movable set become args of the member template;
      * all the other args stay on ``F``.
    The rebuilt ``F`` literal is created with ``may_be_alias=True``. Any other access is delegated to the base
    class implementation.

    Returns:
        The rewritten ``ir0.TemplateInstantiation``, or whatever the base implementation returns.
    """
    instantiation = class_member_access.expr
    is_rewritable = (isinstance(instantiation, ir0.TemplateInstantiation)
                     and isinstance(instantiation.template_expr, ir0.AtomicTypeLiteral)
                     and instantiation.template_expr.cpp_type in self.movable_arg_indexes_by_template_name)
    if not is_rewritable:
        return super().transform_class_member_access(class_member_access)

    template_literal = instantiation.template_expr
    template_type = template_literal.expr_type
    assert isinstance(template_type, ir0.TemplateType)

    # F<X, Y>::type -> F<X>::type<Y> (if X is used in non-trivial patterns and Y isn't)
    transformed_args = self.transform_exprs(instantiation.args, instantiation)
    movable_indexes = self.movable_arg_indexes_by_template_name[template_literal.cpp_type]

    # Split the arg expressions and the declared arg types in one pass each, preserving relative order.
    kept_arg_exprs = []
    moved_arg_exprs = []
    for position, arg_expr in enumerate(transformed_args):
        (moved_arg_exprs if position in movable_indexes else kept_arg_exprs).append(arg_expr)

    kept_arg_types = []
    moved_arg_types = []
    for position, arg_type in enumerate(template_type.args):
        (moved_arg_types if position in movable_indexes else kept_arg_types).append(arg_type)

    might_trigger_static_asserts = instantiation.instantiation_might_trigger_static_asserts

    reduced_template_literal = ir0.AtomicTypeLiteral.for_nonlocal_template(
        cpp_type=template_literal.cpp_type,
        args=kept_arg_types,
        is_metafunction_that_may_return_error=template_literal.is_metafunction_that_may_return_error,
        may_be_alias=True)
    reduced_instantiation = ir0.TemplateInstantiation(
        template_expr=reduced_template_literal,
        args=kept_arg_exprs,
        instantiation_might_trigger_static_asserts=might_trigger_static_asserts)

    # `F<kept...>::member` is now itself a template taking the moved args.
    member_template = ir0.ClassMemberAccess(
        class_type_expr=reduced_instantiation,
        member_name=class_member_access.member_name,
        member_type=ir0.TemplateType(moved_arg_types))
    return ir0.TemplateInstantiation(
        template_expr=member_template,
        args=moved_arg_exprs,
        instantiation_might_trigger_static_asserts=might_trigger_static_asserts)
def _metafunction_call(template_expr: ir0.Expr,
                       args: Sequence[ir0.Expr],
                       instantiation_might_trigger_static_asserts: bool,
                       member_name: str,
                       member_type: ir0.ExprType):
    """Build the expression ``template_expr<args...>::member_name``.

    Args:
        template_expr: The template being instantiated.
        args: Template arguments for the instantiation.
        instantiation_might_trigger_static_asserts: Forwarded unchanged to ``ir0.TemplateInstantiation``.
        member_name: Name of the member accessed on the instantiated class.
        member_type: Expression type of that member.

    Returns:
        An ``ir0.ClassMemberAccess`` whose class is the instantiation.
    """
    instantiation = ir0.TemplateInstantiation(
        template_expr=template_expr,
        args=args,
        instantiation_might_trigger_static_asserts=instantiation_might_trigger_static_asserts)
    return ir0.ClassMemberAccess(class_type_expr=instantiation,
                                 member_name=member_name,
                                 member_type=member_type)
def transform_type_literal(self, type_literal: ir.AtomicTypeLiteral):
    """Turn a reference to a local typedef into an instantiation of that (now templated) typedef.

    Example: ``X5 -> X5<T, U>``, when ``X5`` is defined by a local typedef (its ``cpp_type`` is in
    ``self.locals_to_instantiate``) and ``{T, U}`` (``self.additional_typedef_args_in_current_template``) are being
    moved to be typedef template args instead of args of the template definition.

    Returns:
        The new ``ir0.TemplateInstantiation``, or ``type_literal`` unchanged when no rewrite applies.
    """
    extra_typedef_args = self.additional_typedef_args_in_current_template
    if not extra_typedef_args or type_literal.cpp_type not in self.locals_to_instantiate:
        return type_literal

    assert isinstance(type_literal.expr_type, ir0.TypeType)

    # The local typedef becomes a template over the extra args...
    typedef_template_type = ir0.TemplateType(tuple(
        ir0.TemplateArgType(expr_type=extra_arg.expr_type, is_variadic=extra_arg.is_variadic)
        for extra_arg in extra_typedef_args))
    typedef_template = ir.AtomicTypeLiteral.for_local(cpp_type=type_literal.cpp_type,
                                                      expr_type=typedef_template_type,
                                                      is_variadic=False)
    # ...and every use forwards those same args by name.
    forwarded_args = tuple(
        ir0.AtomicTypeLiteral.for_local(cpp_type=extra_arg.name,
                                        expr_type=extra_arg.expr_type,
                                        is_variadic=extra_arg.is_variadic)
        for extra_arg in extra_typedef_args)
    return ir0.TemplateInstantiation(template_expr=typedef_template,
                                     args=forwarded_args,
                                     instantiation_might_trigger_static_asserts=False)
def typedef_to_cpp(typedef: ir0.Typedef,
                   enclosing_function_defn_args: List[ir0.TemplateArgDecl],
                   writer: Writer):
    """Emit a C++ alias declaration (optionally an alias template) for ``typedef``.

    The text is written with ``writer.write_template_body_elem``. A non-empty ``typedef.description`` is emitted as a
    ``//`` comment line before the declaration.

    Args:
        typedef: The typedef to emit. Its expr must be TYPE-kind, or TEMPLATE-kind with no template args of its own.
        enclosing_function_defn_args: Forwarded to ``expr_to_cpp``.
        writer: Destination writer; also used to allocate fresh template param names.
    """
    if typedef.expr.expr_type.kind == ir0.ExprKind.TEMPLATE:
        # C++ has no `using X = SomeTemplate;`, so eta-expand into an alias template:
        # `template <P1, P2> using X = SomeTemplate<P1, P2>;` with freshly named params.
        assert not typedef.template_args
        fresh_params = [
            ir0.TemplateArgDecl(expr_type=param_type.expr_type, name=writer.new_id(), is_variadic=param_type.is_variadic)
            for param_type in typedef.expr.expr_type.args]
        forwarded_params = [
            ir0.AtomicTypeLiteral.for_local(expr_type=param.expr_type, cpp_type=param.name, is_variadic=param.is_variadic)
            for param in fresh_params]
        instantiated_expr = ir0.TemplateInstantiation(template_expr=typedef.expr,
                                                      args=forwarded_params,
                                                      # TODO: use static analysis to determine when it's
                                                      # safe to set this to False.
                                                      instantiation_might_trigger_static_asserts=True)
        typedef = ir0.Typedef(name=typedef.name,
                              expr=instantiated_expr,
                              description=typedef.description,
                              template_args=fresh_params)

    assert typedef.expr.expr_type.kind == ir0.ExprKind.TYPE, typedef.expr.expr_type.kind

    cpp_meta_expr = expr_to_cpp(typedef.expr, enclosing_function_defn_args, writer)
    description = ('// ' + typedef.description + '\n') if typedef.description else ''

    if typedef.template_args:
        template_args_decl = ', '.join(template_arg_decl_to_cpp(arg) for arg in typedef.template_args)
        cpp_code = ('{description}template <{template_args_decl}>\n'
                    'using {name} = {cpp_meta_expr};\n').format(description=description,
                                                                template_args_decl=template_args_decl,
                                                                name=typedef.name,
                                                                cpp_meta_expr=cpp_meta_expr)
    else:
        cpp_code = '{description}using {name} = {cpp_meta_expr};\n'.format(description=description,
                                                                            name=typedef.name,
                                                                            cpp_meta_expr=cpp_meta_expr)
    writer.write_template_body_elem(cpp_code)
lambda: ir0.PointerTypeExpr(type_literal('int')), lambda: ir0.ReferenceTypeExpr(type_literal('int')), lambda: ir0.RvalueReferenceTypeExpr(type_literal('int')), lambda: ir0.ConstTypeExpr(type_literal('int')), lambda: ir0.ArrayTypeExpr(type_literal('int')), lambda: ir0.FunctionTypeExpr(type_literal('int'), []), lambda: ir0.FunctionTypeExpr(type_literal('int'), [type_literal('float')]), lambda: ir0.ComparisonExpr(literal(1), literal(2), op='=='), lambda: ir0.Int64BinaryOpExpr(literal(1), literal(2), op='+'), lambda: ir0.BoolBinaryOpExpr(literal(True), literal(False), op='||'), lambda: ir0.NotExpr(literal(True)), lambda: ir0.UnaryMinusExpr(literal(1)), lambda: ir0.TemplateInstantiation( template_expr=ir0.AtomicTypeLiteral.for_nonlocal_template( cpp_type='std::vector', args=[], is_metafunction_that_may_return_error=False, may_be_alias=False), args=[], instantiation_might_trigger_static_asserts=False), lambda: ir0.TemplateInstantiation( template_expr=ir0.AtomicTypeLiteral.for_nonlocal_template( cpp_type='std::vector', args=[ir0.TemplateArgType(ir0.TypeType(), is_variadic=False)], is_metafunction_that_may_return_error=False, may_be_alias=False), args=[type_literal('int')], instantiation_might_trigger_static_asserts=False), lambda: ir0.ClassMemberAccess(class_type_expr=type_literal('MyClass'), member_name='value_type', member_type=ir0.TypeType()), ],
def tmp_instantiation(template_expr: ir0.Expr, args: List[ir0.Expr]):
    """Return ``template_expr<args...>`` as an ``ir0.TemplateInstantiation``.

    The instantiation is always marked as not triggering static asserts.
    """
    return ir0.TemplateInstantiation(template_expr=template_expr,
                                     args=args,
                                     instantiation_might_trigger_static_asserts=False)
def _type_list_of(*args: ir0.Expr):
    """Instantiate ``GlobalLiterals.LIST`` with the given args (passed through as a tuple).

    The instantiation is marked as not triggering static asserts.
    """
    list_template = GlobalLiterals.LIST
    return ir0.TemplateInstantiation(template_expr=list_template,
                                     args=args,
                                     instantiation_might_trigger_static_asserts=False)
def template_instantiation_to_cpp(instantiation_expr: ir0.TemplateInstantiation,
                                  enclosing_function_defn_args: List[ir0.TemplateArgDecl],
                                  writer: Writer,
                                  omit_typename=False):
    """Render ``instantiation_expr`` as C++ source of the form ``F<arg1, arg2, ...>``.

    Special case: the instantiation might trigger static asserts, it is emitted inside a definition with template
    params (``enclosing_function_defn_args``), it has at least one arg, and none of its args references those
    params. Then one arg is replaced by ``Select1st*<arg, Param>::value``. This adds a reference to a variable bound
    in the enclosing definition, which prevents the instantiation from happening early and potentially triggering
    static asserts.

    Args:
        instantiation_expr: The instantiation to render.
        enclosing_function_defn_args: Template params of the enclosing definition; may be empty.
        writer: Used to allocate fresh ids and, when no predefined ``Select1st*`` variant fits, to emit a custom one.
        omit_typename: Forwarded to ``class_member_access_to_cpp`` when the instantiated template is itself a class
            member access.

    Returns:
        The C++ text of the instantiation.
    """
    args = instantiation_expr.args
    if instantiation_expr.instantiation_might_trigger_static_asserts and enclosing_function_defn_args and args:
        bound_variables = {arg_decl.name for arg_decl in enclosing_function_defn_args}
        assert bound_variables

        # TODO: We could avoid adding a param dependency in more cases by checking for references to local variables
        # that depend (directly or indirectly) on a param.
        if not any(arg.references_any_of(bound_variables) for arg in args):
            # All template arguments are (or might be) constants, we need to add a reference to a variable bound in this
            # function to prevent the instantiation from happening early, potentially triggering static asserts.

            arg_decl = _select_best_arg_decl_for_select1st(enclosing_function_defn_args)
            arg_index = _select_best_arg_expr_index_for_select1st(args)
            arg_to_replace = args[arg_index]

            is_variadic = is_expr_variadic(arg_to_replace)
            if arg_decl.expr_type.kind != ir0.ExprKind.TEMPLATE and arg_to_replace.expr_type.kind != ir0.ExprKind.TEMPLATE:
                # We use lambdas here just to make sure we collect code coverage of each "branch". They are not necessary.
                # Note that we use the *Type variants for variadic types too. That's ok, since e.g.
                # Select1stBoolType<b, Args> will be expanded as e.g. Select1stBoolType<b, Args>... so it's exactly what
                # we want in the variadic case too.
                select1st_variant = {
                    (ir0.ExprKind.BOOL, ir0.ExprKind.BOOL): lambda: 'Select1stBoolBool',
                    (ir0.ExprKind.BOOL, ir0.ExprKind.INT64): lambda: 'Select1stBoolInt64',
                    (ir0.ExprKind.BOOL, ir0.ExprKind.TYPE): lambda: 'Select1stBoolType',
                    (ir0.ExprKind.INT64, ir0.ExprKind.BOOL): lambda: 'Select1stInt64Bool',
                    (ir0.ExprKind.INT64, ir0.ExprKind.INT64): lambda: 'Select1stInt64Int64',
                    (ir0.ExprKind.INT64, ir0.ExprKind.TYPE): lambda: 'Select1stInt64Type',
                    (ir0.ExprKind.TYPE, ir0.ExprKind.BOOL): lambda: 'Select1stTypeBool',
                    (ir0.ExprKind.TYPE, ir0.ExprKind.INT64): lambda: 'Select1stTypeInt64',
                    (ir0.ExprKind.TYPE, ir0.ExprKind.TYPE): lambda: 'Select1stTypeType',
                }[(arg_to_replace.expr_type.kind, arg_decl.expr_type.kind)]()
            else:
                # We need to define a new Select1st variant for the desired function type.
                select1st_variant = writer.new_id()
                forwarded_param_id = writer.new_id()

                template_param_decl1 = _type_to_template_param_declaration(expr_type=arg_to_replace.expr_type,
                                                                           is_variadic=is_variadic)
                template_param_decl2 = _type_to_template_param_declaration(expr_type=arg_decl.expr_type,
                                                                           is_variadic=arg_decl.is_variadic)

                # The body (`value` = the forwarded first param) is rendered into a separate writer so it can be
                # spliced into the struct definition below.
                select1st_variant_body_writer = TemplateElemWriter(writer.get_toplevel_writer())
                if arg_to_replace.expr_type.kind in (ir0.ExprKind.BOOL, ir0.ExprKind.INT64):
                    select1st_variant_body = ir0.ConstantDef(
                        name='value',
                        expr=ir0.AtomicTypeLiteral.for_local(cpp_type=forwarded_param_id,
                                                             expr_type=arg_to_replace.expr_type,
                                                             is_variadic=is_variadic))
                    constant_def_to_cpp(select1st_variant_body, enclosing_function_defn_args, select1st_variant_body_writer)
                else:
                    replaced_type = arg_to_replace.expr_type
                    assert replaced_type.kind in (ir0.ExprKind.TYPE, ir0.ExprKind.TEMPLATE)
                    select1st_variant_body = ir0.Typedef(
                        name='value',
                        expr=ir0.AtomicTypeLiteral.for_local(cpp_type=forwarded_param_id,
                                                             expr_type=replaced_type,
                                                             is_variadic=is_variadic))
                    typedef_to_cpp(select1st_variant_body, enclosing_function_defn_args, select1st_variant_body_writer)

                select1st_variant_body_str = ''.join(select1st_variant_body_writer.strings)

                # The second template param is intentionally unnamed: it only exists to create the dependency.
                custom_select1st_cpp = ('\n'
                                        '// Custom Select1st* template\n'
                                        'template <{template_param_decl1} {forwarded_param_id}, {template_param_decl2}>\n'
                                        'struct {select1st_variant} {{\n'
                                        '  {select1st_variant_body_str}\n'
                                        '}};\n').format(template_param_decl1=template_param_decl1,
                                                        forwarded_param_id=forwarded_param_id,
                                                        template_param_decl2=template_param_decl2,
                                                        select1st_variant=select1st_variant,
                                                        select1st_variant_body_str=select1st_variant_body_str)
                writer.write_template_body_elem(custom_select1st_cpp)

            # Both params are described as TemplateArgType (as for every other TemplateType), not by passing the
            # TemplateArgDecl `arg_decl` itself.
            select1st_type = ir0.TemplateType(args=[
                ir0.TemplateArgType(expr_type=arg_to_replace.expr_type, is_variadic=is_variadic),
                ir0.TemplateArgType(expr_type=arg_decl.expr_type, is_variadic=arg_decl.is_variadic)])
            select1st_instantiation = ir0.TemplateInstantiation(
                template_expr=ir0.AtomicTypeLiteral.for_local(cpp_type=select1st_variant,
                                                              expr_type=select1st_type,
                                                              is_variadic=False),
                args=[arg_to_replace,
                      ir0.AtomicTypeLiteral.for_local(cpp_type=arg_decl.name,
                                                      expr_type=arg_decl.expr_type,
                                                      is_variadic=arg_decl.is_variadic)],
                instantiation_might_trigger_static_asserts=False)
            new_arg = ir0.ClassMemberAccess(class_type_expr=select1st_instantiation,
                                            member_name='value',
                                            member_type=arg_to_replace.expr_type)
            # TemplateInstantiation is built with both lists and tuples for `args`; unless its constructor
            # normalizes them, `list + tuple` would raise TypeError. Normalize to a tuple before splicing.
            args = tuple(args[:arg_index]) + (new_arg,) + tuple(args[arg_index + 1:])

    template_params = ', '.join(expr_to_cpp(arg, enclosing_function_defn_args, writer) for arg in args)

    if isinstance(instantiation_expr.template_expr, ir0.ClassMemberAccess):
        cpp_fun = class_member_access_to_cpp(instantiation_expr.template_expr,
                                             enclosing_function_defn_args,
                                             writer,
                                             omit_typename=omit_typename,
                                             parent_expr_is_template_instantiation=True)
    else:
        cpp_fun = expr_to_cpp(instantiation_expr.template_expr, enclosing_function_defn_args, writer)

    return '{cpp_fun}<{template_params}>'.format(cpp_fun=cpp_fun, template_params=template_params)