def __init__(self, saved_model_dir, model_dir=None):
    """Build a SavedModelEstimator from an exported SavedModel.

    Both the model function and the initial variable values are taken from
    the graphs stored in the SavedModel. Since the model function graph is
    defined statically in the SavedModel, there is no way to pass `RunConfig`
    or `params` arguments here.

    Args:
      saved_model_dir: Directory containing SavedModel protobuf and
        subfolders.
      model_dir: Directory to save new checkpoints during training.

    Raises:
      NotImplementedError: If a DistributionStrategy is defined in the
        config. Unless the SavedModelEstimator is subclassed, this shouldn't
        happen.
    """
    ckpt_path = estimator_lib._get_saved_model_ckpt(saved_model_dir)  # pylint: disable=protected-access

    # Every variable listed in the SavedModel checkpoint is warm-started.
    ckpt_var_names = [
        var_name
        for var_name, _ in checkpoint_utils.list_variables(ckpt_path)
    ]
    settings = estimator_lib.WarmStartSettings(
        ckpt_to_initialize_from=ckpt_path,
        vars_to_warm_start=ckpt_var_names)

    super(SavedModelEstimator, self).__init__(
        model_fn=self._model_fn_from_saved_model,
        model_dir=model_dir,
        warm_start_from=settings)

    # Checked only after the base constructor has run, because that is what
    # populates `self._distribution`.
    if self._distribution is not None:
        raise NotImplementedError(
            'SavedModelEstimator currently does not support '
            'DistributionStrategy.')

    self.saved_model_dir = saved_model_dir
    self.saved_model_loader = loader_impl.SavedModelLoader(saved_model_dir)
    self._available_modes = self._extract_available_modes()
def test_warm_starting_selective_variables(self, fc_impl):
    """Tests that a regex limits which variables are warm-started."""
    age_column = fc_impl.numeric_column('age')
    city_categories = fc_impl.categorical_column_with_vocabulary_list(
        'city', vocabulary_list=['Mountain View', 'Palo Alto'])
    city_column = fc_impl.embedding_column(city_categories, dimension=5)

    # Train a first DNNLinearCombinedClassifier for one step so that a
    # checkpoint exists to warm-start from.
    source_classifier = dnn_linear_combined.DNNLinearCombinedClassifier(
        linear_feature_columns=[age_column],
        dnn_feature_columns=[city_column],
        dnn_hidden_units=[256, 128],
        model_dir=self._ckpt_and_vocab_dir,
        n_classes=4,
        linear_optimizer='SGD',
        dnn_optimizer='SGD')
    source_classifier.train(input_fn=self._input_fn, max_steps=1)

    # The second classifier uses learning_rate = 0.0 optimizers so that its
    # single training step leaves values untouched and they can be compared
    # (SGD is used so there are no accumulator values that change).
    frozen_linear_optimizer = gradient_descent.GradientDescentOptimizer(
        learning_rate=0.0)
    frozen_dnn_optimizer = gradient_descent.GradientDescentOptimizer(
        learning_rate=0.0)
    # The provided regular expression will only warm-start the deep portion
    # of the model.
    dnn_only_settings = estimator.WarmStartSettings(
        ckpt_to_initialize_from=source_classifier.model_dir,
        vars_to_warm_start='.*(dnn).*')
    target_classifier = dnn_linear_combined.DNNLinearCombinedClassifier(
        linear_feature_columns=[age_column],
        dnn_feature_columns=[city_column],
        dnn_hidden_units=[256, 128],
        n_classes=4,
        linear_optimizer=frozen_linear_optimizer,
        dnn_optimizer=frozen_dnn_optimizer,
        warm_start_from=dnn_only_settings)
    target_classifier.train(input_fn=self._input_fn, max_steps=1)

    for var_name in target_classifier.get_variable_names():
        if 'dnn' in var_name:
            # Deep variables must carry over the source model's values.
            expected = source_classifier.get_variable_value(var_name)
            actual = target_classifier.get_variable_value(var_name)
            self.assertAllClose(expected, actual)
            continue
        if 'linear' in var_name:
            # Since they're not warm-started, the linear weights will be
            # zero-initialized.
            linear_values = target_classifier.get_variable_value(var_name)
            self.assertAllClose(np.zeros_like(linear_values), linear_values)
def get_latent_gan_estimator(generator_fn, discriminator_fn, loss_fn,
                             optimizer, params, config, ckpt_dir,
                             warmstart_options=True):
    """Gets an estimator that passes gradients to the input.

    This function takes in a generator and adds a trainable z variable that is
    used as input to this generator_fn. The generator itself is treated as a
    black box through which gradients can pass through without updating any
    weights. The result is a trainable way to traverse the GAN latent space.
    The loss_fn is used to actually train the z variable. The generator_fn and
    discriminator_fn should be previously trained by the tfgan library (on
    reload, the variables are expected to follow the tfgan format. It may be
    possible to use the latent gan estimator with entirely custom GANs that do
    not use the tfgan library as long as the appropriate variables are wired
    properly).

    Args:
      generator_fn: a function defining a Tensorflow graph for a GAN
        generator. The weights defined in this graph should already be
        defined in the given checkpoint location. Should have 'mode' as an
        argument.
      discriminator_fn: a function defining a Tensorflow graph for a GAN
        discriminator. Should have 'mode' as an argument.
      loss_fn: a function defining a Tensorflow graph for a GAN loss. Takes in
        a GANModel tuple, features, labels, and add_summaries as inputs.
      optimizer: a tf.Optimizer or a function that returns a tf.Optimizer with
        no inputs.
      params: An object containing the following parameters:
        - batch_size: an int indicating the size of the training batch.
        - z_shape: the desired shape of the input z values (not counting
          batch).
        - learning_rate: a scalar or function defining a learning rate applied
          to optimizer.
        - input_clip: the amount to clip the x training variable by.
        - add_summaries: whether or not to add summaries.
        - opt_kwargs: optimizer kwargs. Optional; when absent, an empty dict
          is supplied on a shallow copy of `params`, so the caller's object is
          left unmodified.
      config: tf.RunConfig. Should point model to output dir and should
        indicate whether to save checkpoints (to avoid saving checkpoints, set
        save_checkpoints_steps to a number larger than the number of train
        steps). The model_dir field in the RunConfig should point to a
        directory WITHOUT any saved checkpoints.
      ckpt_dir: the directory where the model checkpoints live. The checkpoint
        is used to warm start the underlying GAN. This should NOT be the same
        as config.model_dir.
      warmstart_options: boolean, None, or a WarmStartSettings object. If set
        to True, uses a default WarmStartSettings object. If set to False or
        None, does not use warm start. If using a custom WarmStartSettings
        object, make sure that new variables are properly accounted for when
        reloading the underlying GAN. Defaults to True.

    Returns:
      An Estimator (not an EstimatorSpec) for training the GAN input.
    """
    # Imported locally so this block does not depend on the module's import
    # list, which is not visible from here.
    import copy

    model_fn = _get_latent_gan_model_fn(generator_fn, discriminator_fn,
                                        loss_fn, optimizer)

    if isinstance(warmstart_options, estimator.WarmStartSettings):
        # A caller-supplied settings object is used untouched.
        ws = warmstart_options
    elif warmstart_options:
        # Default WarmStart loads all variable names except INPUT_NAME and
        # OPTIMIZER_NAME.
        var_regex = '^(?!.*(%s|%s).*)' % (INPUT_NAME, OPTIMIZER_NAME)
        ws = estimator.WarmStartSettings(
            ckpt_to_initialize_from=ckpt_dir,
            vars_to_warm_start=var_regex)
    else:
        ws = None

    if 'opt_kwargs' not in params:
        # Fill in the default on a shallow copy rather than writing into the
        # caller's object; the copy keeps the original container type.
        params = copy.copy(params)
        params['opt_kwargs'] = {}

    return estimator.Estimator(
        model_fn=model_fn, config=config, params=params, warm_start_from=ws)