def recalculate_template_instantiation_can_trigger_static_asserts_info(
        header: ir.Header):
    """Recomputes which template defns in `header` can trigger a static assert when instantiated.

    A template defn is considered able to trigger a static assert if any template defn in its strongly connected
    component (of the template dependency graph) contains a static assert statement, or if its component references
    (directly or transitively) a component that can.

    Args:
        header: the header whose template defns are analyzed.

    Returns:
        `header` itself if it has no template defns; otherwise the result of
        `_apply_template_instantiation_can_trigger_static_asserts_info` with the computed per-template flags.
    """
    if not header.template_defns:
        return header

    template_defn_by_name = {
        template_defn.name: template_defn
        for template_defn in header.template_defns
    }
    template_defn_dependency_graph = compute_template_dependency_graph(
        header.template_defns, template_defn_by_name)

    # Each node of the condensation is one strongly connected component; its 'members' attribute holds the
    # template names in that component.
    condensed_graph = nx.condensation(template_defn_dependency_graph)
    assert isinstance(condensed_graph, nx.DiGraph)

    # Determine which components can trigger static assert errors.
    # Visiting components in reverse topological order guarantees that every successor (a component referenced by
    # the current one) has already been classified, so a plain dict lookup is always safe here.
    condensed_node_can_trigger_static_asserts = dict()
    for connected_component_index in reversed(
            list(nx.lexicographical_topological_sort(condensed_graph))):
        # `.nodes[...]` rather than `.node[...]`: the latter was removed in networkx 2.4.
        members = condensed_graph.nodes[connected_component_index]['members']

        # If a template defn in this component can trigger a static assert, the whole component can.
        contains_static_assert = any(
            _template_defn_contains_static_assert_stmt(
                template_defn_by_name[template_defn_name])
            for template_defn_name in members)

        # If this component references a component that can trigger static asserts, it can trigger them too.
        references_static_assert = any(
            condensed_node_can_trigger_static_asserts[called_condensed_node_index]
            for called_condensed_node_index in condensed_graph.successors(
                connected_component_index))

        condensed_node_can_trigger_static_asserts[connected_component_index] = (
            contains_static_assert or references_static_assert)

    # Propagate each component's flag to every template defn in it.
    template_defn_can_trigger_static_asserts = {
        template_defn_name:
        condensed_node_can_trigger_static_asserts[connected_component_index]
        for connected_component_index, members in condensed_graph.nodes(
            data='members')
        for template_defn_name in members
    }

    return _apply_template_instantiation_can_trigger_static_asserts_info(
        header, template_defn_can_trigger_static_asserts)
def template_defns_to_cpp(template_defns: Iterable[ir0.TemplateDefn], writer: ToplevelWriter):
    """Emits C++ code for `template_defns` to `writer`, one strongly connected dependency component at a time.

    Components are emitted in reverse of the order produced by `compute_condensation_in_topological_order`
    (presumably so that referenced templates come before the templates referencing them — confirm against
    `compute_template_dependency_graph`'s edge direction). Within a component:
      * forward declarations are emitted first when the component has more than one template, or when its single
        template has no main definition;
      * templates that must come after other members of the same component (per
        `compute_template_defns_that_must_come_before`) are emitted after the others. At most one such template is
        allowed, and within it at most one specialization may need to come after its siblings.

    Args:
        template_defns: the template definitions to emit. Any iterable is accepted; it is consumed exactly once.
        writer: destination for the generated toplevel C++ elements.

    Raises:
        AssertionError: if more than one template (or more than one specialization of the last template) in a
            component must be emitted after the others.
    """
    # Materialize the input: it is needed both for the name index and for the dependency graph. Iterating a one-shot
    # iterator (e.g. a generator) twice would silently give an empty dependency graph and emit nothing.
    template_defns = tuple(template_defns)
    template_defn_by_template_name = {elem.name: elem for elem in template_defns}
    template_dependency_graph = compute_template_dependency_graph(template_defns, template_defn_by_template_name)
    if template_dependency_graph.number_of_nodes():
        template_dependency_graph_condensed = compute_condensation_in_topological_order(template_dependency_graph)
    else:
        template_dependency_graph_condensed = []

    def emit_specialization(template_defn, specialization):
        # Every emitted specialization is preceded by the template's description comment, if it has one.
        if template_defn.description:
            writer.write_toplevel_elem('// %s\n' % template_defn.description)
        template_specialization_to_cpp(specialization,
                                       cxx_name=template_defn.name,
                                       enclosing_function_defn_args=[],
                                       writer=writer)

    for connected_component_names in reversed(list(template_dependency_graph_condensed)):
        # Set for O(1) membership tests below.
        component_name_set = set(connected_component_names)
        connected_component = sorted((template_defn_by_template_name[template_name]
                                      for template_name in connected_component_names),
                                     key=lambda template_defn: template_defn.name)

        if len(connected_component) > 1:
            # There's a dependency loop with >1 templates, we first need to emit all forward decls.
            for template_defn in connected_component:
                template_defn_to_cpp_forward_decl(template_defn, enclosing_function_defn_args=[], writer=writer)
        else:
            [template_defn] = connected_component
            if not template_defn.main_definition:
                # There's no loop here, but this template has only specializations and no main definition, so we
                # need the forward declaration anyway.
                template_defn_to_cpp_forward_decl(template_defn, enclosing_function_defn_args=[], writer=writer)

        template_defns_that_must_be_last = set()
        for template_defn in connected_component:
            template_order_dependencies = compute_template_defns_that_must_come_before(template_defn)
            if any(template_name in component_name_set for template_name in template_order_dependencies):
                # This doesn't only need to be before the ones it immediately references, it really needs to be last
                # since these templates instantiate each other in a cycle.
                template_defns_that_must_be_last.add(template_defn.name)

        assert len(template_defns_that_must_be_last) <= 1, 'Found multiple template defns that must appear before each other: ' + ', '.join(template_defns_that_must_be_last)

        for template_defn in connected_component:
            if template_defn.name not in template_defns_that_must_be_last:
                template_defn_to_cpp(template_defn, enclosing_function_defn_args=[], writer=writer)

        for template_defn in connected_component:
            if template_defn.name not in template_defns_that_must_be_last:
                continue

            specializations = list(template_defn.specializations or tuple())
            if template_defn.main_definition:
                specializations.append(template_defn.main_definition)

            # The (at most one) specialization that must be emitted after all the others, or None.
            last_specialization = None
            for specialization in specializations:
                if any(template_name in component_name_set
                       for template_name in compute_template_defns_that_must_come_before_specialization(specialization)):
                    assert last_specialization is None, 'Found multiple specializations of ' + template_defn.name + ' that must appear before each other: ' + ', '.join(template_defns_that_must_be_last)
                    last_specialization = specialization
                else:
                    emit_specialization(template_defn, specialization)

            # Explicit None check: the specialization object's truthiness isn't what we want to test here.
            if last_specialization is not None:
                emit_specialization(template_defn, last_specialization)
def _optimize_header_second_pass(
        header: ir.Header, identifier_generator: Iterator[str],
        context_object_file_content: ObjectFileContent):
    """Runs the second optimization pass over `header` and returns the optimized header.

    First, each template defn is optimized (template inlining followed by local optimizations) until
    `_iterate_optimization` stops. This is done one strongly connected component of the template dependency graph
    at a time, in reverse of the order produced by `compute_condensation_in_topological_order`. Then the toplevel
    content is optimized the same way, with all (already optimized) template defns available for inlining.

    Args:
        header: the header to optimize.
        identifier_generator: source of fresh identifiers for the optimizations.
        context_object_file_content: passed through to the template inlining optimizations.

    Returns:
        A new `ir.Header` with the optimized template defns (in the same order as in `header`) and the optimized
        toplevel content. All other fields are copied from `header`.
    """
    # Name -> current (progressively optimized) template defn.
    new_template_defns = {elem.name: elem for elem in header.template_defns}

    template_dependency_graph = compute_template_dependency_graph(
        header.template_defns, new_template_defns)

    template_dependency_graph_transitive_closure = nx.transitive_closure(
        template_dependency_graph)
    assert isinstance(template_dependency_graph_transitive_closure, nx.DiGraph)

    def acyclic_dependencies_of(template_name: str):
        # Templates that `template_name` (transitively) depends on and that don't (transitively) depend back on it,
        # i.e. dependencies outside its dependency cycle, if any.
        return {
            other_node
            for other_node in template_dependency_graph_transitive_closure.
            successors(template_name)
            if not template_dependency_graph_transitive_closure.has_edge(
                other_node, template_name)
        }

    # Distinct names for the two optimization lists: `optimize` below closes over `template_optimizations`, so
    # reusing one name for both lists would make any deferred call to `optimize` silently pick up the toplevel list.
    template_optimizations = [
        lambda template_defn: perform_template_inlining(
            template_defn,
            acyclic_dependencies_of(template_defn.name),
            new_template_defns,
            identifier_generator,
            context_object_file_content),
        lambda template_defn: perform_local_optimizations_on_template_defn(
            template_defn,
            identifier_generator,
            inline_template_instantiations_with_multiple_references=False),
    ]

    def optimize(template_name: str):
        new_template_defns[template_name], needs_another_loop = combine_optimizations(
            new_template_defns[template_name], template_optimizations)
        return None, needs_another_loop

    # Same empty-graph guard as template_defns_to_cpp: don't ask for a condensation of an empty graph.
    if template_dependency_graph.number_of_nodes():
        connected_components = list(
            compute_condensation_in_topological_order(template_dependency_graph))
    else:
        connected_components = []

    for connected_component in reversed(connected_components):
        # The lambdas below refer to the current `connected_component`; `_iterate_optimization` is called right here
        # with them, within this iteration.
        _iterate_optimization(
            None,
            lambda _: optimize_list(
                sorted(connected_component,
                       key=lambda node: new_template_defns[node].name),
                optimize),
            len(connected_component),
            lambda _: '\n'.join(
                template_defn_to_cpp_simple(new_template_defns[template_name],
                                            identifier_generator)
                for template_name in connected_component))

    toplevel_optimizations = [
        lambda toplevel_content: perform_template_inlining_on_toplevel_elems(
            toplevel_content, new_template_defns.keys(), new_template_defns,
            identifier_generator, context_object_file_content),
        lambda toplevel_content: perform_local_optimizations_on_toplevel_elems(
            toplevel_content,
            identifier_generator,
            inline_template_instantiations_with_multiple_references=False),
    ]

    toplevel_content = _iterate_optimization(
        header.toplevel_content,
        lambda toplevel_content: combine_optimizations(toplevel_content,
                                                       toplevel_optimizations),
        len(header.toplevel_content),
        lambda toplevel_content: '\n'.join(
            toplevel_elem_to_cpp_simple(elem, identifier_generator)
            for elem in toplevel_content))

    return ir.Header(
        template_defns=tuple(new_template_defns[template_defn.name]
                             for template_defn in header.template_defns),
        toplevel_content=toplevel_content,
        public_names=header.public_names,
        split_template_name_by_old_name_and_result_element_name=header.
        split_template_name_by_old_name_and_result_element_name,
        check_if_error_specializations=header.check_if_error_specializations)