예제 #1
0
    def __init__(
        self,
        path_or_bytes,
        fw_feed_names,
        fw_outputs_device_info,
        bw_fetches_names,
        bw_outputs_device_info,
        session_options=None,
        providers=None,
        provider_options=None,
        local_rank=None,
    ):
        """
        Create an inference session for the model, then wrap its native
        session in a C training agent.

        :param path_or_bytes: filename or serialized ONNX or ORT format model in a byte string
        :param fw_feed_names: Feed names for the forward pass.
        :param fw_outputs_device_info: Device info for fetches in the forward pass.
        :param bw_fetches_names: Fetch names for the backward pass.
        :param bw_outputs_device_info: Device info for fetches in the backward pass.
        :param session_options: session options handed to the inference session
        :param providers: Optional sequence of providers in order of decreasing
            precedence. Each value is either a provider name or a tuple
            (provider name, options dict). When omitted, all available
            providers are used with the default precedence.
        :param provider_options: Optional sequence of options dicts matching
            the entries of 'providers'.
        :param local_rank: Optional rank of the current device, used for memory
            profiling only. It is forwarded unchanged to the native agent;
            rank 0 is assumed when it is not specified.

        Model format: it is inferred unless explicitly set in the
        SessionOptions. To set it explicitly:
          so = onnxruntime.SessionOptions()
          so.add_session_config_entry('session.load_model_format', 'ONNX')
        or
          so.add_session_config_entry('session.load_model_format', 'ORT')

        A filename ending in '.ort' is treated as an ORT format model; any
        other filename is assumed to be an ONNX format model.

        'providers' may hold names only, or names together with options. When
        options are given inside 'providers', 'provider_options' should not
        be used.

        Providers are ordered by precedence. For example
        ['CUDAExecutionProvider', 'CPUExecutionProvider'] means a node runs on
        CUDAExecutionProvider if capable, otherwise on CPUExecutionProvider.
        """

        session = onnxruntime.InferenceSession(
            path_or_bytes, session_options, providers, provider_options
        )
        # Keep a reference to the Python session on the instance before the
        # native agent is built on its underlying `_sess` object.
        self._inference_session = session

        self._training_agent = C_TrainingAgent(
            session._sess,
            fw_feed_names,
            fw_outputs_device_info,
            bw_fetches_names,
            bw_outputs_device_info,
            local_rank,
        )
예제 #2
0
    def _create_onnx_graphs(self):
        """
        Builds the training ONNX graph and the inference-optimized graph
        with :epkg:`OrtModuleGraphBuilder`, creates one
        :epkg:`InferenceSession` for each and a :epkg:`TrainingAgent`,
        then returns everything in a dictionary. The log messages suggest
        the caller uses this dictionary to populate a dynamically created
        class (to be confirmed against the caller).

        :return: dictionary with the following keys

        Sessions and runtime objects

        * `_run_options`: value of `self.run_options`
        * `_sess`: :epkg:`InferenceSession` on the model returned
            by `builder.get_model()`
        * `_sess_eval`: :epkg:`InferenceSession` on the model returned
            by `builder.get_inference_optimized_model()`
        * `_training_agent`: :epkg:`TrainingAgent` built on `_sess`
        * `_cache`: a new :epkg:`OrtValueCache`
        * `_logger`: logger (`self._logger`, may be None)
        * `_debug`: value of `self.debug`

        Names

        * `_input_names`: value of `self.input_names`
        * `_output_names`: value of `self.output_names`
        * `_grad_input_names`: input names of `_sess`
        * `_bw_fetches_names`: output names of `_sess`
        * `_weights_to_train`: sorted list of `self.weights_to_train`
        * `_onx_inp`: input names of the graph in `_trained_onnx`
        * `_onx_out`: output names of the graph in `_trained_onnx`

        Devices

        * `_fw_outputs_device_info`: one :epkg:`OrtDevice` per entry
            of `self.providers`
        * `_bw_outputs_device_info`: one :epkg:`OrtDevice` per name in
            `_bw_fetches_names`, all based on the first provider
        * `_fw_no_grad_output_device_info`: one :epkg:`OrtDevice` per
            name in `self.output_names`, all based on the first provider
        * `_devices`: copy of `_fw_outputs_device_info` padded with its
            last element until it is as long as `_grad_input_names`

        Graphs

        * `_graph_info`: value returned by `builder.get_graph_info()`
        * `_trained_onnx`: ONNX model loaded from the serialized training
            graph, unnamed nodes are renamed `N<index>`
        * `_optimized_pre_grad_model`: ONNX model loaded from the
            inference-optimized graph, unnamed nodes are renamed
            `N<index>`
        * `_graph_builder`: the :epkg:`OrtModuleGraphBuilder` instance

        :raises RuntimeError: if the training graph does not have
            as many inputs as outputs
        """
        logger = self._logger
        if logger is not None:
            logger.info("[OrtGradientForwardBackward] create training onnx")
            logger.info("[OrtGradientForwardBackward] input_names=%r",
                        self.input_names)
            logger.info("[OrtGradientForwardBackward] output_names=%r",
                        self.output_names)
            logger.info("[OrtGradientForwardBackward] weights_to_train=%r",
                        self.weights_to_train)

        builder = OrtModuleGraphBuilder()

        if logger is not None:
            # dumps the whole builder configuration before it is used
            cf = self.graph_builder_config.graph_transformer_config
            cfp = cf.propagate_cast_ops_config
            logger.info("[OrtGradientForwardBackward] "
                        "OrtModuleGraphBuilder.initialize")
            logger.info(
                "[OrtGradientForwardBackward] graph_builder_config=%s",
                OrtGradientForwardBackward._repr_helper_(
                    self.graph_builder_config, indent=4))
            logger.info(
                "[OrtGradientForwardBackward] graph_builder_config."
                "graph_transformer_config=%s",
                OrtGradientForwardBackward._repr_helper_(cf, indent=4))
            logger.info(
                "[OrtGradientForwardBackward] graph_builder_config."
                "graph_transformer_config.propagate_cast_ops_config=%s",
                OrtGradientForwardBackward._repr_helper_(cfp, indent=4))

        # the builder receives the serialized original model
        builder.initialize(self.onnx_model.SerializeToString(),
                           self.graph_builder_config)

        if logger is not None:
            logger.info(
                "[OrtGradientForwardBackward] OrtModuleGraphBuilder.build")
        builder.build()

        if logger is not None:
            logger.info(
                "[OrtGradientForwardBackward] OrtModuleGraphBuilder.get_model")

        # both models are later given to InferenceSession and
        # parsed with onnx.load through BytesIO
        train_onnx_model_serialized = builder.get_model()

        optimized_pre_grad_model = builder.get_inference_optimized_model()
        graph_info = builder.get_graph_info()

        if logger is not None:
            logger.info(
                "[OrtGradientForwardBackward] graph_info=%s",
                OrtGradientForwardBackward._repr_helper_(graph_info, indent=4))
            logger.info("[OrtGradientForwardBackward] create TrainSession")
            logger.info(
                "[OrtGradientForwardBackward] sess_options=%s",
                OrtGradientForwardBackward._repr_helper_(self.sess_options,
                                                         indent=4))
            logger.info("[OrtGradientForwardBackward] providers=%r",
                        self.providers)

        # session running the training graph
        sess = InferenceSession(train_onnx_model_serialized,
                                sess_options=self.sess_options,
                                provider_options=self.provider_options,
                                providers=self.providers)

        if logger is not None:
            logger.info("[OrtGradientForwardBackward] create InferenceSession")

        # session running the inference-optimized graph,
        # same options and providers as the training session
        sess_eval = InferenceSession(optimized_pre_grad_model,
                                     sess_options=self.sess_options,
                                     provider_options=self.provider_options,
                                     providers=self.providers)

        if logger is not None:
            logger.info("[OrtGradientForwardBackward] create training agent")

        # feeds and fetches of the training agent are the inputs
        # and outputs of the training session
        grad_input_names = [obj.name for obj in sess.get_inputs()]
        bw_fetches_names = [obj.name for obj in sess.get_outputs()]

        # one device per provider here, whereas the two lists below
        # contain one device per name, always built from the first provider
        # NOTE(review): the asymmetry looks intentional but should be
        # confirmed
        fw_outputs_device_info = [
            OrtDevice(
                OrtGradientForwardBackward._provider_name_to_device_type(i),
                OrtDevice.default_memory(), self.device_index)
            for i in self.providers
        ]
        bw_outputs_device_info = [
            OrtDevice(
                OrtGradientForwardBackward._provider_name_to_device_type(
                    self.providers[0]), OrtDevice.default_memory(),
                self.device_index) for i in bw_fetches_names
        ]
        fw_no_grad_output_device_info = [
            OrtDevice(
                OrtGradientForwardBackward._provider_name_to_device_type(
                    self.providers[0]), OrtDevice.default_memory(),
                self.device_index) for i in self.output_names
        ]

        # The constructor signature depends on the onnxruntime version:
        # the call with the extra trailing argument is tried first and a
        # TypeError triggers the call without it.
        # NOTE(review): the extra argument is presumably the local rank,
        # to be confirmed
        try:
            # onnxruntime>=1.12
            training_agent = TrainingAgent(sess._sess, grad_input_names,
                                           fw_outputs_device_info,
                                           bw_fetches_names,
                                           bw_outputs_device_info, 0)
        except TypeError:
            # onnxruntime<=1.11
            training_agent = TrainingAgent(sess._sess, grad_input_names,
                                           fw_outputs_device_info,
                                           bw_fetches_names,
                                           bw_outputs_device_info)

        if logger is not None:
            logger.info(
                "[OrtGradientForwardBackward] instantiate dynamic class %r",
                self.class_name)
            logger.info("[OrtGradientForwardBackward] weights_to_train=%r",
                        self.weights_to_train)
            logger.info("[OrtGradientForwardBackward] grad_input_names=%r",
                        grad_input_names)
            logger.info("[OrtGradientForwardBackward] bw_fetches_names=%r",
                        bw_fetches_names)
            logger.info("[OrtGradientForwardBackward] device_index=%r",
                        self.device_index)
        # pads the device list with its last element so that
        # there is at least one device per gradient input
        devices = list(fw_outputs_device_info)
        while len(devices) < len(grad_input_names):
            devices.append(devices[-1])

        # parses both graphs and gives a name to every unnamed node
        trained_onnx = onnx.load(BytesIO(train_onnx_model_serialized))
        onnx_loss = onnx.load(BytesIO(optimized_pre_grad_model))
        for i, node in enumerate(trained_onnx.graph.node):
            if node.name == '':
                node.name = "N%d" % i
        for i, node in enumerate(onnx_loss.graph.node):
            if node.name == '':
                node.name = "N%d" % i

        kwargs = {
            '_run_options': self.run_options,
            '_sess': sess,
            '_sess_eval': sess_eval,
            '_training_agent': training_agent,
            '_cache': OrtValueCache(),
            '_logger': logger,
            '_input_names': self.input_names,
            '_grad_input_names': grad_input_names,
            '_output_names': self.output_names,
            '_bw_fetches_names': bw_fetches_names,
            '_fw_outputs_device_info': fw_outputs_device_info,
            '_bw_outputs_device_info': bw_outputs_device_info,
            '_fw_no_grad_output_device_info': fw_no_grad_output_device_info,
            '_weights_to_train': list(sorted(self.weights_to_train)),
            '_graph_info': graph_info,
            #
            '_trained_onnx': trained_onnx,
            '_optimized_pre_grad_model': onnx_loss,
            '_graph_builder': builder,
            '_devices': devices,
            '_debug': self.debug
        }
        # input and output names of the training graph
        graph = kwargs['_trained_onnx'].graph
        kwargs.update({
            '_onx_inp': [o.name for o in graph.input],
            '_onx_out': [o.name for o in graph.output]
        })

        # sanity check: the training graph must have
        # as many inputs as outputs
        if len(kwargs['_onx_inp']) != len(kwargs['_onx_out']):
            raise RuntimeError(  # pragma: no cover
                "Gradient input and output are inconsistant: "
                "%r != %r" % (kwargs['_onx_inp'], kwargs['_onx_out']))
        return kwargs
예제 #3
0
class TrainingAgent(object):
    """
    Entry point used to run the training of an ORTModule model.

    The instance owns an ``onnxruntime.InferenceSession`` and a native
    training agent built on that session; forward and backward runs are
    delegated to the native agent.
    """

    def __init__(
        self,
        path_or_bytes,
        fw_feed_names,
        fw_fetches_names,
        fw_outputs_device_info,
        bw_feed_names,
        bw_fetches_names,
        bw_outputs_device_info,
        session_options=None,
        providers=None,
        provider_options=None,
    ):
        """
        Create the inference session, then the native training agent.

        :param path_or_bytes: filename or serialized ONNX or ORT format model in a byte string
        :param fw_feed_names: Feed names for the forward pass.
        :param fw_fetches_names: Fetch names for the forward pass.
        :param fw_outputs_device_info: Device info for fetches in the forward pass.
        :param bw_feed_names: Feed names for the backward pass.
        :param bw_fetches_names: Fetch names for the backward pass.
        :param bw_outputs_device_info: Device info for fetches in the backward pass.
        :param session_options: session options handed to the inference session
        :param providers: Optional sequence of providers in order of decreasing
            precedence. Each value is either a provider name or a tuple
            (provider name, options dict). When omitted, all available
            providers are used with the default precedence.
        :param provider_options: Optional sequence of options dicts matching
            the entries of 'providers'.

        Model format: it is inferred unless explicitly set in the
        SessionOptions. To set it explicitly:
          so = onnxruntime.SessionOptions()
          so.add_session_config_entry('session.load_model_format', 'ONNX')
        or
          so.add_session_config_entry('session.load_model_format', 'ORT')

        A filename ending in '.ort' is treated as an ORT format model; any
        other filename is assumed to be an ONNX format model.

        'providers' may hold names only, or names together with options. When
        options are given inside 'providers', 'provider_options' should not
        be used.

        Providers are ordered by precedence. For example
        ['CUDAExecutionProvider', 'CPUExecutionProvider'] means a node runs on
        CUDAExecutionProvider if capable, otherwise on CPUExecutionProvider.
        """
        session = onnxruntime.InferenceSession(
            path_or_bytes, session_options, providers, provider_options
        )
        # Store the Python session first; the native agent is then built on
        # its underlying `_sess` object.
        self._inference_session = session

        self._training_agent = C_TrainingAgent(
            session._sess,
            fw_feed_names,
            fw_fetches_names,
            fw_outputs_device_info,
            bw_feed_names,
            bw_fetches_names,
            bw_outputs_device_info,
        )

    def run_forward(self, feeds, fetches, state):
        """
        Run the forward subgraph for the given feeds and fetches.

        The call is delegated to the native agent and nothing is returned.

        :param feeds: Inputs to the graph run.
        :param fetches: Outputs of the graph run.
        :param state: State of the graph that is used for executing partial graph runs.
        """
        native_agent = self._training_agent
        native_agent.run_forward(feeds, fetches, state)

    def run_backward(self, feeds, fetches, state):
        """
        Run the backward subgraph for the given feeds and fetches.

        The call is delegated to the native agent and nothing is returned.

        :param feeds: Inputs to the graph run.
        :param fetches: Outputs of the graph run.
        :param state: State of the graph that is used for executing partial graph runs.
        """
        native_agent = self._training_agent
        native_agent.run_backward(feeds, fetches, state)
예제 #4
0
 def create_training_agent(self, path_or_bytes, session_options, providers,
                           provider_options):
     """
     Create the inference session and the native training agent.

     Both objects are stored on the instance, as `_inference_session`
     and `_training_agent`; nothing is returned.

     :param path_or_bytes: model handed to `onnxruntime.InferenceSession`
     :param session_options: session options handed to the session
     :param providers: providers handed to the session
     :param provider_options: provider options handed to the session
     """
     session = onnxruntime.InferenceSession(
         path_or_bytes, session_options, providers, provider_options
     )
     # The Python session is stored before the native agent is created
     # from its underlying `_sess` object.
     self._inference_session = session
     self._training_agent = C_TrainingAgent(session._sess)
예제 #5
0
class TrainingAgent(object):
    """
    Entry point used to run an ORTModule model.

    The instance owns an ``onnxruntime.InferenceSession`` and a native
    training agent built on that session.
    """

    def __init__(
        self,
        path_or_bytes,
        session_options=None,
        providers=None,
        provider_options=None,
    ):
        """
        Initialize both attributes to None, then create the session and agent.

        :param path_or_bytes: filename or serialized ONNX or ORT format model in a byte string
        :param session_options: session options handed to the inference session
        :param providers: Optional sequence of providers in order of decreasing
            precedence. Each value is either a provider name or a tuple
            (provider name, options dict). When omitted, all available
            providers are used with the default precedence.
        :param provider_options: Optional sequence of options dicts matching
            the entries of 'providers'.

        Model format: it is inferred unless explicitly set in the
        SessionOptions. To set it explicitly:
          so = onnxruntime.SessionOptions()
          so.add_session_config_entry('session.load_model_format', 'ONNX')
        or
          so.add_session_config_entry('session.load_model_format', 'ORT')

        A filename ending in '.ort' is treated as an ORT format model; any
        other filename is assumed to be an ONNX format model.

        'providers' may hold names only, or names together with options. When
        options are given inside 'providers', 'provider_options' should not
        be used.

        Providers are ordered by precedence. For example
        ['CUDAExecutionProvider', 'CPUExecutionProvider'] means a node runs on
        CUDAExecutionProvider if capable, otherwise on CPUExecutionProvider.
        """
        # Both attributes exist (as None) even if the creation below fails.
        self._training_agent = self._inference_session = None

        self.create_training_agent(
            path_or_bytes, session_options, providers, provider_options
        )

    def create_training_agent(self, path_or_bytes, session_options, providers,
                              provider_options):
        """
        Create the inference session and the native training agent.

        Both objects are stored on the instance, as `_inference_session`
        and `_training_agent`; nothing is returned.

        :param path_or_bytes: model handed to `onnxruntime.InferenceSession`
        :param session_options: session options handed to the session
        :param providers: providers handed to the session
        :param provider_options: provider options handed to the session
        """
        session = onnxruntime.InferenceSession(
            path_or_bytes, session_options, providers, provider_options
        )
        self._inference_session = session
        self._training_agent = C_TrainingAgent(session._sess)

    def io_binding(self):
        "Return an IOBinding object created on the inference session."
        session = self._inference_session
        return IOBinding(session)

    def run_forward(self, iobinding, run_options):
        """
        Compute the forward subgraph until it hits the Yield Op.

        :param iobinding: the iobinding object that has graph inputs/outputs
            bound; its `_iobinding` attribute is passed to the native agent.
        :param run_options: See :class:`onnxruntime.RunOptions`.
        :return: tuple `(values, run_id)` where `values` is the list of
            values returned by the native agent, each wrapped in an
            `OrtValue`, and `run_id` is returned by the native agent as is.
        """
        raw_values, run_id = self._training_agent.run_forward(
            iobinding._iobinding, run_options
        )
        wrapped = [OrtValue(raw) for raw in raw_values]
        return wrapped, run_id

    def run_backward(self, backward_output_grads, run_id):
        """
        Resume executing the backward subgraph starting from the Yield Op.

        :param backward_output_grads: Output gradients for backward; the
            `_ortvalue` attribute of each item is passed to the native agent.
        :param run_id: run identifier, forwarded unchanged to the native agent.
        """
        raw_grads = [grad._ortvalue for grad in backward_output_grads]
        self._training_agent.run_backward(raw_grads, run_id)