def testFreezing(self):
        """Exercise a frozen saver's save/restore round trip in three ways.

        The test writes one checkpoint containing `v == 3` and
        `save_counter == 12`, then checks that the value 3 comes back when
        restoring (1) with the same frozen saver, (2) with a new frozen saver
        built over a freshly created but identically shaped object graph, and
        (3) through the object-based `Checkpoint.restore` API.
        """
        with test_util.use_gpu():
            # Save an object-based checkpoint using a frozen saver
            directory = self.get_temp_dir()
            prefix = os.path.join(directory, "ckpt")
            v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
            checkpoint = trackable_utils.Checkpoint(v=v)
            # Give the variable a value that differs from its initial value (0)
            # so a successful restore is distinguishable from re-initialization.
            self.evaluate(v.assign(3))
            # Create the save counter so assert_consumed doesn't complain about it not
            # existing in the checkpoint on restore.
            self.evaluate(checkpoint.save_counter.assign(12))
            saver = trackable_utils.frozen_saver(checkpoint)
            # The prefix constant is explicitly placed on the CPU even though the
            # surrounding scope is use_gpu().
            # NOTE(review): presumably because string tensors need a CPU
            # placement -- confirm.
            with ops.device("cpu:0"):
                prefix_tensor = constant_op.constant(prefix)
            self.evaluate(saver.save(prefix_tensor))
            # Overwrite the saved value so the restore below has something to undo.
            self.evaluate(v.assign(10))
            # Use the frozen saver to restore the same object graph
            self.evaluate(saver.restore(prefix_tensor))
            self.assertEqual(3, self.evaluate(v))

            # Restore using another frozen saver on an identical object graph
            del v, checkpoint, saver
            v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
            checkpoint = trackable_utils.Checkpoint(v=v)
            saver = trackable_utils.frozen_saver(checkpoint)
            self.evaluate(saver.restore(prefix_tensor))
            self.assertEqual(3, self.evaluate(v))

            # Restore as an object-based checkpoint
            del v, checkpoint, saver
            checkpoint = trackable_utils.Checkpoint()
            # restore() is called before `v` is attached to the checkpoint; the
            # returned status object is used further down.
            status = checkpoint.restore(prefix)
            v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
            if context.executing_eagerly():
                # In eager mode the test expects the save counter to read 12
                # already, while the not-yet-attached `v` still holds its
                # initial value 0.
                self.assertEqual(12, self.evaluate(checkpoint.save_counter))
                self.assertEqual(0, self.evaluate(v))
            # Attach the variable under the same attribute name ("v") that was
            # used when the checkpoint was written.
            checkpoint.v = v
            status.assert_consumed().run_restore_ops()
            self.assertEqual(3, self.evaluate(v))
            self.assertEqual(12, self.evaluate(checkpoint.save_counter))
示例#2
0
    def model_fn(features, labels, mode):
        """model_fn for keras Estimator.

        Clones and builds the Keras model for the requested mode and packages
        its outputs, loss, train op and metrics into an `EstimatorSpec`.

        Besides its arguments, this closure reads `keras_model`,
        `custom_objects`, `optimizer_config` and `save_object_ckpt` from the
        enclosing scope.

        Args:
          features: Forwarded unchanged to `_clone_and_build_model`.
          labels: Forwarded unchanged to `_clone_and_build_model`.
          mode: One of the `ModeKeys` values. It is compared by equality, so
            an equal value that is not the very same object as the `ModeKeys`
            constant is still recognised.

        Returns:
          A `model_fn_lib.EstimatorSpec`. `loss` and `eval_metric_ops` are
          only set when `mode` is not `ModeKeys.PREDICT`; `train_op` is only
          set when `mode` is `ModeKeys.TRAIN`; `scaffold` is only set when
          `save_object_ckpt` is truthy.
        """
        model = _clone_and_build_model(mode=mode,
                                       keras_model=keras_model,
                                       custom_objects=custom_objects,
                                       features=features,
                                       labels=labels,
                                       optimizer_config=optimizer_config)

        # We need to make sure that the output names of the last layer in the model
        # is the same for each of the cloned models. This is required for mirrored
        # strategy when we call regroup.
        if tf.distribute.has_strategy():
            # Compile once instead of once per output name. The pattern only
            # removes a trailing underscore followed by a single digit.
            trailing_index = re.compile(r'_\d$')
            model_output_names = [
                trailing_index.sub('', name) for name in model.output_names
            ]
        else:
            model_output_names = model.output_names

        # Get inputs to EstimatorSpec
        predictions = dict(zip(model_output_names, model.outputs))

        loss = None
        train_op = None
        eval_metric_ops = None

        # Compare modes with ==/!= rather than identity: an identity test
        # against a constant would misclassify an equal value that happens to
        # be a different object.
        is_training = mode == ModeKeys.TRAIN

        # Set loss and metric only during train and evaluate.
        if mode != ModeKeys.PREDICT:
            if is_training:
                model._make_train_function()  # pylint: disable=protected-access
            else:
                model._make_test_function()  # pylint: disable=protected-access
            loss = model.total_loss

            eval_metric_ops = _convert_keras_metrics_to_estimator(model)

        # Set train_op only during train.
        if is_training:
            train_op = model.train_function.updates_op

        if (not model._is_graph_network
                and hasattr(keras_model, '_original_attributes_cache')
                and keras_model._original_attributes_cache is not None):
            # To avoid `model_fn` being destructive for the initial model argument.
            models.in_place_subclassed_model_state_restoration(keras_model)

        scaffold = None
        if save_object_ckpt:
            # Track the global step on the model under a fixed name so it is
            # part of the object graph built below.
            model._track_trackable(tf.compat.v1.train.get_global_step(),
                                   'estimator_global_step')
            # Create saver that maps variable names to object-checkpoint keys.
            object_graph = graph_view.ObjectGraphView(model)
            var_list = object_graph.frozen_saveable_objects()
            saver = tf.compat.v1.train.Saver(var_list=var_list, sharded=True)
            saver._object_restore_saver = trackable_util.frozen_saver(model)
            scaffold = tf.compat.v1.train.Scaffold(saver=saver)

        return model_fn_lib.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            loss=loss,
            train_op=train_op,
            eval_metric_ops=eval_metric_ops,
            export_outputs={
                _DEFAULT_SERVING_KEY: export_lib.PredictOutput(predictions)
            },
            scaffold=scaffold)