def __init__(self, saved_model_dir, model_dir=None):
  """Initialize a SavedModelEstimator.

  The SavedModelEstimator loads its model function and variable values from
  the graphs defined in the SavedModel. There is no option to pass in
  `RunConfig` or `params` arguments, because the model function graph is
  defined statically in the SavedModel.

  Args:
    saved_model_dir: Directory containing SavedModel protobuf and subfolders.
    model_dir: Directory to save new checkpoints during training.

  Raises:
    NotImplementedError: If a DistributionStrategy is defined in the config.
      Unless the SavedModelEstimator is subclassed, this shouldn't happen.
  """
  # Locate the checkpoint shipped inside the SavedModel and warm-start
  # every variable it contains.
  ckpt = estimator_lib._get_saved_model_ckpt(saved_model_dir)  # pylint: disable=protected-access
  warm_start = estimator_lib.WarmStartSettings(
      ckpt_to_initialize_from=ckpt,
      vars_to_warm_start=[
          var_name for var_name, _ in checkpoint_utils.list_variables(ckpt)
      ])
  super(SavedModelEstimator, self).__init__(
      model_fn=self._model_fn_from_saved_model,
      model_dir=model_dir,
      warm_start_from=warm_start)
  # Distribution strategies are incompatible with the statically-defined
  # SavedModel graph, so reject them outright.
  if self._train_distribution or self._eval_distribution:
    raise NotImplementedError(
        'SavedModelEstimator currently does not support '
        'DistributionStrategy.')
  self.saved_model_dir = saved_model_dir
  self.saved_model_loader = loader_impl.SavedModelLoader(saved_model_dir)
  self._available_modes = self._extract_available_modes()
def test_warm_starting_selective_variables(self, fc_impl):
  """Tests selecting variables to warm-start."""
  age = fc_impl.numeric_column('age')
  city = fc_impl.embedding_column(
      fc_impl.categorical_column_with_vocabulary_list(
          'city', vocabulary_list=['Mountain View', 'Palo Alto']),
      dimension=5)

  # Train a classifier for one step so a checkpoint exists to warm-start
  # from.
  source_classifier = dnn_linear_combined.DNNLinearCombinedClassifierV2(
      linear_feature_columns=[age],
      dnn_feature_columns=[city],
      dnn_hidden_units=[256, 128],
      model_dir=self._ckpt_and_vocab_dir,
      n_classes=4,
      linear_optimizer='SGD',
      dnn_optimizer='SGD')
  source_classifier.train(input_fn=self._input_fn, max_steps=1)

  # Build a second classifier warm-started from the first. Zero learning
  # rates (plain SGD, so there are no accumulator values that change) keep
  # variables fixed so they can be compared. The provided regular expression
  # will only warm-start the deep portion of the model.
  restored_classifier = dnn_linear_combined.DNNLinearCombinedClassifierV2(
      linear_feature_columns=[age],
      dnn_feature_columns=[city],
      dnn_hidden_units=[256, 128],
      n_classes=4,
      linear_optimizer=gradient_descent.GradientDescentOptimizer(
          learning_rate=0.0),
      dnn_optimizer=gradient_descent_v2.SGD(learning_rate=0.0),
      warm_start_from=estimator.WarmStartSettings(
          ckpt_to_initialize_from=source_classifier.model_dir,
          vars_to_warm_start='.*(dnn).*'))
  restored_classifier.train(input_fn=self._input_fn, max_steps=1)

  for variable_name in restored_classifier.get_variable_names():
    if 'dnn' in variable_name:
      if 'learning_rate' in variable_name:
        # Learning-rate variables come from the freshly constructed
        # zero-rate optimizers, not from the checkpoint.
        self.assertAllClose(
            0.0, restored_classifier.get_variable_value(variable_name))
      else:
        # Deep-side variables were warm-started, so they must match the
        # source checkpoint exactly.
        self.assertAllClose(
            source_classifier.get_variable_value(variable_name),
            restored_classifier.get_variable_value(variable_name))
    elif 'linear' in variable_name:
      # Since they're not warm-started, the linear weights will be
      # zero-initialized.
      linear_values = restored_classifier.get_variable_value(variable_name)
      self.assertAllClose(np.zeros_like(linear_values), linear_values)