Example #1
import onnxruntime
# C_TrainingAgent is the native training-agent binding; it is assumed here to be
# exposed by onnxruntime.capi._pybind_state in training-enabled builds.
from onnxruntime.capi._pybind_state import TrainingAgent as C_TrainingAgent


class TrainingAgent(object):
    """
    This is the main class used to run training for an ORTModule model.
    """
    def __init__(self,
                 path_or_bytes,
                 fw_feed_names,
                 fw_fetches_names,
                 fw_outputs_device_info,
                 bw_feed_names,
                 bw_fetches_names,
                 bw_outputs_device_info,
                 session_options=None,
                 providers=None,
                 provider_options=None):
        """
        :param path_or_bytes: filename or serialized ONNX or ORT format model in a byte string
        :param fw_feed_names: Feed names for forward pass.
        :param fw_fetches_names: Fetch names for forward pass.
        :param fw_outputs_device_info: Device info for fetches in forward pass.
        :param bw_feed_names: Feed names for backward pass.
        :param bw_fetches_names: Fetch names for backward pass.
        :param bw_outputs_device_info: Device info for fetches in backward pass.
        :param session_options: Session options.
        :param providers: Optional sequence of providers in order of decreasing
            precedence. Values can either be provider names or tuples of
            (provider name, options dict). If not provided, then all available
            providers are used with the default precedence.
        :param provider_options: Optional sequence of options dicts corresponding
            to the providers listed in 'providers'.

        The model type will be inferred unless explicitly set in the SessionOptions.
        To explicitly set:
          so = onnxruntime.SessionOptions()
          so.add_session_config_entry('session.load_model_format', 'ONNX') or
          so.add_session_config_entry('session.load_model_format', 'ORT')

        A file extension of '.ort' will be inferred as an ORT format model.
        All other filenames are assumed to be ONNX format models.

        'providers' can contain either names or names and options. When any options
        are given in 'providers', 'provider_options' should not be used.

        The list of providers is ordered by precedence. For example ['CUDAExecutionProvider', 'CPUExecutionProvider']
        means execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider.
        """

        self._inference_session = onnxruntime.InferenceSession(
            path_or_bytes, session_options, providers, provider_options)

        self._training_agent = C_TrainingAgent(self._inference_session._sess,
                                               fw_feed_names, fw_fetches_names,
                                               fw_outputs_device_info,
                                               bw_feed_names, bw_fetches_names,
                                               bw_outputs_device_info)

    def run_forward(self, feeds, fetches, state):
        """
         Compute the forward subgraph for given feeds and fetches.
         :param feeds: Inputs to the graph run.
         :param fetches: Outputs of the graph run.
         :param state: State of the graph that is used for executing partial graph runs.
        """
        self._training_agent.run_forward(feeds, fetches, state)

    def run_backward(self, feeds, fetches, state):
        """
         Compute the backward subgraph for given feeds and fetches.
         :param feeds: Inputs to the graph run.
         :param fetches: Outputs of the graph run.
         :param state: State of the graph that is used for executing partial graph runs.
        """
        self._training_agent.run_backward(feeds, fetches, state)
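
The construction follows the docstring above: the model format can be pinned through a SessionOptions config entry, and providers are listed in decreasing precedence. Below is a minimal construction sketch for Example #1; the model path and feed/fetch names are hypothetical, and the OrtDevice usage assumes the native onnxruntime.capi binding.

# Construction sketch (hypothetical model path and graph names).
from onnxruntime.capi import _pybind_state as C

so = onnxruntime.SessionOptions()
so.add_session_config_entry('session.load_model_format', 'ONNX')  # force ONNX format

# One device-info entry per forward/backward fetch (CPU here; assumed OrtDevice binding).
cpu_device = C.OrtDevice(C.OrtDevice.cpu(), C.OrtDevice.default_memory(), 0)

agent = TrainingAgent(
    'training_model.onnx',
    fw_feed_names=['input', 'target'],                  # forward graph inputs
    fw_fetches_names=['loss'],                          # forward graph outputs
    fw_outputs_device_info=[cpu_device],                # device info for each forward fetch
    bw_feed_names=['loss_grad'],                        # backward graph inputs
    bw_fetches_names=['input_grad'],                    # backward graph outputs
    bw_outputs_device_info=[cpu_device],                # device info for each backward fetch
    session_options=so,
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])  # CUDA first, CPU fallback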
Example #2
import onnxruntime
from onnxruntime.capi.onnxruntime_inference_collection import IOBinding, OrtValue
# C_TrainingAgent is the native training-agent binding; it is assumed here to be
# exposed by onnxruntime.capi._pybind_state in training-enabled builds.
from onnxruntime.capi._pybind_state import TrainingAgent as C_TrainingAgent


class TrainingAgent(object):
    """
    This is the main class used to run an ORTModule model.
    """
    def __init__(self,
                 path_or_bytes,
                 session_options=None,
                 providers=None,
                 provider_options=None):
        """
        :param path_or_bytes: filename or serialized ONNX or ORT format model in a byte string
        :param session_options: Session options.
        :param providers: Optional sequence of providers in order of decreasing
            precedence. Values can either be provider names or tuples of
            (provider name, options dict). If not provided, then all available
            providers are used with the default precedence.
        :param provider_options: Optional sequence of options dicts corresponding
            to the providers listed in 'providers'.

        The model type will be inferred unless explicitly set in the SessionOptions.
        To explicitly set:
          so = onnxruntime.SessionOptions()
          so.add_session_config_entry('session.load_model_format', 'ONNX') or
          so.add_session_config_entry('session.load_model_format', 'ORT')

        A file extension of '.ort' will be inferred as an ORT format model.
        All other filenames are assumed to be ONNX format models.

        'providers' can contain either names or names and options. When any options
        are given in 'providers', 'provider_options' should not be used.

        The list of providers is ordered by precedence. For example ['CUDAExecutionProvider', 'CPUExecutionProvider']
        means execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider.
        """

        self._training_agent = None
        self._inference_session = None

        self.create_training_agent(path_or_bytes, session_options, providers,
                                   provider_options)

    def create_training_agent(self, path_or_bytes, session_options, providers,
                              provider_options):
        self._inference_session = onnxruntime.InferenceSession(
            path_or_bytes, session_options, providers, provider_options)
        self._training_agent = C_TrainingAgent(self._inference_session._sess)

    def io_binding(self):
        "Return an onnxruntime.IOBinding object`."
        return IOBinding(self._inference_session)

    def run_forward(self, iobinding, run_options):
        """
         Compute the forward subgraph until it hits the Yield Op.
         :param iobinding: The IOBinding object with graph inputs/outputs bound.
         :param run_options: See :class:`onnxruntime.RunOptions`.
        """
        ortvalues, run_id = self._training_agent.run_forward(
            iobinding._iobinding, run_options)
        return [OrtValue(ortvalue) for ortvalue in ortvalues], run_id

    def run_backward(self, backward_output_grads, run_id):
        """
         Resume executing the backward subgraph starting from the Yield Op.
         :param backward_output_grads: Output gradients for backward.
         :param run_id: Run id returned by the matching run_forward call.
        """
        self._training_agent.run_backward(
            [ortvalue._ortvalue for ortvalue in backward_output_grads], run_id)
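
In Example #2 the forward half runs until the Yield Op through an IOBinding and returns a run id, which is then passed back to run_backward together with the output gradients. Below is a minimal flow sketch; the model path, bound names, and gradient values are hypothetical.

# Forward/backward flow sketch (hypothetical model and binding targets).
agent = TrainingAgent('training_model.onnx',
                      providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

binding = agent.io_binding()
# Bind the graph inputs/outputs on `binding` before running, e.g. with
# binding.bind_ortvalue_input(name, ort_value) and binding.bind_output(name, 'cuda').

run_options = onnxruntime.RunOptions()
forward_outputs, run_id = agent.run_forward(binding, run_options)  # runs until the Yield Op

# `output_grads` would be the OrtValue gradients of the forward outputs, computed by the
# surrounding framework; they are passed back with the run_id from the matching forward call.
# agent.run_backward(output_grads, run_id)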