def make_tensor_summary(name, nparray):
    """Wrap an array in a ``Summary`` holding a single ``DT_FLOAT`` tensor.

    Args:
        name: Tag attached to the summary value.
        nparray: Array whose elements are flattened into ``float_val`` and whose
            ``shape`` is recorded dimension by dimension in the tensor shape.

    Returns:
        Summary: A summary with exactly one value, tagged ``name``.
    """
    flat_values = nparray.reshape(-1).tolist()
    dims = [TensorShapeProto.Dim(size=extent) for extent in nparray.shape]
    proto = TensorProto(
        dtype='DT_FLOAT',
        float_val=flat_values,
        tensor_shape=TensorShapeProto(dim=dims),
    )
    entry = Summary.Value(tag=name, tensor=proto)
    return Summary(value=[entry])
def pr_curve_raw(tag, tp, fp, tn, fn, precision, recall, num_thresholds=127, weights=None):
    """Build a ``pr_curves`` summary from precomputed curve data.

    The six inputs are stacked row-wise with ``np.stack`` in the order
    tp, fp, tn, fn, precision, recall, so they must have matching shapes. The
    stacked array is stored flattened as ``DT_FLOAT``. Its declared shape is
    the stacked array's first two dimensions.

    Args:
        tag: Tag for the summary value.
        tp, fp, tn, fn, precision, recall: Curve components to stack.
        num_thresholds: Threshold count written into the plugin metadata;
            values above 127 are clamped to 127.
        weights: Accepted but ignored.

    Returns:
        Summary: One value carrying ``pr_curves`` plugin metadata.
    """
    # weird, value > 127 breaks protobuf
    num_thresholds = min(num_thresholds, 127)

    stacked = np.stack((tp, fp, tn, fn, precision, recall))

    plugin_content = PrCurvePluginData(
        version=0, num_thresholds=num_thresholds).SerializeToString()
    summary_md = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name="pr_curves", content=plugin_content))

    flat_values = stacked.reshape(-1).tolist()
    shape = TensorShapeProto(dim=[
        TensorShapeProto.Dim(size=stacked.shape[axis]) for axis in (0, 1)
    ])
    tensor = TensorProto(
        dtype="DT_FLOAT",
        float_val=flat_values,
        tensor_shape=shape,
    )
    return Summary(value=[Summary.Value(tag=tag, metadata=summary_md, tensor=tensor)])
def add_3dvolume(self, volume, tag, global_step=None, walltime=None):
    """Save a 3D volume as NIfTI under the log directory and log its file name.

    The volume is written with ``ants`` to
    ``<self._log_dir>/<tag>_<suffix>.nii.gz``. ``<suffix>`` is
    ``str(global_step)``, or the current local time when ``global_step`` is
    ``None``. A ``DT_STRING`` tensor holding ``<tag>_<suffix>`` (without the
    ``.nii.gz`` extension) is then logged under ``tag`` with
    ``tb_3d_volume_plugin`` metadata. Finally the file writer is flushed.

    Args:
        volume: Volume data; a ``torch.Tensor`` is detached and converted to a
            CPU NumPy array first.
        tag: Summary tag, also used as the file-name prefix.
        global_step: Optional step, also used as the file-name suffix.
        walltime: Optional wall time forwarded to ``add_summary``.
    """
    prefix = tag + "_"
    if global_step is None:
        # No step given: fall back to a timestamp for the file-name suffix.
        stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    else:
        stamp = str(global_step)
    filename = prefix + stamp

    if isinstance(volume, torch.Tensor):
        volume = volume.detach().cpu().numpy()
    image = ants.from_numpy(volume)
    ants.image_write(image, os.path.join(self._log_dir, filename + ".nii.gz"))

    plugin_content = TextPluginData(version=0).SerializeToString()
    volume_md = tf.SummaryMetadata(
        plugin_data=tf.SummaryMetadata.PluginData(
            plugin_name="tb_3d_volume_plugin", content=plugin_content))
    name_tensor = TensorProto(
        dtype='DT_STRING',
        string_val=[filename.encode(encoding='utf_8')],
        tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]),
    )
    entry = summary_pb2.Summary.Value(tag=tag, metadata=volume_md, tensor=name_tensor)
    summary = summary_pb2.Summary(value=[entry])

    self._file_writer.add_summary(summary, global_step=global_step, walltime=walltime)
    self._file_writer.flush()
def text(tag, text):
    """Build a ``text`` plugin summary containing one UTF-8 string.

    Args:
        tag: Base tag; the value is logged under ``tag + '/text_summary'``.
        text: String to store; it is encoded as UTF-8.

    Returns:
        Summary: One value holding a shape-``[1]`` ``DT_STRING`` tensor.
    """
    content = TextPluginData(version=0).SerializeToString()
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name='text', content=content))
    payload = [text.encode(encoding='utf_8')]
    single = TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)])
    tensor = TensorProto(dtype='DT_STRING', string_val=payload, tensor_shape=single)
    value = Summary.Value(tag=tag + '/text_summary', metadata=smd, tensor=tensor)
    return Summary(value=[value])
def _get_tensor_summary(
    name, display_name, description, tensor, content_type, components, json_config
):
    """Creates a tensor summary with summary metadata.

    Args:
      name: Uniquely identifiable name of the summary op. Could be replaced by
        combination of name and type to make it unique even outside of this
        summary.
      display_name: Will be used as the display name in TensorBoard.
        Defaults to `name`.
      description: A longform readable description of the summary data.
        Markdown is supported.
      tensor: Tensor to display in summary. It is converted with
        ``torch.as_tensor``. Its full shape is recorded both in the mesh
        plugin metadata and in the ``TensorShapeProto``.
      content_type: Type of content inside the Tensor.
      components: Bitmask representing present parts (vertices, colors, etc.)
        that belong to the summary.
      json_config: A string, JSON-serialized dictionary of ThreeJS classes
        configuration.

    Returns:
      ``Summary.Value`` carrying the flattened float data and mesh metadata.
    """
    import torch
    from tensorboard.plugins.mesh import metadata

    tensor = torch.as_tensor(tensor)
    tensor_metadata = metadata.create_summary_metadata(
        name,
        display_name,
        content_type,
        components,
        tensor.shape,
        description,
        json_config=json_config,
    )
    # Build one Dim per axis so the declared shape always describes exactly
    # the values in float_val. A fixed three-dim shape mis-describes
    # higher-rank tensors and fails with IndexError on lower-rank ones.
    tensor_proto = TensorProto(
        dtype="DT_FLOAT",
        float_val=tensor.reshape(-1).tolist(),
        tensor_shape=TensorShapeProto(
            dim=[TensorShapeProto.Dim(size=extent) for extent in tensor.shape]
        ),
    )
    tensor_summary = Summary.Value(
        tag=metadata.get_instance_name(name, content_type),
        tensor=tensor_proto,
        metadata=tensor_metadata,
    )
    return tensor_summary
def tensor_shape_proto(shape):
    r"""Build a ``TensorShapeProto`` from a sequence of dimension sizes.

    Args:
        shape (tuple of int): Size of each dimension, outermost first.

    Returns:
        TensorShapeProto: A shape proto with one ``Dim`` per entry of ``shape``.
    """
    dims = []
    for size in shape:
        dims.append(TensorShapeProto.Dim(size=size))
    return TensorShapeProto(dim=dims)
def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None):
    """Compute PR-curve data with ``compute_curve`` and wrap it in a summary.

    Args:
        tag: Tag for the summary value.
        labels: Forwarded to ``compute_curve``.
        predictions: Forwarded to ``compute_curve``.
        num_thresholds: Threshold count, clamped to at most 127. It is passed
            to ``compute_curve`` and recorded in the plugin metadata.
        weights: Forwarded to ``compute_curve``.

    Returns:
        Summary: One value carrying ``pr_curves`` plugin metadata. Its tensor
        is the curve flattened to ``DT_FLOAT`` with the curve's first two
        dimensions as its shape.
    """
    # weird, value > 127 breaks protobuf
    if num_thresholds > 127:
        num_thresholds = 127
    curve = compute_curve(labels, predictions, num_thresholds=num_thresholds, weights=weights)

    serialized = PrCurvePluginData(
        version=0, num_thresholds=num_thresholds).SerializeToString()
    plugin_data = SummaryMetadata.PluginData(plugin_name='pr_curves', content=serialized)
    smd = SummaryMetadata(plugin_data=plugin_data)

    flat_values = curve.reshape(-1).tolist()
    rows = TensorShapeProto.Dim(size=curve.shape[0])
    cols = TensorShapeProto.Dim(size=curve.shape[1])
    tensor = TensorProto(
        dtype='DT_FLOAT',
        float_val=flat_values,
        tensor_shape=TensorShapeProto(dim=[rows, cols]),
    )
    return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
def _add_3d_torch(self, tag, data, step, logdir=None, max_outputs=1, label_to_names=None, description=None):
    """Write 3D geometry for ``tag`` into the plugin directory and log a reference.

    Args:
        tag: Summary tag.
        data: Geometry forwarded to ``_write_geometry_data``.
        step: Global step; must not be ``None``.
        logdir: Root log directory. Defaults to the file writer's
            ``get_logdir()``.
        max_outputs: Forwarded to ``_write_geometry_data``.
        label_to_names: Optional value stored in the summary metadata under
            ``'label_to_names'``.
        description: Forwarded to ``metadata.create_summary_metadata``.

    Raises:
        ValueError: If ``step`` is ``None``.
    """
    if step is None:
        raise ValueError("Step is not provided or set.")

    extra_metadata = {}
    if label_to_names is not None:
        extra_metadata['label_to_names'] = label_to_names
    summary_metadata = metadata.create_summary_metadata(
        description=description, metadata=extra_metadata)

    writer = self._get_file_writer()
    root_dir = writer.get_logdir() if logdir is None else logdir
    plugin_dir = PluginDirectory(root_dir, metadata.PLUGIN_NAME)
    geometry_ref = _write_geometry_data(plugin_dir, tag, step, data, max_outputs)

    tensor_proto = TensorProto(
        dtype='DT_STRING',
        string_val=[geometry_ref],
        tensor_shape=TensorShapeProto(),
    )
    value = Summary.Value(tag=tag, tensor=tensor_proto, metadata=summary_metadata)
    # No wall time is tracked here; None is forwarded explicitly.
    writer.add_summary(Summary(value=[value]), step, None)
def custom_scalars(layout):
    """Build the ``custom_scalars`` layout summary.

    Args:
        layout: Mapping of category title -> mapping of chart title ->
            ``(kind, tags)``. A kind of ``"Margin"`` needs exactly three tags
            (value, lower, upper) and yields a margin chart. Any other kind
            yields a multiline chart over ``tags``.

    Returns:
        Summary: One value tagged ``custom_scalars__config__``. It holds the
        serialized ``layout_pb2.Layout`` as a ``DT_STRING`` tensor.
    """
    def _build_chart(chart_name, chart_metadata):
        # Turn one (kind, tags) entry into a layout_pb2.Chart.
        tags = chart_metadata[1]
        if chart_metadata[0] != "Margin":
            lines = layout_pb2.MultilineChartContent(tag=tags)
            return layout_pb2.Chart(title=chart_name, multiline=lines)
        assert len(tags) == 3
        series = layout_pb2.MarginChartContent.Series(
            value=tags[0], lower=tags[1], upper=tags[2])
        margin = layout_pb2.MarginChartContent(series=[series])
        return layout_pb2.Chart(title=chart_name, margin=margin)

    categories = [
        layout_pb2.Category(
            title=category_title,
            chart=[_build_chart(title, meta) for title, meta in charts.items()],
        )
        for category_title, charts in layout.items()
    ]
    layout_proto = layout_pb2.Layout(category=categories)

    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name="custom_scalars"))
    tensor = TensorProto(
        dtype="DT_STRING",
        string_val=[layout_proto.SerializeToString()],
        tensor_shape=TensorShapeProto(),
    )
    config_value = Summary.Value(
        tag="custom_scalars__config__", tensor=tensor, metadata=smd)
    return Summary(value=[config_value])
def parse(graph):
    """Convert a graph proto into a TensorBoard ``GraphDef``.

    ``graph`` is presumably an ONNX ``GraphProto``; the code reads its
    ``input``, ``output`` and ``node`` fields.

    * Each entry of ``graph.input`` and ``graph.output`` becomes a
      ``"Variable"`` node with ``dtype`` and ``shape`` attrs.
    * Each entry of ``graph.node`` becomes a node named after its first
      output. Its op is ``op_type``, and its attributes are flattened into one
      UTF-8 ``"parameters"`` string.

    Returns:
        GraphDef: The converted nodes, with ``versions.producer`` set to 22.
    """
    import itertools

    nodes = []
    for value_info in itertools.chain(graph.input, graph.output):
        tensor_type = value_info.type.tensor_type
        shapeproto = TensorShapeProto(dim=[
            TensorShapeProto.Dim(size=d.dim_value) for d in tensor_type.shape.dim
        ])
        nodes.append(
            NodeDef(
                name=value_info.name.encode(encoding="utf_8"),
                op="Variable",
                input=[],
                attr={
                    "dtype": AttrValue(type=tensor_type.elem_type),
                    "shape": AttrValue(shape=shapeproto),
                },
            ))

    for node in graph.node:
        # Render each attribute as its set field values joined by " = ", then
        # join the attributes with ", ".
        rendered = [
            " = ".join(str(field[1]) for field in attribute.ListFields())
            for attribute in node.attribute
        ]
        attr = ", ".join(rendered).encode(encoding="utf_8")
        nodes.append(
            NodeDef(
                name=node.output[0].encode(encoding="utf_8"),
                op=node.op_type,
                input=node.input,
                attr={"parameters": AttrValue(s=attr)},
            ))

    return GraphDef(node=nodes, versions=VersionDef(producer=22))
def _add_tf_shape(attr_dict, ints):
    '''
    Append a TensorShapeProto built from ``ints`` to the node's output shapes.

    Args:
        attr_dict: Dictionary to update (usually attributes of a Node). The
            object at ``attr_dict['_output_shapes']`` has its
            ``list.shape`` extended in place.
        ints: Iterable of integer dimension sizes, outermost first.

    Returns:
        None. Modifies attr_dict in-place.
    '''
    shape_proto = TensorShapeProto(
        dim=[TensorShapeProto.Dim(size=extent) for extent in ints])
    attr_dict['_output_shapes'].list.shape.extend([shape_proto])
def tensor_shape_proto(outputsize):
    """Build a ``TensorShapeProto`` with one ``Dim`` per entry of ``outputsize``.

    The resulting object follows the schema in
    https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto
    """
    dims = [TensorShapeProto.Dim(size=size) for size in outputsize]
    return TensorShapeProto(dim=dims)
def visualize(
    model_path: str,
    log_path: str,
    input: np.ndarray = None,
    inp_dict: dict = None,
    cal_params: bool = True,
    cal_flops: bool = True,
    cal_activations: bool = True,
    logging_to_stdout: bool = True,
    bar_length_max: int = 20,
):
    r"""
    Load megengine dumped model and visualize graph structure with tensorboard log files.
    Can also record and print model's statistics like :func:`~.module_stats`

    :param model_path: path of the megengine dumped model (loaded with ``Network.load``).
    :param log_path: dir path for tensorboard graph log. If falsy, no graph log
        is written and only statistics are computed.
    :param input: user defined input data for running model and calculating stats, alternative with inp_dict, used when the model has only one input.
    :param inp_dict: input dict for running model and calculating stats, alternative with input, used when the model has more than one input.
        When both input and inp_dict are None, a random input will be used.
    :param cal_params: whether calculate and record params size.
    :param cal_flops: whether calculate and record op flops.
    :param cal_activations: whether calculate and record op activations.
    :param logging_to_stdout: whether print all calculated statistic details.
    :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
    :return: ``(total_stats(param_size, flops, act_size),
        module_stats(params, flops, activations))``. Totals of disabled
        statistics are 0. Returns ``None`` if ``log_path`` is set but
        TensorBoard / tensorboardX cannot be imported.
    """
    if log_path:
        try:
            from tensorboard.compat.proto.attr_value_pb2 import AttrValue
            from tensorboard.compat.proto.config_pb2 import RunMetadata
            from tensorboard.compat.proto.graph_pb2 import GraphDef
            from tensorboard.compat.proto.node_def_pb2 import NodeDef
            from tensorboard.compat.proto.step_stats_pb2 import (
                AllocatorMemoryUsed,
                DeviceStepStats,
                NodeExecStats,
                StepStats,
            )
            from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
            from tensorboard.compat.proto.versions_pb2 import VersionDef
            from tensorboardX import SummaryWriter
        except ImportError:
            logger.error(
                "TensorBoard and TensorboardX are required for visualize.",
                exc_info=True,
            )
            return

    enable_receptive_field()

    graph = Network.load(model_path)
    graph.reset_batch_size(1)

    has_input = False
    if input is not None or inp_dict is not None:
        has_input = True
        repl_dict = {}
        inp_vars = graph.input_vars
        if inp_dict is not None:
            assert len(inp_dict) == len(
                inp_vars
            ), "Inputs are not sufficient for calculation."
            for v in inp_vars:
                new_input = graph.make_const(inp_dict[v.name], name=v.name)
                repl_dict[v] = new_input
        else:
            assert len(inp_vars) == 1, "The graph needs more than one input."
            inp_var = inp_vars[0]
            repl_dict[inp_var] = graph.make_const(input, name=inp_var.name)
        # Swap the graph's input vars for constants built from the user data.
        graph.replace_vars(repl_dict=repl_dict)

    graph._compile()

    def process_name(name):
        # nodes that start with point or contain float const will lead to display bug
        if not re.match(r"^[+-]?\d*\.\d*", name):
            name = name.replace(".", "/")
        return name.encode(encoding="utf-8")

    node_list = []
    flops_list = []
    params_list = []
    activations_list = []
    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
    stats_details = namedtuple("module_stats", ["params", "flops", "activations"])

    for node in tqdm(graph.all_oprs):
        # Choose the output var whose shape/dtype/values describe this op.
        if hasattr(node, "output_idx"):
            node_oup = node.outputs[node.output_idx]
        else:
            if len(node.outputs) != 1:
                logger.warning(
                    "OpNode {} has more than one output and not has 'output_idx' attr.".format(
                        node
                    )
                )
            node_oup = node.outputs[0]

        inp_list = [process_name(var.owner.name) for var in node.inputs]
        if log_path:
            # detail format see tensorboard/compat/proto/attr_value.proto
            attr = {
                "_output_shapes": AttrValue(
                    list=AttrValue.ListValue(
                        shape=[
                            TensorShapeProto(
                                dim=[
                                    TensorShapeProto.Dim(size=d)
                                    for d in node_oup.shape
                                ]
                            )
                        ]
                    )
                ),
                "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
                "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
            }

        if cal_flops:
            flops_stats = get_op_stats(node, node.inputs, node.outputs)
            if flops_stats is not None:
                # add op flops attr. flops_stats is subscripted like a dict,
                # so test key membership: hasattr() never sees dict keys.
                if log_path and "flops" in flops_stats:
                    attr["flops"] = AttrValue(
                        s=sizeof_fmt(flops_stats["flops"], suffix="OPs").encode(
                            encoding="utf-8"
                        )
                    )
                flops_stats["name"] = node.name
                flops_stats["class_name"] = node.type
                flops_list.append(flops_stats)

        if cal_activations:
            acts = get_activation_stats(node_oup.numpy(), has_input=has_input)
            acts["name"] = node.name
            acts["class_name"] = node.type
            activations_list.append(acts)

        if cal_params:
            if node.type == "ImmutableTensor":
                param_stats = get_param_stats(node.numpy())
                # add tensor size attr
                if log_path:
                    attr["size"] = AttrValue(
                        s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
                    )
                param_stats["name"] = node.name
                params_list.append(param_stats)

        if log_path:
            node_list.append(
                NodeDef(
                    name=process_name(node.name),
                    op=node.type,
                    input=inp_list,
                    attr=attr,
                )
            )

    # summary
    extra_info = {
        "#ops": len(graph.all_oprs),
        "#params": len(params_list),
    }

    (
        total_flops,
        total_param_dims,
        total_param_size,
        total_act_dims,
        total_act_size,
    ) = (0, 0, 0, 0, 0)

    if cal_params:
        total_param_dims, total_param_size, params_list = sum_param_stats(
            params_list, bar_length_max
        )
        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
        if logging_to_stdout:
            print_param_stats(params_list)

    if cal_flops:
        total_flops, flops_list = sum_op_stats(flops_list, bar_length_max)
        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
        if logging_to_stdout:
            print_op_stats(flops_list)

    if cal_activations:
        total_act_dims, total_act_size, activations_list = sum_activations_stats(
            activations_list, bar_length_max
        )
        extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
        extra_info["total_act_size"] = sizeof_fmt(total_act_size)
        if logging_to_stdout:
            print_activations_stats(activations_list, has_input=has_input)

    # total_param_size can be 0 (e.g. no ImmutableTensor params); skip the
    # ratio instead of raising ZeroDivisionError.
    if cal_flops and cal_params and total_param_size:
        extra_info["flops/param_size"] = "{:3.3f}".format(
            total_flops / total_param_size
        )

    if log_path:
        graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))

        device = "/device:CPU:0"
        stepstats = RunMetadata(
            step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
        )
        writer = SummaryWriter(log_path)
        try:
            writer._get_file_writer().add_graph((graph_def, stepstats))
        finally:
            # Close the writer so the event file is flushed and released.
            writer.close()

    print_summary(**extra_info)

    return (
        total_stats(
            param_size=total_param_size,
            flops=total_flops,
            act_size=total_act_size,
        ),
        stats_details(
            params=params_list, flops=flops_list, activations=activations_list
        ),
    )
def visualize(
    model_path: str,
    log_path: str,
    bar_length_max: int = 20,
    log_params: bool = True,
    log_flops: bool = True,
):
    r"""
    Load megengine dumped model and visualize graph structure with tensorboard log files.
    Can also record and print model's statistics like :func:`~.module_stats`

    :param model_path: path of the megengine dumped model (loaded with ``Network.load``).
    :param log_path: dir path for tensorboard graph log. If falsy, no graph log
        is written and only statistics are computed.
    :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
    :param log_params: whether print and record params size.
    :param log_flops: whether print and record op flops.
    :return: ``(total_param_size, total_flops)``. A value is 0 when its
        statistic is disabled. Returns ``None`` if ``log_path`` is set but
        TensorBoard / tensorboardX cannot be imported.
    """
    if log_path:
        try:
            from tensorboard.compat.proto.attr_value_pb2 import AttrValue
            from tensorboard.compat.proto.config_pb2 import RunMetadata
            from tensorboard.compat.proto.graph_pb2 import GraphDef
            from tensorboard.compat.proto.node_def_pb2 import NodeDef
            from tensorboard.compat.proto.step_stats_pb2 import (
                AllocatorMemoryUsed,
                DeviceStepStats,
                NodeExecStats,
                StepStats,
            )
            from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
            from tensorboard.compat.proto.versions_pb2 import VersionDef
            from tensorboardX import SummaryWriter
        except ImportError:
            logger.error(
                "TensorBoard and TensorboardX are required for visualize.",
                exc_info=True,
            )
            return

    # FIXME: remove this after resolving "span dist too large" warning
    old_level = set_mgb_log_level(logging.ERROR)
    # Restore the log level even if loading or analysing the model raises.
    try:
        enable_receptive_field()

        graph = Network.load(model_path)

        def process_name(name):
            # nodes that start with point or contain float const will lead to display bug
            if not re.match(r"^[+-]?\d*\.\d*", name):
                name = name.replace(".", "/")
            return name.encode(encoding="utf-8")

        node_list = []
        flops_list = []
        params_list = []
        for node in graph.all_oprs:
            # Choose the output var whose shape/dtype describe this op.
            if hasattr(node, "output_idx"):
                node_oup = node.outputs[node.output_idx]
            else:
                if len(node.outputs) != 1:
                    logger.warning(
                        "OpNode {} has more than one output and not has 'output_idx' attr."
                        .format(node))
                node_oup = node.outputs[0]

            inp_list = [process_name(var.owner.name) for var in node.inputs]
            if log_path:
                # detail format see tensorboard/compat/proto/attr_value.proto
                attr = {
                    "_output_shapes": AttrValue(list=AttrValue.ListValue(shape=[
                        TensorShapeProto(dim=[
                            TensorShapeProto.Dim(size=d) for d in node_oup.shape
                        ])
                    ])),
                    "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
                    "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
                }

            flops_stats = get_op_stats(node, node.inputs, node.outputs)
            if flops_stats is not None:
                # add op flops attr. flops_stats is subscripted like a dict,
                # so test key membership: hasattr() never sees dict keys.
                if log_path and "flops" in flops_stats:
                    attr["flops"] = AttrValue(
                        s=sizeof_fmt(flops_stats["flops"], suffix="OPs").encode(
                            encoding="utf-8"))
                flops_stats["name"] = node.name
                flops_stats["class_name"] = node.type
                flops_list.append(flops_stats)

            if node.type == "ImmutableTensor":
                param_stats = get_param_stats(node.numpy())
                # add tensor size attr
                if log_path:
                    attr["size"] = AttrValue(
                        s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8"))
                param_stats["name"] = node.name
                params_list.append(param_stats)

            if log_path:
                node_list.append(
                    NodeDef(
                        name=process_name(node.name),
                        op=node.type,
                        input=inp_list,
                        attr=attr,
                    ))

        # summary
        extra_info = {
            "#ops": len(graph.all_oprs),
            "#params": len(params_list),
        }

        total_flops, total_param_dims, total_param_size = 0, 0, 0
        if log_params:
            total_param_dims, total_param_size = print_param_stats(
                params_list, bar_length_max)
            # Dimension counts are unitless, so no byte suffix.
            extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
            extra_info["total_param_size"] = sizeof_fmt(total_param_size)
        if log_flops:
            total_flops = print_op_stats(flops_list, bar_length_max)
            extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
        # total_param_size can be 0 (e.g. no ImmutableTensor params); skip the
        # ratio instead of raising ZeroDivisionError.
        if log_params and log_flops and total_param_size:
            extra_info["flops/param_size"] = "{:3.3f}".format(total_flops /
                                                              total_param_size)

        if log_path:
            graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))

            device = "/device:CPU:0"
            stepstats = RunMetadata(step_stats=StepStats(
                dev_stats=[DeviceStepStats(device=device)]))
            writer = SummaryWriter(log_path)
            try:
                writer._get_file_writer().add_graph((graph_def, stepstats))
            finally:
                # Close the writer so the event file is flushed and released.
                writer.close()

        print_summary(**extra_info)
    finally:
        # FIXME: remove this after resolving "span dist too large" warning
        _imperative_rt_logger.set_log_level(old_level)

    return total_param_size, total_flops