def get_tpu_estimator(generator, discriminator, hparams, config):
  return tfgan.estimator.TPUGANEstimator(
      generator_fn=generator,
      discriminator_fn=discriminator,
      generator_loss_fn=tfgan.losses.wasserstein_hinge_generator_loss,
      discriminator_loss_fn=tfgan.losses.wasserstein_hinge_discriminator_loss,
      generator_optimizer=tf.compat.v1.train.AdamOptimizer(
          hparams.generator_lr, hparams.beta1),
      discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(
          hparams.discriminator_lr, hparams.beta1),
      prepare_arguments_for_eval_metric_fn=prepare_metric_arguments,
      get_eval_metric_ops_fn=functools.partial(get_metrics, hparams=hparams),
      eval_on_tpu=hparams.debug_params.eval_on_tpu,
      # Define the generator/discriminator update schedule.
      gan_train_steps=tfgan_tuples.GANTrainSteps(flags.FLAGS.G_steps,
                                                 flags.FLAGS.D_steps),
      train_batch_size=hparams.train_batch_size,
      eval_batch_size=hparams.eval_batch_size,
      predict_batch_size=hparams.predict_batch_size,
      use_tpu=hparams.debug_params.use_tpu,
      config=config,
      params=hparams._asdict(),
      warm_start_from=os.path.join(hparams.model_dir, hparams.load_dir)
      if hparams.load_dir else None)
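# A minimal, hedged usage sketch of the helper above (not part of the
# original source): `make_hparams`, `my_generator`, and `my_discriminator`
# are hypothetical stand-ins for whatever builds the hyperparameters and
# networks in this codebase. The RunConfig construction uses the standard
# TF 1.x TPU estimator API.
def example_train(input_fn, master, model_dir, train_steps=10000):
  hparams = make_hparams()  # hypothetical hparams namedtuple
  run_config = tf.compat.v1.estimator.tpu.RunConfig(
      master=master,
      model_dir=model_dir,
      tpu_config=tf.compat.v1.estimator.tpu.TPUConfig(
          iterations_per_loop=100))
  estimator = get_tpu_estimator(my_generator, my_discriminator, hparams,
                                run_config)
  estimator.train(input_fn, max_steps=train_steps)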
def estimator_model_fn(cls,
                       hparams,
                       features,
                       labels,
                       mode,
                       config=None,
                       params=None,
                       decode_hparams=None,
                       use_tpu=False):
  if mode not in [
      model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL,
      model_fn_lib.ModeKeys.PREDICT
  ]:
    raise ValueError('Mode not recognized: %s' % mode)

  is_training = (mode == model_fn_lib.ModeKeys.TRAIN)

  hparams = hparams_lib.copy_hparams(hparams)

  data_parallelism = None
  if not use_tpu and config:
    data_parallelism = config.data_parallelism
  reuse = tf.get_variable_scope().reuse

  # Instantiate the model.
  self = cls(hparams,
             mode,
             data_parallelism=data_parallelism,
             decode_hparams=decode_hparams,
             _reuse=reuse)

  generator_inputs = self.sample_noise()
  real_data = features['inputs']  # rename inputs for clarity
  img_shape = common_layers.shape_list(real_data)[1:4]
  real_data.set_shape([hparams.batch_size] + img_shape)

  # To satisfy the TF-GAN API, real data is set to None in predict mode.
  if mode == tf.estimator.ModeKeys.PREDICT:
    real_data = None

  optimizers = Optimizers(
      tf.compat.v1.train.AdamOptimizer(hparams.generator_lr, hparams.beta1),
      tf.compat.v1.train.AdamOptimizer(hparams.discriminator_lr,
                                       hparams.beta1))

  # Create TF-Hub module specs for both the generator and the discriminator.
  def make_discriminator_spec():
    input_layer = tf.placeholder(tf.float32, shape=[None] + img_shape)
    disc_output = self.discriminator(input_layer, None, mode)
    hub.add_signature(inputs=input_layer, outputs=disc_output)
  disc_spec = hub.create_module_spec(make_discriminator_spec)

  def make_generator_spec():
    input_layer = tf.placeholder(
        tf.float32,
        shape=[None] + common_layers.shape_list(generator_inputs)[1:])
    gen_output = self.generator(input_layer, mode)
    hub.add_signature(inputs=input_layer, outputs=gen_output)
  gen_spec = hub.create_module_spec(make_generator_spec)

  # Instantiate the modules.
  discriminator_module = hub.Module(
      disc_spec, name="Discriminator_Module", trainable=True)
  generator_module = hub.Module(
      gen_spec, name="Generator_Module", trainable=True)

  # Wrap the modules into functions expected by TF-GAN.
  def generator(code, mode):
    p = hparams
    out = generator_module(code)
    shape = common_layers.shape_list(out)
    # Apply the PSF by convolution.
    if p.apply_psf and 'psf' in features:
      out = convolve(out, tf.cast(features['psf'][..., 0], tf.complex64))
    # Add noise according to the provided power spectrum.
    noise = tf.spectral.rfft2d(tf.random_normal(out.get_shape()[:3]))
    thresholded_ps = tf.where(features['ps'] >= 9,
                              tf.zeros_like(features['ps']),
                              tf.sqrt(tf.exp(features['ps'])))
    noise = noise * tf.cast(thresholded_ps, tf.complex64)
    out = out + tf.expand_dims(tf.spectral.irfft2d(noise), axis=-1)
    return out

  discriminator = lambda image, conditioning, mode: discriminator_module(
      image)

  # Make GANModel, which encapsulates the GAN model architectures.
  gan_model = get_gan_model(mode,
                            generator,
                            discriminator,
                            real_data,
                            generator_inputs,
                            add_summaries=self.summaries)

  # Make GANLoss, which encapsulates the losses.
  if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]:
    gan_loss = tfgan_train.gan_loss(gan_model,
                                    self.generator_loss,
                                    self.discriminator_loss,
                                    add_summaries=True)

  # Make the EstimatorSpec, which incorporates the GANModel, losses, eval
  # metrics, and optimizers (if required).
  if mode == tf.estimator.ModeKeys.TRAIN:
    get_hooks_fn = tfgan_train.get_sequential_train_hooks(
        namedtuples.GANTrainSteps(hparams.gen_steps, hparams.disc_steps))
    estimator_spec = get_train_estimator_spec(
        gan_model, gan_loss, optimizers, get_hooks_fn, is_chief=True)
  elif mode == tf.estimator.ModeKeys.EVAL:
    estimator_spec = get_eval_estimator_spec(gan_model, gan_loss)
  else:  # tf.estimator.ModeKeys.PREDICT
    # Register the hub modules for export.
    hub.register_module_for_export(generator_module, "generator")
    hub.register_module_for_export(discriminator_module, "discriminator")
    estimator_spec = get_predict_estimator_spec(gan_model)

  return estimator_spec
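# A hedged sketch (not part of the original source) of how the exported
# generator module could be reloaded for sampling. `export_dir`, `latent_dim`,
# and the batch size are placeholders/assumptions: modules registered with
# `hub.register_module_for_export` above are only written to disk once an
# exporter (e.g. `hub.LatestModuleExporter`) runs, after which they can be
# instantiated directly with the TF-Hub 1.x API.
def sample_from_export(export_dir, latent_dim=64, batch_size=16):
  generator = hub.Module(export_dir)  # e.g. ".../export/generator"
  code = tf.random_normal([batch_size, latent_dim])
  samples = generator(code)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())
    return sess.run(samples)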
def __init__(
    self,
    # Arguments to construct the `model_fn`.
    generator_fn=None,
    discriminator_fn=None,
    generator_loss_fn=None,
    discriminator_loss_fn=None,
    generator_optimizer=None,
    discriminator_optimizer=None,
    prepare_arguments_for_eval_metric_fn=None,
    get_eval_metric_ops_fn=None,
    add_summaries=None,
    joint_train=False,
    gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1),
    # TPUEstimator options.
    model_dir=None,
    config=None,
    params=None,
    use_tpu=True,
    train_batch_size=None,
    eval_batch_size=None,
    predict_batch_size=None,
    batch_axis=None,
    eval_on_tpu=True,
    export_to_tpu=True,
    warm_start_from=None):
  """Initializes a TPUGANEstimator instance.

  Args:
    generator_fn: A python function that takes a Tensor, Tensor list, or
      Tensor dictionary as inputs and returns the outputs of the GAN
      generator. See `TFGAN` for more details and examples. Additionally, if
      it has an argument called `mode`, the Estimator's `mode` will be passed
      in (e.g. TRAIN, EVAL, PREDICT). This is useful for things like batch
      normalization.
    discriminator_fn: A python function that takes the output of
      `generator_fn` or real data in the GAN setup, and `generator_inputs`.
      Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details
      and examples.
    generator_loss_fn: The loss function on the generator. Takes a `GANModel`
      tuple.
    discriminator_loss_fn: The loss function on the discriminator. Takes a
      `GANModel` tuple.
    generator_optimizer: The optimizer for generator updates, or a function
      that takes no arguments and returns an optimizer. This function will be
      called when the default graph is the `GANEstimator`'s graph, so
      utilities like `tf.train.get_or_create_global_step` will work.
    discriminator_optimizer: Same as `generator_optimizer`, but for the
      discriminator updates.
    prepare_arguments_for_eval_metric_fn: A function that takes a list of
      arguments and returns a nested structure of tensors keyed by name. The
      returned tensors must be compatible with
      `TPUEstimatorSpec.eval_metrics` (i.e., in batch-major format, where the
      batch size is the first dimension) and will be passed to the provided
      `get_eval_metric_ops_fn`. The arguments must be:
          * generator_inputs
          * generated_data
          * real_data
          * discriminator_real_outputs
          * discriminator_gen_outputs
      The default implementation simply returns the arguments as-is. This
      function is executed on the TPU, allowing for compute-heavy eval-only
      operations to be performed.
    get_eval_metric_ops_fn: A function that takes a list of arguments and
      returns a dict of metric results keyed by name, executed on CPU. The
      arguments of the function should be the keys of the dict returned by
      `prepare_arguments_for_eval_metric_fn` (see
      `prepare_arguments_for_eval_metric_fn` for the defaults), and should
      return a dict from metric string name to the result of calling a metric
      function, namely a (metric_tensor, update_op) tuple.
    add_summaries: `None`, a single `SummaryType`, or a list of
      `SummaryType`. This is ignored for jobs that run on TPU, such as the
      train job if `use_tpu` is `True` or the eval job if `eval_on_tpu` is
      `True`.
    joint_train: A Python boolean. If `True`, jointly train the generator and
      the discriminator. If `False`, sequentially train them. See `train.py`
      in TFGAN for more details on the differences between the two GAN
      training methods.
    gan_train_steps: A `tfgan.GANTrainSteps` named tuple describing the ratio
      of generator to discriminator steps.
    model_dir: Same as `TPUEstimator`: Directory to save model parameters,
      graph, etc.
      This can also be used to load checkpoints from the directory into an
      estimator to continue training a previously saved model. If `None`, the
      model_dir in `config` will be used if set. If both are set, they must
      be the same. If both are `None`, a temporary directory will be used.
    config: Same as `TPUEstimator`: A `tpu_config.RunConfig` configuration
      object. Cannot be `None`.
    params: Same as `TPUEstimator`: An optional `dict` of hyper parameters
      that will be passed into `input_fn` and `model_fn`. Keys are names of
      parameters, values are basic python types. There are reserved keys for
      `TPUEstimator`, including 'batch_size'. If any `params` are args to
      TF-GAN's `gan_loss`, they will be passed to `gan_loss` during training
      and evaluation.
    use_tpu: Same as `TPUEstimator`: A bool indicating whether TPU support is
      enabled. Currently, TPU training and evaluation respect this bit, but
      `eval_on_tpu` can override execution of eval (see below). Predict still
      happens on CPU.
    train_batch_size: Same as `TPUEstimator`: An int representing the global
      training batch size. TPUEstimator transforms this global batch size to
      a per-shard batch size, as params['batch_size'], when calling
      `input_fn` and `model_fn`. Cannot be `None` if `use_tpu` is `True`.
      Must be divisible by the total number of replicas.
    eval_batch_size: Same as `TPUEstimator`: An int representing the
      evaluation batch size. Must be divisible by the total number of
      replicas.
    predict_batch_size: Same as `TPUEstimator`: An int representing the
      prediction batch size. Must be divisible by the total number of
      replicas.
    batch_axis: Same as `TPUEstimator`: A python tuple of int values
      describing how each tensor produced by the Estimator `input_fn` should
      be split across the TPU compute shards. For example, if your input_fn
      produced (images, labels) where the images tensor is in `HWCN` format,
      your shard dimensions would be [3, 0], where 3 corresponds to the `N`
      dimension of your images Tensor, and 0 corresponds to the dimension
      along which to split the labels to match up with the corresponding
      images. If None is supplied, and per_host_input_for_training is True,
      batches will be sharded based on the major dimension. If
      tpu_config.per_host_input_for_training is False or `PER_HOST_V2`,
      batch_axis is ignored.
    eval_on_tpu: Same as `TPUEstimator`: If False, evaluation runs on CPU or
      GPU. In this case, the model_fn must return `EstimatorSpec` when called
      with `mode` as `EVAL`.
    export_to_tpu: Same as `TPUEstimator`: If True, `export_savedmodel()`
      exports a metagraph for serving on TPU besides the one on CPU.
    warm_start_from: Same as `TPUEstimator`: Optional string filepath to a
      checkpoint or SavedModel to warm-start from, or a
      `tf.estimator.WarmStartSettings` object to fully configure
      warm-starting. If the string filepath is provided instead of a
      `WarmStartSettings`, then all variables are warm-started, and it is
      assumed that vocabularies and Tensor names are unchanged.

  Raises:
    ValueError: If loss functions aren't callable.
    ValueError: If `gan_train_steps` isn't a `tfgan_tuples.GANTrainSteps`
      tuple.
    ValueError: If `gan_train_steps` isn't 1:1 training.
  """
  _validate_input_args(generator_loss_fn, discriminator_loss_fn,
                       gan_train_steps)
  loss_fns = LossFns(generator_loss_fn, discriminator_loss_fn)
  optimizers = Optimizers(generator_optimizer, discriminator_optimizer)

  # Determine the number of GAN models that need to be created in order to
  # train with different D:G step ratios on TPU.
  required_train_models = _required_train_models(gan_train_steps, joint_train)
  effective_train_batch_size = required_train_models * train_batch_size

  def _model_fn(features, labels, mode, params):
    """GANEstimator model function."""
    if mode not in [
        tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL,
        tf.estimator.ModeKeys.PREDICT
    ]:
      raise ValueError('Mode not recognized: %s' % mode)
    real_data = labels  # rename inputs for clarity
    generator_inputs = features  # rename inputs for clarity

    # Collect GANModel builder functions, which encapsulate the GAN model
    # architectures. Don't actually execute them here, since the functions
    # actually create the TF ops and the variable reads need to be chained
    # after the writes from the previous step. Instead just pass the
    # functions with bound arguments down so that they can easily be
    # executed later.
    gan_model_fns = _get_gan_model_fns(
        mode,
        generator_fn,
        discriminator_fn,
        real_data,
        generator_inputs,
        num_train_models=required_train_models)

    # TODO(joelshor): Switch TF-GAN over to TPU-compatible summaries, then
    # remove `add_summaries` logic below.
    is_on_tpu = _is_on_tpu(mode, use_tpu, eval_on_tpu)
    summary_types = None if is_on_tpu else add_summaries

    # Make the TPUEstimatorSpec, which incorporates the model, losses, eval
    # metrics, and optimizers (if required).
    gan_loss_kwargs = gan_estimator.extract_gan_loss_args_from_params(params)
    if mode == tf.estimator.ModeKeys.TRAIN:
      estimator_spec = get_train_estimator_spec(
          gan_model_fns,
          loss_fns,
          gan_loss_kwargs,
          optimizers,
          joint_train,
          is_on_tpu,
          gan_train_steps,
          add_summaries=summary_types)
    elif mode == tf.estimator.ModeKeys.EVAL:
      estimator_spec = get_eval_estimator_spec(
          gan_model_fns,
          loss_fns,
          gan_loss_kwargs,
          prepare_arguments_for_eval_metric_fn,
          get_eval_metric_ops_fn,
          add_summaries=summary_types)
    else:  # tf.estimator.ModeKeys.PREDICT
      estimator_spec = get_predict_estimator_spec(gan_model_fns)
    assert isinstance(estimator_spec,
                      tf.compat.v1.estimator.tpu.TPUEstimatorSpec)
    return estimator_spec

  super(TPUGANEstimator, self).__init__(
      model_fn=_model_fn,
      model_dir=model_dir,
      config=config,
      params=params,
      use_tpu=use_tpu,
      train_batch_size=effective_train_batch_size,
      eval_batch_size=eval_batch_size,
      predict_batch_size=predict_batch_size,
      batch_axis=batch_axis,
      eval_on_tpu=eval_on_tpu,
      export_to_tpu=export_to_tpu,
      warm_start_from=warm_start_from)
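# A minimal, hedged usage sketch of the constructor above (not part of the
# original source). `my_generator`/`my_discriminator` and the learning rates
# are placeholders; the loss choices simply mirror the `get_tpu_estimator`
# helper earlier in this file.
def example_tpu_gan_estimator(run_config):
  return TPUGANEstimator(
      generator_fn=my_generator,
      discriminator_fn=my_discriminator,
      generator_loss_fn=tfgan.losses.wasserstein_hinge_generator_loss,
      discriminator_loss_fn=tfgan.losses.wasserstein_hinge_discriminator_loss,
      generator_optimizer=tf.compat.v1.train.AdamOptimizer(1e-4, 0.5),
      discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(4e-4, 0.5),
      gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1),
      joint_train=False,
      # Note: the constructor scales this by the number of GAN models implied
      # by `gan_train_steps` and `joint_train` (see
      # `effective_train_batch_size` above), so the input_fn must be able to
      # supply the larger global batch.
      train_batch_size=32,
      use_tpu=True,
      config=run_config)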