def _unpack_if_variable(
        expr: ir.Expr,
        var_names: Set[str],
        literal_expr_by_unique_name: MutableMapping[str, ir.AtomicTypeLiteral]):
    """Convert an IR expression into the term form handed to the unifier.

    * A ``VariadicTypeExpansion`` becomes a ``ListExpansion`` wrapping its
      recursively converted inner expression.
    * An ``AtomicTypeLiteral`` whose ``cpp_type`` is in ``var_names`` is replaced
      by that name (a plain ``str``).
    * Anything else is returned unchanged.

    Every literal replaced by its name is recorded in ``literal_expr_by_unique_name``
    so it can be re-packed later. When the same name is seen again, the literal must
    be equal to the one already recorded; this is checked with an ``assert``.
    """
    if isinstance(expr, ir.VariadicTypeExpansion):
        # NOTE(review): `_unify` in this file reads `.inner_expr` from a
        # VariadicTypeExpansion while this reads `.expr`; confirm which attribute exists.
        unpacked_inner = _unpack_if_variable(expr.expr, var_names, literal_expr_by_unique_name)
        return ListExpansion(unpacked_inner)

    is_variable = isinstance(expr, ir.AtomicTypeLiteral) and expr.cpp_type in var_names
    if not is_variable:
        return expr

    unique_name = expr.cpp_type
    if unique_name not in literal_expr_by_unique_name:
        # First occurrence: remember the literal so that it can be re-packed later.
        literal_expr_by_unique_name[unique_name] = expr
    else:
        # Repeated occurrences of a name must all be the very same literal.
        recorded = literal_expr_by_unique_name[unique_name]
        assert expr == recorded, '%s vs %s. Detailed:\n%s\n-- vs --\n%s' % (
            expr_to_cpp_simple(expr),
            expr_to_cpp_simple(recorded),
            ir_to_string(expr),
            ir_to_string(recorded))
    return unique_name
def _determine_template_arg_indexes_that_can_be_moved_to_typedef_args(
        template_defn: ir0.TemplateDefn):
    """Compute the indexes of template args that are candidates for becoming typedef args.

    The result is the empty set unless every definition returned by
    ``template_defn.get_all_definitions()`` has a body made only of ``ir0.Typedef``
    elements without ``template_args``.

    Otherwise, an arg is excluded from the result when:
      * some specialization matches it with a pattern rejected by ``_is_trivial_pattern``; or
      * it is the last template arg and some specialization has a different number of
        patterns than the template has args (that last arg must be variadic, which is
        asserted).

    If nothing was excluded this way, the first arg is excluded anyway so that at
    least one template argument remains.
    """
    constrained_arg_names = set()
    only_simple_typedefs = True

    for specialization in template_defn.get_all_definitions():
        body_is_simple = all(isinstance(elem, ir0.Typedef) and not elem.template_args
                             for elem in specialization.body)
        only_simple_typedefs = only_simple_typedefs & body_is_simple

        patterns = specialization.patterns
        if not patterns:
            continue

        for arg_decl, pattern in zip(template_defn.args, patterns):
            if not _is_trivial_pattern(arg_decl, pattern):
                constrained_arg_names.add(arg_decl.name)

        if len(patterns) != len(template_defn.args):
            last_arg = template_defn.args[-1]
            assert last_arg.is_variadic, 'Template defn args: %s, patterns: %s' % (
                {arg.name for arg in template_defn.args},
                [expr_to_cpp_simple(expr) for expr in patterns])
            constrained_arg_names.add(last_arg.name)

    if not only_simple_typedefs:
        return set()

    if not constrained_arg_names:
        # So that there's always at least 1 template argument.
        constrained_arg_names.add(template_defn.args[0].name)

    return {index
            for index, arg in enumerate(template_defn.args)
            if arg.name not in constrained_arg_names}
def _unify(initial_exprs: Tuple[ir.Expr, ...],
           local_var_definitions: Mapping[str, ir.Expr],
           patterns: Tuple[ir.Expr, ...],
           expr_variables: Set[str],
           pattern_variables: Set[str],
           identifier_generator: Iterator[str],
           verbose: bool) -> UnificationResult:
    """Unify ``initial_exprs`` (in the context of ``local_var_definitions``) with ``patterns``.

    Every type literal on both sides is first renamed to a fresh name taken from
    ``identifier_generator``. This way a literal ``T`` in the exprs is not confused
    with a literal ``T`` in the patterns. The renamed literals that are variables are
    passed to ``unify()`` as unification variables:
      * on the expr side, names in ``expr_variables`` or keys of ``local_var_definitions``;
      * on the pattern side, names in ``pattern_variables``.
    The solution is then canonicalized, type-checked and mapped back to the original
    names.

    Args:
        initial_exprs: expressions to match.
        local_var_definitions: definitions (name -> expr) available as unification context.
        patterns: patterns to match the expressions against.
        expr_variables: literal names on the expr side that may be bound.
        pattern_variables: literal names on the pattern side that may be bound.
        identifier_generator: source of fresh names for the renaming.
        verbose: when true, print a diagnostic describing the outcome.

    Returns:
        A ``UnificationResult`` of kind:
          * IMPOSSIBLE, if unification fails or an equation binds a var to an expr
            of a different ``expr_type``;
          * POSSIBLE, if unification is ambiguous or canonicalization fails;
          * CERTAIN otherwise, together with the (var, value) equations and the
            (variadic expansion, values) equations, using the original names.

    Raises:
        AssertionError: re-raised from ``unify()`` / ``canonicalize()`` (after printing
            a diagnostic when ``verbose``), or raised by the internal consistency checks.
    """
    # We need to replace local literals before doing the unification, to avoid assuming that e.g. T in an expr
    # is equal to T in a pattern just because they have the same name.
    lhs_type_literal_names = set(local_var_definitions.keys())
    for expr in itertools.chain(initial_exprs, local_var_definitions.values()):
        for expr_literal in expr.free_vars:
            lhs_type_literal_names.add(expr_literal.cpp_type)
    unique_var_name_by_expr_type_literal_name = bidict({lhs_type_literal_name: next(identifier_generator)
                                                        for lhs_type_literal_name in lhs_type_literal_names})
    unique_var_name_by_pattern_type_literal_name = bidict({pattern_literal.cpp_type: next(identifier_generator)
                                                           for pattern in patterns
                                                           for pattern_literal in pattern.free_vars})

    # Only the renamed literals that correspond to actual variables may be bound by the unifier.
    unique_var_names = set()
    for expr_var_name, unique_var_name in unique_var_name_by_expr_type_literal_name.items():
        if expr_var_name in expr_variables or expr_var_name in local_var_definitions:
            unique_var_names.add(unique_var_name)
    for pattern_var_name, unique_var_name in unique_var_name_by_pattern_type_literal_name.items():
        if pattern_var_name in pattern_variables:
            unique_var_names.add(unique_var_name)

    literal_expr_by_unique_name: Dict[str, ir.AtomicTypeLiteral] = dict()

    lhs = tuple(_replace_var_names_in_expr(expr, unique_var_name_by_expr_type_literal_name)
                for expr in initial_exprs)
    rhs = tuple(_replace_var_names_in_expr(pattern, unique_var_name_by_pattern_type_literal_name)
                for pattern in patterns)
    context = [(unique_var_name_by_expr_type_literal_name[local_var_name],
                _replace_var_names_in_expr(value, unique_var_name_by_expr_type_literal_name))
               for local_var_name, value in local_var_definitions.items()]

    lhs = tuple(_unpack_if_variable(expr, unique_var_names, literal_expr_by_unique_name)
                for expr in lhs)
    rhs = tuple(_unpack_if_variable(pattern, unique_var_names, literal_expr_by_unique_name)
                for pattern in rhs)
    context = {_unpack_if_variable(var, unique_var_names, literal_expr_by_unique_name):
                   _unpack_if_variable(expr, unique_var_names, literal_expr_by_unique_name)
               for var, expr in context}

    unification_strategy = _ExprUnificationStrategy(unique_var_names,
                                                    set(unique_var_name_by_pattern_type_literal_name.inv.keys()),
                                                    literal_expr_by_unique_name)

    def describe_call() -> str:
        # Common prefix of every verbose diagnostic below; only evaluated when `verbose` is set.
        return 'unify(exprs=[%s], local_var_definitions={%s}, patterns=[%s], expr_variables=[%s], pattern_variables=[%s], ...):\nUsing name mappings: %s, %s\n' % (
            ', '.join(expr_to_cpp_simple(expr) for expr in initial_exprs),
            ', '.join('%s = %s' % (var, expr_to_cpp_simple(expr)) for var, expr in local_var_definitions.items()),
            ', '.join(expr_to_cpp_simple(pattern) for pattern in patterns),
            ', '.join(expr_variable for expr_variable in expr_variables),
            ', '.join(pattern_variable for pattern_variable in pattern_variables),
            unique_var_name_by_expr_type_literal_name,
            unique_var_name_by_pattern_type_literal_name)

    def format_equations(equations) -> str:
        # One "var = [value, ...]" line per equation; a non-tuple value is shown as a 1-element list.
        return '\n'.join(
            expr_to_cpp_simple(var) + ' = [' + ', '.join(
                expr_to_cpp_simple(expr) for expr in (exprs if isinstance(exprs, tuple) else (exprs,))) + ']'
            for var, exprs in equations)

    try:
        var_expr_equations, expanded_var_expr_equations = unify([(lhs, rhs)], context, unification_strategy)  # type: Tuple[Dict[str, Union[str, ExprOrExprTuple]], Dict[str, List[ir.Expr]]]
    except UnificationFailedException:
        if verbose:
            print(describe_call() + 'Returning IMPOSSIBLE due to exception: %s' % traceback.format_exc())
        return UnificationResult(UnificationResultKind.IMPOSSIBLE)
    except UnificationAmbiguousException:
        if verbose:
            print(describe_call() + 'Returning POSSIBLE due to exception: %s' % traceback.format_exc())
        return UnificationResult(UnificationResultKind.POSSIBLE)
    except AssertionError:
        if verbose:
            print(describe_call() + 'AssertionError')
        raise

    try:
        var_expr_equations = canonicalize(var_expr_equations, expanded_var_expr_equations, unification_strategy)
    except CanonicalizationFailedException:
        if verbose:
            print(describe_call() + 'Returning POSSIBLE due to exception: %s' % traceback.format_exc())
        return UnificationResult(UnificationResultKind.POSSIBLE)
    except AssertionError:
        if verbose:
            # Wrapped in a 1-tuple so that a tuple-valued `var_expr_equations` is formatted as one value.
            print(describe_call() + 'var_expr_equations = %s\nAssertionError' % (var_expr_equations,))
        raise

    # Turn the unifier's variable names back into the literal exprs recorded while unpacking.
    var_expr_equations: List[Tuple[Union[ir.Expr, Tuple[ir.Expr, ...]], Union[Tuple[ExprOrExprTuple, ...], ExprOrExprTuple]]] = [
        (_pack_if_variable(var, literal_expr_by_unique_name),
         tuple(_pack_if_variable(expr, literal_expr_by_unique_name) for expr in exprs)
         if isinstance(exprs, tuple)
         else _pack_if_variable(exprs, literal_expr_by_unique_name))
        for var, exprs in var_expr_equations]

    # At this point all equations should be of the form var=expr, with var a variable from a pattern and expr containing
    # no vars from patterns.
    for lhs_var, exprs in var_expr_equations:
        if isinstance(lhs_var, ir.VariadicTypeExpansion):
            lhs_var = lhs_var.inner_expr
        assert isinstance(lhs_var, ir.AtomicTypeLiteral)
        if lhs_var.cpp_type in unique_var_name_by_pattern_type_literal_name.inv:
            for expr in (exprs if isinstance(exprs, tuple) else (exprs,)):
                for rhs_var in expr.free_vars:
                    assert rhs_var.cpp_type not in unique_var_name_by_pattern_type_literal_name.inv

    # We reverse the var renaming done above
    result_var_expr_equations: List[Tuple[ir.AtomicTypeLiteral, Tuple[ir.Expr, ...]]] = []
    result_expanded_var_expr_equations: List[Tuple[ir.VariadicTypeExpansion, Tuple[ir.Expr, ...]]] = []
    for var, exprs in var_expr_equations:
        if isinstance(var, ir.VariadicTypeExpansion):
            assert isinstance(exprs, tuple)
            result_expanded_var_expr_equations.append(
                (_replace_var_names_in_expr(var, unique_var_name_by_pattern_type_literal_name.inv),
                 tuple(_replace_var_names_in_expr(expr, unique_var_name_by_expr_type_literal_name.inv)
                       for expr in exprs)))
        else:
            assert isinstance(var, ir.AtomicTypeLiteral)
            result_var_expr_equations.append(
                (_replace_var_names_in_expr(var, unique_var_name_by_pattern_type_literal_name.inv),
                 _replace_var_names_in_expr(exprs, unique_var_name_by_expr_type_literal_name.inv)))

    # A binding between exprs of different types can never hold.
    for var, exprs in var_expr_equations:
        for expr in (exprs if isinstance(exprs, tuple) else (exprs,)):
            if var.expr_type != expr.expr_type:
                if verbose:
                    print(describe_call()
                          + 'Returning IMPOSSIBLE due to type mismatch:\n%s\nwith type:\n%s\n=== vs ===\n%s\nwith type:\n%s' % (
                              expr_to_cpp_simple(var),
                              str(var.expr_type),
                              expr_to_cpp_simple(expr),
                              str(expr.expr_type)))
                return UnificationResult(UnificationResultKind.IMPOSSIBLE)

    for var, _ in result_var_expr_equations:
        assert isinstance(var, ir.AtomicTypeLiteral)
    for var, _ in result_expanded_var_expr_equations:
        assert isinstance(var, ir.VariadicTypeExpansion) and isinstance(var.inner_expr, ir.AtomicTypeLiteral)

    if verbose:
        print(describe_call()
              + 'Returning CERTAIN with result_var_expr_equations:\n%s\nresult_expanded_var_expr_equations:\n%s' % (
                  format_equations(result_var_expr_equations),
                  format_equations(result_expanded_var_expr_equations)))

    return UnificationResult(UnificationResultKind.CERTAIN,
                             tuple(result_var_expr_equations),
                             tuple(result_expanded_var_expr_equations))
def unify_template_instantiation_with_definition(
        template_instantiation: ir.TemplateInstantiation,
        local_var_definitions: Mapping[str, ir.Expr],
        result_elem_name: str,
        template_defn: ir.TemplateDefn,
        identifier_generator: Iterator[str],
        verbose: bool) -> Union[None,
                                Tuple[ir.TemplateSpecialization,
                                      Optional[Tuple[Tuple[ir.AtomicTypeLiteral, Tuple[ir.Expr, ...]], ...]],
                                      Optional[Tuple[Tuple[ir.VariadicTypeExpansion, Tuple[ir.Expr, ...]], ...]]],
                                ir.Expr]:
    """Try to resolve member ``result_elem_name`` of ``template_instantiation`` using ``template_defn``.

    Returns:
        * An ``ir.Expr``, when every CERTAIN/POSSIBLE matching specialization has no
          ``StaticAssert`` and defines ``result_elem_name`` as the same (syntactically
          equal) expression without free variables.
        * The only CERTAIN match ``(specialization, value_by_pattern_variable,
          value_by_expanded_pattern_variable)``, when there is exactly one CERTAIN
          match and no POSSIBLE match.
        * ``None`` otherwise: some POSSIBLE match exists, there is no CERTAIN match,
          or several CERTAIN matches exist.
    """
    certain_matches, possible_matches = find_matches_in_unification_of_template_instantiation_with_definition(
        template_instantiation=template_instantiation,
        local_var_definitions=local_var_definitions,
        template_defn=template_defn,
        identifier_generator=identifier_generator,
        verbose=verbose)
    # Give possible matches the same 3-tuple shape as certain ones (their replacements are unknown).
    possible_matches = [(specialization, None, None) for specialization in possible_matches]

    if certain_matches or possible_matches:
        result_exprs: List[ir.Expr] = []
        for specialization, _, _ in itertools.chain(certain_matches, possible_matches):
            if any(isinstance(elem, ir.StaticAssert) for elem in specialization.body):
                break
            [result_elem] = [elem for elem in specialization.body
                             if isinstance(elem, (ir.ConstantDef, ir.Typedef)) and elem.name == result_elem_name]
            assert isinstance(result_elem, (ir.ConstantDef, ir.Typedef))
            if any(True for _ in result_elem.expr.free_vars):
                break
            result_exprs.append(result_elem.expr)
        else:
            # If we didn't break out of the loop, it means that all certain/possible matches would lead to a result
            # expr with no free vars. If they are all *the same* expr, that's the answer regardless of which
            # specialization gets picked.
            first_result_expr = result_exprs[0]
            if all(is_syntactically_equal(expr, first_result_expr) for expr in result_exprs):
                return first_result_expr

    if possible_matches:
        # This must be stricter than all certain matches (if any) so we can't pick one for sure.
        if verbose:
            print('No unification found for template %s because there was a result with kind POSSIBLE, so we can\'t inline that.' % template_defn.name)
        return None

    if not certain_matches:
        if verbose:
            print('No unification found for template %s because there were no matches with kind==CERTAIN, so we can\'t inline that.' % template_defn.name)
        return None

    if len(certain_matches) == 1:
        return certain_matches[0]

    # We've found multiple specializations that definitely match and aren't stricter than each other.
    # We can't say for certain which one will be chosen (it probably depends on the specific arguments of the caller
    # template).
    if verbose:
        def describe_solution(replacements, expanded_replacements) -> str:
            # Either sequence may be None; each value may be a single expr or a list/tuple of exprs.
            parts = []
            for pattern_var, values in itertools.chain(replacements or (), expanded_replacements or ()):
                if not isinstance(values, (list, tuple)):
                    values = (values,)
                var_name = (pattern_var.cpp_type if isinstance(pattern_var, ir.AtomicTypeLiteral)
                            else expr_to_cpp_simple(pattern_var))
                parts.append('%s = [%s]' % (var_name, ', '.join(expr_to_cpp_simple(value) for value in values)))
            return '{%s}' % ', '.join(parts)

        # Matches are (specialization, replacements, expanded_replacements) 3-tuples.
        print('No unification found for template %s because there were multiple specializations with kind==CERTAIN and none of them was stricter than the others. Solutions:\n%s' % (
            template_defn.name,
            '\n'.join(describe_solution(replacements, expanded_replacements)
                      for _, replacements, expanded_replacements in certain_matches)))
    return None
def term_to_string(self, term: ir.Expr):
    """Render ``term`` for diagnostics: its simple C++ form in double quotes, a newline, then its IR dump."""
    cpp_form = expr_to_cpp_simple(term)
    ir_dump = ir_to_string(term)
    return ''.join(('"', cpp_form, '"\n', ir_dump))
def transform_class_member_access(
        self, class_member_access: ir.ClassMemberAccess):
    """Inline ``SomeTemplate<...>::member`` when ``SomeTemplate`` is an inlineable template.

    The base transformation runs first. If the accessed expression is an instantiation
    of a template listed in ``self.inlineable_templates_by_name``, it is unified with
    that template's definition. Depending on the outcome this returns:
      * the (base-transformed) access unchanged, when inlining isn't possible;
      * the common closed result expression, when all candidate specializations agree;
      * the member's expression from the single matched specialization, with the
        pattern variables substituted.

    In the last case, the other elements of that specialization's body are renamed
    to fresh identifiers. They are then emitted through the renaming transformation
    with ``self.writer`` set as its writer. ``self.needs_another_loop`` is set to
    True whenever something was inlined.
    """
    class_member_access = super().transform_class_member_access(class_member_access)
    assert isinstance(class_member_access, ir.ClassMemberAccess)

    template_instantiation = class_member_access.expr
    if not (isinstance(template_instantiation, ir.TemplateInstantiation)
            and isinstance(template_instantiation.template_expr, ir.AtomicTypeLiteral)
            and template_instantiation.template_expr.cpp_type in self.inlineable_templates_by_name):
        return class_member_access
    template_defn_to_inline = self.inlineable_templates_by_name[template_instantiation.template_expr.cpp_type]

    unification = unify_template_instantiation_with_definition(
        template_instantiation,
        self.parent_template_specialization_definitions,
        class_member_access.member_name,
        template_defn_to_inline,
        self.identifier_generator,
        verbose=ConfigurationKnobs.verbose)
    if unification is None:
        return class_member_access

    if isinstance(unification, ir.Expr):
        # Every candidate specialization yields this same closed expression: use it directly.
        self.needs_another_loop = True
        return _ensure_remains_variadic_if_it_was(original_expr=class_member_access,
                                                  transformed_expr=unification)

    specialization, value_by_pattern_variable, value_by_expanded_pattern_variable = unification
    assert len(value_by_pattern_variable) + len(value_by_expanded_pattern_variable) == len(specialization.args)

    # Normalize the unifier output into {var name: expr}.
    # A 1-element list/tuple is accepted for a single expr.
    new_value_by_pattern_variable: Dict[str, ir.Expr] = dict()
    for var, exprs in value_by_pattern_variable:
        assert isinstance(var, ir.AtomicTypeLiteral)
        if isinstance(exprs, (list, tuple)):
            [exprs] = exprs
        assert not isinstance(exprs, list)
        assert not isinstance(exprs, ir.VariadicTypeExpansion)
        new_value_by_pattern_variable[var.cpp_type] = exprs
    value_by_pattern_variable = new_value_by_pattern_variable

    # Normalize the expanded (variadic) bindings into {var name: [expr, ...]}.
    new_value_by_expanded_pattern_variable: Dict[str, List[ir.Expr]] = dict()
    for var, exprs in value_by_expanded_pattern_variable:
        # `_unify` in this file builds expanded values as tuples, which may reach here unchanged.
        # Treat them as the sequence they are instead of wrapping/rejecting them.
        if isinstance(exprs, tuple):
            exprs = list(exprs)
        if isinstance(var, ir.AtomicTypeLiteral):
            if not isinstance(exprs, list):
                exprs = [exprs]
            for expr in exprs:
                assert not isinstance(expr, list)
            new_value_by_expanded_pattern_variable[var.cpp_type] = exprs
        else:
            assert isinstance(var, ir.VariadicTypeExpansion) and isinstance(var.expr, ir.AtomicTypeLiteral)
            assert isinstance(exprs, list)
            new_value_by_expanded_pattern_variable[var.expr.cpp_type] = exprs
    value_by_expanded_pattern_variable = new_value_by_expanded_pattern_variable

    # Split the specialization body into the requested member's expression and everything else.
    body = []
    result_expr = None
    for elem in specialization.body:
        if isinstance(elem, (ir.ConstantDef, ir.Typedef)) and elem.name == class_member_access.member_name:
            assert result_expr is None
            result_expr = elem.expr
        else:
            body.append(elem)
    assert result_expr is not None

    # Give every other definition in the body a fresh name (presumably so the inlined copies can't clash with
    # names at the inlining site).
    new_var_name_by_old_var_name = dict()
    for elem in body:
        if isinstance(elem, (ir.TemplateDefn, ir.ConstantDef, ir.Typedef)):
            new_var_name_by_old_var_name[elem.name] = next(self.identifier_generator)
        elif not isinstance(elem, ir.StaticAssert):
            raise NotImplementedError('Unexpected elem: ' + elem.__class__.__name__)

    transformation = NameReplacementTransformation(new_var_name_by_old_var_name)
    body = transformation.transform_template_body_elems(body)
    result_expr = transformation.transform_expr(result_expr)

    try:
        body = replace_var_with_expr_in_template_body_elements(
            body, value_by_pattern_variable, value_by_expanded_pattern_variable)
        for elem in body:
            if isinstance(elem, (ir.ConstantDef, ir.Typedef)) and compute_non_expanded_variadic_vars(elem.expr):
                raise VariadicVarReplacementNotPossibleException(
                    'Needed to replace a non-variadic var with an expr with non-expanded variadic vars in a non-result ConstantDef/Typedef')
        result_expr = replace_var_with_expr_in_expr(
            result_expr, value_by_pattern_variable, value_by_expanded_pattern_variable)
    except VariadicVarReplacementNotPossibleException as e:
        [message] = e.args
        # We thought we could perform the inlining but we actually can't.
        if ConfigurationKnobs.verbose:
            print('VariadicVarReplacementNotPossibleException raised for template %s (reason: %s), we can\'t inline that.' % (
                template_instantiation.template_expr.cpp_type, message))
        return class_member_access

    result_expr = _ensure_remains_variadic_if_it_was(original_expr=class_member_access,
                                                     transformed_expr=result_expr)

    # Don't inline when the result is itself a Select1st* member access and the inlined template is a Select1st* or
    # Always* one (NOTE(review): presumably to avoid trading one wrapper for another indefinitely).
    # `template_instantiation` is `class_member_access.expr`, already checked above to be a TemplateInstantiation of an
    # AtomicTypeLiteral.
    if (isinstance(result_expr, ir.ClassMemberAccess)
            and isinstance(result_expr.expr, ir.TemplateInstantiation)
            and isinstance(result_expr.expr.template_expr, ir.AtomicTypeLiteral)
            and result_expr.expr.template_expr.cpp_type.startswith('Select1st')
            # TODO: make this more precise. 'Always' is meant to match the Always*From* templates.
            and template_instantiation.template_expr.cpp_type.startswith(('Select1st', 'Always'))):
        return class_member_access

    self.needs_another_loop = True

    if ConfigurationKnobs.verbose:
        print('Inlining template defn: %s into %s' % (
            template_defn_to_inline.name,
            self.root_template_defn_name or expr_to_cpp_simple(class_member_access)))

    for elem in body:
        with transformation.set_writer(self.writer):
            transformation.transform_template_body_elem(elem)

    return result_expr