def graph_based_converter_tf_to_ms(graph_path: str,
                                   input_nodes: dict,
                                   output_nodes: List[str],
                                   output_folder: str,
                                   report_folder: str = None,
                                   query_result_folder: str = None):
    """
    Tensorflow to MindSpore based on Graph.

    When `query_result_folder` is provided, only the intermediate graph is
    saved and no code is generated. The global context is released on every
    exit path, including when a conversion step raises.

    Args:
        graph_path (str): Graph file path.
        input_nodes (dict): Input node(s) of the model.
        output_nodes (list[str]): Output node(s) of the model.
        output_folder (str): Output folder.
        report_folder (str): Report output folder path.
        query_result_folder (str): Save the optimized graph and its topological order to disk.
    """
    # Close unnecessary log.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    try:
        graph_obj = GraphFactory.init(graph_path, input_nodes=input_nodes, output_nodes=output_nodes)
        if query_result_folder:
            # Query mode: dump the intermediate graph and stop before code generation.
            save_intermediate_graph(graph_obj.dataloader, query_result_folder)
            return
        graph_obj.build()
        generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
        model_name = _extract_model_name(graph_path)
        log_console.info("Code saving begins.")
        code_fragments = generator_inst.generate()
        save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)
        log_console.info("Code saving is finished.")
    finally:
        # Release global context even on failure, so that a later conversion
        # in the same process does not start from stale state.
        GlobalContext.release()
def graph_based_converter_pytorch_to_ms(graph_path: str,
                                        input_nodes: dict,
                                        output_nodes: List[str],
                                        output_folder: str,
                                        report_folder: str = None):
    """
    PyTorch to MindSpore based on Graph.

    The global context is released on every exit path, including when a
    conversion step raises.

    Args:
        graph_path (str): Graph file path.
        input_nodes (dict): Input node(s) of the model.
        output_nodes (list[str]): Output node(s) of the model.
        output_folder (str): Output folder.
        report_folder (str): Report output folder path.
    """
    try:
        graph_obj = GraphFactory.init(graph_path, input_nodes=input_nodes, output_nodes=output_nodes)
        generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
        model_name = _extract_model_name(graph_path)
        code_fragments = generator_inst.generate()
        save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)
    finally:
        # Release global context even on failure, so that a later conversion
        # in the same process does not start from stale state.
        GlobalContext.release()
def __init__(self, args):
    """
    Create an empty node struct, then populate it from `args`.

    Args:
        args (Iterable): Node related objects; each one is handed to
            `self.update` in the given order.
    """
    # Handles on the global context (exposed under both attribute names).
    self.global_context_mgr = GlobalContext()
    self._global_context = GlobalContext()

    # Identity and bookkeeping, filled in later by `update` and the setters.
    self._identifier = None
    self._fragment = None
    self._args_translator = None
    self._parent_module_struct = None
    self.topo_idx = None
    self.onnx_name = None
    self.graph_node_ref = None
    self.scope_name = None
    self.ready_to_generate = False

    # Scope instance of this node.
    self.scope = None

    # Attributes used for code generation.
    # Maps precursor node name -> `x` used in the code line.
    self.inputs_in_construct_header = OrderedDict()
    # Matched inputs that code line generation can use directly.
    self.matched_inputs = []

    # Apply every provided argument in order.
    for node_info in args:
        self.update(node_info)
def graph_based_converter_tf_to_ms(graph_path: str,
                                   input_nodes: dict,
                                   output_nodes: List[str],
                                   output_folder: str,
                                   report_folder: str = None):
    """
    Tensorflow to MindSpore based on Graph.

    The global context is released on every exit path, including when a
    conversion step raises.

    Args:
        graph_path (str): Graph file path.
        input_nodes (dict): Input node(s) of the model.
        output_nodes (list[str]): Output node(s) of the model.
        output_folder (str): Output folder.
        report_folder (str): Report output folder path.
    """
    # Close unnecessary log.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    try:
        graph_obj = GraphFactory.init(graph_path, input_nodes=input_nodes, output_nodes=output_nodes)
        generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
        model_name = _extract_model_name(graph_path)
        code_fragments = generator_inst.generate()
        save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)
    finally:
        # Release global context even on failure, so that a later conversion
        # in the same process does not start from stale state.
        GlobalContext.release()
def __init__(self):
    """Set up empty collections and bind them to the global context."""
    # Mandatory containers.
    node_structs = OrderedDict()
    self._node_struct_collections = node_structs
    self._module_struct_collections = OrderedDict()
    self._module_depth_max = 0
    self._module_depth_min = 0

    # Intermediate state used during conversion.
    self._module_map = OrderedDict()
    context = GlobalContext()
    self._global_context = context
    # Share the node struct collection through the global context.
    context.node_struct_collections = node_structs
    self._repeated_submodules = set()
def __init__(self, main_model: ModuleStruct):
    """
    Init the launcher with the main model and the graph inputs/outputs.

    Args:
        main_model (ModuleStruct): The main model to run matching on.
    """
    # The previous `super(MatcherLauncher).__init__()` created an unbound
    # super object and re-initialized that object itself, so the base class
    # initializer was never invoked. The zero-argument form binds to `self`.
    super().__init__()
    self.main_model = main_model
    self._global_context = GlobalContext()
    # Graph level inputs/outputs as recorded in the global context.
    self._graph_inputs = self._global_context.onnx_graph_info.get("graph_inputs")
    self._graph_outputs = self._global_context.onnx_graph_info.get("graph_outputs")
def public_module_shared_weight_statement_generation(public_module: ModuleStruct):
    """
    Build the declaration statements of shared weights for a public module.

    Args:
        public_module (ModuleStruct): Module whose `shared_weights_collection`
            maps weight onnx name to the variable name used in the module.

    Returns:
        list[str], one `self.<var_name> = <declaration>` statement per shared weight.
    """
    return [
        "self.{} = {}".format(
            weight_var_name,
            GlobalContext().repeated_weights_declaration.get(weight_onnx_name))
        for weight_onnx_name, weight_var_name in public_module.shared_weights_collection.items()
    ]
def main_model_special_process_inputs(main_model: ModuleStruct):
    """
    Register the graph inputs consumed by the main model (called in preprocess).

    Every graph input that is also an external precursor of the main model is
    mapped to its regularized edge name and stored in `inputs_register`.

    Args:
        main_model (ModuleStruct): The main model to allocate construct inputs for.
    """
    external_precursors = main_model.external_precursor_nodes_names
    registered = {}
    for graph_input in GlobalContext().onnx_graph_info.get('graph_inputs'):
        # Skip edges already registered or not consumed by the main model.
        if graph_input in registered or graph_input not in external_precursors:
            continue
        registered[graph_input] = MatcherHelper.regular_edge_name(graph_input)
    main_model.inputs_register = registered
def __init__(self, sid, module_name, nodes, dataloader, merged_modules):
    """
    Init the area from its nodes and assign it a generated module name.

    Args:
        sid: Identifier stored as `self.sid`.
        module_name: Original module name; used as the key for the generated name.
        nodes (Iterable): Node names belonging to this area.
        dataloader: Object exposing `nodes_dict`; also passed to the rank helpers.
        merged_modules (dict): Fallback lookup for names missing in `dataloader.nodes_dict`.
    """
    global MODULE_NAME_MGR, MODULE_NAME_INDEXING

    self.sid = sid
    self.module_name = module_name

    # Reuse the generated name of a module seen before, otherwise mint a new one.
    if module_name not in MODULE_NAME_MGR:
        generated_name = f"Module{MODULE_NAME_INDEXING}"
        MODULE_NAME_INDEXING += 1
        MODULE_NAME_MGR[module_name] = generated_name
        self.fake_module_name = generated_name
    else:
        self.fake_module_name = MODULE_NAME_MGR[module_name]

    # NOTE(review): membership is tested with `module_name` while the entry is
    # stored under `fake_module_name` - confirm this asymmetry is intended.
    if module_name not in GlobalContext().known_module_name:
        GlobalContext().known_module_name[self.fake_module_name] = module_name

    # Resolve every node name, preferring the dataloader over merged modules.
    resolved_nodes = {}
    for node_name in nodes:
        resolved_nodes[node_name] = dataloader.nodes_dict.get(node_name) or merged_modules.get(node_name)
    self.nodes = resolved_nodes

    heads, inputs, tails, outputs = get_area_heads_and_tails(self.nodes, dataloader)
    self.heads = heads
    self.inputs = inputs
    self.tails = tails
    self.outputs = outputs

    start_rank, end_rank = get_area_rank(self.heads, self.tails, dataloader)
    self.start_rank = start_rank
    self.end_rank = end_rank
def convert_according_to_user_selections(graph_obj,
                                         output_folder: str,
                                         report_folder: str = None,
                                         user_operations: Mapping[str, Dict] = None):
    """
    ONNX to MindSpore based on Graph.

    The global context is released on every exit path, including when a
    conversion step raises.

    Args:
        graph_obj (OnnxGraph): Onnx graph object.
        output_folder (str): Output folder.
        report_folder (str): Report output folder path.
        user_operations (dict): Record user's operations.
    """
    try:
        graph_obj.generate_scope_name(user_operations)
        graph_obj.build()
        generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
        model_name = _extract_model_name(graph_obj.model_path)
        log_console.info("Code saving begins.")
        code_fragments = generator_inst.generate()
        save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)
        log_console.info("Code saving is finished.")
    finally:
        # Release global context even on failure, so that a later conversion
        # in the same process does not start from stale state.
        GlobalContext.release()
def check_node_has_shared_weight(node: NodeStruct):
    """
    Collect the shared weights used by the given node.

    Args:
        node (NodeStruct): NodeStruct instance.

    Returns:
        list, onnx names of every repeated weight whose node list contains `node.onnx_name`.
    """
    return [
        weight_name
        for weight_name, user_nodes in GlobalContext().repeated_weights.items()
        if node.onnx_name in user_nodes
    ]
def _add_known_module_name(search_path):
    """
    Record known module names of a search path into GlobalContext.

    The pattern of `search_path` is processed first, followed by the pattern of
    each entry in `search_path.recursion_path`.

    Args:
        search_path (SearchPath): Search path.

    Returns:
        GlobalContext, the context that was updated.
    """
    ctx = GlobalContext()

    def _register(pattern):
        """Store the pattern's known module name, if it has one."""
        known_name = pattern.known_module_name
        if known_name:
            ctx.known_module_name[pattern.module_name] = known_name

    _register(search_path.pattern)
    for visited_path in search_path.recursion_path:
        _register(visited_path.pattern)
    return ctx
def __init__(self, onnx_model, model_path: str, input_nodes: dict, output_nodes: list, infer_shape=True):
    """
    Simplify the onnx model, set up the containers and run initialization.

    Args:
        onnx_model: Onnx model passed to the simplifier.
        model_path (str): Path of the model file.
        input_nodes (dict): Input node(s) of the model.
        output_nodes (list): Output node(s) of the model.
        infer_shape (bool): Stored as `_is_infer_shape`. Default: True.
    """
    log_console.info("Onnx simplifying begins.")
    simplified_model = OnnxSimplify().run_onnx_simplify(onnx_model, model_path, input_nodes)
    log_console.info("Onnx simplifying is finished.")

    # The simplified model is the one everything else works on.
    self.model = simplified_model
    self.model_path = model_path
    self.graph = simplified_model.graph
    self.nodes = simplified_model.graph.node
    self.input_nodes = input_nodes
    self.output_nodes = output_nodes

    # Arguments for init.
    self._is_infer_shape = infer_shape
    self._global_context = GlobalContext()

    # Parameters parsed during init.
    self.inferred_model = None
    self._nodes_dict = OrderedDict()  # {node_name: OnnxNode} NO INPUT NODE
    self.tensors_dict = {}  # {tensor_name: OnnxTensor}
    self.value_info_dict = {}  # Not contains input and output nodes

    # Weight names that are used more than once.
    self.repeated_weight = {}

    self.node_output_shape_dict = OrderedDict()  # {node_name: [int]}

    # Maps an edge of the ONNX ir graph to its precursor node.
    self.output_name_to_node_name = {}

    # Dynamic nodes to be evaluated with onnxruntime.
    self.dynamic_resize_node = []
    self.dynamic_reshape_node = []
    self.eliminated_nodes = []

    # Validate init params before building anything.
    self._check_user_provided_info()

    self.initialize()
def _recursive_form_module(self):
    """Build modules bottom-up, then connect outputs and run input matching."""
    # 1. List repeated submodules
    repeated_submodules = self._list_repeated_submodules()

    # 2. List reused parameters
    formal_parameters = self._list_formal_parameters(repeated_submodules)

    # 3. Build base submodules and set in/ext params translation
    for module_struct in self.module_structs.values():
        if module_struct.pattern_id != -1:  # -1 marks the main module
            module_struct.update_args_translation_list(
                formal_parameters.get(module_struct.pattern_id))

    # 4. Form parent modules until the collection stops growing
    known_count = len(self.module_structs.keys())
    while True:
        self._add_submodule_to_parent()
        current_count = len(self.module_structs.keys())
        if current_count == known_count:
            break
        known_count = current_count
    GlobalContext().build_struct_finished = True

    # 5. Update all translated args from module map
    self._update_all_modules_args_translator()

    # 6. Update all nodes and modules input/output
    self.build_outputs_connection()
    self.module_structs.get('[]').allocate_construct_header_x()
    self.module_structs.get('[]').collect_returns()

    launcher = MatcherLauncher(self.module_structs.get('[]'))
    launcher.matching_process()

    # Split operations need their dedicated post-processing.
    for node_struct in self.node_structs.values():
        if node_struct.fragment.metadata.get("operation") != "Split":
            continue
        self._split_op_procs(node_struct)
class NodeStruct:
    """
    Define a node struct which stores all info. to generate statement.

    Args:
        args (Iterable[Union[OnnxGraphNode, Fragment]]): Node related objects.
            Any other type makes `update` raise TypeError.

    Note:
        You can pass as many args as possible and the Node Struct will update
        by arguments order.
    """

    def __init__(self, args):
        # define attributes here
        self.global_context_mgr = GlobalContext()
        self._identifier = None
        self._fragment = None
        self._args_translator = None
        self._parent_module_struct = None
        self._global_context = GlobalContext()
        self.topo_idx = None
        self.onnx_name = None
        self.graph_node_ref = None
        self.scope_name = None
        self.ready_to_generate = False

        # Defined Scope class
        self.scope = None

        # Define attributes used for code generation

        # key is prec_node_name, value is x; For code line use
        self.inputs_in_construct_header = OrderedDict()

        # Matched inputs can be directly used by code line generation
        self.matched_inputs = []

        # initialize funcs.
        for arg in args:
            self.update(arg)

    def __repr__(self):
        return str({
            "address": hex(id(self)),
            "idx": self.topo_idx,
            "identifier": self.identifier
        })

    def ori_topo_idx(self):
        """
        Get the original topological index in the onnx graph.

        Note:
            As a side effect, `onnx_name` is set to the fragment's `source` metadata.
        """
        ori_name = self._fragment.metadata.get('source')
        self.onnx_name = ori_name
        return self._global_context.onnx_node_name_to_topo_idx.get(ori_name)

    def update_var_name(self, idx=None):
        """
        Update the var_name of each node.

        Args:
            idx (int): The index of the node in this module. When None,
                `topo_idx` is used instead.

        Raises:
            ValueError: If both `idx` and `topo_idx` are None.
        """

        def _remove_op_header(op_name):
            """Remove op header which indicating their sources of op set."""
            op_name = op_name.replace('nn.', '')
            op_name = op_name.replace('P.', '')
            op_name = op_name.replace('onnx.', '')
            return op_name

        suffix = idx if idx is not None else self.topo_idx
        if suffix is None:
            raise ValueError("Unable to update var name when topo_idx is None.")
        # The setter stores the name into the fragment's `variable_name`.
        self.ms_var_name = f"{_remove_op_header(self.ms_op)}_{suffix}".lower()

    def _update_basics_from_gn(self, gn):
        """Update basic info from GraphNode."""
        self.graph_node_ref = gn
        self.scope_name = gn.scope_name

    def _update_from_onnx_gn(self, gn: OnnxGraphNode):
        """Update basic info from OnnxGraphNode."""
        self._update_basics_from_gn(gn)

    def _update_from_fragment(self, frag: Fragment):
        """Update info from CodeFragment."""
        self._fragment = FragmentHandler(frag)

        if self.ms_op:
            idx = self._global_context.latest_node_struct_count
            self.update_var_name(idx=idx)

    def _set_scope_from_identifier(self):
        """Set the Node scope from identifier."""
        parsed_scope = Scope.parse_scope_from_node_identifier(self.identifier)
        self.scope = Scope(parsed_scope)

    @GeneratorError.check_except("Generator occurs an error when initializing node's args translator.")
    def init_args_translator(self, translated_args: list):
        """
        Initialize the ArgsTranslator for each Node.

        Args:
            translated_args (list): The list of args should be translated to formal args.

        Raises:
            ValueError: If the node has no fragment yet.
        """
        if not self._fragment:
            raise ValueError("Initialize argument translator failed.")
        if self._fragment.converted and self._fragment.default_var["args"] and translated_args:
            self._args_translator = ArgsTranslation(self._fragment.default_var["args"],
                                                    self.ms_var_name,
                                                    translated_args)

    @GeneratorError.check_except("Generator occurs an error when creating node struct.")
    def update(self, arg):
        """
        Pass Node info. to generator NodeStruct.

        Args:
            arg (Union[OnnxGraphNode, Fragment]): Node related obj.

        Raises:
            TypeError: If `arg` is neither an OnnxGraphNode nor a Fragment.
        """
        if isinstance(arg, OnnxGraphNode):
            self._update_from_onnx_gn(arg)
        elif isinstance(arg, Fragment):
            self._update_from_fragment(arg)
        else:
            raise TypeError("NodeStruct received an unsupported initializing argument.")

    @property
    def identifier(self):
        """Return the identifier of the node."""
        return self._identifier

    @identifier.setter
    def identifier(self, s):
        """
        Set the Node identifier, and update the scope.

        Besides the scope, `topo_idx` and `onnx_name` are refreshed and the node
        registers itself in the global onnx-name-to-node-struct map.

        Args:
            s (str): The node identifier string.
        """
        self._identifier = s
        self._set_scope_from_identifier()
        self.topo_idx = self.ori_topo_idx()
        self._global_context.onnx_node_name_to_node_struct_map[self.onnx_name] = self

    @property
    def fragment(self):
        """Return the fragment of the node."""
        return self._fragment

    @fragment.setter
    def fragment(self, frag):
        """
        Set the Node fragment.

        Args:
            frag (NodeFragment): The fragment to store.
        """
        self._fragment = frag

    @property
    def graph_node(self):
        """Return the GraphNode reference."""
        return self.graph_node_ref

    @graph_node.setter
    def graph_node(self, graphnode):
        """Set the GraphNode reference."""
        self.graph_node_ref = graphnode

    @property
    def onnx_node(self):
        """Return the original onnx node reference."""
        return self._global_context.onnx_nodes_collection.get(self.onnx_name)

    @property
    def ms_op(self):
        """Return the operation name in MindSpore."""
        return self._fragment.default_var.get('operation')

    @ms_op.setter
    def ms_op(self, ms_op_name: str):
        """Set the operation name in MindSpore."""
        self._fragment.default_var['operation'] = ms_op_name

    @property
    def ms_var_name(self):
        """Return the variable name of this Node in the MindSpore script."""
        return self._fragment.default_var.get('variable_name')

    @ms_var_name.setter
    def ms_var_name(self, ms_var_name: str):
        """Set the variable name of this Node in the MindSpore script."""
        self._fragment.default_var['variable_name'] = ms_var_name

    @property
    def ms_opt_var_name(self):
        """Return the output variable name of current node."""
        return self.fragment.fragment.get_outputs_by_idx(0)

    @property
    def args_translator(self):
        """Return the args translator of this Node."""
        return self._args_translator

    @property
    def precursor_nodes_names(self) -> list:
        """Return the names of precursor nodes."""
        return self.graph_node_ref.precursor_nodes

    @property
    def precursor_nodes_structs(self) -> list:
        """
        Return the node struct instances of precursor nodes.

        Names without a registered node struct yield None entries.
        """
        name_to_struct = self._global_context.onnx_node_name_to_node_struct_map
        return [name_to_struct.get(name) for name in self.precursor_nodes_names]

    @property
    def successor_nodes_names(self) -> list:
        """Return the names of successor nodes."""
        return self.graph_node_ref.successor_nodes

    @property
    def successor_nodes_structs(self) -> list:
        """
        Return the node struct instances of successor nodes.

        Names without a registered node struct yield None entries.
        """
        name_to_struct = self._global_context.onnx_node_name_to_node_struct_map
        return [name_to_struct.get(name) for name in self.successor_nodes_names]

    @property
    def parent_module_struct(self):
        """Return the parent struct of this node."""
        return self._parent_module_struct

    @parent_module_struct.setter
    def parent_module_struct(self, ref):
        """Set the parent struct of this node."""
        self._parent_module_struct = ref

    @property
    def outputs_manager(self):
        """Return the outputs manager instance."""
        return self.fragment.outputs_manager

    @property
    def outputs_in_construct(self):
        """Return the outputs var(s) in construct statement."""
        return self.fragment.fragment.outputs()

    @property
    def inputs_edges_names(self):
        """
        Return the inputs edges of this node.

        Edges for which the global context returns an onnx tensor are left out.
        """
        # Consider moving this process to metadata.
        return [
            edge for edge in self.fragment.metadata.get('inputs')
            if not self._global_context.get_onnx_tensor(edge)
        ]

    @property
    def shared_weights(self):
        """Return the shared weights in this node."""
        return [
            weight_name
            for weight_name, node_names in self._global_context.repeated_weights.items()
            if self.onnx_name in node_names
        ]

    # Code Generation funcs below

    def _get_shared_weight_var_names_from_parent(self, onnx_name=None):
        """
        Get shared weight var name in the parent module.

        Args:
            onnx_name (str): The onnx name of this weight. Default None.

        Returns:
            [List, str], a list of all shared weights the node has or the specific name provided.
        """
        if onnx_name is None:
            shared_weights = self.shared_weights
            if not shared_weights:
                # Nothing shared: do not touch the parent module at all.
                return []
            parent_weights = self._parent_module_struct.shared_weights_collection
            # Direct key lookup replaces the former scan over every parent entry.
            return [parent_weights[name] for name in shared_weights if name in parent_weights]
        if isinstance(onnx_name, str):
            return self._parent_module_struct.shared_weights_collection.get(onnx_name)
        return []

    def code_line_in_init(self):
        """
        Initialization line of code in module init block.

        Raises:
            ValueError: If a trainable parameter's data has no `shape`/`dtype`.
        """
        if self._args_translator is not None:
            self.fragment.default_var['args'] = {**self._args_translator.actual_args,
                                                 **self._args_translator.formal_args}

        # create a parameter for shared weight scenario
        trainable_params = self.fragment.default_var.get("trainable_params")
        if not (trainable_params and self.fragment.default_var.get("parameters")):
            return

        # Reaching here means there are trainable params and the mapper accepts
        # the rewritten param declaration.
        for trainable_param_postfix, data_dict in trainable_params.items():
            onnx_name = data_dict.get('onnx_name')
            nparray = data_dict.get('data')
            try:
                shape = nparray.shape
                dtype = nparray.dtype
            except AttributeError as err:
                # Keep the root cause attached for easier debugging.
                raise ValueError("Parameters has inconsistent data type.") from err

            # set declare statement
            declare_statement = self.fragment.fragment.create_parameter(shape, dtype)

            if onnx_name not in self._global_context.repeated_weights:
                # if the weight is not a shared weight, set to actual declaration.
                if not self.fragment.default_var["parameters"].get(trainable_param_postfix):
                    self.fragment.default_var["parameters"][trainable_param_postfix] = declare_statement
                continue  # not a shared weight, skip the rest

            if onnx_name not in self._global_context.repeated_weights_declaration:
                self._global_context.repeated_weights_declaration[onnx_name] = declare_statement

            # set template to mapper parameter rewritten.
            shared_w_var_in_parent = self._get_shared_weight_var_names_from_parent(onnx_name=onnx_name)

            # add self for node under public parent module
            if self.parent_module_struct.identifier == []:
                # now only consider declaration in the main model
                shared_w_var_in_parent = f"self.{shared_w_var_in_parent}"
            self.fragment.default_var["parameters"][trainable_param_postfix] = shared_w_var_in_parent

    def code_line_in_construct(self, inputs=None):
        """
        Construct line of code in module construct block.

        Args:
            inputs: Ignored; the inputs are always rebuilt from the parent
                module. Kept for interface compatibility.

        Returns:
            The output variable name of this node, read before outputs are bound.
        """
        left = self.ms_opt_var_name
        inputs = []

        # Bind current node opt_var_name & register to parent
        self.outputs_manager.bind_opt_var_names(self.fragment.fragment)
        for base_out in self.outputs_manager.outputs:
            opt_var = base_out.opt_var_name
            self.parent_module_struct.internal_outputs_collection[base_out.onnx_edge_name] = opt_var

        # Take inputs from parents module
        for input_edge in self.inputs_edges_names:
            if input_edge in self.parent_module_struct.inputs_register:
                inputs.append(self.parent_module_struct.inputs_register.get(input_edge))
            elif input_edge in self.parent_module_struct.internal_outputs_collection:
                inputs.append(self.parent_module_struct.internal_outputs_collection.get(input_edge))

        self.fragment.default_var['inputs'] = inputs
        return left

    def add_extra_tensor(self):
        """
        Add extra tensor.

        Returns:
            tuple[str, str], the left and right hand side of the declaration statement.
        """
        left = "self.{}_w".format(self.ms_var_name)
        shape = self._fragment.code_setting.op_extra_tensor.shape
        right = f"Tensor(np.random.uniform(0, 1, {shape}), mindspore.float32)"
        return left, right

    # The following functions are specified for multiple in/out support.
    # and should be called only after generator._recursive_form_modules()

    def set_inputs_in_construct_header(self, header_x, onnx_precursor_node_name):
        """
        Mark the registered external inputs for code generation.

        Note:
            This function to be called by its parent (ModuleStruct).

        Args:
            header_x (str): The `x` in module construct header.
            onnx_precursor_node_name (str): The original onnx node name.

        Raises:
            ValueError: If an input from the same precursor is already registered.
        """
        if self.inputs_in_construct_header.get(onnx_precursor_node_name) is not None:
            # Adjacent literals keep source indentation out of the message.
            raise ValueError(
                "The input from {} has already registered. Check this node "
                "{} has duplicate inputs or not.".format(onnx_precursor_node_name, self.identifier))
        self.inputs_in_construct_header[onnx_precursor_node_name] = header_x

    def check_target_node_internal(self, name: str) -> bool:
        """
        Check given node under the same scope.

        Args:
            name (str): Can accept both node identifier or original onnx node name.

        Returns:
            bool, True if the target node shares this node's scope path.

        Raises:
            ValueError: If the target cannot be found and is not a graph input
                (and this node is not the first node).
        """
        target_nd_struct = self._global_context.node_struct_collections.get(name) \
            or self._global_context.onnx_node_name_to_node_struct_map.get(name)
        if target_nd_struct is None and self.topo_idx == 0:
            # First node always has external input
            return False

        if target_nd_struct is None and (name in self._global_context.onnx_graph_info.get('graph_inputs')):
            return False

        if target_nd_struct is None:
            raise ValueError("Unable to find the NodeStruct of given target node {}.".format(name))
        return target_nd_struct.scope.path == self.scope.path

    @property
    def has_successor_node_external(self) -> bool:
        """
        Check if any successor_node is in external module.

        Note:
            The previous implementation was inverted: it returned False as soon
            as an external successor was found, and True otherwise.
        """
        return any(
            not self.check_target_node_internal(name)
            for name in self.successor_nodes_names
        )

    @property
    def precursor_nodes_names_external(self) -> list:
        """Return a list of external precursor nodes names."""
        return [name for name in self.precursor_nodes_names
                if not self.check_target_node_internal(name)]

    @property
    def successor_nodes_names_external(self) -> list:
        """Return a list of external successor nodes names."""
        return [name for name in self.successor_nodes_names
                if not self.check_target_node_internal(name)]
def __init__(self, nd_struct_list, init_as_parent=False, parent_base=None):
    """
    Init. a module by NodeStructs.

    Args:
        nd_struct_list (list): Pairs whose second item is a NodeStruct; each
            NodeStruct gets this module assigned as its parent.
        init_as_parent (bool): When True and `parent_base` is given, the module
            is reset from `parent_base` instead of being built from
            `nd_struct_list`. Default: False.
        parent_base: Object handed to `reset_as_parent_passed_in`. Default: None.
    """
    self.pattern_id = -1  # pattern num, -1 as Main module
    self.pattern_uid = -1  # unique module id for this pattern
    self.parent_id = None  # parent's pattern num
    self.parent_uid = None  # parent's pattern module unique id
    self.initialized = False
    self.identifier = None
    self.module_name = None
    self.scope_depth = None
    self.head_nd_struct = None
    self.head_nd_struct_index = None
    self.tail_nd_struct = None
    self.tail_nd_struct_index = None
    self._node_structs = list()
    self._module_structs = list()

    self._fragment = None
    self._args_translator = None
    self._parent_module_struct = None
    # only store original formal args name, not global
    self._nodes_structs_formal_args_list = list()

    # define other settings here
    self._node_args_translation_list = list()
    self._var_name_mgr = LocalVarNameMgr()
    self.construct_header_x = OrderedDict()  # key is header x, value is precursors onnx name
    self.inputs_in_construct_header = OrderedDict()  # key is precursors onnx name, value is x in parent construct

    # key is node's onnx name(output provider), value is (provider_succ_name, opt_var_name)
    self.outputs_collection = dict()
    self.matched_inputs = list()  # Matched inputs can be directly used by code line generation

    # key is ext. succ node onnx name, value is local opt_var
    self.external_successor_local_returns_map = OrderedDict()

    # Define outputs manager, note this will be assigned later by Generator.
    self.outputs_manager = None

    self._global_context = GlobalContext()

    # Define a dict to store the reference for quick searching
    self.rapid_reference = dict()

    # new vars for matcher
    self.inputs_register = OrderedDict()  # reg by sub
    self.outputs_register = OrderedDict()  # reg by sub
    self.internal_outputs_collection = dict()  # reg by sub

    # new vars for shared weights
    self.shared_weights_collection = dict()  # reg by sub
    self.shared_weights_counter = 0  # updated by sub

    if init_as_parent and (parent_base is not None):
        # Parent module: take over the state of the passed-in base.
        self.reset_as_parent_passed_in(parent_base)
    else:
        # start initialization
        # NOTE(review): `initialized` was set to False above and nothing in
        # between changes it, so the update branch looks unreachable here.
        if not self.initialized:
            self._init_module(nd_struct_list)
        else:
            self._update_module(nd_struct_list)

        # assign this module reference to node
        for (_, nd_struct) in nd_struct_list:
            nd_struct.parent_module_struct = self
class Generator:
    """
    The generator controls all routines of code generation.

    Nodes are registered through `add_node`; `generate` then groups them into
    modules by scope, resolves formal args and outputs, and returns the
    formatted code together with the report and checkpoint data.
    """

    def __init__(self):
        """Init the generator."""
        # define MUST have params
        self._node_struct_collections = OrderedDict()
        self._module_struct_collections = OrderedDict()
        self._module_depth_max = 0
        self._module_depth_min = 0

        # define intermediate var. during conversion
        self._module_map = OrderedDict()
        self._global_context = GlobalContext()
        self._global_context.node_struct_collections = self._node_struct_collections
        # Replaced by a {depth: set(module_num)} mapping in _list_repeated_submodules().
        self._repeated_submodules = set()

    @GeneratorError.check_except("Generator occurs an error when forming base submodules.")
    def _form_bottom_submodule(self):
        """
        Form the basic submodules, which only contains nodes.

        Consecutive nodes sharing a scope path are grouped; groups with the same
        scope path are merged under one key of the module map. A ModuleStruct is
        then created for every key.
        """
        # Form module map
        curr_scope_path = None
        nd_struct_list_in_submodule = []
        for nd_struct in self.node_structs.values():
            idx = nd_struct.topo_idx
            if curr_scope_path is None:
                curr_scope_path = nd_struct.scope.path
                nd_struct_list_in_submodule.append((idx, nd_struct))
            elif curr_scope_path == nd_struct.scope.path:
                nd_struct_list_in_submodule.append((idx, nd_struct))
            else:  # curr_scope_path changed
                # save this submodule
                if self._module_map.get(str(curr_scope_path)) is not None:
                    self._module_map[str(curr_scope_path)] += nd_struct_list_in_submodule
                else:
                    self._module_map[str(curr_scope_path)] = nd_struct_list_in_submodule

                # create a new one
                curr_scope_path = nd_struct.scope.path
                nd_struct_list_in_submodule = [(idx, nd_struct)]

        # save last submodule
        if self._module_map.get(str(curr_scope_path)) is not None:
            self._module_map[str(curr_scope_path)] += nd_struct_list_in_submodule
        else:
            self._module_map[str(curr_scope_path)] = nd_struct_list_in_submodule

        # Form bottom modules' ModuleStruct
        for scope_path_str, nd_struct_list in self._module_map.items():
            self._module_struct_collections[scope_path_str] = ModuleStruct(nd_struct_list)

    def _list_repeated_submodules(self) -> OrderedDict:
        """
        Return the repeated submodules by its depth and num.

        For example, "Model/Module3_3" will return {1:(3)}

        The result is also stored in `self._repeated_submodules`.

        Return:
            OrderedDict, a dict contains collections of repeated submodules.
        """
        ret = OrderedDict()
        for depth_control in range(self._module_depth_max, 0, -1):
            repeated_submodules_at_this_depth = set()
            for scope_path in self._module_map:
                path = Scope.path_str_to_list(scope_path)
                if len(path) < depth_control:
                    # depth control within path length.
                    continue
                module_num = path[depth_control - 1][0]
                repeated_submodules_at_this_depth.add(module_num)
            ret[depth_control] = repeated_submodules_at_this_depth

        self._repeated_submodules = ret
        return ret

    def _compare_with_base_parameters(self, nd_struct_list):
        """
        Compare the parameter to check if it should be a formal args.

        The first node is the base; a parameter becomes formal as soon as one of
        the other nodes holds a different value for it.

        Args:
            nd_struct_list (list): A list of NodeStructs which contains
                                    all same nodes in repeated submodules.

        Return:
            set, a set of all formal args in this node.
        """
        formal_args = set()
        if len(nd_struct_list) < 2:
            return formal_args
        (_, base_nd_struct) = nd_struct_list[0]
        for (base_parameter, base_value) in base_nd_struct.fragment.default_var["args"].items():
            # for each param
            for (_, nd_struct) in nd_struct_list[1:]:
                compared_value = nd_struct.fragment.default_var["args"].get(base_parameter)
                if compared_value == base_value:
                    continue
                formal_args.add(base_parameter)
                break
        return formal_args

    @staticmethod
    def _set_translated_args_for_unshared_weights(t_param_postfix, nd_struct_list):
        """
        Set the weight with given param postfix to args translation.

        Args:
            t_param_postfix (str): Key of the trainable param in each node's default_var.
            nd_struct_list (list): Pairs whose second element is a NodeStruct.
        """
        # The template only depends on the postfix, so build it once.
        init_tensor_template = (f"Parameter(Tensor(np.random.uniform(0, 1, "
                                f"{{{t_param_postfix}_shape}}).astype(np.{{{t_param_postfix}_dtype}})), "
                                f"name=None)")
        for _, nd_struct in nd_struct_list:
            nparr = nd_struct.fragment.default_var["trainable_params"].get(t_param_postfix).get('data')
            nd_struct.fragment.default_var["args"][f"{t_param_postfix}_shape"] = nparr.shape
            nd_struct.fragment.default_var["args"][f"{t_param_postfix}_dtype"] = nparr.dtype
            nd_struct.fragment.default_var["parameters"][t_param_postfix] = init_tensor_template

    def _get_same_trainable_params_onnx_name_from_repeated_nodes(self,
                                                                 t_param_postfix,
                                                                 t_param_data_dict,
                                                                 nd_struct_list: list):
        """
        Return all onnx names from the same weights in repeated nodes.

        Args:
            t_param_postfix (str): Key of the trainable param to look up in every node.
            t_param_data_dict (dict): The base node's data dict for that param.
            nd_struct_list (list): Pairs whose second element is a NodeStruct; index 0 is the base.

        Returns:
            list, onnx names of this weight, in the order of `nd_struct_list`.

        Raises:
            ValueError: If a repeated node has no entry for this trainable param.
        """
        (_, base_nd_struct) = nd_struct_list[0]
        t_base_name = t_param_data_dict.get('onnx_name')
        t_onnx_names = [t_base_name]
        for (_, nd_struct) in nd_struct_list[1:]:
            compared_t_param_data_dict = nd_struct.fragment.default_var["trainable_params"].get(t_param_postfix)
            if not compared_t_param_data_dict:
                raise ValueError(f"Inconsistent trainable params detected for node "
                                 f"{nd_struct.topo_idx} with base node {base_nd_struct.topo_idx}")
            compared_t_name = compared_t_param_data_dict.get('onnx_name')
            t_onnx_names.append(compared_t_name)
        return t_onnx_names

    def _partial_shared_weights_in_repeated_submodule_procs(self, nd_struct_list):
        """
        Check each node in repeated submodule to ensure the node has a fully / partial shared weight.

        Args:
            nd_struct_list (list): A list of node structs which are same node in repeated modules.
        """
        # Not repeated will skip this function
        if len(nd_struct_list) < 2:
            return
        (_, base_nd_struct) = nd_struct_list[0]
        # Kept as a dict view on purpose: entries registered below are visible
        # to the membership checks of later weights.
        shared_w_list = self._global_context.repeated_weights.keys()
        if not shared_w_list:
            if base_nd_struct.fragment.default_var.get("parameters"):
                # set only if has parameters as it requires rewritten.
                for (t_param_postfix, t_param_data_dict) in \
                        base_nd_struct.fragment.default_var["trainable_params"].items():
                    if not isinstance(t_param_data_dict.get('data'), np.ndarray):
                        continue
                    Generator._set_translated_args_for_unshared_weights(t_param_postfix, nd_struct_list)
            return

        for (t_param_postfix, t_param_data_dict) in base_nd_struct.fragment.default_var["trainable_params"].items():
            # check each weight if partial shared or fully shared weight
            if not t_param_data_dict:
                continue
            t_onnx_names = self._get_same_trainable_params_onnx_name_from_repeated_nodes(t_param_postfix,
                                                                                         t_param_data_dict,
                                                                                         nd_struct_list)
            t_shared_status = [name in shared_w_list for name in t_onnx_names]
            if True in t_shared_status and False in t_shared_status:
                # is partial shared, set unshared to fake shared in GlobalContext
                for idx, (name, status) in enumerate(zip(t_onnx_names, t_shared_status)):
                    if status:
                        # actual shared, do nothing, skip
                        continue
                    node_onnx_name = nd_struct_list[idx][1].onnx_name
                    if not self._global_context.repeated_weights.get(name):
                        self._global_context.repeated_weights[name] = [node_onnx_name]
                    else:
                        self._global_context.repeated_weights[name] += [node_onnx_name]
            if True not in t_shared_status and base_nd_struct.fragment.default_var.get("parameters"):
                # if the repeated node is not shared weight and the mapper accept parameters rewritten.
                if not isinstance(t_param_data_dict.get('data'), np.ndarray):
                    continue
                Generator._set_translated_args_for_unshared_weights(t_param_postfix, nd_struct_list)

    def _list_formal_parameters_in_a_module(self, module_filter_return):
        """
        Find all formal args / params from nodes in a module.

        Args:
            module_filter_return (list): The filtered results from the module_map_filter.

        Return:
            list, a list of sets or None indicates all formal args of each node in the module in order.
        """
        formal_params_list = list()
        # Transpose so that each entry holds the same node across all repeated modules.
        transposed = [list(e) for e in zip(*module_filter_return)]
        for operation in transposed:
            # use the map filtered result for partial shared weights procs
            self._partial_shared_weights_in_repeated_submodule_procs(operation)
        for operation in transposed:
            formal_parameters = self._compare_with_base_parameters(operation)
            if formal_parameters:
                formal_params_list.append(formal_parameters)
            else:
                formal_params_list.append(None)
        return formal_params_list

    def _list_formal_parameters(self, repeated_submodules) -> dict:
        """
        Return a list of formal parameters in each submodule.

        Args:
            repeated_submodules (dict): A dict which contains repeated submodules,
                                        acquire this dict from _list_repeated_submodules()

        Return:
            OrderedDict, a dict with each submodule's formal args.

        Example:
            A return for ResNet50 could be:

            {0: # submodule 0
             [set('stride', 'in_channels', 'out_channels'), # args of the first node to be set as formal
              set('num_features'), # args of the second node to be set as formal
              None, # args of third node to be set as formal, which does not have
              set('in_channels', 'out_channels'),
              set('num_features'),
              None
             ]},
            {3: # submodule 3
             [...],
            {5: # submodule 5
             []} # empty returns means no nodes or it's a parent module of submodules.
            }
        """
        formal_args_in_each_submodule = OrderedDict()
        checked_module = set()
        # filter module_map by submodule_num (without depth)
        for _, module_nums in repeated_submodules.items():
            for module_num in module_nums:
                if module_num in checked_module:  # module already checked
                    continue
                checked_module.add(module_num)
                map_filtered = self.module_map_filter(module_num=module_num)
                formal_args_in_this_module = self._list_formal_parameters_in_a_module(map_filtered)
                formal_args_in_each_submodule[module_num] = formal_args_in_this_module
        return formal_args_in_each_submodule

    def _add_submodule_to_parent(self):
        """
        Recursively add all submodule to its parent module until Main module.

        Note:
            This function deepcopy the first node of the submodule, and reset its params as parent module.
        """
        depth = self._module_depth_max
        while depth > 0:
            # Iterate over a copy because the collection is modified in the loop.
            for (scope_path_str, md_struct) in self.module_structs.copy().items():
                if scope_path_str == '[]':
                    continue  # is main module, skip
                if md_struct.scope_depth != depth:
                    continue  # skip all submodules not at current depth
                md_struct_scope = copy.deepcopy(md_struct.identifier)
                md_struct_scope.pop()
                parent_scope = md_struct_scope
                # 1. check if this module has parent module
                parent_md_struct = self.module_structs.get(str(parent_scope))
                if parent_md_struct is None:
                    # 1B. not has parent, use this submodule to create a parent module
                    parent_md_struct = ModuleStruct(None, init_as_parent=True, parent_base=md_struct)
                # 1A / 1B. add md_struct to its parent ModuleStruct and (re)register the parent.
                parent_md_struct.add_submodule(md_struct)
                self.module_structs[str(parent_scope)] = parent_md_struct

                sub = self.module_structs.pop(scope_path_str)  # remove this submodule from collections
                self._global_context.add_module_struct(sub.pattern_id, sub)
            depth -= 1

    @GeneratorError.check_except("Generator occurs an error when building modules.")
    def _recursive_form_module(self):
        """Main routine in generator to build modules from bottom to top."""
        # 1. List repeated submodules
        repeated_submodules = self._list_repeated_submodules()
        # 2. List reused parameters
        formal_parameters = self._list_formal_parameters(repeated_submodules)
        # 3. Build base submodules and set in/ext params translation
        for module_struct in self.module_structs.values():
            if module_struct.pattern_id == -1:  # is main module
                continue
            formal_args = formal_parameters.get(module_struct.pattern_id)
            module_struct.update_args_translation_list(formal_args)

        # 4. Form parent modules, repeat until the collection size is stable
        md_collection_len = len(self.module_structs)
        len_changes = True
        while len_changes:
            self._add_submodule_to_parent()
            new_len = len(self.module_structs)
            if md_collection_len != new_len:
                md_collection_len = new_len
            else:
                len_changes = False

        GlobalContext().build_struct_finished = True
        # 5. Update all translated args from module map
        self._update_all_modules_args_translator()

        # 6. Update all nodes and modules input/output
        self.build_outputs_connection()
        self.module_structs.get('[]').allocate_construct_header_x()
        self.module_structs.get('[]').collect_returns()

        matcher = MatcherLauncher(self.module_structs.get('[]'))
        matcher.matching_process()

        for nd_struct in self.node_structs.values():
            if nd_struct.fragment.metadata.get("operation") == "Split":
                self._split_op_procs(nd_struct)

    def _shared_weights_processing(self):
        """Process shared weights."""
        # check each node has shared weight
        for nd_struct in self.node_structs.values():
            shared_weights = SharedWeightHelper.check_node_has_shared_weight(nd_struct)
            if shared_weights:
                # register each shared weight to public module
                for shared_w in shared_weights:
                    SharedWeightHelper.register_shared_weight_to_public_parent(nd_struct,
                                                                               shared_w,
                                                                               pub_module_identifier=[])

    def _update_all_modules_args_translator(self):
        """Update all modules' args translators."""
        done_submodule = set()
        # The snapshot does not depend on the depth, so take it once.
        repeated_submodules = copy.deepcopy(self._repeated_submodules)
        for depth in range(self._module_depth_max, 0, -1):
            # check modules from bottom to top
            repeated_modules = repeated_submodules.get(depth)
            # `depth` comes from range() and is never None; the value that can
            # be missing is the lookup result, so that is what must be guarded.
            if repeated_modules is None:
                continue
            for pattern_id in repeated_modules:
                if pattern_id in done_submodule:
                    continue
                # get all md_structs by same pattern
                md_list = self._global_context.module_structs.get(pattern_id)
                self._take_formal_args_from_updated_submodules(md_list)
                args_translators = self.get_args_translator_from_module_structs_list(md_list)
                formal_args_list = ArgsTranslationHelper.find_formal_args_in_modules(args_translators)
                changed_args_translators = self.get_args_translator_from_module_structs_list(
                    md_list, exclude_root_son=True)
                ArgsTranslationHelper.change_args_to_formal_for_all_translators(
                    formal_args_list, changed_args_translators)
                done_submodule.add(pattern_id)

    def _take_formal_args_from_updated_submodules(self, md_list):
        """
        Take formal args from provided modules' nodes and submodules.

        Args:
            md_list (list): A list of ModuleStruct. A single ModuleStruct is also accepted.
        """
        if isinstance(md_list, ModuleStruct):
            md_list = [md_list]

        for md in md_list:
            md.args_translator.take_formal_args_from_nodes_and_submodules(md.get_all_sub_translators())

    def _update_module_depth_max(self, nd_struct: NodeStruct):
        """
        Update the Generator attribute module_depth_max, which is the maximum depth in the Model.

        Args:
            nd_struct (NodeStruct): NodeStruct to be checked its depth.

        Raises:
            TypeError: If the scope depth of the NodeStruct is not an int.
        """
        depth = nd_struct.scope.depth
        if isinstance(depth, int):
            if depth > self._module_depth_max:
                self._module_depth_max = depth
        else:
            raise TypeError("Unable to update global depth due to TypeError in NodeStruct.scope.depth")

    def add_node(self, node_identifier, node_instance=None, node_fragment=None):
        """
        Add Node information to the generator.

        Args:
            node_identifier (str): The unique identifier for the node passed in.
            node_instance (GraphNode): The GraphNode instance of each node.
            node_fragment (NodeFragment): The NodeFragment instance of this node passed in.

        Raises:
            ValueError: If `node_identifier` is None.
        """
        if node_identifier is None:
            raise ValueError("Node Identifier should not be None.")
        self._global_context.node_fragments[node_identifier] = node_fragment
        args = []
        if node_instance is not None:
            args.append(node_instance)
        if node_fragment is not None:
            args.append(node_fragment)

        nd_struct = self.node_structs.get(node_identifier)
        if nd_struct:  # NodeStruct already exists
            nd_struct.update(args)
        else:  # create new Node Struct
            nd_struct = NodeStruct(args)
            nd_struct.identifier = node_identifier
            self._update_module_depth_max(nd_struct)
            self.node_structs[node_identifier] = nd_struct

    @property
    def node_structs(self):
        """Return all NodeStructs in this model."""
        return self._node_struct_collections

    @property
    def module_structs(self):
        """Return all ModuleStructs in this model."""
        return self._module_struct_collections

    def generate_weight_scope_name(self, node):
        """
        Generate weight scope name for checkpoint.

        Args:
            node (str): Identifier of the node in `node_structs`.

        Returns:
            str, the lower-cased, dot-joined scope names followed by the node's variable name.
        """
        nd_struct = self.node_structs[node]
        replaced_module_dict = nd_struct.global_context_mgr.known_module_name
        scope_list = nd_struct.scope.scope_list
        ms_var_name = nd_struct.ms_var_name

        weight_scope_name = None
        # The first scope entry is skipped.
        for scope in scope_list[1:]:
            replaced_module = replaced_module_dict.get(scope.split(SEPARATOR_BTW_NAME_AND_ID)[0])
            if replaced_module:
                scope = scope.replace(scope.split(SEPARATOR_BTW_NAME_AND_ID)[0], replaced_module)
            if not weight_scope_name:
                weight_scope_name = scope
            else:
                weight_scope_name = '.'.join((weight_scope_name, scope))

        if not weight_scope_name:
            weight_scope_name = ms_var_name
        else:
            weight_scope_name = '.'.join((weight_scope_name, ms_var_name))

        return weight_scope_name.lower()

    def generate_checkpoint(self):
        """
        Generate checkpoint.

        Returns:
            tuple, a list of {'name', 'data'} dicts holding the weights as Tensors,
            and a list mapping every converted weight to its source weight.
        """
        # Imported lazily through import_module.
        mindspore = import_module('mindspore')
        trainable_weights_dict = dict()
        weight_map = list()
        for node_name, node_inst in self.node_structs.items():
            if node_inst.fragment.exchange_msg['var_0']['trainable_params']:
                weights_scope_name = self.generate_weight_scope_name(node_name)
                onnx_weight_inst = node_inst.fragment.exchange_msg['var_0']['weights']
                for idx, (weight_key, weight_value_object) in \
                        enumerate(node_inst.fragment.exchange_msg['var_0']['trainable_params'].items()):
                    value_type = weight_value_object.get('type', WeightType.COMMON.value)
                    value_data = weight_value_object['data']
                    if value_type == WeightType.PARAMETER.value:
                        weight_name = SEPARATOR_BTW_NAME_AND_ID.join((weights_scope_name, weight_key))
                    else:
                        weight_name = LINK_IN_WEIGHT_NAME.join((weights_scope_name, weight_key))
                    # Build the tensor once and read both attributes from it.
                    converted_tensor = mindspore.Tensor(value_data)
                    weight_shape = converted_tensor.shape
                    data_type = converted_tensor.dtype
                    trainable_weights_dict[weight_name] = value_data

                    onnx_weight_name = onnx_weight_inst[idx].name
                    onnx_weight_shape = onnx_weight_inst[idx].value.shape
                    onnx_data_type = onnx_weight_inst[idx].value.dtype

                    weight_map.append(
                        {
                            'converted_weight': {
                                'name': weight_name,
                                'shape': weight_shape,
                                'data_type': str(data_type)
                            },
                            'source_weight': {
                                'name': onnx_weight_name,
                                'shape': onnx_weight_shape,
                                'data_type': str(onnx_data_type)
                            }
                        }
                    )

        save_obj = list()
        for weight_name, weight_value in trainable_weights_dict.items():
            obj = {
                'name': weight_name,
                'data': mindspore.Tensor(weight_value)
            }
            save_obj.append(obj)

        return save_obj, weight_map

    @GeneratorError.check_except("Generator occurs an error when generating code statements.")
    def generate(self):
        """
        Generate the final script file.

        Note:
            The reference to the global context held by this generator is
            deleted at the end, so the instance is not meant to be reused.

        Returns:
            dict, key "model" maps to a tuple of (formatted code, report,
            checkpoint data list, weight map).
        """
        self._form_bottom_submodule()
        self._recursive_form_module()
        self._shared_weights_processing()

        ckpt_data_list, weight_map = self.generate_checkpoint()

        # The CodeStruct registers itself (and its submodules) in the global context.
        CodeStruct(self.module_structs.get('[]'), self._repeated_submodules)

        outputs = [get_imported_module()]

        for code_struct in self._global_context.code_structs.values():
            outputs.extend(code_struct.code_line_list)

        formatted_code, _ = FormatCode("\n".join(outputs), style_config=mindspore_yapf_config())

        report_generator = ReportGenerator()
        report = report_generator.gen_report(formatted_code)
        del self._global_context

        return {"model": (formatted_code, report, ckpt_data_list, weight_map)}

    def get_node_struct(self, node_identifier):
        """
        Get specific NodeStruct by node_identifier.

        Args:
            node_identifier (str): The node unique identifier.

        Return:
            NodeStruct, the node's NodeStruct, or None if it is unknown.
        """
        return self._node_struct_collections.get(node_identifier, None)

    def get_module_struct(self, module_identifier):
        """
        Get specific ModuleStruct by module_identifier.

        Args:
            module_identifier (str): The module unique identifier.

        Return:
            ModuleStruct, the module's ModuleStruct, or None if it is unknown.
        """
        return self._module_struct_collections.get(module_identifier, None)

    def get_args_translator_from_module_structs_list(self, md_list, exclude_root_son=False):
        """
        Return a list of args translators which belongs to given module structs.

        Args:
            md_list (list): A list of ModuleStruct.
            exclude_root_son (Bool): If True, modules whose parent is the Main module
                (parent_id == -1) are left out of the result.

        Returns:
            list, a list of args translators which belongs to given module structs.
        """
        ret = []
        for md in md_list:
            if exclude_root_son and md.parent_id == -1:
                continue
            if md.args_translator is not None:
                ret.append(md.args_translator)

        return ret

    def module_map_filter(self, depth=None, module_num=None, uid=None):
        """
        Filter the module map by given conditions.

        Args:
            depth (int): Scope depth.
            module_num (int): The submodule number.
            uid (int): The unique identifier of a submodule.

        Return:
            list, list of NodeStruct list of each submodule.
        """
        ret = list()
        for scope_path, nd_struct_list in self._module_map.items():
            path = Scope.path_str_to_list(scope_path)
            if not path:  # skip main
                continue

            # if depth not equals to the indicated depth, skip
            if depth is not None and len(path) != depth:
                continue

            # Only the innermost scope entry is compared.
            (m_num, m_uid) = path[-1]
            if uid is not None:
                if m_num == module_num and m_uid == uid:
                    ret.append(nd_struct_list)
            else:
                if m_num == module_num:
                    ret.append(nd_struct_list)
        return ret

    def build_outputs_connection(self):
        """Build all nodes and modules outputs connections."""
        for nd_struct in self.node_structs.values():
            # for each output in curr node output manager
            for out in nd_struct.outputs_manager.outputs:
                # Set the onnx output edge name to this output
                self._global_context.outputs_storage.add_output(out)
                self._global_context.outputs_storage.add_onnx_node_name(out.onnx_edge_name,
                                                                        nd_struct.fragment.metadata.get('source'))
                self._global_context.outputs_storage.add_ms_identifier(out.onnx_edge_name, nd_struct.identifier)

            # Set input with existing output mapping
            for idx, inp in enumerate(nd_struct.inputs_edges_names):
                if inp in self._global_context.outputs_storage.outputs_collections:
                    output_obj = self._global_context.outputs_storage.outputs_collections[inp]
                    output_obj.idx_in_onnx_user[nd_struct.onnx_name] = idx

                    # set ms_user idx, need to modify if not follow onnx order
                    output_obj.idx_in_ms_user[nd_struct.identifier] = idx

                    # set this output to be returned to external
                    output_obj.to_external = not (nd_struct.check_target_node_internal(
                        self._global_context.outputs_storage.onnx_name(inp)
                    ))

        # collect submodule's and nodes' outputs mgr
        self._collect_output_mgr()

    def _collect_output_mgr(self, module=None):
        """
        Collect the outputs manager from nodes and submodules the current module has.

        Args:
            module (ModuleStruct): The module struct collecting its nodes and submodules.
                The Main module is used when omitted.
        """
        root_module = module or self.get_module_struct('[]')
        output_mgr_list = list()
        for struct in root_module.get_generate_order():
            if isinstance(struct, tuple):
                # index 1 is the NodeStruct while 0 is topological index.
                struct = struct[1]
            if isinstance(struct, ModuleStruct) and struct.outputs_manager is None:
                # Submodules are resolved first so their outputs are available below.
                self._collect_output_mgr(module=struct)
            for out in struct.outputs_manager.outputs:
                if Generator.check_output_need_to_external(root_module, out):
                    output_mgr_list.append(out.deepcopy())
        root_module.outputs_manager = ModuleOutputManager(root_module.identifier,
                                                          base_out=output_mgr_list)
        root_module.outputs_manager.assign_opt_var_name_to_each_output(root_module.ms_opt_var_name)

    @staticmethod
    def check_output_need_to_external(root_module: ModuleStruct, checked_output: BaseOutput):
        """
        Check the output still need to be returned to module external.

        Args:
            root_module (ModuleStruct): The Module that the output to be determined.
            checked_output (BaseOutput): The output to be checked whether returned by the Module.

        Returns:
            bool, True if the output need to be returned to the module external.
        """
        return any(user in root_module.external_successor_nodes_names
                   for user in checked_output.onnx_user)

    def _split_op_procs(self, split_struct: NodeStruct):
        """
        Support for Split operation multiple outputs.

        For every successor, the first matched input containing "split" is
        replaced by the Split output the successor actually consumes.

        Args:
            split_struct (NodeStruct): The NodeStruct of the Split op.

        Raises:
            ValueError: If the consumed Split output or the submodule to modify cannot be located.
        """
        for successor in split_struct.successor_nodes_structs:
            # 1. target user is internal
            if split_struct.check_target_node_internal(successor.identifier):
                idx = self._get_correct_input_idx_from_split(split_struct, successor)
                if idx is None:
                    raise ValueError("The Split OP should not has empty output.")
                correct_input = split_struct.fragment.fragment.get_outputs_by_idx(0, idx)
                to_be_replaced = None
                for inp in successor.matched_inputs:
                    if "split" in inp:
                        to_be_replaced = inp
                        break
                if to_be_replaced is not None:
                    successor.matched_inputs = replace_string_in_list(successor.matched_inputs,
                                                                      to_be_replaced,
                                                                      correct_input)
            # 2. target user is external
            else:
                public_parent = self._get_public_parent_module(split_struct, successor)
                to_be_modified_md = self._get_submodule_has_out_user_under_public_parent(public_parent, successor)
                idx = self._get_correct_input_idx_from_split(split_struct, successor)
                if idx is None:
                    raise ValueError("The Split OP should not has empty output.")
                if to_be_modified_md is None:
                    raise ValueError("Unable to locate the submodule to be modified for Split output matching.")
                correct_input = split_struct.fragment.fragment.get_outputs_by_idx(0, idx)
                to_be_replaced = None
                for inp in to_be_modified_md.matched_inputs:
                    if "split" in inp:
                        to_be_replaced = inp
                        break
                if to_be_replaced is not None:
                    to_be_modified_md.matched_inputs = replace_string_in_list(to_be_modified_md.matched_inputs,
                                                                              to_be_replaced,
                                                                              correct_input)

    def _get_correct_input_idx_from_split(self, split_struct: NodeStruct, split_out_user: NodeStruct):
        """
        Return the index of the split output the user used.

        Returns:
            int, index of the first Split output found in the user's inputs, or None if there is none.
        """
        split_struct_out_edges = split_struct.fragment.metadata.get("outputs")
        for idx, out in enumerate(split_struct_out_edges):
            if out in split_out_user.fragment.metadata.get("inputs"):
                return idx
        return None

    def _get_public_parent_module(self, node_a: NodeStruct, node_b: NodeStruct):
        """
        Return the public parent module of both Node A and Node B.

        Walks up from Node A through `parent_module_struct` until a module
        listing Node B's onnx name is reached.
        """
        find = False
        b_onnx_name = node_b.onnx_name
        tmp = node_a
        while not find:
            parent_struct = tmp.parent_module_struct
            if b_onnx_name in parent_struct.onnx_names:
                find = True
            tmp = parent_struct
        return tmp

    def _get_submodule_has_out_user_under_public_parent(self, public_module: ModuleStruct,
                                                        node_out_user: NodeStruct):
        """
        Return the ModuleStruct which under the public module and contains the NodeStruct which provided.

        Returns:
            ModuleStruct, the first direct submodule listing the node's onnx name, or None.
        """
        for module_struct in public_module.module_structs:
            if node_out_user.onnx_name in module_struct.onnx_names:
                return module_struct
        return None
def _generate_from_module_struct(self, md_struct, repeated_submodules):
    """
    Emit the class definition of a ModuleStruct into this code struct.

    The class header is written first, then the code lines of every node and
    submodule are gathered, and finally the `__init__` block, the `construct`
    block and the return statement are written. The finished code struct is
    registered in the global context under the module's pattern id.

    Args:
        md_struct (ModuleStruct): The ModuleStruct which generates codes.
        repeated_submodules (dict): The dict contains all submodules which use repeatedly.
            Can get this dict from generator.

    Raises:
        ValueError: If a registered output has no entry in `internal_outputs_collection`.
    """
    # Class declaration.
    class_name = md_struct.class_name
    self.new_line = f"class {class_name}(nn.Cell):"

    # Signature of __init__: actual args first, then formal args.
    translator = md_struct.args_translator
    module_def_args = ['self']
    if translator.actual_args:
        module_def_args.extend(translator.actual_args.keys())
    if translator.formal_args:
        module_def_args.extend(translator.formal_args.keys())

    # set passthrough weights for shared weights, no need for main model
    if md_struct.identifier != []:
        module_def_args = SharedWeightHelper.add_shared_weights_in_init_statement(md_struct, module_def_args)

    # Gather the statements of the init and construct blocks.
    init_lines = []
    cons_lines = []
    for _, member in md_struct.get_generate_order():
        if not isinstance(member, NodeStruct):
            # Submodule: make sure its class has been generated once.
            if GlobalContext().code_structs.get(member.pattern_id) is None:
                CodeStruct(member, repeated_submodules)
            init_parts = member.code_line_in_init()
            cons_parts = member.code_line_in_construct(inputs=member.matched_inputs)
            init_stmt = ' = '.join(init_parts)
            cons_stmt = ' = '.join(cons_parts)
            init_lines.append(f"{SECOND_LEVEL_INDENT}{init_stmt}")
            cons_lines.append(f"{SECOND_LEVEL_INDENT}{cons_stmt}")
            continue

        # Node: the return values of both calls are not used here; the
        # statements are read from the fragment afterwards.
        member.code_line_in_init()
        member.code_line_in_construct()
        node_init, node_cons = member.fragment.fragment()
        indented_init = [f"{SECOND_LEVEL_INDENT}{stmt}" for stmt in node_init]
        indented_cons = [f"{SECOND_LEVEL_INDENT}{stmt}" for stmt in node_cons]
        init_lines += indented_init
        cons_lines += indented_cons

    # Header of the init block.
    joined_def_args = ', '.join(module_def_args)
    self.new_line = f"{FIRST_LEVEL_INDENT}def __init__({joined_def_args}):"
    self.new_line = f"{SECOND_LEVEL_INDENT}super({class_name}, self).__init__()"

    # Shared weights are declared in the init block of the main model only.
    if md_struct.identifier == []:
        for declaration in SharedWeightHelper.public_module_shared_weight_statement_generation(md_struct):
            self.new_line = f"{SECOND_LEVEL_INDENT}{declaration}"

    # Body of the init block.
    self.code_line_list += init_lines
    self.new_line = f"{NEW_LINE * 2}"

    # Header and body of the construct block.
    construct_args = ['self']
    construct_args += list(md_struct.inputs_register.values())
    joined_construct_args = ', '.join(construct_args)
    self.new_line = f"{FIRST_LEVEL_INDENT}def construct({joined_construct_args}):"
    self.code_line_list += cons_lines

    # Return statement: one opt_var_name per registered output edge.
    returns = []
    for output_edge in md_struct.outputs_register.keys():
        opt_var_name = md_struct.internal_outputs_collection.get(output_edge)
        if opt_var_name is None:
            raise ValueError(f"Module {md_struct.identifier} has an output {output_edge} has unknown opt_var_name.")
        returns.append(opt_var_name)
    joined_returns = ', '.join(returns)
    self.new_line = f"{SECOND_LEVEL_INDENT}return {joined_returns}"
    self.new_line = f"{NEW_LINE * 2}"

    GlobalContext().code_structs[md_struct.pattern_id] = self