def testFreezing(self):
  """Round-trips a checkpoint through `frozen_saver` and object-based restore.

  Saves `v` (value 3) and the save counter (value 12) with a frozen saver,
  then checks the checkpoint can be restored by (1) the same frozen saver,
  (2) a new frozen saver built over an identical object graph, and (3) an
  object-based `Checkpoint.restore` call issued before `v` is attached to the
  checkpoint.
  """
  with test_util.use_gpu():
    # Save an object-based checkpoint using a frozen saver
    directory = self.get_temp_dir()
    prefix = os.path.join(directory, "ckpt")
    v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
    checkpoint = trackable_utils.Checkpoint(v=v)
    self.evaluate(v.assign(3))
    # Create the save counter so assert_consumed doesn't complain about it not
    # existing in the checkpoint on restore.
    self.evaluate(checkpoint.save_counter.assign(12))
    saver = trackable_utils.frozen_saver(checkpoint)
    # NOTE(review): the prefix tensor is pinned to the CPU, presumably because
    # string tensors cannot live on the GPU requested by use_gpu() — confirm.
    with ops.device("cpu:0"):
      prefix_tensor = constant_op.constant(prefix)
    self.evaluate(saver.save(prefix_tensor))
    # Overwrite the saved value so that the restore below is observable.
    self.evaluate(v.assign(10))
    # Use the frozen saver to restore the same object graph
    self.evaluate(saver.restore(prefix_tensor))
    self.assertEqual(3, self.evaluate(v))

    # Restore using another frozen saver on an identical object graph. The
    # fresh variable starts at 0, so reading 3 afterwards means the value
    # came from the checkpoint.
    del v, checkpoint, saver
    v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
    checkpoint = trackable_utils.Checkpoint(v=v)
    saver = trackable_utils.frozen_saver(checkpoint)
    self.evaluate(saver.restore(prefix_tensor))
    self.assertEqual(3, self.evaluate(v))

    # Restore as an object-based checkpoint. The restore is requested while
    # the checkpoint has no `v`; `v` is created and attached afterwards and
    # is expected to pick up the saved value (3).
    del v, checkpoint, saver
    checkpoint = trackable_utils.Checkpoint()
    status = checkpoint.restore(prefix)
    v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
    if context.executing_eagerly():
      # When executing eagerly, the save counter already reads its saved value
      # right after restore(), while the not-yet-attached `v` still holds 0.
      self.assertEqual(12, self.evaluate(checkpoint.save_counter))
      self.assertEqual(0, self.evaluate(v))
    checkpoint.v = v
    # NOTE(review): assert_consumed() presumably verifies that the checkpoint
    # and the object graph match (hence the save counter above), and
    # run_restore_ops() applies pending restores, which graph mode needs.
    status.assert_consumed().run_restore_ops()
    self.assertEqual(3, self.evaluate(v))
    self.assertEqual(12, self.evaluate(checkpoint.save_counter))
def model_fn(features, labels, mode):
  """model_fn for keras Estimator.

  Clones and builds `keras_model` for `mode` and wraps the result in an
  `EstimatorSpec`. `keras_model`, `custom_objects`, `optimizer_config` and
  `save_object_ckpt` are read from the enclosing scope (not visible here).

  Args:
    features: Input features, forwarded to `_clone_and_build_model`.
    labels: Labels, forwarded to `_clone_and_build_model`.
    mode: A `ModeKeys` constant selecting train / eval / predict behavior.

  Returns:
    A `model_fn_lib.EstimatorSpec`. `loss` and `eval_metric_ops` are set for
    every mode other than `ModeKeys.PREDICT`, `train_op` only for
    `ModeKeys.TRAIN`, and `scaffold` only when `save_object_ckpt` is truthy.
  """
  model = _clone_and_build_model(
      mode=mode,
      keras_model=keras_model,
      custom_objects=custom_objects,
      features=features,
      labels=labels,
      optimizer_config=optimizer_config)

  # We need to make sure that the output names of the last layer in the model
  # are the same for each of the cloned models. This is required for mirrored
  # strategy when we call regroup. The pattern strips one trailing
  # "_<digit>" suffix (presumably the one added by cloning — confirm).
  if tf.distribute.has_strategy():
    trailing_index = re.compile(r'_\d$')  # compiled once, not per name
    model_output_names = [
        trailing_index.sub('', name) for name in model.output_names
    ]
  else:
    model_output_names = model.output_names

  # Get inputs to EstimatorSpec
  predictions = dict(zip(model_output_names, model.outputs))

  loss = None
  train_op = None
  eval_metric_ops = None

  # Modes are compared by value (==), not identity (is): identity only holds
  # when the caller passes the very object stored on `ModeKeys`, which an
  # equal value produced elsewhere need not be.
  # Set loss and metric only during train and evaluate.
  if mode != ModeKeys.PREDICT:
    if mode == ModeKeys.TRAIN:
      model._make_train_function()  # pylint: disable=protected-access
    else:
      model._make_test_function()  # pylint: disable=protected-access
    loss = model.total_loss
    eval_metric_ops = _convert_keras_metrics_to_estimator(model)

  # Set train_op only during train.
  if mode == ModeKeys.TRAIN:
    train_op = model.train_function.updates_op

  if (not model._is_graph_network and
      getattr(keras_model, '_original_attributes_cache', None) is not None):
    # To avoid `model_fn` being destructive for the initial model argument.
    models.in_place_subclassed_model_state_restoration(keras_model)

  scaffold = None
  if save_object_ckpt:
    model._track_trackable(tf.compat.v1.train.get_global_step(),
                           'estimator_global_step')
    # Create saver that maps variable names to object-checkpoint keys.
    object_graph = graph_view.ObjectGraphView(model)
    var_list = object_graph.frozen_saveable_objects()
    saver = tf.compat.v1.train.Saver(var_list=var_list, sharded=True)
    # NOTE(review): presumably lets the Saver also restore object-based
    # checkpoints of `model` — confirm against Saver internals.
    saver._object_restore_saver = trackable_util.frozen_saver(model)
    scaffold = tf.compat.v1.train.Scaffold(saver=saver)

  return model_fn_lib.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops,
      export_outputs={
          _DEFAULT_SERVING_KEY: export_lib.PredictOutput(predictions)
      },
      scaffold=scaffold)