    def __init__(self, saved_model_dir, model_dir=None):
        """Initialize a SavedModelEstimator.

        The SavedModelEstimator loads its model function and variable values
        from the graphs defined in the SavedModel. There is no option to pass
        in `RunConfig` or `params` arguments, because the model function graph
        is defined statically in the SavedModel.

        Args:
          saved_model_dir: Directory containing the SavedModel protobuf and
            subfolders.
          model_dir: Directory to save new checkpoints during training.

        Raises:
          NotImplementedError: If a DistributionStrategy is defined in the
            config. Unless the SavedModelEstimator is subclassed, this
            shouldn't happen.
        """
        checkpoint = estimator_lib._get_saved_model_ckpt(saved_model_dir)  # pylint: disable=protected-access
        vars_to_warm_start = [
            name for name, _ in checkpoint_utils.list_variables(checkpoint)
        ]
        warm_start_settings = estimator_lib.WarmStartSettings(
            ckpt_to_initialize_from=checkpoint,
            vars_to_warm_start=vars_to_warm_start)

        super(SavedModelEstimator,
              self).__init__(model_fn=self._model_fn_from_saved_model,
                             model_dir=model_dir,
                             warm_start_from=warm_start_settings)
        if self._train_distribution or self._eval_distribution:
            raise NotImplementedError(
                'SavedModelEstimator currently does not support '
                'DistributionStrategy.')
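        # Keep the SavedModel location and loader around; the available modes
        # (train/eval/predict) depend on which meta graphs the SavedModel
        # contains.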
        self.saved_model_dir = saved_model_dir
        self.saved_model_loader = loader_impl.SavedModelLoader(saved_model_dir)
        self._available_modes = self._extract_available_modes()
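
For context, a minimal usage sketch follows. It assumes a model was previously exported with `Estimator.export_saved_model`; the paths and `train_input_fn` are hypothetical, and the import assumes the `tf.contrib.estimator` location of `SavedModelEstimator` in TF 1.x.

import tensorflow as tf
from tensorflow.contrib.estimator import SavedModelEstimator  # TF 1.x location (assumed)

def train_input_fn():
    # Hypothetical features/labels; they must match the signature of the
    # exported model.
    return {'x': tf.constant([[1.0], [2.0]])}, tf.constant([[3.0], [5.0]])

# Directory produced earlier by estimator.export_saved_model(...) (hypothetical).
saved_model_dir = '/tmp/exported_model/1234567890'

# The model_fn and initial variable values come from the SavedModel; new
# checkpoints produced by training are written to model_dir.
sme = SavedModelEstimator(saved_model_dir, model_dir='/tmp/new_checkpoints')
sme.train(input_fn=train_input_fn, max_steps=100)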
Example #2
    def test_warm_starting_selective_variables(self, fc_impl):
        """Tests selecting variables to warm-start."""
        age = fc_impl.numeric_column('age')
        city = fc_impl.embedding_column(
            fc_impl.categorical_column_with_vocabulary_list(
                'city', vocabulary_list=['Mountain View', 'Palo Alto']),
            dimension=5)

        # Create a DNNLinearCombinedClassifier and train to save a checkpoint.
        dnn_lc_classifier = dnn_linear_combined.DNNLinearCombinedClassifierV2(
            linear_feature_columns=[age],
            dnn_feature_columns=[city],
            dnn_hidden_units=[256, 128],
            model_dir=self._ckpt_and_vocab_dir,
            n_classes=4,
            linear_optimizer='SGD',
            dnn_optimizer='SGD')
        dnn_lc_classifier.train(input_fn=self._input_fn, max_steps=1)

        # Create a second DNNLinearCombinedClassifier, warm-started from the first.
        # Use a learning_rate = 0.0 optimizer to check values (use SGD so we don't
        # have accumulator values that change).
        warm_started_dnn_lc_classifier = (
            dnn_linear_combined.DNNLinearCombinedClassifierV2(
                linear_feature_columns=[age],
                dnn_feature_columns=[city],
                dnn_hidden_units=[256, 128],
                n_classes=4,
                linear_optimizer=gradient_descent.GradientDescentOptimizer(
                    learning_rate=0.0),
                dnn_optimizer=gradient_descent_v2.SGD(learning_rate=0.0),
                # The provided regular expression will only warm-start the deep
                # portion of the model.
                warm_start_from=estimator.WarmStartSettings(
                    ckpt_to_initialize_from=dnn_lc_classifier.model_dir,
                    vars_to_warm_start='.*(dnn).*')))

        warm_started_dnn_lc_classifier.train(input_fn=self._input_fn,
                                             max_steps=1)
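        # DNN variables should keep the warm-started values (apart from the
        # learning-rate variables created by the zero-learning-rate
        # optimizers), while linear variables are freshly zero-initialized.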
        var_names = warm_started_dnn_lc_classifier.get_variable_names()
        for variable_name in var_names:
            if 'dnn' in variable_name:
                if 'learning_rate' in variable_name:
                    self.assertAllClose(
                        0.0,
                        warm_started_dnn_lc_classifier.get_variable_value(
                            variable_name))
                else:
                    self.assertAllClose(
                        dnn_lc_classifier.get_variable_value(variable_name),
                        warm_started_dnn_lc_classifier.get_variable_value(
                            variable_name))
            elif 'linear' in variable_name:
                linear_values = warm_started_dnn_lc_classifier.get_variable_value(
                    variable_name)
                # Since they're not warm-started, the linear weights will be
                # zero-initialized.
                self.assertAllClose(np.zeros_like(linear_values),
                                    linear_values)
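
As a side note, `vars_to_warm_start` accepts either an explicit list of variable names (as in Example #1) or a regular expression string (as here). To preview which checkpoint variables a pattern such as `.*(dnn).*` would select, the checkpoint can be inspected directly; this sketch assumes a checkpoint already exists at the hypothetical `ckpt_dir`:

import re

import tensorflow as tf

ckpt_dir = '/tmp/ckpt_and_vocab'  # hypothetical checkpoint directory
pattern = re.compile('.*(dnn).*')

# tf.train.list_variables yields (name, shape) pairs for every variable in
# the checkpoint; only names matching the pattern would be warm-started.
matched = [name for name, _ in tf.train.list_variables(ckpt_dir)
           if pattern.match(name)]
print(matched)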