def __init__(self, path_or_bytes, parameters, sess_options=None, providers=None, provider_options=None):
    """Create a native training session and load a model into it.

    Args:
        path_or_bytes: model source. A ``str`` is handed to the native
            ``load_model`` (presumably a model file path -- confirm) and
            ``bytes`` to ``read_bytes`` (presumably a serialized model --
            confirm).
        parameters: forwarded unchanged to the native loader call.
        sess_options: optional options object; when truthy it is passed to
            ``C.TrainingSession``, otherwise the session is built without it.
        providers: execution providers to use. When ``None``, every provider
            reported by ``C.get_available_providers()`` is used.
        provider_options: per-provider options, normalized together with
            ``providers`` by ``check_and_normalize_provider_args``.

    Raises:
        TypeError: if ``path_or_bytes`` is neither ``str`` nor ``bytes``.
    """
    Session.__init__(self)

    if sess_options:
        self._sess = C.TrainingSession(sess_options)
    else:
        self._sess = C.TrainingSession()

    # providers needs to be passed explicitly as of ORT 1.10;
    # retain the pre-1.10 behavior by defaulting to the available providers.
    available_providers = C.get_available_providers()
    if providers is None:
        providers = available_providers
    providers, provider_options = check_and_normalize_provider_args(
        providers, provider_options, available_providers)

    if isinstance(path_or_bytes, str):
        config_result = self._sess.load_model(path_or_bytes, parameters, providers, provider_options)
    elif isinstance(path_or_bytes, bytes):
        config_result = self._sess.read_bytes(path_or_bytes, parameters, providers, provider_options)
    else:
        raise TypeError("Unable to load from type '{0}'".format(
            type(path_or_bytes)))

    # Cache metadata reported by the native loader/session on the wrapper.
    self.loss_scale_input_name = config_result.loss_scale_input_name
    self._inputs_meta = self._sess.inputs_meta
    self._outputs_meta = self._sess.outputs_meta
def __init__(self, path_or_bytes, parameters, sess_options=None, providers=None, provider_options=None):
    """Build the native training session and load the model described by ``path_or_bytes``.

    Args:
        path_or_bytes: a ``str`` (routed to the native ``load_model``;
            presumably a file path -- confirm) or ``bytes`` (routed to
            ``read_bytes``; presumably a serialized model -- confirm).
        parameters: passed through untouched to the chosen native loader.
        sess_options: used to construct ``C.TrainingSession`` only when truthy.
        providers: execution providers; ``None`` means "all providers listed
            by ``C.get_available_providers()``".
        provider_options: per-provider options, normalized alongside
            ``providers`` by ``check_and_normalize_provider_args``.

    Raises:
        TypeError: when ``path_or_bytes`` is neither ``str`` nor ``bytes``.
    """
    Session.__init__(self)
    self._sess = C.TrainingSession(sess_options) if sess_options else C.TrainingSession()

    # Since ORT 1.10 the providers must be spelled out; falling back to every
    # available provider keeps the behavior callers relied on before 1.10.
    if providers is None:
        providers = C.get_available_providers()
    providers, provider_options = check_and_normalize_provider_args(
        providers, provider_options, C.get_available_providers())

    # Pick the native entry point matching the input kind, then call it once.
    if isinstance(path_or_bytes, str):
        load = self._sess.load_model
    elif isinstance(path_or_bytes, bytes):
        load = self._sess.read_bytes
    else:
        raise TypeError("Unable to load from type '{0}'".format(type(path_or_bytes)))
    config_result = load(path_or_bytes, parameters, providers, provider_options)

    self.loss_scale_input_name = config_result.loss_scale_input_name
    self._inputs_meta = self._sess.inputs_meta
    self._outputs_meta = self._sess.outputs_meta
def __init__(self, path_or_bytes, parameters, sess_options=None):
    """Create the native training session, load a model into it, then initialize the base class.

    Args:
        path_or_bytes: a ``str`` (given to the native ``load_model``;
            presumably a file path -- confirm) or ``bytes`` (given to
            ``read_bytes``; presumably a serialized model -- confirm).
        parameters: forwarded unchanged to the chosen native loader.
        sess_options: used to construct ``C.TrainingSession`` only when truthy.

    Raises:
        TypeError: when ``path_or_bytes`` is neither ``str`` nor ``bytes``.
    """
    self._sess = C.TrainingSession(sess_options) if sess_options else C.TrainingSession()

    # Select the loader by input kind; reject anything else before loading.
    if isinstance(path_or_bytes, str):
        load = self._sess.load_model
    elif isinstance(path_or_bytes, bytes):
        load = self._sess.read_bytes
    else:
        raise TypeError("Unable to load from type '{0}'".format(type(path_or_bytes)))
    config_result = load(path_or_bytes, parameters)

    self.loss_scale_input_name = config_result.loss_scale_input_name
    self._inputs_meta = self._sess.inputs_meta
    self._outputs_meta = self._sess.outputs_meta

    # The base initializer is handed self._sess, so it must run after the
    # native session has been created.
    Session.__init__(self, self._sess)