Example #1
def make_status_message(model):
    """Makes a string `Tensor` of training status."""
    return tf.strings.join([
        'Starting train step: current_image_id: ',
        tf.as_string(model.current_image_id), ', progress: ',
        tf.as_string(model.progress), ', num_blocks: {}'.format(
            model.num_blocks), ', batch_size: {}'.format(model.batch_size)
    ],
                           name='status_message')
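A minimal usage sketch (assumed, not from the original source): any object exposing the four attributes works, e.g. a namedtuple standing in for the model.

import collections
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Hypothetical stand-in for the model; only the four attributes matter.
FakeModel = collections.namedtuple(
    'FakeModel', ['current_image_id', 'progress', 'num_blocks', 'batch_size'])
model = FakeModel(current_image_id=tf.constant(42),  # scalar int Tensor
                  progress=tf.constant(0.5),         # scalar float Tensor
                  num_blocks=3,                      # plain Python ints are
                  batch_size=8)                      # formatted via str.format
with tf.Session() as sess:
    print(sess.run(make_status_message(model)))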
Example #2
def op(name,
       images,
       max_outputs=3,
       display_name=None,
       description=None,
       collections=None):
    """Create a legacy image summary op for use in a TensorFlow graph.

  Arguments:
    name: A unique name for the generated summary node.
    images: A `Tensor` representing pixel data with shape `[k, h, w, c]`,
      where `k` is the number of images, `h` and `w` are the height and
      width of the images, and `c` is the number of channels, which
      should be 1, 3, or 4. Any of the dimensions may be statically
      unknown (i.e., `None`).
    max_outputs: Optional `int` or rank-0 integer `Tensor`. At most this
      many images will be emitted at each step. When more than
      `max_outputs` many images are provided, the first `max_outputs` many
      images will be used and the rest silently discarded.
    display_name: Optional name for this summary in TensorBoard, as a
      constant `str`. Defaults to `name`.
    description: Optional long-form description for this summary, as a
      constant `str`. Markdown is supported. Defaults to empty.
    collections: Optional list of graph collections keys. The new
      summary op is added to these collections. Defaults to
      `[GraphKeys.SUMMARIES]`.

  Returns:
    A TensorFlow summary op.
  """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    if display_name is None:
        display_name = name
    summary_metadata = metadata.create_summary_metadata(
        display_name=display_name, description=description)
    with tf.name_scope(name), \
         tf.control_dependencies([tf.assert_rank(images, 4),
                                  tf.assert_type(images, tf.uint8),
                                  tf.assert_non_negative(max_outputs)]):
        limited_images = images[:max_outputs]
        encoded_images = tf.map_fn(tf.image.encode_png,
                                   limited_images,
                                   dtype=tf.string,
                                   name='encode_each_image')
        image_shape = tf.shape(images)
        dimensions = tf.stack([
            tf.as_string(image_shape[2], name='width'),
            tf.as_string(image_shape[1], name='height')
        ],
                              name='dimensions')
        tensor = tf.concat([dimensions, encoded_images], axis=0)
        return tf.summary.tensor_summary(name='image_summary',
                                         tensor=tensor,
                                         collections=collections,
                                         summary_metadata=summary_metadata)
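A hedged usage sketch: assuming the surrounding TensorBoard image-plugin module (including its `metadata` helper) is in scope, the op takes a batch of uint8 images.

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Four 32x32 RGB noise images, purely illustrative.
images = tf.constant(
    np.random.randint(0, 256, size=(4, 32, 32, 3)).astype(np.uint8))
summary_op = op('noise_images', images, max_outputs=2,
                description='Only two of the four images are emitted.')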
Example #3
def _extend_with_dummy(extend_with, to_extend, dummy_value='n/a'):
  """Extends one SparseTensor with dummy_values at positions of other."""
  dense_shape = tf.to_int64(
      tf.concat([[tf.shape(extend_with)[0]],
                 [tf.maximum(tf.shape(extend_with)[1],
                             tf.shape(to_extend)[1])],
                 [tf.maximum(tf.shape(extend_with)[2],
                             tf.shape(to_extend)[2])]],
                axis=0))
  additional_indices = tf.sets.set_difference(
      tf.SparseTensor(
          indices=extend_with.indices,
          values=tf.zeros_like(extend_with.values, dtype=tf.int32),
          dense_shape=dense_shape),
      tf.SparseTensor(
          indices=to_extend.indices,
          values=tf.zeros([tf.shape(to_extend.indices)[0]], dtype=tf.int32),
          dense_shape=dense_shape)).indices
  # Supply defaults for all other indices.
  default = tf.tile(
      tf.constant([dummy_value]), multiples=[tf.shape(additional_indices)[0]])

  string_value = (
      tf.as_string(to_extend.values)
      if to_extend.values.dtype != tf.string else to_extend.values)
  return tf.sparse_reorder(
      tf.SparseTensor(
          indices=tf.concat([to_extend.indices, additional_indices], axis=0),
          values=tf.concat([string_value, default], axis=0),
          dense_shape=dense_shape))
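A small usage sketch (assumed inputs): two aligned 3-D SparseTensors, where `b` is missing a value at position [0, 1, 0] and therefore receives the 'n/a' dummy after extension.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

a = tf.SparseTensor(indices=[[0, 0, 0], [0, 1, 0]],
                    values=tf.constant(['x', 'y']),
                    dense_shape=[1, 2, 1])
b = tf.SparseTensor(indices=[[0, 0, 0]],
                    values=tf.constant(['z']),
                    dense_shape=[1, 2, 1])
extended_b = _extend_with_dummy(a, b)  # values become ['z', 'n/a']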
Example #4
def _replace_empty_string_with_random_number(string_tensor):
  """Returns string unchanged if non-empty, and random string tensor otherwise.

  The random string is a random integer between 0 and 2**63 - 1, cast as a
  string.

  Args:
    string_tensor: A tf.tensor of dtype string.

  Returns:
    out_string: A tf.tensor of dtype string. If string_tensor contains the empty
      string, out_string will contain a random integer casted to a string.
      Otherwise string_tensor is returned unchanged.

  """

  empty_string = tf.constant('', dtype=tf.string, name='EmptyString')

  random_source_id = tf.as_string(
      tf.random_uniform(shape=[], maxval=2**63 - 1, dtype=tf.int64))

  out_string = tf.cond(
      tf.equal(string_tensor, empty_string),
      true_fn=lambda: random_source_id,
      false_fn=lambda: string_tensor)

  return out_string
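Usage sketch (assumed input): a scalar string tensor, with the empty string swapped for a random integer id.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

source_id = tf.constant('')
out = _replace_empty_string_with_random_number(source_id)
with tf.Session() as sess:
    print(sess.run(out))  # e.g. b'72034...' instead of the empty string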
Example #5
def sampled_softmax_loss(src_emb,
                         pos_ids,
                         neg_num,
                         output_emb_table,
                         output_emb_bias,
                         node_size,
                         s2h=True):
  """Sampled softmax loss.
  Args:
    src_emb: positive src embedding with shape [batch_size, dim]
    pos_ids: positive ids.
    output_emb_table:
    output_emb_bias:
    node_size: total node size.
    s2h: set True if need string to hash.
  """
  if s2h:
    pos_ids = tf.as_string(pos_ids)
    pos_ids = tf.string_to_hash_bucket_fast(
        pos_ids,
        node_size,
        name='softmax_loss_to_hash_bucket_oper')

  loss = tf.nn.sampled_softmax_loss(
      weights=output_emb_table,
      biases=output_emb_bias,
      labels=tf.reshape(pos_ids, [-1, 1]),
      inputs=src_emb,
      num_sampled=neg_num,
      num_classes=node_size,
      partition_strategy='mod',
      remove_accidental_hits=True)

  return [tf.reduce_mean(loss), None, None]
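A hedged wiring sketch with assumed sizes; the table and bias shapes are what tf.nn.sampled_softmax_loss expects ([node_size, dim] and [node_size]).

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

dim, node_size, batch_size, neg_num = 16, 10000, 32, 5  # illustrative only
output_emb_table = tf.get_variable('out_emb_table', [node_size, dim])
output_emb_bias = tf.get_variable('out_emb_bias', [node_size])
src_emb = tf.random.normal([batch_size, dim])
pos_ids = tf.range(batch_size, dtype=tf.int64)  # integer ids, hashed via s2h
loss, _, _ = sampled_softmax_loss(src_emb, pos_ids, neg_num,
                                  output_emb_table, output_emb_bias,
                                  node_size)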
Example #6
def _dedup_tensor(sp_tensor: tf.SparseTensor) -> tf.SparseTensor:
  """Dedup values of a SparseTensor along each row.

  Args:
    sp_tensor: A 2D SparseTensor to be deduped.
  Returns:
    A deduped SparseTensor of shape [batch_size, max_len], where max_len is
    the maximum number of unique values for a row in the Tensor.
  """
  string_batch_index = tf.as_string(sp_tensor.indices[:, 0])

  # tf.unique only works on 1D tensors. To avoid deduping across examples,
  # prepend each feature value with the example index. This requires casting
  # to and from strings for non-string features.
  string_values = sp_tensor.values
  original_dtype = sp_tensor.values.dtype
  if original_dtype != tf.string:
    string_values = tf.as_string(sp_tensor.values)
  index_and_value = tf.strings.join([string_batch_index, string_values],
                                    separator='|')
  unique_index_and_value, _ = tf.unique(index_and_value)

  # split is a shape [tf.size(values), 2] tensor. The first column contains
  # indices and the second column contains the feature value (we assume no
  # feature contains | so we get exactly 2 values from the string split).
  split = tf.string_split(unique_index_and_value, delimiter='|')
  split = tf.reshape(split.values, [-1, 2])
  string_indices = split[:, 0]
  values = split[:, 1]

  indices = tf.reshape(
      tf.string_to_number(string_indices, out_type=tf.int32), [-1])
  if original_dtype != tf.string:
    values = tf.string_to_number(values, out_type=original_dtype)
  values = tf.reshape(values, [-1])
  # Convert example indices into SparseTensor indices, e.g.
  # [0, 0, 0, 1, 3, 3] -> [[0,0], [0,1], [0,2], [1,0], [3,0], [3,1]]
  batch_size = tf.to_int32(sp_tensor.dense_shape[0])
  new_indices, max_len = _example_index_to_sparse_index(indices, batch_size)
  return tf.SparseTensor(
      indices=tf.to_int64(new_indices),
      values=values,
      dense_shape=[tf.to_int64(batch_size), max_len])
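Usage sketch (assumed input); this also assumes the module's `_example_index_to_sparse_index` helper is in scope, since `_dedup_tensor` calls it.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

sp = tf.SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [1, 0]],
                     values=tf.constant([7, 7, 8, 9], dtype=tf.int64),
                     dense_shape=[2, 3])
deduped = _dedup_tensor(sp)  # row 0 keeps [7, 8]; row 1 keeps [9]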
Example #7
def train(hparams):
    """Trains a CycleGAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
    if not tf.io.gfile.exists(hparams.train_log_dir):
        tf.io.gfile.makedirs(hparams.train_log_dir)

    with tf.device(tf.train.replica_device_setter(hparams.ps_replicas)):
        with tf.name_scope('inputs'), tf.device('/cpu:0'):
            images_x, images_y = _get_data(hparams.image_set_x_file_pattern,
                                           hparams.image_set_y_file_pattern,
                                           hparams.batch_size,
                                           hparams.patch_size)

        # Define CycleGAN model.
        cyclegan_model = _define_model(images_x, images_y)

        # Define CycleGAN loss.
        cyclegan_loss = tfgan.cyclegan_loss(
            cyclegan_model,
            cycle_consistency_loss_weight=hparams.cycle_consistency_loss_weight,
            tensor_pool_fn=tfgan.features.tensor_pool)

        # Define CycleGAN train ops.
        train_ops = _define_train_ops(cyclegan_model, cyclegan_loss, hparams)

        # Training
        train_steps = tfgan.GANTrainSteps(1, 1)
        status_message = tf.strings.join([
            'Starting train step: ',
            tf.as_string(tf.train.get_or_create_global_step())
        ],
                                         name='status_message')
        if not hparams.max_number_of_steps:
            return
        tfgan.gan_train(
            train_ops,
            hparams.train_log_dir,
            get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
            hooks=[
                tf.estimator.StopAtStepHook(
                    num_steps=hparams.max_number_of_steps),
                tf.estimator.LoggingTensorHook(
                    {'status_message': status_message}, every_n_iter=10)
            ],
            master=hparams.master,
            is_chief=hparams.task == 0)
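The status-message pattern these train() functions share, in isolation (hedged sketch): tf.as_string turns the integer global step into a string so it can be joined and logged by a hook.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

global_step = tf.train.get_or_create_global_step()
status_message = tf.strings.join(
    ['Starting train step: ', tf.as_string(global_step)],
    name='status_message')
logging_hook = tf.estimator.LoggingTensorHook({'status': status_message},
                                              every_n_iter=10)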
Example #8
    def encode(self, ego_tensor):
        if not self._use_edge:
            ids = ego_tensor.src.ids
        else:
            # TODO: replace this hack with edge ids.
            ids = tf.cast(ego_tensor.src.continuous_attrs, dtype=tf.int64)

        if self._str2hash:
            index = tf.as_string(ids)
            ids = tf.string_to_hash_bucket_fast(
                index, self._num, name=self._name + 'str_to_hash_bucket_op')

        emb = tf.nn.embedding_lookup(
            self._emb_table, ids, name=self._name + 'ids_embedding_lookup_op')
        emb = tf.reshape(emb, [-1, self._dim])

        return emb
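The hash-embedding core of encode() as a standalone sketch (sizes assumed for illustration):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

num_buckets, dim = 1000, 8  # stand-ins for self._num and self._dim
emb_table = tf.get_variable('emb_table', [num_buckets, dim])
raw_ids = tf.constant([3, 7, 42], dtype=tf.int64)
hashed = tf.string_to_hash_bucket_fast(tf.as_string(raw_ids), num_buckets)
emb = tf.reshape(tf.nn.embedding_lookup(emb_table, hashed), [-1, dim])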
Example #9
# A quick fix to run TF 1.X code in TF 2.X; we may want to properly migrate the Python script to the TF 2.X API.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import os

a = tf.placeholder(tf.float32, shape=None, name="a")
b = tf.placeholder(tf.float32, shape=None, name="b")
e = tf.placeholder(tf.string, shape=None, name="e")
ee = tf.strings.to_number(e, out_type=tf.float32, name="ee")
v = tf.Variable(dtype=tf.float32, initial_value=tf.constant(1.0), name="v")
c = tf.add(a, b, name="c")
d = tf.add(c, v)
eee = tf.add(ee, v)
e4 = tf.as_string(eee)
global_step = tf.train.get_or_create_global_step()
global_step_inc = tf.assign_add(global_step, 1)
hooks = [tf.train.StopAtStepHook(last_step=2)]
with tf.Session() as mon_sess:
    mon_sess.run(tf.global_variables_initializer())
    for i in range(2):
        print(
            mon_sess.run([d, e4, global_step_inc],
                         feed_dict={
                             a: [1.0, 2.0, 3.0],
                             b: [1.0, 2.0, 3.0],
                             e: ["1.0", "2.0", "3.0"]
                         }))

    signatures = {
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: {
            # The original snippet is truncated here; the mapping below is an
            # assumed completion using the script's placeholders.
            'inputs': {'a': a, 'b': b, 'e': e},
        }
    }
Example #10
def train(hparams):
    """Trains a StarGAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """

    # Create the log_dir if not exist.
    if not tf.io.gfile.exists(hparams.train_log_dir):
        tf.io.gfile.makedirs(hparams.train_log_dir)

    # Shard the model to different parameter servers.
    with tf.device(tf.train.replica_device_setter(hparams.ps_replicas)):

        # Create the input dataset.
        with tf.name_scope('inputs'), tf.device('/cpu:0'):
            images, labels = data_provider.provide_data(
                'train', hparams.batch_size, hparams.patch_size)

        # Define the model.
        with tf.name_scope('model'):
            model = _define_model(images, labels)

        # Add image summary.
        tfgan.eval.add_stargan_image_summaries(
            model, num_images=3 * hparams.batch_size, display_diffs=True)

        # Define the model loss.
        loss = tfgan.stargan_loss(model)

        # Define the train ops.
        with tf.name_scope('train_ops'):
            train_ops = _define_train_ops(model, loss, hparams.generator_lr,
                                          hparams.discriminator_lr,
                                          hparams.adam_beta1,
                                          hparams.adam_beta2,
                                          hparams.max_number_of_steps)

        # Define the train steps.
        train_steps = _define_train_step(hparams.gen_disc_step_ratio)

        # Define a status message.
        status_message = tf.strings.join([
            'Starting train step: ',
            tf.as_string(tf.train.get_or_create_global_step())
        ],
                                         name='status_message')

        # Train the model.
        tfgan.gan_train(
            train_ops,
            hparams.train_log_dir,
            get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
            hooks=[
                tf.estimator.StopAtStepHook(
                    num_steps=hparams.max_number_of_steps),
                tf.estimator.LoggingTensorHook([status_message],
                                               every_n_iter=10)
            ],
            master=hparams.tf_master,
            is_chief=hparams.task == 0)
Example #11
def train(hparams):
    """Trains an MNIST GAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
    if not tf.io.gfile.exists(hparams.train_log_dir):
        tf.io.gfile.makedirs(hparams.train_log_dir)

    # Force all input processing onto CPU in order to reserve the GPU for
    # the forward inference and back-propagation.
    with tf.name_scope('inputs'), tf.device('/cpu:0'):
        images, one_hot_labels = data_provider.provide_data(
            'train', hparams.batch_size, num_parallel_calls=4)

    # Define the GANModel tuple. Optionally, condition the GAN on the label or
    # use an InfoGAN to learn a latent representation.
    if hparams.gan_type == 'unconditional':
        gan_model = tfgan.gan_model(
            generator_fn=networks.unconditional_generator,
            discriminator_fn=networks.unconditional_discriminator,
            real_data=images,
            generator_inputs=tf.random.normal(
                [hparams.batch_size, hparams.noise_dims]))
    elif hparams.gan_type == 'conditional':
        noise = tf.random.normal([hparams.batch_size, hparams.noise_dims])
        gan_model = tfgan.gan_model(
            generator_fn=networks.conditional_generator,
            discriminator_fn=networks.conditional_discriminator,
            real_data=images,
            generator_inputs=(noise, one_hot_labels))
    elif hparams.gan_type == 'infogan':
        cat_dim, cont_dim = 10, 2
        generator_fn = functools.partial(networks.infogan_generator,
                                         categorical_dim=cat_dim)
        discriminator_fn = functools.partial(networks.infogan_discriminator,
                                             categorical_dim=cat_dim,
                                             continuous_dim=cont_dim)
        unstructured_inputs, structured_inputs = util.get_infogan_noise(
            hparams.batch_size, cat_dim, cont_dim, hparams.noise_dims)
        gan_model = tfgan.infogan_model(
            generator_fn=generator_fn,
            discriminator_fn=discriminator_fn,
            real_data=images,
            unstructured_generator_inputs=unstructured_inputs,
            structured_generator_inputs=structured_inputs)
    tfgan.eval.add_gan_model_image_summaries(gan_model, hparams.grid_size)

    # Get the GANLoss tuple. You can pass a custom function, use one of the
    # already-implemented losses from the losses library, or use the defaults.
    with tf.name_scope('loss'):
        if hparams.gan_type == 'infogan':
            gan_loss = tfgan.gan_loss(
                gan_model,
                generator_loss_fn=tfgan.losses.modified_generator_loss,
                discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
                mutual_information_penalty_weight=1.0,
                add_summaries=True)
        else:
            gan_loss = tfgan.gan_loss(gan_model, add_summaries=True)
        tfgan.eval.add_regularization_loss_summaries(gan_model)

    # Get the GANTrain ops using custom optimizers.
    with tf.name_scope('train'):
        gen_lr, dis_lr = _learning_rate(hparams.gan_type)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
            discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
            summarize_gradients=True,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

    # Run the alternating training loop. Skip it if no steps should be taken
    # (used for graph construction tests).
    status_message = tf.strings.join([
        'Starting train step: ',
        tf.as_string(tf.train.get_or_create_global_step())
    ],
                                     name='status_message')
    if hparams.max_number_of_steps == 0:
        return
    tfgan.gan_train(
        train_ops,
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook([status_message], every_n_iter=10)
        ],
        logdir=hparams.train_log_dir,
        get_hooks_fn=tfgan.get_joint_train_hooks(),
        save_checkpoint_secs=60)
Example #12
    def _process(examples):
        """Supplies input to our model.

    This function supplies input to our model after parsing.

    Args:
      examples: The dictionary from key to (Sparse)Tensors with context
        and sequence features.

    Returns:
      A tuple consisting of 1) a dictionary of tensors whose keys are
      the feature names, and 2) a tensor of target labels if the mode
      is not INFER (and None, otherwise).
    """
        # Combine into a single dictionary.
        feature_map = {}
        # Add age if requested.
        if include_age:
            age_in_seconds = (
                examples[CONTEXT_KEY_PREFIX + 'timestamp'] -
                examples.pop(CONTEXT_KEY_PREFIX + 'Patient.birthDate'))
            age_in_years = tf.to_float(age_in_seconds) / (60 * 60 * 24 * 365.0)
            feature_map[CONTEXT_KEY_PREFIX + AGE_KEY] = age_in_years

        sequence_length = examples.pop(CONTEXT_KEY_PREFIX + 'sequenceLength')
        # Cross the requested features.
        for cross in time_crossed_features:
            # The features may be missing at different rates - we take the union
            # of the indices supplying defaults.
            extended_features = dict()
            dense_shape = tf.concat(
                [[tf.to_int64(tf.shape(sequence_length)[0])],
                 [tf.reduce_max(sequence_length)],
                 tf.constant([1], dtype=tf.int64)],
                axis=0)
            for i, feature in enumerate(cross):
                sp_tensor = examples[SEQUENCE_KEY_PREFIX + feature]
                additional_indices = []
                covered_indices = sp_tensor.indices
                for j, other_feature in enumerate(cross):
                    if i != j:
                        additional_indices.append(
                            tf.sets.set_difference(
                                tf.sparse_reorder(
                                    tf.SparseTensor(
                                        indices=examples[
                                            SEQUENCE_KEY_PREFIX +
                                            other_feature].indices,
                                        values=tf.zeros([
                                            tf.shape(examples[
                                                SEQUENCE_KEY_PREFIX +
                                                other_feature].indices)[0]
                                        ],
                                                        dtype=tf.int32),
                                        dense_shape=dense_shape)),
                                tf.sparse_reorder(
                                    tf.SparseTensor(
                                        indices=covered_indices,
                                        values=tf.zeros(
                                            [tf.shape(covered_indices)[0]],
                                            dtype=tf.int32),
                                        dense_shape=dense_shape))).indices)
                        covered_indices = tf.concat([sp_tensor.indices] +
                                                    additional_indices,
                                                    axis=0)

                additional_indices = tf.concat(additional_indices, axis=0)

                # Supply defaults for all other indices.
                default = tf.tile(tf.constant(['n/a']),
                                  multiples=[tf.shape(additional_indices)[0]])

                string_value = sp_tensor.values
                if string_value.dtype != tf.string:
                    string_value = tf.as_string(string_value)

                extended_features[feature] = tf.sparse_reorder(
                    tf.SparseTensor(indices=tf.concat(
                        [sp_tensor.indices, additional_indices], axis=0),
                                    values=tf.concat([string_value, default],
                                                     axis=0),
                                    dense_shape=dense_shape))

            new_values = tf.strings.join(
                [extended_features[f].values for f in cross], separator='-')
            crossed_sp_tensor = tf.sparse_reorder(
                tf.SparseTensor(
                    indices=extended_features[cross[0]].indices,
                    values=new_values,
                    dense_shape=extended_features[cross[0]].dense_shape))
            examples[SEQUENCE_KEY_PREFIX + '_'.join(cross)] = crossed_sp_tensor
        # Remove unwanted features that are used in the cross but should not be
        # considered outside the cross.
        for cross in time_crossed_features:
            for feature in cross:
                if (feature not in sequence_features
                        and SEQUENCE_KEY_PREFIX + feature in examples):
                    del examples[SEQUENCE_KEY_PREFIX + feature]

        # Flatten sparse tensor to compute event age. This dense tensor also
        # contains padded values. These will not be used when gathering elements
        # from the dense tensor since each sparse feature won't have a value
        # defined for the padding.
        padded_event_age = (
            # Broadcast current time along sequence dimension.
            tf.expand_dims(examples.pop(CONTEXT_KEY_PREFIX + 'timestamp'), 1)
            # Subtract time of events.
            - examples.pop(SEQUENCE_KEY_PREFIX + 'eventId'))

        for i in range(len(time_windows) - 1):
            max_age = time_windows[i]
            min_age = time_windows[i + 1]
            padded_in_time_window = tf.logical_and(padded_event_age <= max_age,
                                                   padded_event_age > min_age)

            for k, v in examples.items():
                if k.startswith(CONTEXT_KEY_PREFIX):
                    continue
                # For each sparse feature entry, look up whether it is in the time
                # window or not.
                in_time_window = tf.gather_nd(padded_in_time_window,
                                              v.indices[:, 0:2])
                v = tf.sparse_retain(v, in_time_window)
                sp_tensor = tf.sparse_reshape(v, [v.dense_shape[0], -1])
                if dedup:
                    sp_tensor = _dedup_tensor(sp_tensor)

                feature_map[k + '-til-%d' % min_age] = sp_tensor

        for k, v in examples.items():
            if k.startswith(CONTEXT_KEY_PREFIX):
                feature_map[k] = v
        return feature_map
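A hedged sketch of the time-window masking step above, with toy inputs: each sparse entry is kept only if its (batch, step) position falls inside the window.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

padded_in_time_window = tf.constant([[True, False], [True, True]])
v = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 1]],
                    values=tf.constant(['a', 'b', 'c']),
                    dense_shape=[2, 2])
in_window = tf.gather_nd(padded_in_time_window, v.indices[:, 0:2])
kept = tf.sparse_retain(v, in_window)  # drops 'b' at position [0, 1]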
Example #13
def train(hparams):
    """Trains a CIFAR10 GAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
    if not tf.io.gfile.exists(hparams.train_log_dir):
        tf.io.gfile.makedirs(hparams.train_log_dir)

    with tf.device(tf.train.replica_device_setter(hparams.ps_replicas)):
        # Force all input processing onto CPU in order to reserve the GPU for
        # the forward inference and back-propagation.
        with tf.name_scope('inputs'):
            with tf.device('/cpu:0'):
                images, _ = data_provider.provide_data('train',
                                                       hparams.batch_size,
                                                       num_parallel_calls=4)

        # Define the GANModel tuple.
        generator_fn = networks.generator
        discriminator_fn = networks.discriminator
        generator_inputs = tf.random.normal([hparams.batch_size, 64])
        gan_model = tfgan.gan_model(generator_fn,
                                    discriminator_fn,
                                    real_data=images,
                                    generator_inputs=generator_inputs)
        tfgan.eval.add_gan_model_image_summaries(gan_model)

        # Get the GANLoss tuple. Use the selected GAN loss functions.
        with tf.name_scope('loss'):
            gan_loss = tfgan.gan_loss(gan_model,
                                      gradient_penalty_weight=1.0,
                                      add_summaries=True)

        # Get the GANTrain ops using the custom optimizers and optional
        # discriminator weight clipping.
        with tf.name_scope('train'):
            gen_opt, dis_opt = _get_optimizers(hparams)
            train_ops = tfgan.gan_train_ops(gan_model,
                                            gan_loss,
                                            generator_optimizer=gen_opt,
                                            discriminator_optimizer=dis_opt,
                                            summarize_gradients=True)

        # Run the alternating training loop. Skip it if no steps should be taken
        # (used for graph construction tests).
        status_message = tf.strings.join([
            'Starting train step: ',
            tf.as_string(tf.train.get_or_create_global_step())
        ],
                                         name='status_message')
        if hparams.max_number_of_steps == 0:
            return
        tfgan.gan_train(train_ops,
                        hooks=([
                            tf.estimator.StopAtStepHook(
                                num_steps=hparams.max_number_of_steps),
                            tf.estimator.LoggingTensorHook([status_message],
                                                           every_n_iter=10)
                        ]),
                        logdir=hparams.train_log_dir,
                        master=hparams.master,
                        is_chief=hparams.task == 0)