예제 #1
0
    def __init__(self,
                 path_or_bytes,
                 parameters,
                 sess_options=None,
                 providers=None,
                 provider_options=None):
        """Create a training session and load a model into it.

        Args:
            path_or_bytes: Model location as a ``str`` file path (loaded via
                ``load_model``), or the serialized model as ``bytes``
                (loaded via ``read_bytes``).
            parameters: Forwarded unchanged to the chosen loader.
            sess_options: Optional options passed to ``C.TrainingSession``;
                when falsy the session is created without options.
            providers: Execution providers to use. When ``None``, every
                provider reported by ``C.get_available_providers()`` is used.
            provider_options: Per-provider options, normalized together with
                ``providers`` by ``check_and_normalize_provider_args``.

        Raises:
            TypeError: If ``path_or_bytes`` is neither ``str`` nor ``bytes``.
        """
        Session.__init__(self)

        if sess_options:
            self._sess = C.TrainingSession(sess_options)
        else:
            self._sess = C.TrainingSession()

        # providers needs to be passed explicitly as of ORT 1.10; an unset
        # value keeps the pre-1.10 behavior of using every available provider
        # instead of passing None through to normalization/loading.
        if providers is None:
            providers = C.get_available_providers()

        providers, provider_options = check_and_normalize_provider_args(
            providers, provider_options, C.get_available_providers())

        if isinstance(path_or_bytes, str):
            config_result = self._sess.load_model(path_or_bytes, parameters,
                                                  providers, provider_options)
        elif isinstance(path_or_bytes, bytes):
            config_result = self._sess.read_bytes(path_or_bytes, parameters,
                                                  providers, provider_options)
        else:
            raise TypeError("Unable to load from type '{0}'".format(
                type(path_or_bytes)))

        # Cache values reported by the backend after loading.
        self.loss_scale_input_name = config_result.loss_scale_input_name

        self._inputs_meta = self._sess.inputs_meta
        self._outputs_meta = self._sess.outputs_meta
예제 #2
0
    def __init__(self, path_or_bytes, parameters, sess_options=None,
                 providers=None, provider_options=None):
        """Build the training session and load a model into it.

        ``path_or_bytes`` is either a ``str`` path handed to ``load_model``
        or a ``bytes`` payload handed to ``read_bytes``; ``parameters``,
        together with the normalized providers and provider options, is
        forwarded to whichever loader is chosen. ``sess_options`` (when
        truthy) is passed to ``C.TrainingSession``. ``providers`` defaults to
        everything returned by ``C.get_available_providers()``.

        Raises:
            TypeError: ``path_or_bytes`` is neither ``str`` nor ``bytes``.
        """
        Session.__init__(self)

        self._sess = (C.TrainingSession(sess_options) if sess_options
                      else C.TrainingSession())

        # As of ORT 1.10 the provider list must be explicit; fall back to all
        # available providers so the pre-1.10 default is preserved.
        providers = (C.get_available_providers() if providers is None
                     else providers)
        providers, provider_options = check_and_normalize_provider_args(
            providers, provider_options, C.get_available_providers())

        # Pick the loader matching the input kind, then invoke it once.
        if isinstance(path_or_bytes, str):
            loader = self._sess.load_model
        elif isinstance(path_or_bytes, bytes):
            loader = self._sess.read_bytes
        else:
            raise TypeError("Unable to load from type '{0}'".format(
                type(path_or_bytes)))
        config_result = loader(path_or_bytes, parameters,
                               providers, provider_options)

        native = self._sess
        self.loss_scale_input_name = config_result.loss_scale_input_name
        self._inputs_meta = native.inputs_meta
        self._outputs_meta = native.outputs_meta
    def __init__(self, path_or_bytes, parameters, sess_options=None):
        """Create the backend training session, load a model, then init the base.

        ``path_or_bytes`` is a ``str`` path (loaded with ``load_model``) or a
        ``bytes`` payload (loaded with ``read_bytes``); ``parameters`` is
        forwarded to the loader. ``sess_options``, when truthy, is passed to
        ``C.TrainingSession``. The base ``Session`` is initialized last, with
        the fully loaded backend session.

        Raises:
            TypeError: ``path_or_bytes`` is neither ``str`` nor ``bytes``.
        """
        self._sess = (C.TrainingSession(sess_options) if sess_options
                      else C.TrainingSession())

        if not isinstance(path_or_bytes, (str, bytes)):
            raise TypeError("Unable to load from type '{0}'".format(type(path_or_bytes)))

        load = (self._sess.load_model if isinstance(path_or_bytes, str)
                else self._sess.read_bytes)
        config_result = load(path_or_bytes, parameters)

        self.loss_scale_input_name = config_result.loss_scale_input_name
        self._inputs_meta, self._outputs_meta = (self._sess.inputs_meta,
                                                 self._sess.outputs_meta)

        # Base-class init must come after loading: it receives the loaded session.
        Session.__init__(self, self._sess)