Ejemplo n.º 1
0
def _get_tensor_summary(
    name, display_name, description, tensor, content_type, components, json_config
):
    """Build a ``Summary.Value`` carrying a tensor and its mesh metadata.

    The input is converted with ``torch.as_tensor``, flattened into the
    ``float_val`` field of a ``DT_FLOAT`` ``TensorProto``, and tagged with the
    instance name that the mesh ``metadata`` module derives from ``name`` and
    ``content_type``.

    Args:
      name: Uniquely identifiable name of the summary op.
      display_name: Name shown in TensorBoard; forwarded to the metadata
        helper.
      description: Longform, readable description of the summary data.
        Markdown is supported.
      tensor: Data to store; anything ``torch.as_tensor`` accepts. Its first
        three axes are recorded as the proto shape.
      content_type: Type of content held by the tensor.
      components: Bitmask of the parts (vertices, colors, etc.) that belong to
        the summary.
      json_config: JSON-serialized dictionary (a string) of ThreeJS class
        configuration.

    Returns:
      A ``Summary.Value`` holding the tensor proto and its metadata.
    """
    import torch
    from tensorboard.plugins.mesh import metadata

    data = torch.as_tensor(tensor)
    shape = data.shape

    summary_metadata = metadata.create_summary_metadata(
        name,
        display_name,
        content_type,
        components,
        shape,
        description,
        json_config=json_config,
    )

    flat_values = data.reshape(-1).tolist()
    # Exactly three axes are written to the proto; indexing a shape with
    # fewer dimensions raises IndexError here.
    dims = [TensorShapeProto.Dim(size=shape[axis]) for axis in range(3)]
    proto = TensorProto(
        dtype="DT_FLOAT",
        float_val=flat_values,
        tensor_shape=TensorShapeProto(dim=dims),
    )

    return Summary.Value(
        tag=metadata.get_instance_name(name, content_type),
        tensor=proto,
        metadata=summary_metadata,
    )
Ejemplo n.º 2
0
def _add_3d_torch(self,
                  tag,
                  data,
                  step,
                  logdir=None,
                  max_outputs=1,
                  label_to_names=None,
                  description=None):
    """Write geometry data for ``tag`` and log a summary that refers to it.

    The geometry is written by ``_write_geometry_data`` into a
    ``PluginDirectory`` under the log directory. Whatever that helper returns
    is stored as the single string value of a ``DT_STRING`` tensor summary,
    which is then handed to the file writer.

    Args:
      tag: Summary tag; also passed to the geometry writer.
      data: Geometry payload forwarded to ``_write_geometry_data``.
      step: Step to record the summary at. Must not be ``None``.
      logdir: Directory to write into. When ``None``, the file writer's own
        log directory is used.
      max_outputs: Forwarded to ``_write_geometry_data``.
      label_to_names: Optional value stored in the summary metadata under the
        ``'label_to_names'`` key; when ``None`` the metadata dict is empty.
      description: Description forwarded to the summary metadata.

    Raises:
      ValueError: If ``step`` is ``None``.
    """
    if step is None:
        raise ValueError("Step is not provided or set.")

    if label_to_names is None:
        extra_metadata = {}
    else:
        extra_metadata = {'label_to_names': label_to_names}
    summary_metadata = metadata.create_summary_metadata(
        description=description, metadata=extra_metadata)

    file_writer = self._get_file_writer()
    target_dir = file_writer.get_logdir() if logdir is None else logdir
    plugin_dir = PluginDirectory(target_dir, metadata.PLUGIN_NAME)

    geometry_ref = _write_geometry_data(plugin_dir, tag, step, data,
                                        max_outputs)
    value = Summary.Value(
        tag=tag,
        tensor=TensorProto(dtype='DT_STRING',
                           string_val=[geometry_ref],
                           tensor_shape=TensorShapeProto()),
        metadata=summary_metadata,
    )
    # No wall time is supplied; the writer receives None for it.
    file_writer.add_summary(Summary(value=[value]), step, None)
Ejemplo n.º 3
0
def pr_curve_raw(tag,
                 tp,
                 fp,
                 tn,
                 fn,
                 precision,
                 recall,
                 num_thresholds=127,
                 weights=None):
    """Build a PR-curve ``Summary`` from precomputed counts and rates.

    The six inputs are stacked with ``np.stack`` in the order
    (tp, fp, tn, fn, precision, recall) and stored flattened in a
    ``DT_FLOAT`` tensor tagged for the ``pr_curves`` plugin.

    Args:
      tag: Tag of the summary value.
      tp: True-positive counts.
      fp: False-positive counts.
      tn: True-negative counts.
      fn: False-negative counts.
      precision: Precision values.
      recall: Recall values.
      num_thresholds: Number of thresholds recorded in the plugin data;
        clamped to at most 127.
      weights: Not used by this function.

    Returns:
      A ``Summary`` with a single value holding the stacked data.
    """
    # Values above 127 break protobuf, so cap the threshold count.
    num_thresholds = min(num_thresholds, 127)
    stacked = np.stack((tp, fp, tn, fn, precision, recall))

    serialized_plugin_data = PrCurvePluginData(
        version=0, num_thresholds=num_thresholds).SerializeToString()
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name="pr_curves", content=serialized_plugin_data))

    flat_values = stacked.reshape(-1).tolist()
    dims = [TensorShapeProto.Dim(size=stacked.shape[axis]) for axis in (0, 1)]
    proto = TensorProto(
        dtype="DT_FLOAT",
        float_val=flat_values,
        tensor_shape=TensorShapeProto(dim=dims),
    )
    return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=proto)])
Ejemplo n.º 4
0
def custom_scalars(layout):
    """Serialize a custom-scalars layout into a config ``Summary``.

    Args:
      layout: Mapping of category title to a mapping of chart title to a
        sequence whose first item is the chart kind and whose second item is
        the list of tags. The kind ``"Margin"`` requires exactly three tags
        (value, lower, upper); any other kind produces a multiline chart over
        all the tags.

    Returns:
      A ``Summary`` with a single ``DT_STRING`` value tagged
      ``custom_scalars__config__`` that contains the serialized layout.
    """

    def _make_chart(title, spec):
        # spec[0] is the chart kind, spec[1] the tags it plots.
        tags = spec[1]
        if spec[0] == "Margin":
            assert len(tags) == 3
            series = layout_pb2.MarginChartContent.Series(
                value=tags[0], lower=tags[1], upper=tags[2])
            content = layout_pb2.MarginChartContent(series=[series])
            return layout_pb2.Chart(title=title, margin=content)
        content = layout_pb2.MultilineChartContent(tag=tags)
        return layout_pb2.Chart(title=title, multiline=content)

    categories = [
        layout_pb2.Category(
            title=category_title,
            chart=[_make_chart(title, spec) for title, spec in charts.items()],
        )
        for category_title, charts in layout.items()
    ]

    layout_proto = layout_pb2.Layout(category=categories)
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name="custom_scalars"))
    config_tensor = TensorProto(
        dtype="DT_STRING",
        string_val=[layout_proto.SerializeToString()],
        tensor_shape=TensorShapeProto(),
    )
    config_value = Summary.Value(
        tag="custom_scalars__config__", tensor=config_tensor, metadata=smd)
    return Summary(value=[config_value])
Ejemplo n.º 5
0
    def add_3dvolume(self, volume, tag, global_step=None, walltime=None):
        """Save a volume as a ``.nii.gz`` file and log a summary naming it.

        The file is written under ``self._log_dir`` via ``ants.image_write``.
        Its base name (without the extension) is stored as the single string
        value of a summary for the ``tb_3d_volume_plugin`` plugin, and the
        file writer is flushed afterwards.

        Args:
          volume: Volume data. A ``torch.Tensor`` is detached, moved to CPU
            and converted to NumPy first; anything else is passed to
            ``ants.from_numpy`` unchanged.
          tag: Summary tag, also used as the file-name prefix.
          global_step: Step to record. When ``None``, the file name uses the
            current local time instead of the step.
          walltime: Wall time forwarded to the file writer.
        """
        prefix = tag + "_"
        if global_step is not None:
            suffix = str(global_step)
        else:
            suffix = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        filename = prefix + suffix

        if isinstance(volume, torch.Tensor):
            volume = volume.detach().cpu().numpy()

        image = ants.from_numpy(volume)
        ants.image_write(image,
                         os.path.join(self._log_dir, filename + ".nii.gz"))

        serialized_plugin_data = TextPluginData(version=0).SerializeToString()
        metadata = tf.SummaryMetadata(
            plugin_data=tf.SummaryMetadata.PluginData(
                plugin_name="tb_3d_volume_plugin",
                content=serialized_plugin_data))
        # The summary stores only the file's base name, not the volume data.
        name_tensor = TensorProto(
            dtype='DT_STRING',
            string_val=[filename.encode(encoding='utf_8')],
            tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]))
        value = summary_pb2.Summary.Value(
            tag=tag, metadata=metadata, tensor=name_tensor)
        self._file_writer.add_summary(summary_pb2.Summary(value=[value]),
                                      global_step=global_step,
                                      walltime=walltime)
        self._file_writer.flush()
Ejemplo n.º 6
0
def make_tensor_summary(name, nparray):
    """Wrap an array in a ``Summary`` as a ``DT_FLOAT`` tensor.

    Args:
      name: Tag of the summary value.
      nparray: Array to store; it is flattened with ``reshape(-1)`` and its
        shape is recorded dimension by dimension.

    Returns:
      A ``Summary`` with a single tensor value.
    """
    flat_values = nparray.reshape(-1).tolist()
    dims = [TensorShapeProto.Dim(size=extent) for extent in nparray.shape]
    proto = TensorProto(
        dtype='DT_FLOAT',
        float_val=flat_values,
        tensor_shape=TensorShapeProto(dim=dims),
    )
    value = Summary.Value(tag=name, tensor=proto)
    return Summary(value=[value])
Ejemplo n.º 7
0
def scalar(name, scalar, collections=None, new_style=False):
    """Build a `Summary` protocol buffer containing a single scalar value.

    Args:
      name: Tag of the summary value. Will also serve as the series name in
        TensorBoard.
      scalar: A real numeric value holding exactly one element; it is
        converted with `make_np` and squeezed to zero dimensions.
      collections: Not used by this implementation.
      new_style: Whether to use new style (tensor field) or old style
        (simple_value field). New style could lead to faster data loading.

    Returns:
      A `Summary` protobuf with one value: a `DT_FLOAT` tensor tagged for the
      "scalars" plugin when `new_style` is true, otherwise a `simple_value`.

    Raises:
      AssertionError: If the squeezed input is not 0-dimensional.
    """
    # Squeeze once and convert the 0-d result. Calling float() on an array
    # that still has dimensions (e.g. shape (1,)) is deprecated in NumPy.
    value = make_np(scalar).squeeze()
    assert value.ndim == 0, "scalar should be 0D"
    value = float(value)
    if new_style:
        plugin_data = SummaryMetadata.PluginData(plugin_name="scalars")
        smd = SummaryMetadata(plugin_data=plugin_data)
        return Summary(value=[
            Summary.Value(
                tag=name,
                tensor=TensorProto(float_val=[value], dtype="DT_FLOAT"),
                metadata=smd,
            )
        ])
    return Summary(value=[Summary.Value(tag=name, simple_value=value)])
Ejemplo n.º 8
0
def text(tag, text):
    """Build a text ``Summary`` for the ``text`` plugin.

    Args:
      tag: Base tag; ``'/text_summary'`` is appended to it.
      text: String to store; it is UTF-8 encoded into a one-element
        ``DT_STRING`` tensor.

    Returns:
      A ``Summary`` with a single value holding the encoded text.
    """
    serialized_plugin_data = TextPluginData(version=0).SerializeToString()
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name='text', content=serialized_plugin_data))
    encoded = text.encode(encoding='utf_8')
    single_item_shape = TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)])
    proto = TensorProto(dtype='DT_STRING',
                        string_val=[encoded],
                        tensor_shape=single_item_shape)
    value = Summary.Value(tag=tag + '/text_summary', metadata=smd, tensor=proto)
    return Summary(value=[value])
Ejemplo n.º 9
0
def scalar(name, tensor, collections=None, new_style=False, double_precision=False):
    """Build a `Summary` protocol buffer containing a single scalar value.

    Args:
      name: Tag of the summary value. Will also serve as the series name in
        TensorBoard.
      tensor: A real numeric value holding exactly one element; it is
        converted with `make_np` and squeezed to zero dimensions.
      collections: Not used by this implementation.
      new_style: Whether to use new style (tensor field) or old style
        (simple_value field). New style could lead to faster data loading.
      double_precision: Only consulted when `new_style` is true. Stores the
        value in `double_val` with dtype `DT_DOUBLE` instead of `float_val`
        with dtype `DT_FLOAT`.

    Returns:
      A `Summary` protobuf with one value: a tensor tagged for the "scalars"
      plugin when `new_style` is true, otherwise a `simple_value`.

    Raises:
      AssertionError: If the squeezed input is not 0-dimensional.
    """
    tensor = make_np(tensor).squeeze()
    assert (
        tensor.ndim == 0
    ), f"Tensor should contain one element (0 dimensions). Was given size: {tensor.size} and {tensor.ndim} dimensions."
    # A Python float is double precision, so nothing is lost at this point.
    value = float(tensor)

    if not new_style:
        return Summary(value=[Summary.Value(tag=name, simple_value=value)])

    if double_precision:
        proto = TensorProto(double_val=[value], dtype="DT_DOUBLE")
    else:
        proto = TensorProto(float_val=[value], dtype="DT_FLOAT")

    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name="scalars"))
    entry = Summary.Value(tag=name, tensor=proto, metadata=smd)
    return Summary(value=[entry])
Ejemplo n.º 10
0
def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None):
    """Build a PR-curve ``Summary`` from labels and predictions.

    The curve data comes from ``compute_curve`` and is stored flattened in a
    ``DT_FLOAT`` tensor tagged for the ``pr_curves`` plugin. The first two
    axes of that data are recorded as the tensor shape.

    Args:
      tag: Tag of the summary value.
      labels: Ground-truth labels, forwarded to ``compute_curve``.
      predictions: Predicted values, forwarded to ``compute_curve``.
      num_thresholds: Number of thresholds; clamped to at most 127.
      weights: Optional weights, forwarded to ``compute_curve``.

    Returns:
      A ``Summary`` with a single value holding the curve data.
    """
    # Values above 127 break protobuf, so cap the threshold count.
    if num_thresholds > 127:
        num_thresholds = 127

    curve = compute_curve(labels, predictions,
                          num_thresholds=num_thresholds, weights=weights)

    serialized_plugin_data = PrCurvePluginData(
        version=0, num_thresholds=num_thresholds).SerializeToString()
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name='pr_curves', content=serialized_plugin_data))

    flat_values = curve.reshape(-1).tolist()
    dims = [TensorShapeProto.Dim(size=curve.shape[axis]) for axis in (0, 1)]
    proto = TensorProto(dtype='DT_FLOAT',
                        float_val=flat_values,
                        tensor_shape=TensorShapeProto(dim=dims))
    return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=proto)])