Example #1
0
  def test_discriminator_only_sees_pool(self):
    """Verifies the discriminator receives pooled, non-constant values."""
    if tf.executing_eagerly():
      # This test involves '.op', which doesn't work in eager.
      return

    def const_generator(_):
      # The generator emits a constant; the tensor pool should replace it
      # before the discriminator ever sees it.
      return tf.constant(0.0)

    def pooling_fn(_):
      # Substitute fresh (non-constant) tensors for the (generated, real) pair.
      return (tf.random.uniform([]), tf.random.uniform([]))

    def pool_checking_discriminator(inputs, _):
      """Discriminator that asserts its input is not a constant Tensor."""
      self.assertFalse(inputs.op.type == 'Const')
      return inputs

    model = tfgan.gan_model(
        const_generator,
        discriminator_model,
        real_data=tf.zeros([]),
        generator_inputs=tf.random.normal([]))
    # Swap in the checking discriminator, then build the loss so the pool runs.
    model = model._replace(discriminator_fn=pool_checking_discriminator)
    tfgan.gan_loss(model, tensor_pool_fn=pooling_fn)
Example #2
0
    def define_model(self, images_x, images_y):
        """Defines a GAN model mapping between images_x and images_y.

        Which set is fed to the generator and which is treated as real data
        depends on `self._swap_inputs`.

        Args:
          images_x: A 4D float `Tensor` of NHWC format.  Images in set X.
          images_y: A 4D float `Tensor` of NHWC format.  Images in set Y.

        Returns:
          A `GANModel` namedtuple, as returned by `tfgan.gan_model`.
        """
        if self._swap_inputs:
            generator_inputs = images_y
            real_data = images_x
        else:
            generator_inputs = images_x
            real_data = images_y

        gan_model = tfgan.gan_model(
            generator_fn=_shadowdata_generator_model,
            discriminator_fn=_shadowdata_discriminator_model,
            generator_inputs=generator_inputs,
            real_data=real_data)

        # Add summaries for generated images.
        # tfgan.eval.add_cyclegan_image_summaries(gan_model)
        return gan_model
Example #3
0
  def test_no_shape_check(self):
    """Checks that `check_shapes=False` disables shape validation."""
    if tf.executing_eagerly():
      # None of the usual utilities work in eager.
      return

    def dummy_generator_model(_):
      # Returns non-Tensor outputs, which have no `.shape` attribute.
      return (None, None)

    def dummy_discriminator_model(data, conditioning):  # pylint: disable=unused-argument
      return 1

    # With shape checking on, inspecting the `None` generator output must fail.
    # Note: `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
    # alias, which was removed in Python 3.12.
    with self.assertRaisesRegex(AttributeError, 'object has no attribute'):
      tfgan.gan_model(
          dummy_generator_model,
          dummy_discriminator_model,
          real_data=tf.zeros([1, 2]),
          generator_inputs=tf.zeros([1]),
          check_shapes=True)
    # With shape checking off, the same model builds without error.
    tfgan.gan_model(
        dummy_generator_model,
        dummy_discriminator_model,
        real_data=tf.zeros([1, 2]),
        generator_inputs=tf.zeros([1]),
        check_shapes=False)
Example #4
0
  def test_doesnt_crash_when_in_nested_scope(self):
    """Loss construction should work inside and outside a variable scope."""
    if tf.executing_eagerly():
      # None of the usual utilities work in eager.
      return
    penalty_weight = 1.0
    with tf.compat.v1.variable_scope('outer_scope'):
      model = tfgan.gan_model(
          generator_model,
          discriminator_model,
          real_data=tf.zeros([1, 2]),
          generator_inputs=tf.random.normal([1, 2]))

      # Building the loss while still inside the scope must not raise...
      tfgan.gan_loss(model, gradient_penalty_weight=penalty_weight)

    # ...and neither should building it again after leaving the scope.
    tfgan.gan_loss(model, gradient_penalty_weight=penalty_weight)
Example #5
0
def build_model(stage_id, batch_size, real_images, **kwargs):
    """Builds progressive GAN model.

  Args:
    stage_id: An integer of training stage index.
    batch_size: Number of training images in each minibatch.
    real_images: A 4D `Tensor` of NHWC format.
    **kwargs: A dictionary of
        'start_height': An integer of start image height.
        'start_width': An integer of start image width.
        'scale_base': An integer of resolution multiplier.
        'num_resolutions': An integer of number of progressive resolutions.
        'stable_stage_num_images': An integer of number of training images in
          the stable stage.
        'transition_stage_num_images': An integer of number of training images
          in the transition stage.
        'total_num_images': An integer of total number of training images.
        'kernel_size': Convolution kernel size.
        'colors': Number of image channels.
        'to_rgb_use_tanh_activation': Whether to apply tanh activation when
          output rgb.
        'fmap_base': Base number of filters.
        'fmap_decay': Decay of number of filters.
        'fmap_max': Max number of filters.
        'latent_vector_size': An integer of latent vector size.
        'gradient_penalty_weight': A float of gradient penalty weight for
          wasserstein loss.
        'gradient_penalty_target': A float of gradient norm target for
          wasserstein loss.
        'real_score_penalty_weight': A float of additional penalty to keep the
          scores from drifting too far from zero.
        'adam_beta1': A float of Adam optimizer beta1.
        'adam_beta2': A float of Adam optimizer beta2.
        'generator_learning_rate': A float of generator learning rate.
        'discriminator_learning_rate': A float of discriminator learning rate.

  Returns:
    An internal object that wraps all information about the model.
  """
    kernel_size = kwargs['kernel_size']
    colors = kwargs['colors']
    resolution_schedule = make_resolution_schedule(**kwargs)

    num_blocks, num_images = get_stage_info(stage_id, **kwargs)

    # The global step is used as a counter of training images seen so far;
    # it advances by `batch_size` per step via the op below.
    current_image_id = tf.compat.v1.train.get_or_create_global_step()
    current_image_id_inc_op = current_image_id.assign_add(batch_size)
    tf.compat.v1.summary.scalar('current_image_id', current_image_id)

    # `progress` drives the progressive growing (block fade-in) as a function
    # of how many images have been seen.
    progress = networks.compute_progress(current_image_id,
                                         kwargs['stable_stage_num_images'],
                                         kwargs['transition_stage_num_images'],
                                         num_blocks)
    tf.compat.v1.summary.scalar('progress', progress)

    # Blend real images to match the effective resolution of current progress.
    real_images = networks.blend_images(real_images,
                                        progress,
                                        resolution_schedule,
                                        num_blocks=num_blocks)

    def _num_filters_fn(block_id):
        """Computes number of filters of block `block_id`."""
        return networks.num_filters(block_id, kwargs['fmap_base'],
                                    kwargs['fmap_decay'], kwargs['fmap_max'])

    def _generator_fn(z):
        """Builds generator network."""
        to_rgb_act = tf.tanh if kwargs['to_rgb_use_tanh_activation'] else None
        return networks.generator(z,
                                  progress,
                                  _num_filters_fn,
                                  resolution_schedule,
                                  num_blocks=num_blocks,
                                  kernel_size=kernel_size,
                                  colors=colors,
                                  to_rgb_activation=to_rgb_act)

    def _discriminator_fn(x):
        """Builds discriminator network."""
        return networks.discriminator(x,
                                      progress,
                                      _num_filters_fn,
                                      resolution_schedule,
                                      num_blocks=num_blocks,
                                      kernel_size=kernel_size)

    ########## Define model.
    z = make_latent_vectors(batch_size, **kwargs)

    # The lambdas adapt the network functions to the tfgan.gan_model
    # interface: only the first element of each network's return is used.
    gan_model = tfgan.gan_model(
        generator_fn=lambda z: _generator_fn(z)[0],
        discriminator_fn=lambda x, unused_z: _discriminator_fn(x)[0],
        real_data=real_images,
        generator_inputs=z)

    ########## Define loss.
    gan_loss = define_loss(gan_model, **kwargs)

    ########## Define train ops.
    gan_train_ops, optimizer_var_list = define_train_ops(
        gan_model, gan_loss, **kwargs)
    gan_train_ops = gan_train_ops._replace(
        global_step_inc_op=current_image_id_inc_op)

    ########## Generator smoothing.
    # NOTE(review): uses `tf.train.ExponentialMovingAverage` while the rest of
    # this function uses `tf.compat.v1.*` — confirm both resolve in the target
    # TF version.
    generator_ema = tf.train.ExponentialMovingAverage(decay=0.999)
    gan_train_ops, generator_vars_to_restore = add_generator_smoothing_ops(
        generator_ema, gan_model, gan_train_ops)

    # Ad-hoc namespace object that bundles everything callers need.
    class Model(object):
        pass

    model = Model()
    model.stage_id = stage_id
    model.batch_size = batch_size
    model.resolution_schedule = resolution_schedule
    model.num_images = num_images
    model.num_blocks = num_blocks
    model.current_image_id = current_image_id
    model.progress = progress
    model.num_filters_fn = _num_filters_fn
    model.generator_fn = _generator_fn
    model.discriminator_fn = _discriminator_fn
    model.gan_model = gan_model
    model.gan_loss = gan_loss
    model.gan_train_ops = gan_train_ops
    model.optimizer_var_list = optimizer_var_list
    model.generator_ema = generator_ema
    model.generator_vars_to_restore = generator_vars_to_restore
    return model
Example #6
0
    def __init__(self, stage_id, batch_size, config):
        """Builds a graph stage from a config dictionary.

    Stage_id and batch_size change during training so they are kept separate
    from the global config. This function is also called by 'load_from_path()'.

    Args:
      stage_id: (int) Build generator/discriminator with this many stages.
      batch_size: (int) Build graph with fixed batch size.
      config: (dict) All the global state.
    """
        data_helper = data_helpers.registry[config['data_type']](config)
        real_images, real_one_hot_labels = data_helper.provide_data(batch_size)

        # gen_one_hot_labels = real_one_hot_labels
        gen_one_hot_labels = data_helper.provide_one_hot_labels(batch_size)
        # Depth of the label one-hot vectors.
        # NOTE(review): `.value` implies TF1-style `Dimension` objects —
        # confirm against the target TF version.
        num_tokens = real_one_hot_labels.shape[1].value

        # The global step counts the number of training images seen so far.
        current_image_id = tf.train.get_or_create_global_step()
        current_image_id_inc_op = current_image_id.assign_add(batch_size)
        tf.summary.scalar('current_image_id', current_image_id)

        # Accumulated wall-clock training time; used by the time-based
        # stage schedule below.
        train_time = tf.Variable(0., dtype=tf.float32, trainable=False)
        tf.summary.scalar('train_time', train_time)

        resolution_schedule = train_util.make_resolution_schedule(**config)
        num_blocks, num_images = train_util.get_stage_info(stage_id, **config)

        num_stages = (2 * config['num_resolutions']) - 1
        if config['train_time_limit'] is not None:
            # Geometrically scale per-stage time budgets, normalize them so
            # they sum to the total limit, then convert to cumulative stage
            # end times.
            stage_times = np.zeros(num_stages, dtype='float32')
            stage_times[0] = 1.
            for i in range(1, num_stages):
                stage_times[i] = (stage_times[i - 1] *
                                  config['train_time_stage_multiplier'])
            stage_times *= config['train_time_limit'] / np.sum(stage_times)
            stage_times = np.cumsum(stage_times)
            print('Stage times:')
            for t in stage_times:
                print('\t{}'.format(t))

        if config['train_progressive']:
            # Progress drives block fade-in; it is computed either from
            # elapsed time or from the number of images seen.
            if config['train_time_limit'] is not None:
                progress = networks.compute_progress_from_time(
                    train_time, config['num_resolutions'], num_blocks,
                    stage_times)
            else:
                progress = networks.compute_progress(
                    current_image_id, config['stable_stage_num_images'],
                    config['transition_stage_num_images'], num_blocks)
        else:
            progress = num_blocks - 1.  # Maximum value, must be float.
            num_images = 0
            for stage_id_idx in train_util.get_stage_ids(**config):
                _, n = train_util.get_stage_info(stage_id_idx, **config)
                num_images += n

        # Add to config
        config['resolution_schedule'] = resolution_schedule
        config['num_blocks'] = num_blocks
        config['num_images'] = num_images
        config['progress'] = progress
        config['num_tokens'] = num_tokens
        tf.summary.scalar('progress', progress)

        # Blend real images to match the effective resolution of the current
        # progress value.
        real_images = networks.blend_images(real_images,
                                            progress,
                                            resolution_schedule,
                                            num_blocks=num_blocks)

        ########## Define model.
        noises = train_util.make_latent_vectors(batch_size, **config)

        # Get network functions and wrap with hparams
        g_fn = lambda x: net_fns.g_fn_registry[config['g_fn']](x, **config)
        d_fn = lambda x: net_fns.d_fn_registry[config['d_fn']](x, **config)

        # Extra lambda functions to conform to tfgan.gan_model interface
        gan_model = tfgan.gan_model(
            generator_fn=lambda inputs: g_fn(inputs)[0],
            discriminator_fn=lambda images, unused_cond: d_fn(images)[0],
            real_data=real_images,
            generator_inputs=(noises, gen_one_hot_labels))

        ########## Define loss.
        gan_loss = train_util.define_loss(gan_model, **config)

        ########## Auxiliary loss functions
        def _compute_ac_loss(images, target_one_hot_labels):
            """Auxiliary-classifier cross-entropy via the discriminator head."""
            # Reuse the discriminator variables to classify `images`.
            with tf.variable_scope(gan_model.discriminator_scope, reuse=True):
                _, end_points = d_fn(images)
            return tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(
                    labels=tf.stop_gradient(target_one_hot_labels),
                    logits=end_points['classification_logits']))

        def _compute_gl_consistency_loss(data):
            """G&L consistency loss."""
            sh = data_helper.specgrams_helper
            is_mel = isinstance(data_helper, data_helpers.DataMelHelper)
            if is_mel:
                stfts = sh.melspecgrams_to_stfts(data)
            else:
                stfts = sh.specgrams_to_stfts(data)
            # Round-trip through the waveform domain and compare magnitudes:
            # a consistent spectrogram should survive the round trip.
            waves = sh.stfts_to_waves(stfts)
            new_stfts = sh.waves_to_stfts(waves)
            # Magnitude loss
            mag = tf.abs(stfts)
            new_mag = tf.abs(new_stfts)
            # Normalize loss to max
            get_max = lambda x: tf.reduce_max(x, axis=(1, 2), keepdims=True)
            mag_max = get_max(mag)
            new_mag_max = get_max(new_mag)
            mag_scale = tf.maximum(1.0, tf.maximum(mag_max, new_mag_max))
            mag_diff = (mag - new_mag) / mag_scale
            mag_loss = tf.reduce_mean(tf.square(mag_diff))
            return mag_loss

        with tf.name_scope('losses'):
            # Loss weights
            gen_ac_loss_weight = config['generator_ac_loss_weight']
            dis_ac_loss_weight = config['discriminator_ac_loss_weight']
            gen_gl_consistency_loss_weight = config[
                'gen_gl_consistency_loss_weight']

            # AC losses.
            fake_ac_loss = _compute_ac_loss(gan_model.generated_data,
                                            gen_one_hot_labels)
            real_ac_loss = _compute_ac_loss(gan_model.real_data,
                                            real_one_hot_labels)

            # GL losses.
            is_fourier = isinstance(data_helper,
                                    (data_helpers.DataSTFTHelper,
                                     data_helpers.DataSTFTNoIFreqHelper,
                                     data_helpers.DataMelHelper))
            if isinstance(data_helper, data_helpers.DataWaveHelper):
                is_fourier = False

            if is_fourier:
                fake_gl_loss = _compute_gl_consistency_loss(
                    gan_model.generated_data)
                real_gl_loss = _compute_gl_consistency_loss(
                    gan_model.real_data)

            # Total losses.
            wx_fake_ac_loss = gen_ac_loss_weight * fake_ac_loss
            wx_real_ac_loss = dis_ac_loss_weight * real_ac_loss
            wx_fake_gl_loss = 0.0
            # The GL term only applies during the final training stage.
            if (is_fourier and gen_gl_consistency_loss_weight > 0 and stage_id
                    == train_util.get_total_num_stages(**config) - 1):
                wx_fake_gl_loss = fake_gl_loss * gen_gl_consistency_loss_weight
            # Update the loss functions
            gan_loss = gan_loss._replace(
                generator_loss=(gan_loss.generator_loss + wx_fake_ac_loss +
                                wx_fake_gl_loss),
                discriminator_loss=(gan_loss.discriminator_loss +
                                    wx_real_ac_loss))

            tf.summary.scalar('fake_ac_loss', fake_ac_loss)
            tf.summary.scalar('real_ac_loss', real_ac_loss)
            tf.summary.scalar('wx_fake_ac_loss', wx_fake_ac_loss)
            tf.summary.scalar('wx_real_ac_loss', wx_real_ac_loss)
            tf.summary.scalar('total_gen_loss', gan_loss.generator_loss)
            tf.summary.scalar('total_dis_loss', gan_loss.discriminator_loss)

            if is_fourier:
                tf.summary.scalar('fake_gl_loss', fake_gl_loss)
                tf.summary.scalar('real_gl_loss', real_gl_loss)
                tf.summary.scalar('wx_fake_gl_loss', wx_fake_gl_loss)

        ########## Define train ops.
        gan_train_ops, optimizer_var_list = train_util.define_train_ops(
            gan_model, gan_loss, **config)
        gan_train_ops = gan_train_ops._replace(
            global_step_inc_op=current_image_id_inc_op)

        ########## Generator smoothing.
        generator_ema = tf.train.ExponentialMovingAverage(decay=0.999)
        gan_train_ops, generator_vars_to_restore = \
            train_util.add_generator_smoothing_ops(generator_ema,
                                                   gan_model,
                                                   gan_train_ops)
        # Variable scope whose custom getter substitutes the EMA-smoothed
        # generator weights when the generator is reused for sampling.
        load_scope = tf.variable_scope(
            gan_model.generator_scope,
            reuse=True,
            custom_getter=train_util.make_var_scope_custom_getter_for_ema(
                generator_ema))

        ########## Separate path for generating samples with a placeholder (ph)
        # Mapping of pitches to one-hot labels
        pitch_counts = data_helper.get_pitch_counts()
        pitch_to_label_dict = {}
        for i, pitch in enumerate(sorted(pitch_counts.keys())):
            pitch_to_label_dict[pitch] = i

        # (label_ph, noise_ph) -> fake_wave_ph
        labels_ph = tf.placeholder(tf.int32, [batch_size])
        noises_ph = tf.placeholder(tf.float32,
                                   [batch_size, config['latent_vector_size']])
        num_pitches = len(pitch_counts)
        one_hot_labels_ph = tf.one_hot(labels_ph, num_pitches)
        with load_scope:
            fake_data_ph, _ = g_fn((noises_ph, one_hot_labels_ph))
            fake_waves_ph = data_helper.data_to_waves(fake_data_ph)

        if config['train_time_limit'] is not None:
            stage_train_time_limit = stage_times[stage_id]
            #  config['train_time_limit'] * \
            # (float(stage_id+1) / ((2*config['num_resolutions'])-1))
        else:
            stage_train_time_limit = None

        ########## Add variables as properties
        self.stage_id = stage_id
        self.batch_size = batch_size
        self.config = config
        self.data_helper = data_helper
        self.resolution_schedule = resolution_schedule
        self.num_images = num_images
        self.num_blocks = num_blocks
        self.current_image_id = current_image_id
        self.progress = progress
        self.generator_fn = g_fn
        self.discriminator_fn = d_fn
        self.gan_model = gan_model
        self.fake_ac_loss = fake_ac_loss
        self.real_ac_loss = real_ac_loss
        self.gan_loss = gan_loss
        self.gan_train_ops = gan_train_ops
        self.optimizer_var_list = optimizer_var_list
        self.generator_ema = generator_ema
        self.generator_vars_to_restore = generator_vars_to_restore
        self.real_images = real_images
        self.real_one_hot_labels = real_one_hot_labels
        self.load_scope = load_scope
        self.pitch_counts = pitch_counts
        self.pitch_to_label_dict = pitch_to_label_dict
        self.labels_ph = labels_ph
        self.noises_ph = noises_ph
        self.fake_waves_ph = fake_waves_ph
        self.saver = tf.train.Saver()
        self.sess = tf.Session()
        self.train_time = train_time
        self.stage_train_time_limit = stage_train_time_limit
Example #7
0
def train(hparams):
    """Trains an MNIST GAN.

    Args:
      hparams: An HParams instance containing the hyperparameters for training.

    Raises:
      ValueError: If `hparams.gan_type` is not one of 'unconditional',
        'conditional', or 'infogan'.
    """
    if not tf.io.gfile.exists(hparams.train_log_dir):
        tf.io.gfile.makedirs(hparams.train_log_dir)

    # Force all input processing onto CPU in order to reserve the GPU for
    # the forward inference and back-propagation.
    with tf.name_scope('inputs'), tf.device('/cpu:0'):
        images, one_hot_labels = data_provider.provide_data(
            'train', hparams.batch_size, num_parallel_calls=4)

    # Define the GANModel tuple. Optionally, condition the GAN on the label or
    # use an InfoGAN to learn a latent representation.
    if hparams.gan_type == 'unconditional':
        gan_model = tfgan.gan_model(
            generator_fn=networks.unconditional_generator,
            discriminator_fn=networks.unconditional_discriminator,
            real_data=images,
            generator_inputs=tf.random.normal(
                [hparams.batch_size, hparams.noise_dims]))
    elif hparams.gan_type == 'conditional':
        noise = tf.random.normal([hparams.batch_size, hparams.noise_dims])
        gan_model = tfgan.gan_model(
            generator_fn=networks.conditional_generator,
            discriminator_fn=networks.conditional_discriminator,
            real_data=images,
            generator_inputs=(noise, one_hot_labels))
    elif hparams.gan_type == 'infogan':
        cat_dim, cont_dim = 10, 2
        generator_fn = functools.partial(networks.infogan_generator,
                                         categorical_dim=cat_dim)
        discriminator_fn = functools.partial(networks.infogan_discriminator,
                                             categorical_dim=cat_dim,
                                             continuous_dim=cont_dim)
        unstructured_inputs, structured_inputs = util.get_infogan_noise(
            hparams.batch_size, cat_dim, cont_dim, hparams.noise_dims)
        gan_model = tfgan.infogan_model(
            generator_fn=generator_fn,
            discriminator_fn=discriminator_fn,
            real_data=images,
            unstructured_generator_inputs=unstructured_inputs,
            structured_generator_inputs=structured_inputs)
    else:
        # Fail fast with a clear error instead of a confusing NameError when
        # `gan_model` is referenced below.
        raise ValueError('Invalid gan_type: %s' % hparams.gan_type)
    tfgan.eval.add_gan_model_image_summaries(gan_model, hparams.grid_size)

    # Get the GANLoss tuple. You can pass a custom function, use one of the
    # already-implemented losses from the losses library, or use the defaults.
    with tf.name_scope('loss'):
        if hparams.gan_type == 'infogan':
            gan_loss = tfgan.gan_loss(
                gan_model,
                generator_loss_fn=tfgan.losses.modified_generator_loss,
                discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
                mutual_information_penalty_weight=1.0,
                add_summaries=True)
        else:
            gan_loss = tfgan.gan_loss(gan_model, add_summaries=True)
        tfgan.eval.add_regularization_loss_summaries(gan_model)

    # Get the GANTrain ops using custom optimizers.
    with tf.name_scope('train'):
        gen_lr, dis_lr = _learning_rate(hparams.gan_type)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
            discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
            summarize_gradients=True,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

    # Run the alternating training loop. Skip it if no steps should be taken
    # (used for graph construction tests).
    status_message = tf.strings.join([
        'Starting train step: ',
        tf.as_string(tf.train.get_or_create_global_step())
    ],
                                     name='status_message')
    if hparams.max_number_of_steps == 0:
        return
    tfgan.gan_train(
        train_ops,
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook([status_message], every_n_iter=10)
        ],
        logdir=hparams.train_log_dir,
        get_hooks_fn=tfgan.get_joint_train_hooks(),
        save_checkpoint_secs=60)
Example #8
0
def create_callable_gan_model():
  """Builds a `GANModel` whose networks are callable class instances."""
  generator = Generator()
  discriminator = Discriminator()
  return tfgan.gan_model(
      generator,
      discriminator,
      real_data=tf.zeros([1, 2]),
      generator_inputs=tf.random.normal([1, 2]))
Example #9
0
def create_gan_model():
  """Builds a `GANModel` from the module-level generator/discriminator fns."""
  real_data = tf.zeros([1, 2])
  generator_inputs = tf.random.normal([1, 2])
  return tfgan.gan_model(
      generator_model,
      discriminator_model,
      real_data=real_data,
      generator_inputs=generator_inputs)
Example #10
0
def train(hparams):
    """Trains a CIFAR10 GAN.

    Args:
      hparams: An HParams instance containing the hyperparameters for training.
    """
    if not tf.io.gfile.exists(hparams.train_log_dir):
        tf.io.gfile.makedirs(hparams.train_log_dir)

    device_setter = tf.compat.v1.train.replica_device_setter(
        hparams.ps_replicas)
    with tf.device(device_setter):
        # Keep input processing on the CPU so the GPU is reserved for the
        # forward and backward passes.
        with tf.compat.v1.name_scope('inputs'):
            with tf.device('/cpu:0'):
                images, _ = data_provider.provide_data('train',
                                                       hparams.batch_size,
                                                       num_parallel_calls=4)

        # Build the GANModel tuple.
        generator_fn = networks.generator
        discriminator_fn = networks.discriminator
        noise = tf.random.normal([hparams.batch_size, 64])
        gan_model = tfgan.gan_model(generator_fn,
                                    discriminator_fn,
                                    real_data=images,
                                    generator_inputs=noise)
        tfgan.eval.add_gan_model_image_summaries(gan_model)

        # Build the GANLoss tuple with a gradient penalty.
        with tf.compat.v1.name_scope('loss'):
            gan_loss = tfgan.gan_loss(gan_model,
                                      gradient_penalty_weight=1.0,
                                      add_summaries=True)

        # Build the train ops using the custom optimizers (optional weight
        # clipping is configured inside `_get_optimizers`).
        with tf.compat.v1.name_scope('train'):
            generator_opt, discriminator_opt = _get_optimizers(hparams)
            train_ops = tfgan.gan_train_ops(
                gan_model,
                gan_loss,
                generator_optimizer=generator_opt,
                discriminator_optimizer=discriminator_opt,
                summarize_gradients=True)

        # Run the alternating training loop; skip it when no steps are
        # requested (used for graph-construction tests).
        status_message = tf.strings.join(
            ['Starting train step: ',
             tf.as_string(tf.compat.v1.train.get_or_create_global_step())],
            name='status_message')
        if hparams.max_number_of_steps == 0:
            return
        hooks = [
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook([status_message], every_n_iter=10),
        ]
        tfgan.gan_train(train_ops,
                        hooks=hooks,
                        logdir=hparams.train_log_dir,
                        master=hparams.master,
                        is_chief=hparams.task == 0)