Example #1
  def __init__(self, saved_model_dir, model_dir=None):
    """Initialize a SavedModelEstimator.

    The SavedModelEstimator loads its model function and variable values from
    the graphs defined in the SavedModel. There is no option to pass in
    `RunConfig` or `params` arguments, because the model function graph is
    defined statically in the SavedModel.

    Args:
      saved_model_dir: Directory containing SavedModel protobuf and subfolders.
      model_dir: Directory to save new checkpoints during training.

    Raises:
      NotImplementedError: If a DistributionStrategy is defined in the config.
        Unless the SavedModelEstimator is subclassed, this shouldn't happen.
    """
    checkpoint = estimator_lib._get_saved_model_ckpt(saved_model_dir)  # pylint: disable=protected-access
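    # Warm-start from every variable recorded in the SavedModel's checkpoint,
    # so training picks up the exported variable values rather than
    # reinitializing them.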
    vars_to_warm_start = [name for name, _ in
                          checkpoint_utils.list_variables(checkpoint)]
    warm_start_settings = estimator_lib.WarmStartSettings(
        ckpt_to_initialize_from=checkpoint,
        vars_to_warm_start=vars_to_warm_start)

    super(SavedModelEstimator, self).__init__(
        model_fn=self._model_fn_from_saved_model, model_dir=model_dir,
        warm_start_from=warm_start_settings)
    if self._distribution is not None:
      raise NotImplementedError(
          'SavedModelEstimator currently does not support '
          'DistributionStrategy.')
    self.saved_model_dir = saved_model_dir
    self.saved_model_loader = loader_impl.SavedModelLoader(saved_model_dir)
    self._available_modes = self._extract_available_modes()
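
A minimal usage sketch for the constructor above, assuming the SavedModel was
exported with a train-mode graph; the paths and my_input_fn below are
placeholder assumptions, not part of the original example:

# Hypothetical paths: '/tmp/exported_model' holds the SavedModel, and
# '/tmp/new_model_dir' receives the new checkpoints written during training.
est = SavedModelEstimator('/tmp/exported_model', model_dir='/tmp/new_model_dir')
# Training resumes from the variable values stored with the SavedModel, since
# the constructor warm-starts every variable listed in its checkpoint.
est.train(input_fn=my_input_fn, max_steps=100)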
Example #2
    def test_warm_starting_selective_variables(self, fc_impl):
        """Tests selecting variables to warm-start."""
        age = fc_impl.numeric_column('age')
        city = fc_impl.embedding_column(
            fc_impl.categorical_column_with_vocabulary_list(
                'city', vocabulary_list=['Mountain View', 'Palo Alto']),
            dimension=5)

        # Create a DNNLinearCombinedClassifier and train to save a checkpoint.
        dnn_lc_classifier = dnn_linear_combined.DNNLinearCombinedClassifier(
            linear_feature_columns=[age],
            dnn_feature_columns=[city],
            dnn_hidden_units=[256, 128],
            model_dir=self._ckpt_and_vocab_dir,
            n_classes=4,
            linear_optimizer='SGD',
            dnn_optimizer='SGD')
        dnn_lc_classifier.train(input_fn=self._input_fn, max_steps=1)

        # Create a second DNNLinearCombinedClassifier, warm-started from the first.
        # Use a learning_rate = 0.0 optimizer to check values (use SGD so we don't
        # have accumulator values that change).
        warm_started_dnn_lc_classifier = (
            dnn_linear_combined.DNNLinearCombinedClassifier(
                linear_feature_columns=[age],
                dnn_feature_columns=[city],
                dnn_hidden_units=[256, 128],
                n_classes=4,
                linear_optimizer=gradient_descent.GradientDescentOptimizer(
                    learning_rate=0.0),
                dnn_optimizer=gradient_descent.GradientDescentOptimizer(
                    learning_rate=0.0),
                # The provided regular expression will only warm-start the deep
                # portion of the model.
                warm_start_from=estimator.WarmStartSettings(
                    ckpt_to_initialize_from=dnn_lc_classifier.model_dir,
                    vars_to_warm_start='.*(dnn).*')))

        warm_started_dnn_lc_classifier.train(input_fn=self._input_fn,
                                             max_steps=1)
        variable_names = warm_started_dnn_lc_classifier.get_variable_names()
        for variable_name in variable_names:
            if 'dnn' in variable_name:
                self.assertAllClose(
                    dnn_lc_classifier.get_variable_value(variable_name),
                    warm_started_dnn_lc_classifier.get_variable_value(
                        variable_name))
            elif 'linear' in variable_name:
                linear_values = warm_started_dnn_lc_classifier.get_variable_value(
                    variable_name)
                # Since they're not warm-started, the linear weights will be
                # zero-initialized.
                self.assertAllClose(np.zeros_like(linear_values),
                                    linear_values)
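
To preview which variables a vars_to_warm_start regular expression such as
'.*(dnn).*' selects, the pattern can be tested against variable names directly.
A minimal sketch; the names below are illustrative, not read from a real
checkpoint:

import re

# Illustrative variable names, similar to those found in a
# DNNLinearCombinedClassifier checkpoint.
names = [
    'dnn/hiddenlayer_0/kernel',
    'dnn/logits/bias',
    'linear/linear_model/age/weights',
]
pattern = re.compile('.*(dnn).*')
# Only the deep-portion variables match, so only they would be warm-started.
print([n for n in names if pattern.match(n)])
# -> ['dnn/hiddenlayer_0/kernel', 'dnn/logits/bias']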
Example #3
def get_latent_gan_estimator(generator_fn, discriminator_fn, loss_fn,
                             optimizer, params, config, ckpt_dir,
                             warmstart_options=True):
  """Gets an estimator that passes gradients to the input.

  This function takes in a generator and adds a trainable z variable that is
  used as input to the generator_fn. The generator itself is treated as a black
  box through which gradients can pass without updating any weights. The result
  is a trainable way to traverse the GAN latent space. The loss_fn is used to
  actually train the z variable. The generator_fn and discriminator_fn should
  have been trained previously with the tfgan library (on reload, the variables
  are expected to follow the tfgan format). It may be possible to use the
  latent GAN estimator with entirely custom GANs that do not use the tfgan
  library, as long as the appropriate variables are wired properly.

  Args:
    generator_fn: a function defining a TensorFlow graph for a GAN generator.
      The weights defined in this graph should already be defined in the given
      checkpoint location. Should have 'mode' as an argument.
    discriminator_fn: a function defining a TensorFlow graph for a GAN
      discriminator. Should have 'mode' as an argument.
    loss_fn: a function defining a TensorFlow graph for a GAN loss. Takes in a
      GANModel tuple, features, labels, and add_summaries as inputs.
    optimizer: a tf.Optimizer or a function that returns a tf.Optimizer with no
      inputs.
    params: A dict containing the following parameters:
      - batch_size: an int indicating the size of the training batch.
      - z_shape: the desired shape of the input z values (not counting batch).
      - learning_rate: a scalar or function defining a learning rate applied to
        optimizer.
      - input_clip: the amount to clip the x training variable by.
      - add_summaries: whether or not to add summaries.
      - opt_kwargs: optimizer kwargs.
    config: tf.RunConfig. Should point the model to the output directory and
      indicate whether to save checkpoints (to avoid saving checkpoints, set
      save_checkpoints_steps to a number larger than the number of train
      steps). The model_dir field in the RunConfig should point to a directory
      WITHOUT any saved checkpoints.
    ckpt_dir: the directory where the model checkpoints live. The checkpoint is
      used to warm-start the underlying GAN. This should NOT be the same as
      config.model_dir.
    warmstart_options: boolean, None, or a WarmStartSettings object. If set to
      True, uses a default WarmStartSettings object. If set to False or None,
      does not use warm start. If using a custom WarmStartSettings object, make
      sure that new variables are properly accounted for when reloading the
      underlying GAN. Defaults to True.

  Returns:
    A tf.estimator.Estimator that trains the z input to the GAN generator.
  """
  model_fn = _get_latent_gan_model_fn(generator_fn, discriminator_fn,
                                      loss_fn, optimizer)

  if isinstance(warmstart_options, estimator.WarmStartSettings):
    ws = warmstart_options
  elif warmstart_options:
    # Default WarmStart loads all variable names except INPUT_NAME and
    # OPTIMIZER_NAME.
    var_regex = '^(?!.*(%s|%s).*)' % (INPUT_NAME, OPTIMIZER_NAME)
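    # The negative lookahead matches any variable name that contains neither
    # INPUT_NAME nor OPTIMIZER_NAME, so the pretrained GAN weights are
    # warm-started while the new z variable and its optimizer slots are skipped.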
    ws = estimator.WarmStartSettings(ckpt_to_initialize_from=ckpt_dir,
                                     vars_to_warm_start=var_regex)
  else:
    ws = None

  # Default to an empty dict so downstream code can always read
  # params['opt_kwargs'].
  if 'opt_kwargs' not in params:
    params['opt_kwargs'] = {}

  return estimator.Estimator(model_fn=model_fn, config=config, params=params,
                             warm_start_from=ws)
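
A hedged usage sketch for get_latent_gan_estimator; make_generator,
make_discriminator, my_loss_fn, my_input_fn, and both directory paths are
assumed placeholders for components previously trained with tfgan:

import tensorflow as tf

params = {
    'batch_size': 16,
    'z_shape': [64],        # shape of z, excluding the batch dimension
    'learning_rate': 0.01,
    'input_clip': 1.0,      # clip the trainable z variable to [-1, 1]
    'add_summaries': True,
}
config = tf.estimator.RunConfig(
    model_dir='/tmp/latent_gan',      # must NOT already contain checkpoints
    save_checkpoints_steps=10**8)     # larger than train steps: no checkpoints
est = get_latent_gan_estimator(
    make_generator, make_discriminator, my_loss_fn,
    tf.train.GradientDescentOptimizer(learning_rate=0.01),
    params, config,
    ckpt_dir='/tmp/trained_gan',      # warm-start source for the GAN weights
    warmstart_options=True)
est.train(input_fn=my_input_fn, max_steps=500)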