def _initialize_graph_builder(self, training):
    """Build a fresh OrtModuleGraphBuilder and store it on ``self._graph_builder``.

    A builder configuration is populated from the flattened module's
    parameters, the recorded input info and this object's cast-propagation,
    layer-norm precision and log-level settings. The builder is then
    initialized with the serialized ONNX model and that configuration.

    Args:
        training: Assigned unchanged to the configuration's
            ``build_gradient_graph`` field.
    """
    # TODO: PyTorch exporter bug: changes the initializer order in ONNX model
    # Names are therefore read from the flattened module's named_parameters().
    all_param_names = []
    trainable_param_names = []
    for param_name, param in self._flattened_module.named_parameters():
        all_param_names.append(param_name)
        if param.requires_grad:
            trainable_param_names.append(param_name)

    # Build and optimize the full graph
    config = C.OrtModuleGraphBuilderConfiguration()
    config.initializer_names = all_param_names
    config.initializer_names_to_train = trainable_param_names
    config.input_names_require_grad = self._input_info.require_grad_names
    config.build_gradient_graph = training

    # Install a default transformer configuration first, then adjust it
    # through the attribute on the builder configuration.
    config.graph_transformer_config = C.GraphTransformerConfiguration()
    transformer_config = config.graph_transformer_config
    transformer_config.propagate_cast_ops_level = self._propagate_cast_ops_level
    transformer_config.propagate_cast_ops_allow = self._propagate_cast_ops_allow
    transformer_config.allow_layer_norm_mod_precision = self._allow_layer_norm_mod_precision

    # Translate the Python-side log level to the matching C.Severity value;
    # any level missing from the table falls back to WARNING.
    severity_by_loglevel = {
        _logger.LogLevel.VERBOSE: C.Severity.VERBOSE,
        _logger.LogLevel.INFO: C.Severity.INFO,
        _logger.LogLevel.WARNING: C.Severity.WARNING,
        _logger.LogLevel.ERROR: C.Severity.ERROR,
        _logger.LogLevel.FATAL: C.Severity.FATAL,
    }
    config.loglevel = severity_by_loglevel.get(self._loglevel, C.Severity.WARNING)

    builder = C.OrtModuleGraphBuilder()
    self._graph_builder = builder
    builder.initialize(self._onnx_model.SerializeToString(), config)
def _initialize_graph_builder(self, training):
    """Build a fresh OrtModuleGraphBuilder and store it on ``self._graph_builder``.

    The builder configuration is filled in from the flattened module's
    parameters, the recorded input info and this object's cast-propagation
    settings. The builder is then initialized with the serialized ONNX model
    and that configuration.

    Args:
        training: Assigned unchanged to the configuration's
            ``build_gradient_graph`` field.
    """
    # TODO: PyTorch exporter bug: changes the initializer order in ONNX model
    # Names are therefore read from the flattened module's named_parameters().
    named_params = list(self._flattened_module.named_parameters())
    every_name = [param_name for param_name, _ in named_params]
    trainable_names = [
        param_name for param_name, param in named_params if param.requires_grad
    ]

    # Build and optimize the full graph
    builder_config = C.OrtModuleGraphBuilderConfiguration()
    builder_config.initializer_names = every_name
    builder_config.initializer_names_to_train = trainable_names
    builder_config.input_names_require_grad = self._input_info.require_grad_names
    builder_config.build_gradient_graph = training

    # Install a default transformer configuration first, then adjust it
    # through the attribute on the builder configuration.
    builder_config.graph_transformer_config = C.GraphTransformerConfiguration()
    transformer_settings = builder_config.graph_transformer_config
    transformer_settings.propagate_cast_ops_level = self._propagate_cast_ops_level
    transformer_settings.propagate_cast_ops_allow = self._propagate_cast_ops_allow

    graph_builder = C.OrtModuleGraphBuilder()
    self._graph_builder = graph_builder
    graph_builder.initialize(self._onnx_model.SerializeToString(), builder_config)