def __init__(self, saved_model_dir):
  """Loads a transform SavedModel and records its graph, inputs and outputs.

  Args:
    saved_model_dir: A SavedModel directory providing a transform graph. The
      MetaGraphDef and signature are selected from the SavedModel using keys
      defined in `../constants.py` ('transform' and 'transform_signature',
      respectively).
  """
  loaded = tf.compat.v2.saved_model.load(saved_model_dir)
  self._imported = loaded
  # Whether the loaded object exposes the transform signature decides which
  # of the two loading paths below is used.
  self._load_v2_in_compat = (
      constants.TRANSFORM_SIGNATURE in loaded.signatures)

  if not self._load_v2_in_compat:
    # No transform signature: wrap the object's `transform_fn` attribute and
    # derive the function graph from its concrete function.
    transform_fn = loaded.transform_fn
    self._wrapped = transform_fn
    traced_graph = transform_fn.get_concrete_function().graph
    self._func_graph = self._get_func_graph_from_v2_saved_model(traced_graph)
    self._structured_inputs = self._get_structured_inputs_from_func_graph(
        self._func_graph)
    # Re-nest the flat graph outputs into the nesting described by the
    # graph's `structured_outputs` (composite tensors expanded).
    outputs = tf.nest.pack_sequence_as(
        self._func_graph.structured_outputs,
        self._func_graph.outputs,
        expand_composites=True)
  else:
    signature_fn = loaded.signatures[constants.TRANSFORM_SIGNATURE]
    self._wrapped = signature_fn
    self._func_graph = signature_fn.graph
    self._structured_inputs = self._get_input_signature_from_v1_saved_model(
        saved_model_dir)
    outputs = signature_fn.structured_outputs

  self._output_to_inputs_map = self._get_output_to_inputs_map(outputs)
  saved_transform_io._maybe_register_addon_ops()  # pylint: disable=protected-access
def __init__(self, saved_model_dir):
  """Init method for SavedModelLoader.

  Args:
    saved_model_dir: A SavedModel directory providing a transform graph. The
      MetaGraphDef and signature are selected from the SavedModel using keys
      defined in `../constants.py` ('transform' and 'transform_signature',
      respectively).

  Raises:
    AssertionError: If `transform_fn` exposes `concrete_functions` and that
      collection does not contain exactly one function.
  """
  # TF releases before 2.5 are loaded through `load.load_internal` with the
  # custom `_Loader`. Compare (major, minor) numerically: a plain string
  # comparison treats e.g. '2.10' as older than '2.5' and would wrongly pick
  # the legacy path on newer TF releases.
  import re  # pylint: disable=g-import-not-at-top
  version_match = re.match(r'(\d+)\.(\d+)', tf.version.VERSION)
  if version_match:
    use_legacy_loader = (
        (int(version_match.group(1)), int(version_match.group(2))) < (2, 5))
  else:
    # Unparseable version string: keep the previous lexicographic check.
    use_legacy_loader = tf.version.VERSION < '2.5'

  if use_legacy_loader:
    self._imported = load.load_internal(saved_model_dir, loader_cls=_Loader)
    # When `load_internal` returns a dict, the loaded object is stored under
    # the 'root' key.
    if isinstance(self._imported, dict):
      self._imported = self._imported['root']
  else:
    # TODO(b/160294509): Stop using tf.compat.v2 when TF1.15 support is
    # dropped.
    self._imported = tf.compat.v2.saved_model.load(saved_model_dir)

  # The presence of the transform signature selects the loading path below.
  self.load_v2_in_compat = (
      constants.TRANSFORM_SIGNATURE in self._imported.signatures)
  if self.load_v2_in_compat:
    self._wrapped = self._imported.signatures[constants.TRANSFORM_SIGNATURE]
    self._func_graph = self._wrapped.graph
    self._structured_inputs = self._get_input_signature_from_v1_saved_model(
        saved_model_dir)
    self._structured_outputs = self._wrapped.structured_outputs
  else:
    # TODO(b/160550490): Remove local import.
    from tensorflow_transform import tf2_utils  # pylint: disable=g-import-not-at-top

    # transform_fn is now a ConcreteFunction, but was a tf.function. We need
    # to handle both to maintain backward compatibility. If it's a
    # tf.function, since `input_signature` was specified when exporting the
    # tf function to `SavedModel`, there should be exactly one concrete
    # function present on loading the `SavedModel`.
    if hasattr(self._imported.transform_fn, 'concrete_functions'):
      concrete_functions = self._imported.transform_fn.concrete_functions
      assert len(concrete_functions) == 1, concrete_functions
      self._wrapped = concrete_functions[0]
    else:
      self._wrapped = self._imported.transform_fn
    self._func_graph = self._wrapped.graph
    self._structured_inputs = (
        tf2_utils.get_structured_inputs_from_func_graph(self._func_graph))
    # Re-nest the flat graph outputs into the graph's `structured_outputs`
    # nesting (composite tensors expanded).
    self._structured_outputs = tf.nest.pack_sequence_as(
        self._func_graph.structured_outputs,
        self._func_graph.outputs,
        expand_composites=True)
  self._output_to_inputs_map = (
      self._get_output_to_inputs_map(self._structured_outputs))
  saved_transform_io._maybe_register_addon_ops()  # pylint: disable=protected-access
def __init__(self, saved_model_dir: str):
  """Loads the transform SavedModel and passes its pieces to `_initialize`.

  Args:
    saved_model_dir: A SavedModel directory providing a transform graph. The
      MetaGraphDef and signature are selected from the SavedModel using keys
      defined in `../constants.py` ('transform' and 'transform_signature',
      respectively).

  Raises:
    AssertionError: If `transform_fn` exposes `concrete_functions` and that
      collection does not contain exactly one function.
  """
  # TODO(b/160294509): Stop using tf.compat.v2 when TF1.15 support is
  # dropped.
  imported = tf.compat.v2.saved_model.load(saved_model_dir)
  load_v2_in_compat = constants.TRANSFORM_SIGNATURE in imported.signatures

  if not load_v2_in_compat:
    transform_fn = imported.transform_fn
    # transform_fn is now a ConcreteFunction, but was a tf.function; both are
    # accepted. A tf.function was exported with an `input_signature`, so it
    # must carry exactly one concrete function after loading.
    if not hasattr(transform_fn, 'concrete_functions'):
      wrapped = transform_fn
    else:
      candidates = transform_fn.concrete_functions
      assert len(candidates) == 1, candidates
      wrapped = candidates[0]
    graph = wrapped.graph
    structured_inputs = tf2_utils.get_structured_inputs_from_func_graph(graph)
    # Re-nest the flat graph outputs into the graph's `structured_outputs`
    # nesting (composite tensors expanded).
    structured_outputs = tf.nest.pack_sequence_as(
        graph.structured_outputs, graph.outputs, expand_composites=True)
  else:
    wrapped, structured_inputs, structured_outputs = (
        _restore_from_v1_saved_model(
            imported.signatures[constants.TRANSFORM_SIGNATURE],
            saved_model_dir))

  self._initialize(
      load_v2_in_compat,
      imported,
      wrapped,
      structured_inputs,
      structured_outputs,
      _get_output_to_inputs_map(structured_outputs))
  saved_transform_io._maybe_register_addon_ops()  # pylint: disable=protected-access
def __init__(self, saved_model_dir):
  """Init method for SavedModelLoader.

  Args:
    saved_model_dir: A SavedModel directory providing a transform graph. The
      MetaGraphDef and signature are selected from the SavedModel using keys
      defined in `../constants.py` ('transform' and 'transform_signature',
      respectively).

  Raises:
    AssertionError: If `transform_fn` exposes `concrete_functions` and that
      collection does not contain exactly one function.
  """
  # TODO(b/160294509): Stop using tf.compat.v2 when TF1.15 support is dropped.
  self._imported = tf.compat.v2.saved_model.load(saved_model_dir)
  # The presence of the transform signature selects the loading path below.
  self.load_v2_in_compat = (
      constants.TRANSFORM_SIGNATURE in self._imported.signatures)
  if self.load_v2_in_compat:
    self._wrapped = self._imported.signatures[constants.TRANSFORM_SIGNATURE]
    self._func_graph = self._wrapped.graph
    self._structured_inputs = self._get_input_signature_from_v1_saved_model(
        saved_model_dir)
    self._structured_outputs = self._wrapped.structured_outputs
  else:
    transform_fn = self._imported.transform_fn
    # `transform_fn` may be a tf.function or a ConcreteFunction; accept both.
    # A ConcreteFunction has no `concrete_functions` attribute, so indexing it
    # unconditionally would raise AttributeError. For a tf.function, since
    # `input_signature` was specified when exporting it to `SavedModel`,
    # there should be exactly one concrete function present on loading.
    if hasattr(transform_fn, 'concrete_functions'):
      concrete_functions = transform_fn.concrete_functions
      assert len(concrete_functions) == 1, concrete_functions
      self._wrapped = concrete_functions[0]
    else:
      self._wrapped = transform_fn
    self._func_graph = self._wrapped.graph
    self._structured_inputs = self._get_structured_inputs_from_func_graph(
        self._func_graph)
    # Re-nest the flat graph outputs into the graph's `structured_outputs`
    # nesting (composite tensors expanded).
    self._structured_outputs = tf.nest.pack_sequence_as(
        self._func_graph.structured_outputs,
        self._func_graph.outputs,
        expand_composites=True)
  self._output_to_inputs_map = (
      self._get_output_to_inputs_map(self._structured_outputs))
  saved_transform_io._maybe_register_addon_ops()  # pylint: disable=protected-access