Example #1
# Assumed imports: the TF 1.x API as `tf`, the v2 summary API as `tf_summary`.
import tensorflow.compat.v1 as tf
from tensorflow.compat.v2 import summary as tf_summary


def host_call_fn(**kwargs):
    """Writes scalar summaries on the host during TPU training.

    Args:
      **kwargs: dict mapping summary name to tf.Tensor. Each value is the
        tensor from all cores, concatenated along axis 0. This function
        writes a scalar summary that is the mean of the whole tensor (all
        per-core values are identical, each already being the cross-shard
        mean, a trait of tpu.CrossShardOptimizer).

    Returns:
      A merged summary op.
    """
    gs = kwargs.pop('global_step')[0]
    # `model_dir` is captured from the enclosing scope (e.g. the model_fn).
    with tf_summary.create_file_writer(model_dir).as_default():
        # Only record summaries every 10 global steps.
        with tf_summary.record_if(tf.equal(gs % 10, 0)):
            for name, tensor in kwargs.items():
                # Take the mean across cores.
                tensor = tf.reduce_mean(tensor)
                tf_summary.scalar(name, tensor, step=gs)
            return tf.summary.all_v2_summary_ops()
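For context, a minimal sketch of how a host_call_fn like this is typically wired into a TPUEstimatorSpec inside a model_fn. The `loss`, `train_op`, and `mode` names are placeholders, not part of the original example; host_call tensors need an outer batch dimension, so scalars are reshaped to [1] before TPUEstimator concatenates them across cores.

# Hypothetical wiring inside a model_fn (TF 1.x); `loss` and `train_op`
# come from the surrounding training code.
gs_t = tf.reshape(tf.train.get_global_step(), [1])
loss_t = tf.reshape(loss, [1])
host_call = (host_call_fn, {'global_step': gs_t, 'loss': loss_t})
return tf.estimator.tpu.TPUEstimatorSpec(
    mode=mode, loss=loss, train_op=train_op, host_call=host_call)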
Example #2
# Assumed imports: the TF 1.x API as `tf`, the v2 summary API as `summary`.
import tensorflow.compat.v1 as tf
from tensorflow.compat.v2 import summary

# Assumed module-level constant; the real value is defined elsewhere in
# the source.
IMG_SUMMARY_PREFIX = 'image/'


def host_call_fn(model_dir, **kwargs):
    """host_call function used for creating training summaries when using TPU.

    Args:
      model_dir: String indicating the output directory to save summaries in.
      **kwargs: Set of metric names and tensor values for all desired
        summaries.

    Returns:
      Summary op to be passed to the host_call arg of the estimator function.
    """
    gs = kwargs.pop('global_step')[0]
    with summary.create_file_writer(model_dir).as_default():
        # Always record summaries.
        with summary.record_if(True):
            for name, tensor in kwargs.items():
                if name.startswith(IMG_SUMMARY_PREFIX):
                    # Strip the prefix and log at most one image per step.
                    summary.image(name.replace(IMG_SUMMARY_PREFIX, ''),
                                  tensor,
                                  step=gs,
                                  max_outputs=1)
                else:
                    # Per-core values are concatenated along axis 0; the
                    # first element carries the value.
                    summary.scalar(name, tensor[0], step=gs)
            # all_v2_summary_ops() lives in the TF 1.x summary API, so we
            # call it through tf.summary.
            return tf.summary.all_v2_summary_ops()
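Because this variant takes model_dir as a leading argument while host_call passes only tensors, one common pattern is to bind model_dir up front with functools.partial. The sketch below rests on that assumption; `loss` and the tensor names are hypothetical.

import functools

# Hypothetical host_call wiring inside a model_fn: bind model_dir first,
# then pass only tensors. Each tensor gets an outer dimension of size 1 so
# TPUEstimator can concatenate per-core values along axis 0.
gs_t = tf.reshape(tf.train.get_global_step(), [1])
host_call = (functools.partial(host_call_fn, model_dir),
             {'global_step': gs_t, 'loss': tf.reshape(loss, [1])})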