def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
                        output_name='output_1'):
  """Restores the SavedModel in `saved_dir` and predicts with it.

  Args:
    distribution: strategy forwarded unchanged to `_predict_with_model`.
    saved_dir: path handed to `saved_model.load`.
    predict_dataset: dataset forwarded unchanged to `_predict_with_model`.
    output_name: accepted but not used by this implementation.

  Returns:
    Whatever `self._predict_with_model` returns for the restored object.
  """
  del output_name  # Not consulted by this loader.
  return self._predict_with_model(
      distribution, saved_model.load(saved_dir), predict_dataset)
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
                        output_name):
  """Loads a SavedModel and runs its default signature under `distribution`.

  Args:
    distribution: strategy used both to distribute `predict_dataset` and to
      run the loaded default signature on each replica.
    saved_dir: path handed to `saved_model.load`.
    predict_dataset: dataset whose first distributed element is fed to the
      default signature.
    output_name: key used to select one entry from the signature's result.

  Returns:
    The `output_name` entry of the (per-replica) signature result.
  """
  dist_predict_dataset = distribution.experimental_distribute_dataset(
      predict_dataset)
  per_replica_predict_data = next(iter(dist_predict_dataset))
  func = saved_model.load(saved_dir)
  # `args` is the tuple of positional arguments for the replica function.
  # Passing the batch bare (as the second positional argument) would unpack
  # it into several arguments instead of feeding it as the single input.
  result = distribution.experimental_run_v2(
      func.signatures[_DEFAULT_FUNCTION_KEY],
      args=(per_replica_predict_data,))
  return result[output_name]
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
                        experimental_run_tf_function, output_name='output_1'):
  """Restores `saved_dir` via `saved_model.load` and predicts with it.

  `experimental_run_tf_function` and `output_name` are accepted but not used
  here; prediction is delegated entirely to `self._predict_with_model`.
  """
  del experimental_run_tf_function, output_name  # Unused by this loader.
  restored = saved_model.load(saved_dir)
  return self._predict_with_model(distribution, restored, predict_dataset)
def test_savedmodel(self):
  """Saves a one-method module under a `ram://` path and calls it after load."""

  class MyModule(module.Module):

    @def_function.function(input_signature=[])
    def foo(self):
      return constant_op.constant([1])

  ram_path = 'ram://my_module'
  saved_model.save(MyModule(), ram_path)
  restored = saved_model.load(ram_path)
  # The restored `foo` must reproduce the constant it was saved with.
  self.assertAllEqual(restored.foo(), [1])
def test_no_variable_device_placement(self, model_and_input, distribution,
                                      save_in_scope):
  """Asserts no `ReadVariableOp` in the loaded signature has a device set.

  Every function in the default signature's graph library is scanned, and
  each `ReadVariableOp` node found must have an empty `device` field.
  """
  saved_dir = self.run_test_save_strategy(model_and_input, distribution,
                                          save_in_scope)
  loaded = saved_model.load(saved_dir)
  signature_fn = loaded.signatures[test_base._DEFAULT_FUNCTION_KEY]
  library = signature_fn.graph.as_graph_def().library
  read_nodes = (node
                for fdef in library.function
                for node in fdef.node_def
                if node.op == 'ReadVariableOp')
  for node in read_nodes:
    self.assertEmpty(node.device)
def testSavedModelExport(self):
  """Exports a briefly trained estimator and runs its serving signature once."""
  train_dir = os.path.join(self.get_temp_dir(), 'estimator_train_dir')
  estimator, input_fn = self._make_estimator(train_dir)
  # One training step is enough to write a checkpoint to export from.
  estimator.train(input_fn, steps=1)

  export_base = os.path.join(self.get_temp_dir(), 'estimator_export_dir')
  exported_path = estimator.export_saved_model(export_base,
                                               _serving_input_receiver_fn)

  # Smoke check: the export loads and the default serving signature runs.
  serving_fn = saved_model.load(exported_path).signatures['serving_default']
  serving_fn(tf.constant([[1.]]))
def test_save_load_io_device(self, model_and_input, distribution):
  """Round-trips a trained model with `experimental_io_device` set.

  The model is trained under `distribution`, saved and loaded with options
  pinning I/O to '/job:localhost', and training then resumes on the loaded
  object inside the same strategy scope.
  """
  saved_dir = os.path.join(self.get_temp_dir(), 'io_device')
  with distribution.scope():
    net = model_and_input.get_model()
    x_train, y_train, _ = model_and_input.get_data()
    batch_size = model_and_input.get_batch_size()
    self._train_model(net, x_train, y_train, batch_size)

  call_fn = net.__call__.get_concrete_function(tensor_spec.TensorSpec(None))
  io_device = '/job:localhost'
  save_opts = save_options_lib.SaveOptions(experimental_io_device=io_device)
  saved_model.save(net, saved_dir, signatures=call_fn, options=save_opts)
  load_opts = load_options_lib.LoadOptions(experimental_io_device=io_device)

  # Continuing training on the restored object must not raise.
  with distribution.scope():
    restored = saved_model.load(saved_dir, options=load_opts)
    self._train_model(restored, x_train, y_train, batch_size)
def load_and_run_with_saved_model_api(distribution, saved_dir, predict_dataset,
                                      output_name):
  """Runs the default signature of a SavedModel loaded via tf.saved_model.

  Args:
    distribution: optional strategy. When falsy, the signature is called
      directly on the first element of `predict_dataset`.
    saved_dir: path handed to `saved_model.load`.
    predict_dataset: dataset supplying the single batch to run.
    output_name: key selecting one entry of the signature's result.

  Returns:
    Without a strategy, the selected output of one direct call. With a
    strategy, the per-replica local results of the selected output
    concatenated along axis 0.
  """
  loaded = saved_model.load(saved_dir)
  if not distribution:
    outputs = loaded.signatures[_DEFAULT_FUNCTION_KEY](
        next(iter(predict_dataset)))
    return outputs[output_name]

  per_replica_batch = next(iter(
      distribution.experimental_distribute_dataset(predict_dataset)))
  per_replica_outputs = distribution.run(
      loaded.signatures[_DEFAULT_FUNCTION_KEY], args=(per_replica_batch,))
  # Unwrap the per-replica value into its local components, then join them.
  local_results = distribution.experimental_local_results(
      per_replica_outputs[output_name])
  return array_ops.concat(local_results, 0)
def _load_and_run_model(self, saved_dir, x_predict):
  """Calls the default signature of the SavedModel at `saved_dir` on `x_predict`."""
  serving_fn = saved_model.load(saved_dir).signatures[_DEFAULT_FUNCTION_KEY]
  return serving_fn(x_predict)
def __init__(self, session, save_path):
  """Loads the SavedModel at `save_path` into `session`'s graph.

  Args:
    session: session whose graph receives the loaded objects and which runs
      the global variable initializer.
    save_path: path handed to `saved_model.load`.
  """
  graph = session.graph
  with graph.as_default():
    loaded = saved_model.load(save_path)
    # NOTE(review): presumably required so graph-mode variables created by
    # the load hold values before use — confirm against saved_model.load docs.
    session.run(global_variables_initializer())
  super().__init__(session, loaded)