예제 #1
0
def write_summary_tensor(step_i: int, key: str, tensor: JTensor,
                         summary_type: SummaryType) -> bool:
    """Writes `tensor` as a summary of the given type under `key`.

    Args:
      step_i: Step number the summary is recorded at.
      key: Summary tag. For image summaries every emitted image gets its own
        tag of the form '<key>_<index>'.
      tensor: Value to summarize. For SCALAR it is averaged with `np.mean`.
        For IMAGE the trailing three dims are kept and all leading dims are
        flattened into a single image axis.
      summary_type: Kind of summary to write (SCALAR or IMAGE).

    Returns:
      None on every path. NOTE(review): this does not match the `-> bool`
      annotation, which is kept unchanged for interface compatibility.

    Raises:
      AssertionError: if `summary_type` is neither SCALAR nor IMAGE (only
        while Python assertions are enabled).
    """
    if summary_type == SummaryType.SCALAR:
        tensor = np.mean(tensor).item()
        tf_summary.scalar(key, tensor, step_i)
    elif summary_type == SummaryType.IMAGE:
        # Some eval codepath adds a leading 'test split' dim; fold every
        # leading dim into one image axis.
        tensor = np.reshape(tensor, [-1] + list(tensor.shape)[-3:])
        # Write at most MAX_IMAGES_PER_SUMMARY images. Using `max` here both
        # ignored the cap and, for small batches, sliced past the end of the
        # array (producing empty image summaries).
        num_images = min(tensor.shape[0], MAX_IMAGES_PER_SUMMARY)
        # Create a separate key for each image to avoid RPC oversize issues.
        for i in range(num_images):
            tf_summary.image('%s_%d' % (key, i), tensor[i:i + 1], step_i)
    else:
        assert False, 'Unsupported summary type: ' + str(summary_type)
예제 #2
0
def side_by_side_frames(name, tensors):
    """Builds an image summary that shows several videos next to each other.

    NOTE(review): the axis choices below appear to assume each video is laid
    out as [batch, time, height, width, channels] — confirm against callers.

    Args:
      name: name of the summary.
      tensors: a list of video tensors to be merged side by side.

    Returns:
      the summary result.
    """
    # Join all videos along axis 3 (width, under the assumed layout).
    combined = tf.concat(tensors, axis=3)
    # Split along axis 1 (time) and re-join the per-step slices along axis 1,
    # so all timesteps end up in a single image per batch entry.
    per_step = tf.unstack(combined, axis=1)
    stacked = tf.concat(per_step, axis=1)
    return tfs.image(name, stacked)
예제 #3
0
def host_call_fn(model_dir, **kwargs):
    """host_call function used for creating training summaries when using TPU.

    Args:
      model_dir: String indicating the output_dir to save summaries in.
      **kwargs: Set of metric names and tensor values for all desired
        summaries. Must contain 'global_step'; its first element is used as
        the step for every summary. Names starting with IMG_SUMMARY_PREFIX
        are written as image summaries (with that leading prefix stripped);
        all other entries are written as scalar summaries of their first
        element.

    Returns:
      Summary op to be passed to the host_call arg of the estimator function.

    Raises:
      KeyError: if 'global_step' is not provided.
    """
    gs = kwargs.pop('global_step')[0]
    prefix_len = len(IMG_SUMMARY_PREFIX)
    with summary.create_file_writer(model_dir).as_default():
        # Always record summaries.
        with summary.record_if(True):
            for name, tensor in kwargs.items():
                if name.startswith(IMG_SUMMARY_PREFIX):
                    # Strip only the leading prefix (str.replace would also
                    # drop later occurrences inside the name), and log at
                    # the same explicit step as the scalar summaries rather
                    # than relying on an implicit default step.
                    summary.image(name[prefix_len:],
                                  tensor,
                                  max_images=1,
                                  step=gs)
                else:
                    summary.scalar(name, tensor[0], step=gs)
            # Following function is under tf:1x, so we use it.
            return tf.summary.all_v2_summary_ops()
예제 #4
0
def py_plot_1d_signal(name, signals, labels, max_outputs=3, step=None):
    """Renders 1d signals held in a NumPy array into an image summary.

    Args:
      name: name of the summary.
      signals: a [batch, lines, steps] np.array list of 1d arrays.
      labels: a [lines] list of labels for each signal.
      max_outputs: the maximum number of plots to add to summaries.
      step: an explicit step or None.

    Returns:
      the summary result.
    """
    # Never request more plots than there are entries along axis 0.
    num_plots = min(max_outputs, signals.shape[0])
    rendered = plot_1d_signals(signals, labels, num_plots)
    return tfs.image(name, rendered, step, max_outputs=max_outputs)
예제 #5
0
def tf_plot_1d_signal(name, signals, labels, max_outputs=3, step=None):
    """Renders 1d signals held in a tensor into an image summary.

    The plotting helper is invoked through `tf.py_function`, which declares
    its result as a tf.uint8 tensor.

    Args:
      name: name of the summary.
      signals: a [batch, lines, steps] tensor, each line a 1d signal.
      labels: a [lines] list of labels for each signal.
      max_outputs: the maximum number of plots to add to summaries.
      step: an explicit step or None.

    Returns:
      the summary result.
    """
    # Cap the number of plots at the dynamic size of axis 0.
    batch_size = tf.shape(signals)[0]
    num_plots = tf.math.minimum(max_outputs, batch_size)
    rendered = tf.py_function(plot_1d_signals,
                              (signals, labels, num_plots),
                              tf.uint8)
    return tfs.image(name, rendered, step, max_outputs=max_outputs)