コード例 #1
0
    def _initialize_graph_builder(self, training):
        """Create an OrtModuleGraphBuilder, initialize it and store it in self._graph_builder.

        The builder configuration is filled from the flattened module's
        parameters, ``self._input_info``, this instance's cast-propagation and
        layer-norm settings, and ``self._loglevel``; the builder is then
        initialized with the serialized ``self._onnx_model``.

        Args:
            training: assigned to the configuration's ``build_gradient_graph``
                field.
        """
        module = self._flattened_module

        # TODO: PyTorch exporter bug: changes the initializer order in ONNX model
        # Initializer names therefore come from named_parameters() order; the
        # trainable subset is every parameter whose requires_grad is set.
        initializer_names = [name for name, _ in module.named_parameters()]
        initializer_names_to_train = [
            name
            for name, param in module.named_parameters()
            if param.requires_grad
        ]

        # Map the Python-side log level onto the native severity enum;
        # any level not listed here falls back to WARNING.
        severity_by_loglevel = {
            _logger.LogLevel.VERBOSE: C.Severity.VERBOSE,
            _logger.LogLevel.INFO: C.Severity.INFO,
            _logger.LogLevel.WARNING: C.Severity.WARNING,
            _logger.LogLevel.ERROR: C.Severity.ERROR,
            _logger.LogLevel.FATAL: C.Severity.FATAL,
        }

        # Build and optimize the full graph
        builder_config = C.OrtModuleGraphBuilderConfiguration()
        builder_config.initializer_names = initializer_names
        builder_config.initializer_names_to_train = initializer_names_to_train
        builder_config.input_names_require_grad = self._input_info.require_grad_names
        builder_config.build_gradient_graph = training

        # Attach a fresh transformer config, then tune it through the attribute.
        builder_config.graph_transformer_config = C.GraphTransformerConfiguration()
        transformer_config = builder_config.graph_transformer_config
        transformer_config.propagate_cast_ops_level = self._propagate_cast_ops_level
        transformer_config.propagate_cast_ops_allow = self._propagate_cast_ops_allow
        transformer_config.allow_layer_norm_mod_precision = self._allow_layer_norm_mod_precision

        builder_config.loglevel = severity_by_loglevel.get(self._loglevel, C.Severity.WARNING)

        # The builder is stored on self before initialize() runs.
        self._graph_builder = C.OrtModuleGraphBuilder()
        serialized_model = self._onnx_model.SerializeToString()
        self._graph_builder.initialize(serialized_model, builder_config)
コード例 #2
0
    def _initialize_graph_builder(self, training):
        """Create an OrtModuleGraphBuilder, initialize it and store it in self._graph_builder.

        The builder configuration is filled from the flattened module's
        parameters, ``self._input_info`` and this instance's cast-propagation
        settings; the builder is then initialized with the serialized
        ``self._onnx_model``.

        Args:
            training: assigned to the configuration's ``build_gradient_graph``
                field.
        """
        module = self._flattened_module

        # TODO: PyTorch exporter bug: changes the initializer order in ONNX model
        # Initializer names therefore come from named_parameters() order; the
        # trainable subset is every parameter whose requires_grad is set.
        initializer_names = [name for name, _ in module.named_parameters()]
        initializer_names_to_train = [
            name
            for name, param in module.named_parameters()
            if param.requires_grad
        ]

        # Build and optimize the full graph
        builder_config = C.OrtModuleGraphBuilderConfiguration()
        builder_config.initializer_names = initializer_names
        builder_config.initializer_names_to_train = initializer_names_to_train
        builder_config.input_names_require_grad = self._input_info.require_grad_names
        builder_config.build_gradient_graph = training

        # Attach a fresh transformer config, then tune it through the attribute.
        builder_config.graph_transformer_config = C.GraphTransformerConfiguration()
        transformer_config = builder_config.graph_transformer_config
        transformer_config.propagate_cast_ops_level = self._propagate_cast_ops_level
        transformer_config.propagate_cast_ops_allow = self._propagate_cast_ops_allow

        # The builder is stored on self before initialize() runs.
        self._graph_builder = C.OrtModuleGraphBuilder()
        serialized_model = self._onnx_model.SerializeToString()
        self._graph_builder.initialize(serialized_model, builder_config)