def _add_submodule_to_parent(self):
        """
        Fold every submodule into its parent, deepest level first, until only the main module is left.

        Levels are processed from ``self._module_depth_max`` down to 1. Each module found at the
        current level is attached to the module registered under its parent scope, removed from
        ``self.module_structs`` and handed to the global context under its ``pattern_id``.

        Note:
            When no module is registered for the parent scope, a new ModuleStruct is created with
            ``init_as_parent=True`` and the submodule as ``parent_base``.
        """
        for level in range(self._module_depth_max, 0, -1):
            # Walk a snapshot: the live mapping gains parents and loses children as we go.
            snapshot = self.module_structs.copy()
            for child_key, child in snapshot.items():
                # '[]' is the main module; anything not at this level is handled on another pass.
                if child_key == '[]' or child.scope_depth != level:
                    continue

                # The parent scope is the child's identifier without its last element.
                parent_scope = copy.deepcopy(child.identifier)
                parent_scope.pop()
                parent_key = str(parent_scope)

                parent = self.module_structs.get(parent_key)
                if parent is None:
                    # No parent registered yet: derive one from this submodule.
                    parent = ModuleStruct(None, init_as_parent=True, parent_base=child)
                parent.add_submodule(child)
                self.module_structs[parent_key] = parent

                # The child now lives inside its parent; drop it from the top-level collection.
                removed = self.module_structs.pop(child_key)
                self._global_context.add_module_struct(removed.pattern_id, removed)
    def _form_bottom_submodule(self):
        """Form the basic submodules, which only contain nodes.

        Walks ``self.node_structs`` in iteration order and groups consecutive nodes that share the
        same ``scope.path``. Each run is stored in ``self._module_map`` under ``str(scope.path)``
        as a list of ``(topo_idx, node_struct)`` pairs; a scope path that shows up again later has
        its new run appended to the existing list. Finally one ModuleStruct is built per entry of
        ``self._module_map`` and stored in ``self._module_struct_collections``.

        When there are no node structs, nothing is recorded (no placeholder ``'None'`` scope).
        """
        from itertools import groupby

        # Form module map. groupby yields one group per run of consecutive equal scope paths,
        # and yields nothing at all for empty input.
        runs = groupby(self.node_structs.values(), key=lambda nd_struct: nd_struct.scope.path)
        for scope_path, run in runs:
            scope_key = str(scope_path)
            nd_struct_list_in_submodule = [(nd_struct.topo_idx, nd_struct) for nd_struct in run]
            if self._module_map.get(scope_key) is not None:
                # Scope already seen (earlier run or pre-existing entry): extend it.
                self._module_map[scope_key] += nd_struct_list_in_submodule
            else:
                self._module_map[scope_key] = nd_struct_list_in_submodule

        # Form bottom modules' ModuleStruct
        for scope_path_str, nd_struct_list in self._module_map.items():
            self._module_struct_collections[scope_path_str] = ModuleStruct(nd_struct_list)
Exemple #3
0
 def add_shared_weight_to_parent_module(
         shared_weight_name: str, module_to_be_registered: ModuleStruct):
     """
     Register a shared weight on a module under a generated local variable name.

     The local name is ``passthrough_w_<counter>``, using the module's current
     ``shared_weights_counter``. A name that is already present in
     ``shared_weights_collection`` keeps its existing local name. The counter is
     advanced on every call, whether or not a new entry was added.

     Args:
         shared_weight_name (str): Name of the shared weight to register.
         module_to_be_registered (ModuleStruct): Module that receives the registration.
     """
     local_var_name = f"passthrough_w_{module_to_be_registered.shared_weights_counter}"
     # setdefault only writes when the weight has not been registered yet.
     module_to_be_registered.shared_weights_collection.setdefault(
         shared_weight_name, local_var_name)
     module_to_be_registered.shared_weights_counter += 1
Exemple #4
0
 def main_model_special_process_inputs(main_model: ModuleStruct):
     """
     Call in preprocess.

     Build the main model's ``inputs_register``: for every name in the ONNX graph
     inputs (read from ``GlobalContext().onnx_graph_info``) that also appears in
     ``main_model.external_precursor_nodes_names``, map the original edge name to
     the name returned by ``MatcherHelper.regular_edge_name``.

     Args:
         main_model (ModuleStruct): The main model whose inputs are registered.
     """
     # allocate main model construct x
     upstream_names = main_model.external_precursor_nodes_names
     onnx_graph_inputs = GlobalContext().onnx_graph_info.get('graph_inputs')
     registered = {}
     for edge_name in onnx_graph_inputs:
         # Skip names already registered and names the main model does not consume.
         if edge_name in registered or edge_name not in upstream_names:
             continue
         registered[edge_name] = MatcherHelper.regular_edge_name(edge_name)
     main_model.inputs_register = registered