def _get_tensor_summary(
    name, display_name, description, tensor, content_type, components, json_config
):
    """Creates a mesh-plugin tensor summary with summary metadata.

    Args:
      name: Uniquely identifiable name of the summary op. Could be replaced by
        combination of name and type to make it unique even outside of this
        summary.
      display_name: Will be used as the display name in TensorBoard.
        Defaults to `name`.
      description: A longform readable description of the summary data.
        Markdown is supported.
      tensor: Tensor (or anything accepted by `torch.as_tensor`) to display in
        the summary.
      content_type: Type of content inside the Tensor.
      components: Bitmask representing present parts (vertices, colors, etc.)
        that belong to the summary.
      json_config: A string, JSON-serialized dictionary of ThreeJS classes
        configuration.

    Returns:
      A `Summary.Value` tagged with the mesh instance name. Its tensor holds
      the flattened values as `DT_FLOAT` with a shape equal to `tensor.shape`.
    """
    import torch
    from tensorboard.plugins.mesh import metadata

    tensor = torch.as_tensor(tensor)
    tensor_metadata = metadata.create_summary_metadata(
        name,
        display_name,
        content_type,
        components,
        tensor.shape,
        description,
        json_config=json_config,
    )
    # Record every dimension of the input. Hard-coding exactly three dims
    # raised IndexError for rank < 3 inputs and silently dropped trailing
    # dims for rank > 3 inputs, leaving a shape inconsistent with float_val.
    tensor_proto = TensorProto(
        dtype="DT_FLOAT",
        float_val=tensor.reshape(-1).tolist(),
        tensor_shape=TensorShapeProto(
            dim=[TensorShapeProto.Dim(size=size) for size in tensor.shape]
        ),
    )
    return Summary.Value(
        tag=metadata.get_instance_name(name, content_type),
        tensor=tensor_proto,
        metadata=tensor_metadata,
    )
def _add_3d_torch(self, tag, data, step, logdir=None, max_outputs=1, label_to_names=None, description=None):
    """Store geometry for ``tag`` and log a summary that references it.

    The geometry is written by ``_write_geometry_data`` into a
    ``PluginDirectory`` for ``metadata.PLUGIN_NAME`` under ``logdir``. The
    string that call returns is logged as a ``DT_STRING`` tensor with an
    empty ``TensorShapeProto``. No walltime is passed to the writer (``None``).

    Args:
      tag: Summary tag.
      data: Geometry payload, forwarded unchanged to ``_write_geometry_data``.
      step: Global step. Required.
      logdir: Root log directory. Defaults to the file writer's log dir.
      max_outputs: Forwarded to ``_write_geometry_data``.
      label_to_names: Optional value stored under ``'label_to_names'`` in the
        plugin metadata dict.
      description: Forwarded to ``metadata.create_summary_metadata``.

    Raises:
      ValueError: If ``step`` is None.
    """
    if step is None:
        raise ValueError("Step is not provided or set.")

    plugin_extra = {'label_to_names': label_to_names} if label_to_names is not None else {}
    summary_metadata = metadata.create_summary_metadata(
        description=description, metadata=plugin_extra)

    writer = self._get_file_writer()
    root_dir = writer.get_logdir() if logdir is None else logdir
    plugin_dir = PluginDirectory(root_dir, metadata.PLUGIN_NAME)

    geometry_ref = _write_geometry_data(plugin_dir, tag, step, data, max_outputs)
    ref_tensor = TensorProto(
        dtype='DT_STRING',
        string_val=[geometry_ref],
        tensor_shape=TensorShapeProto(),
    )
    summary_value = Summary.Value(tag=tag, tensor=ref_tensor, metadata=summary_metadata)
    writer.add_summary(Summary(value=[summary_value]), step, None)  # walltime=None
def pr_curve_raw(tag, tp, fp, tn, fn, precision, recall, num_thresholds=127, weights=None):
    """Build a pr_curves ``Summary`` from precomputed curve statistics.

    The six inputs are stacked row-wise with ``np.stack`` and stored as a
    flattened ``DT_FLOAT`` tensor. Only the first two axes of the stacked
    array are recorded in the shape, so the inputs are presumably 1-D
    (TODO confirm against callers).

    Args:
      tag: Summary tag.
      tp, fp, tn, fn, precision, recall: Per-threshold statistics of equal
        shape.
      num_thresholds: Threshold count written to the plugin data; capped at
        127.
      weights: Unused.

    Returns:
      A ``Summary`` with one value carrying the pr_curves plugin metadata.
    """
    # weird, value > 127 breaks protobuf
    num_thresholds = min(num_thresholds, 127)

    stacked = np.stack((tp, fp, tn, fn, precision, recall))
    plugin_content = PrCurvePluginData(
        version=0, num_thresholds=num_thresholds).SerializeToString()
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name="pr_curves", content=plugin_content))

    curve_tensor = TensorProto(
        dtype="DT_FLOAT",
        float_val=stacked.reshape(-1).tolist(),
        tensor_shape=TensorShapeProto(
            dim=[TensorShapeProto.Dim(size=stacked.shape[axis]) for axis in (0, 1)]),
    )
    return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=curve_tensor)])
def custom_scalars(layout):
    """Build the ``custom_scalars`` layout configuration summary.

    Args:
      layout: Mapping of category title to a mapping of chart title to a
        sequence whose item 0 is the chart kind and item 1 is its tags. Kind
        ``"Margin"`` requires exactly three tags (value, lower, upper). Any
        other kind becomes a multiline chart over all tags.

    Returns:
      A ``Summary`` with a single ``custom_scalars__config__`` value holding
      the serialized ``layout_pb2.Layout``.
    """

    def _build_chart(chart_title, chart_metadata):
        # One layout_pb2.Chart for a single (title, metadata) entry.
        tags = chart_metadata[1]
        if chart_metadata[0] != "Margin":
            return layout_pb2.Chart(
                title=chart_title,
                multiline=layout_pb2.MultilineChartContent(tag=tags),
            )
        assert len(tags) == 3
        series = layout_pb2.MarginChartContent.Series(
            value=tags[0], lower=tags[1], upper=tags[2])
        return layout_pb2.Chart(
            title=chart_title,
            margin=layout_pb2.MarginChartContent(series=[series]),
        )

    categories = [
        layout_pb2.Category(
            title=category_title,
            chart=[_build_chart(title, meta) for title, meta in charts.items()],
        )
        for category_title, charts in layout.items()
    ]
    layout_proto = layout_pb2.Layout(category=categories)

    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name="custom_scalars"))
    config_tensor = TensorProto(
        dtype="DT_STRING",
        string_val=[layout_proto.SerializeToString()],
        tensor_shape=TensorShapeProto(),
    )
    config_value = Summary.Value(
        tag="custom_scalars__config__", tensor=config_tensor, metadata=smd)
    return Summary(value=[config_value])
def add_3dvolume(self, volume, tag, global_step=None, walltime=None):
    """Save ``volume`` as ``.nii.gz`` in the log dir and log a summary naming it.

    The file stem is ``tag + "_"`` followed by ``str(global_step)``, or by a
    local ``%Y-%m-%d_%H-%M-%S`` timestamp when ``global_step`` is None. The
    image is written under ``self._log_dir`` with ``ants``. The summary value
    stores the stem (without extension) as a UTF-8 string tagged for the
    ``tb_3d_volume_plugin`` plugin, and the file writer is flushed afterwards.

    Args:
      volume: A ``torch.Tensor`` or an array accepted by ``ants.from_numpy``.
      tag: Summary tag, also used as the filename prefix.
      global_step: Optional step, used in the filename and the summary.
      walltime: Optional walltime forwarded to ``add_summary``.
    """
    filename = tag + "_" + (
        str(global_step)
        if global_step is not None
        else datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    )

    # ants.from_numpy is fed a NumPy array, so detach tensors and move to CPU.
    array = volume.detach().cpu().numpy() if isinstance(volume, torch.Tensor) else volume
    ants.image_write(
        ants.from_numpy(array),
        os.path.join(self._log_dir, filename + ".nii.gz"),
    )

    summary_metadata = tf.SummaryMetadata(
        plugin_data=tf.SummaryMetadata.PluginData(
            plugin_name="tb_3d_volume_plugin",
            content=TextPluginData(version=0).SerializeToString(),
        )
    )
    filename_tensor = TensorProto(
        dtype='DT_STRING',
        string_val=[filename.encode(encoding='utf_8')],
        tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]),
    )
    volume_summary = summary_pb2.Summary(value=[
        summary_pb2.Summary.Value(tag=tag, metadata=summary_metadata, tensor=filename_tensor)
    ])
    self._file_writer.add_summary(volume_summary, global_step=global_step, walltime=walltime)
    self._file_writer.flush()
def make_tensor_summary(name, nparray):
    """Wrap an array in a ``DT_FLOAT`` tensor ``Summary`` tagged ``name``.

    The values are flattened with ``reshape(-1)``, and every axis of
    ``nparray.shape`` is recorded in the tensor shape.
    """
    flat_values = nparray.reshape(-1).tolist()
    shape_proto = TensorShapeProto(
        dim=[TensorShapeProto.Dim(size=extent) for extent in nparray.shape])
    tensor_pb = TensorProto(
        dtype='DT_FLOAT', float_val=flat_values, tensor_shape=shape_proto)
    value = Summary.Value(tag=name, tensor=tensor_pb)
    return Summary(value=[value])
def scalar(name, scalar, collections=None, new_style=False):
    """Outputs a `Summary` protocol buffer containing a single scalar value.

    The value is converted with `make_np`, checked to squeeze down to 0-d,
    and stored as a Python float.

    Args:
      name: A name for the generated node. Will also serve as the series name
        in TensorBoard.
      scalar: A real numeric value (anything `make_np` accepts) containing a
        single element.
      collections: Unused by this implementation; kept in the signature for
        compatibility.
      new_style: Whether to use new style (tensor field with `scalars` plugin
        metadata) or old style (simple_value field). New style could lead to
        faster data loading.

    Returns:
      A `Summary` protobuf with one value tagged `name`.

    Raises:
      AssertionError: If `scalar` does not squeeze to a 0-d array. (Asserts
        are stripped under `python -O`.)
    """
    array = make_np(scalar)
    assert array.squeeze().ndim == 0, "scalar should be 0D"
    value = float(array)

    if not new_style:
        return Summary(value=[Summary.Value(tag=name, simple_value=value)])

    smd = SummaryMetadata(plugin_data=SummaryMetadata.PluginData(plugin_name="scalars"))
    return Summary(value=[
        Summary.Value(
            tag=name,
            tensor=TensorProto(float_val=[value], dtype="DT_FLOAT"),
            metadata=smd,
        )
    ])
def text(tag, text):
    """Build a text-plugin ``Summary`` for ``text`` tagged ``tag + '/text_summary'``.

    The string is UTF-8 encoded into a one-element ``DT_STRING`` tensor.
    """
    plugin_content = TextPluginData(version=0).SerializeToString()
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name='text', content=plugin_content))
    one_element_shape = TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)])
    text_tensor = TensorProto(
        dtype='DT_STRING',
        string_val=[text.encode(encoding='utf_8')],
        tensor_shape=one_element_shape,
    )
    value = Summary.Value(tag=tag + '/text_summary', metadata=smd, tensor=text_tensor)
    return Summary(value=[value])
def scalar(name, tensor, collections=None, new_style=False, double_precision=False):
    """Outputs a `Summary` protocol buffer containing a single scalar value.

    The input is converted with `make_np`, squeezed, checked to be 0-d, and
    stored as a Python float.

    Args:
      name: A name for the generated node. Will also serve as the series name
        in TensorBoard.
      tensor: A real numeric Tensor (anything `make_np` accepts) containing a
        single value.
      collections: Unused by this implementation; kept in the signature for
        compatibility.
      new_style: Whether to use new style (tensor field with `scalars` plugin
        metadata) or old style (simple_value field). New style could lead to
        faster data loading.
      double_precision: Only meaningful with `new_style=True`. Stores the
        value as `DT_DOUBLE` (double_val) instead of `DT_FLOAT` (float_val).

    Returns:
      A `Summary` protobuf with one value tagged `name`.

    Raises:
      AssertionError: If `tensor` does not squeeze to a 0-d array. (Asserts
        are stripped under `python -O`.)
    """
    tensor = make_np(tensor).squeeze()
    assert (
        tensor.ndim == 0
    ), f"Tensor should contain one element (0 dimensions). Was given size: {tensor.size} and {tensor.ndim} dimensions."
    # python float is double precision in numpy
    scalar = float(tensor)

    if not new_style:
        return Summary(value=[Summary.Value(tag=name, simple_value=scalar)])

    # Build only the proto that is actually used, instead of constructing a
    # DT_FLOAT proto and then discarding it when double precision is asked for.
    if double_precision:
        tensor_proto = TensorProto(double_val=[scalar], dtype="DT_DOUBLE")
    else:
        tensor_proto = TensorProto(float_val=[scalar], dtype="DT_FLOAT")

    plugin_data = SummaryMetadata.PluginData(plugin_name="scalars")
    smd = SummaryMetadata(plugin_data=plugin_data)
    return Summary(
        value=[
            Summary.Value(
                tag=name,
                tensor=tensor_proto,
                metadata=smd,
            )
        ]
    )
def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None):
    """Build a pr_curves ``Summary`` from labels and predictions.

    The curve data comes from ``compute_curve`` and is stored as a flattened
    ``DT_FLOAT`` tensor. Its recorded shape is the first two axes of that
    data, so ``compute_curve`` presumably returns a 2-D array (TODO confirm).

    Args:
      tag: Summary tag.
      labels: Ground-truth values, forwarded to ``compute_curve``.
      predictions: Predicted values, forwarded to ``compute_curve``.
      num_thresholds: Threshold count, capped at 127, forwarded to
        ``compute_curve`` and written to the plugin data.
      weights: Optional weights, forwarded to ``compute_curve``.

    Returns:
      A ``Summary`` with one value carrying the pr_curves plugin metadata.
    """
    # weird, value > 127 breaks protobuf
    if num_thresholds > 127:
        num_thresholds = 127

    curve = compute_curve(labels, predictions, num_thresholds=num_thresholds, weights=weights)

    serialized = PrCurvePluginData(version=0, num_thresholds=num_thresholds).SerializeToString()
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name='pr_curves', content=serialized))

    row_dim = TensorShapeProto.Dim(size=curve.shape[0])
    col_dim = TensorShapeProto.Dim(size=curve.shape[1])
    curve_tensor = TensorProto(
        dtype='DT_FLOAT',
        float_val=curve.reshape(-1).tolist(),
        tensor_shape=TensorShapeProto(dim=[row_dim, col_dim]),
    )
    return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=curve_tensor)])