def transform_variadic_type_expansion(self, expr: ir.VariadicTypeExpansion):
    """Transform a VariadicTypeExpansion, expanding any variadic vars that have replacements.

    The non-expanded variadic vars of ``expr.expr`` are added to
    ``self.variadic_vars_with_expansion_in_progress`` for the duration of the call.
    The previous value is restored afterwards, including when an exception propagates.

    If at least one of those vars has a replacement in ``self.replacement_expr_by_var`` or
    ``self.replacement_expr_by_expanded_var``:
    - ``expr.expr`` is transformed once per replacement element. A replacement that is not a
      list counts as a single element.
    - Each pass uses a child ``_ReplaceVarWithExprTransformation`` in which every expanded var
      maps to its i-th element.
    - The results are concatenated.
    Otherwise ``expr.expr`` is transformed with this transformation and re-wrapped in a
    VariadicTypeExpansion.

    Args:
        expr: the variadic expansion to transform.

    Returns:
        A list of expressions that replace ``expr``.

    Raises:
        VariadicVarReplacementNotPossibleException: if the expansion produced a number of
            elements other than one and at least one of them still contains non-expanded
            variadic vars.
    """
    variadic_vars_to_expand = compute_non_expanded_variadic_vars(expr.expr).keys()
    previous_variadic_vars_with_expansion_in_progress = self.variadic_vars_with_expansion_in_progress
    self.variadic_vars_with_expansion_in_progress = previous_variadic_vars_with_expansion_in_progress.union(
        variadic_vars_to_expand)
    try:
        values_by_variadic_var_to_expand = {var: self.replacement_expr_by_var[var]
                                            for var in variadic_vars_to_expand
                                            if var in self.replacement_expr_by_var}
        values_by_expanded_variadic_var_to_expand = {var: self.replacement_expr_by_expanded_var[var]
                                                     for var in variadic_vars_to_expand
                                                     if var in self.replacement_expr_by_expanded_var}

        transformed_exprs = []
        if values_by_variadic_var_to_expand or values_by_expanded_variadic_var_to_expand:
            self._check_variadic_var_replacement(values_by_variadic_var_to_expand,
                                                 values_by_expanded_variadic_var_to_expand)

            # All replacements must have the same number of elements. The single-element set
            # unpacking below raises ValueError otherwise.
            (num_values_to_expand,) = {(len(values) if isinstance(values, list) else 1)
                                       for values in itertools.chain(values_by_variadic_var_to_expand.values(),
                                                                     values_by_expanded_variadic_var_to_expand.values())}

            for i in range(num_values_to_expand):
                child_replacement_expr_by_var = self.replacement_expr_by_var.copy()
                child_replacement_expr_by_expanded_var = self.replacement_expr_by_expanded_var.copy()
                for var, values in values_by_variadic_var_to_expand.items():
                    if not isinstance(values, list):
                        values = [values]
                    child_replacement_expr_by_var[var] = values[i]
                for var, values in values_by_expanded_variadic_var_to_expand.items():
                    if not isinstance(values, list):
                        values = [values]
                    child_replacement_expr_by_expanded_var[var] = values[i]

                child_transformation = _ReplaceVarWithExprTransformation(child_replacement_expr_by_var,
                                                                         child_replacement_expr_by_expanded_var,
                                                                         self.variadic_vars_with_expansion_in_progress)
                transformed_expr = child_transformation.transform_expr(expr.expr)
                # The child transformation may return either a single expr or a list of exprs.
                if isinstance(transformed_expr, list):
                    transformed_exprs.extend(transformed_expr)
                else:
                    transformed_exprs.append(transformed_expr)

            if len(transformed_exprs) == 1:
                # A single result that still has non-expanded variadic vars must stay inside an expansion.
                if compute_non_expanded_variadic_vars(transformed_exprs[0]):
                    transformed_exprs = [ir.VariadicTypeExpansion(transformed_exprs[0])]
            elif any(compute_non_expanded_variadic_vars(transformed) for transformed in transformed_exprs):
                raise VariadicVarReplacementNotPossibleException(
                    'Found non-expanded variadic vars after expanding one to multiple elements')
        else:
            transformed_expr = self.transform_expr(expr.expr)
            if isinstance(transformed_expr, list):
                [transformed_expr] = transformed_expr
            assert not isinstance(transformed_expr, list)
            transformed_exprs.append(ir.VariadicTypeExpansion(transformed_expr))
    finally:
        # Restore the state even if an exception propagates, so a reused transformation does not
        # keep treating these vars as "expansion in progress".
        self.variadic_vars_with_expansion_in_progress = previous_variadic_vars_with_expansion_in_progress
    return transformed_exprs
def _pack_if_variable(var_or_expr: Union[str, ir.Expr, TupleExpansion[ir.Expr]],
                      literal_expr_by_unique_name: Mapping[str, ir.AtomicTypeLiteral]) -> Union[ir.Expr, Tuple[ir.Expr, ...]]:
    """Turn a variable name, a TupleExpansion or an expression into an ir expression.

    - A ``str`` is looked up in ``literal_expr_by_unique_name`` (KeyError if absent).
    - A ``TupleExpansion`` is converted recursively and wrapped in ``ir.VariadicTypeExpansion``.
    - Anything else is returned unchanged.

    Plain tuples are not accepted (asserted).
    """
    assert not isinstance(var_or_expr, tuple)

    if isinstance(var_or_expr, str):
        return literal_expr_by_unique_name[var_or_expr]

    if isinstance(var_or_expr, TupleExpansion):
        packed_inner = _pack_if_variable(var_or_expr.expr, literal_expr_by_unique_name)
        return ir.VariadicTypeExpansion(packed_inner)

    return var_or_expr
def transform_variadic_type_expansion(self, expr: ir.VariadicTypeExpansion):
    """Transform the inner expression of a variadic expansion.

    The result is wrapped in a new ``ir.VariadicTypeExpansion`` only if it is still variadic
    (per ``is_expr_variadic``). Otherwise the bare transformed expression is returned.
    """
    transformed_inner = self.transform_expr(expr.expr)
    if not is_expr_variadic(transformed_inner):
        # Dropping the wrapper is required, not merely an optimization: a VariadicTypeExpansion()
        # that contains no variadic var refs is an error.
        return transformed_inner
    return ir.VariadicTypeExpansion(transformed_inner)
def find_matches_in_unification_of_template_instantiation_with_definition(
        template_instantiation: ir.TemplateInstantiation,
        local_var_definitions: Mapping[str, ir.Expr],
        template_defn: ir.TemplateDefn,
        identifier_generator: Iterator[str],
        verbose: bool,
) -> Tuple[Tuple[Tuple[ir.TemplateSpecialization,
                       Tuple[Tuple[ir.AtomicTypeLiteral, Tuple[ir.Expr, ...]], ...],
                       Tuple[Tuple[ir.AtomicTypeLiteral, Tuple[ir.Expr, ...]], ...]], ...],
           Tuple[ir.TemplateSpecialization, ...]]:
    """Determine which definitions of ``template_defn`` the given instantiation may select.

    Candidate patterns are unified with ``template_instantiation.args`` using ``_unify``:
    - every specialization's ``patterns``;
    - the main definition's own argument list, if the main definition exists and has a body.
      Variadic args are wrapped in ``ir.VariadicTypeExpansion``.

    Among the CERTAIN matches, a match is discarded when:
    - it has no patterns and there is another candidate still in the running; or
    - another match's patterns unify CERTAINly with its patterns, i.e. that other match is at
      least as specific.

    Args:
        template_instantiation: the instantiation whose args are matched.
        local_var_definitions: passed through to ``_unify`` for the instantiation side.
        template_defn: the template whose specializations/main definition are candidates.
        identifier_generator: passed to every ``_unify`` call.
        verbose: passed to every ``_unify`` call.

    Returns:
        A pair ``(certain_matches, possible_matches)``.
        - ``certain_matches`` holds the surviving ``(specialization, value_by_pattern_variable,
          value_by_expanded_pattern_variable)`` triples. It is ``()`` when nothing matched
          CERTAINly.
        - ``possible_matches`` holds the candidates whose unification result was POSSIBLE.
    """
    instantiation_vars = {var.cpp_type for var in template_instantiation.free_vars}

    certain_matches: List[Tuple[ir.TemplateSpecialization,
                                Tuple[Tuple[ir.AtomicTypeLiteral, Tuple[ir.Expr, ...]], ...],
                                Tuple[Tuple[ir.AtomicTypeLiteral, Tuple[ir.Expr, ...]], ...]]] = []
    possible_matches: List[ir.TemplateSpecialization] = []

    for specialization in template_defn.specializations:
        result = _unify(template_instantiation.args, local_var_definitions, specialization.patterns,
                        instantiation_vars, {arg.name for arg in specialization.args},
                        identifier_generator, verbose)
        if result.kind == UnificationResultKind.CERTAIN:
            certain_matches.append((specialization,
                                    result.value_by_pattern_variable,
                                    result.value_by_expanded_pattern_variable))
        elif result.kind == UnificationResultKind.POSSIBLE:
            possible_matches.append(specialization)

    main_definition = template_defn.main_definition
    if main_definition and main_definition.body:
        # Use the main definition's own parameters as patterns.
        # Variadic parameters must appear inside an expansion.
        main_pattern_list = []
        for var in main_definition.args:
            literal = ir.AtomicTypeLiteral.for_local(var.name, var.expr_type, is_variadic=var.is_variadic)
            main_pattern_list.append(ir.VariadicTypeExpansion(literal) if var.is_variadic else literal)
        patterns = tuple(main_pattern_list)

        result = _unify(template_instantiation.args, local_var_definitions, patterns,
                        instantiation_vars, {var.name for var in main_definition.args},
                        identifier_generator, verbose)
        # The main definition's patterns are plain parameters, so unification is asserted never to be IMPOSSIBLE.
        assert result.kind != UnificationResultKind.IMPOSSIBLE
        if result.kind == UnificationResultKind.CERTAIN:
            certain_matches.append((main_definition,
                                    result.value_by_pattern_variable,
                                    result.value_by_expanded_pattern_variable))
        else:
            possible_matches.append(main_definition)

    if not certain_matches:
        return (), tuple(possible_matches)

    might_be_best_match = [True] * len(certain_matches)
    for i, (specialization1, _, _) in enumerate(certain_matches):
        specialization1_arg_vars = {var.name for var in specialization1.args}
        for j, (specialization2, _, _) in enumerate(certain_matches):
            if i == j or not might_be_best_match[i] or not might_be_best_match[j]:
                continue
            # Let's see if we can prove that certain_matches[i] is more strict than certain_matches[j].
            # A candidate without patterns is treated as the least specific one.
            if not specialization1.patterns:
                might_be_best_match[i] = False
                continue
            if not specialization2.patterns:
                might_be_best_match[j] = False
                continue
            specialization2_arg_vars = {var.name for var in specialization2.args}
            result = _unify(specialization1.patterns, {}, specialization2.patterns,
                            specialization1_arg_vars, specialization2_arg_vars,
                            identifier_generator, verbose)
            if result.kind == UnificationResultKind.CERTAIN:
                # specialization1's patterns certainly match specialization2's, so j is not more specific.
                might_be_best_match[j] = False

    best_matches = tuple(match
                         for match, might_be_best in zip(certain_matches, might_be_best_match)
                         if might_be_best)
    assert best_matches
    return best_matches, tuple(possible_matches)