def attr_value_proto(dtype, shape, s):
    """Creates a dict of objects matching
    https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto
    specifically designed for a NodeDef. The values have been
    reverse engineered from standard TensorBoard logged data.
    """
    # Note: `dtype` is currently unused.
    attr = {}
    if s is not None:
        attr['attr'] = AttrValue(s=s.encode(encoding='utf_8'))
    if shape is not None:
        shapeproto = tensor_shape_proto(shape)
        attr['_output_shapes'] = AttrValue(
            list=AttrValue.ListValue(shape=[shapeproto]))
    return attr
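# Usage sketch (illustrative): build the attr dict for a node with a single
# output of shape (1, 3, 224, 224). Assumes AttrValue/NodeDef and the sibling
# tensor_shape_proto() helper are importable in this module, as in
# torch.utils.tensorboard._proto_graph; the node and attribute names are made up.
example_attrs = attr_value_proto(dtype=None, shape=[1, 3, 224, 224], s="stride=2")
example_node = NodeDef(
    name="conv1".encode(encoding="utf_8"),
    op="Conv2d",
    input=["input".encode(encoding="utf_8")],
    attr=example_attrs,
)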
def parse(graph):
    """Converts an ONNX graph proto into a TensorBoard ``GraphDef``."""
    import itertools

    nodes_proto = []
    nodes = []
    for node in itertools.chain(graph.input, graph.output):
        nodes_proto.append(node)

    for node in nodes_proto:
        print(node.name)
        shapeproto = TensorShapeProto(dim=[
            TensorShapeProto.Dim(size=d.dim_value)
            for d in node.type.tensor_type.shape.dim
        ])
        nodes.append(
            NodeDef(
                name=node.name.encode(encoding="utf_8"),
                op="Variable",
                input=[],
                attr={
                    "dtype": AttrValue(type=node.type.tensor_type.elem_type),
                    "shape": AttrValue(shape=shapeproto),
                },
            ))

    for node in graph.node:
        _attr = []
        for s in node.attribute:
            _attr.append(" = ".join([str(f[1]) for f in s.ListFields()]))
        attr = ", ".join(_attr).encode(encoding="utf_8")
        print(node.output[0])
        nodes.append(
            NodeDef(
                name=node.output[0].encode(encoding="utf_8"),
                op=node.op_type,
                input=node.input,
                attr={"parameters": AttrValue(s=attr)},
            ))

    # two-pass token replacement, appends opname to object id
    mapping = {}
    for node in nodes:
        mapping[node.name] = node.op + "_" + node.name

    return GraphDef(node=nodes, versions=VersionDef(producer=22))
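# Usage sketch (illustrative): convert a dumped ONNX model into a TensorBoard
# GraphDef. Requires the `onnx` package; "model.onnx" is a placeholder path.
import onnx

onnx_model = onnx.load("model.onnx")
graph_def = parse(onnx_model.graph)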
def graph(model):
    """Converts a crypten.nn graph for consumption by TensorBoard."""

    # convert individual module to graph:
    assert isinstance(model, nn.Module), "model must be crypten.nn.Module"
    if not isinstance(model, nn.Graph):
        graph = nn.Graph("input", "output")
        graph.add_module("output", model, ["input"])
        model = graph

    # create mapping to more interpretable node naming:
    mapping = {input_name: input_name for input_name in model.input_names}
    modules = {name: module for name, module in model.named_modules()}
    for name, module in modules.items():
        # strip the "<class '...'>" wrapper from the type string,
        # e.g. "<class 'crypten.nn.module.Linear'>" -> "Linear"
        op = str(type(module))[26:-2]
        mapping[name] = "%s_%s" % (op, name)

    # create input variables:
    nodes = [
        NodeDef(
            name=mapping[input_name].encode(encoding="utf_8"),
            op="Variable",
            input=[],
        )
        for input_name in model.input_names
    ]

    # loop over all graph connections:
    for output_name, input_names in model._graph.items():

        # get parameters and type of module:
        module = modules[output_name]
        op = str(type(module))
        input_names = [mapping[name] for name in input_names]
        parameters = [
            "%s: %s" % (name, parameter.size())
            for name, parameter in module.named_parameters()
        ]
        parameter_string = "; ".join(parameters).encode(encoding="utf_8")

        # add to graph:
        nodes.append(
            NodeDef(
                name=mapping[output_name].encode(encoding="utf_8"),
                op=op,
                input=input_names,
                attr={"attr": AttrValue(s=parameter_string)},
            ))

    # return graph definition:
    return GraphDef(node=nodes, versions=VersionDef(producer=22))
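# Usage sketch (illustrative): log a small crypten.nn model. crypten.nn.Linear
# is only an example module, and the writer call mirrors the
# (GraphDef, RunMetadata) pattern used by the visualize() helpers below.
import crypten
from tensorboard.compat.proto.config_pb2 import RunMetadata
from tensorboard.compat.proto.step_stats_pb2 import DeviceStepStats, StepStats
from tensorboardX import SummaryWriter

crypten.init()
example_model = crypten.nn.Linear(10, 2)
example_run_meta = RunMetadata(
    step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")])
)
example_writer = SummaryWriter("runs/crypten_graph")
example_writer._get_file_writer().add_graph((graph(example_model), example_run_meta))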
def visualize(
    model_path: str,
    log_path: str,
    input: np.ndarray = None,
    inp_dict: dict = None,
    cal_params: bool = True,
    cal_flops: bool = True,
    cal_activations: bool = True,
    logging_to_stdout: bool = True,
    bar_length_max: int = 20,
):
    r"""
    Load a megengine dumped model and visualize its graph structure with
    tensorboard log files. Can also record and print the model's statistics,
    like :func:`~.module_stats`.

    :param model_path: path of the megengine dumped model.
    :param log_path: dir path for the tensorboard graph log.
    :param input: user-defined input data used to run the model and calculate
        stats; an alternative to inp_dict, used when the model has only one input.
    :param inp_dict: dict of input data used to run the model and calculate
        stats; an alternative to input, used when the model has more than one
        input. When both input and inp_dict are None, a random input will be used.
    :param cal_params: whether to calculate and record the parameter size.
    :param cal_flops: whether to calculate and record op flops.
    :param cal_activations: whether to calculate and record op activations.
    :param logging_to_stdout: whether to print all calculated statistic details.
    :param bar_length_max: maximum length of the bar indicating the largest
        flops or parameter size in the net stats.
    """
    if log_path:
        try:
            from tensorboard.compat.proto.attr_value_pb2 import AttrValue
            from tensorboard.compat.proto.config_pb2 import RunMetadata
            from tensorboard.compat.proto.graph_pb2 import GraphDef
            from tensorboard.compat.proto.node_def_pb2 import NodeDef
            from tensorboard.compat.proto.step_stats_pb2 import (
                AllocatorMemoryUsed,
                DeviceStepStats,
                NodeExecStats,
                StepStats,
            )
            from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
            from tensorboard.compat.proto.versions_pb2 import VersionDef
            from tensorboardX import SummaryWriter
        except ImportError:
            logger.error(
                "TensorBoard and TensorboardX are required for visualize.",
                exc_info=True,
            )
            return

    enable_receptive_field()

    graph = Network.load(model_path)
    graph.reset_batch_size(1)

    has_input = False
    if input is not None or inp_dict is not None:
        has_input = True
        repl_dict = {}
        inp_vars = graph.input_vars
        if inp_dict is not None:
            assert len(inp_dict) == len(
                inp_vars
            ), "Inputs are not sufficient for calculation."
            for v in inp_vars:
                new_input = graph.make_const(inp_dict[v.name], name=v.name)
                repl_dict[v] = new_input
        else:
            assert (
                len(inp_vars) == 1
            ), "The graph has more than one input; please use inp_dict."
            inp_var = inp_vars[0]
            repl_dict[inp_var] = graph.make_const(input, name=inp_var.name)
        graph.replace_vars(repl_dict=repl_dict)

    graph._compile()

    def process_name(name):
        # node names that start with a dot or contain a float constant
        # break the TensorBoard display
        if not re.match(r"^[+-]?\d*\.\d*", name):
            name = name.replace(".", "/")
        return name.encode(encoding="utf-8")

    summary = [["item", "value"]]
    node_list = []
    flops_list = []
    params_list = []
    activations_list = []
    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
    stats_details = namedtuple("module_stats", ["params", "flops", "activations"])

    for node in tqdm(graph.all_oprs):
        if hasattr(node, "output_idx"):
            node_oup = node.outputs[node.output_idx]
        else:
            if len(node.outputs) != 1:
                logger.warning(
                    "OpNode {} has more than one output and no 'output_idx' attr.".format(
                        node
                    )
                )
            node_oup = node.outputs[0]

        inp_list = [process_name(var.owner.name) for var in node.inputs]
        if log_path:
            # for the detailed format, see tensorboard/compat/proto/attr_value.proto
            attr = {
                "_output_shapes": AttrValue(
                    list=AttrValue.ListValue(
                        shape=[
                            TensorShapeProto(
                                dim=[
                                    TensorShapeProto.Dim(size=d)
                                    for d in node_oup.shape
                                ]
                            )
                        ]
                    )
                ),
                "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
                "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
            }

        if cal_flops:
            flops_stats = get_op_stats(node, node.inputs, node.outputs)
            if flops_stats is not None:
                # add op flops attr (flops_stats is a dict, so check for the
                # key instead of using hasattr, which is always False on a dict)
                if log_path and "flops" in flops_stats:
                    attr["flops"] = AttrValue(
                        s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
                    )
                flops_stats["name"] = node.name
                flops_stats["class_name"] = node.type
                flops_list.append(flops_stats)

        if cal_activations:
            acts = get_activation_stats(node_oup.numpy(), has_input=has_input)
            acts["name"] = node.name
            acts["class_name"] = node.type
            activations_list.append(acts)

        if cal_params:
            if node.type == "ImmutableTensor":
                param_stats = get_param_stats(node.numpy())
                # add tensor size attr
                if log_path:
                    attr["size"] = AttrValue(
                        s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
                    )
                param_stats["name"] = node.name
                params_list.append(param_stats)

        if log_path:
            node_list.append(
                NodeDef(
                    name=process_name(node.name),
                    op=node.type,
                    input=inp_list,
                    attr=attr,
                )
            )

    # summary
    extra_info = {
        "#ops": len(graph.all_oprs),
        "#params": len(params_list),
    }

    (
        total_flops,
        total_param_dims,
        total_param_size,
        total_act_dims,
        total_act_size,
    ) = (0, 0, 0, 0, 0)

    if cal_params:
        total_param_dims, total_param_size, params_list = sum_param_stats(
            params_list, bar_length_max
        )
        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
        if logging_to_stdout:
            print_param_stats(params_list)

    if cal_flops:
        total_flops, flops_list = sum_op_stats(flops_list, bar_length_max)
        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
        if logging_to_stdout:
            print_op_stats(flops_list)

    if cal_activations:
        total_act_dims, total_act_size, activations_list = sum_activations_stats(
            activations_list, bar_length_max
        )
        extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
        extra_info["total_act_size"] = sizeof_fmt(total_act_size)
        if logging_to_stdout:
            print_activations_stats(activations_list, has_input=has_input)

    if cal_flops and cal_params:
        extra_info["flops/param_size"] = "{:3.3f}".format(
            total_flops / total_param_size
        )

    if log_path:
        graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))

        device = "/device:CPU:0"
        stepstats = RunMetadata(
            step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
        )
        writer = SummaryWriter(log_path)
        writer._get_file_writer().add_graph((graph_def, stepstats))

    print_summary(**extra_info)

    return (
        total_stats(
            param_size=total_param_size, flops=total_flops, act_size=total_act_size,
        ),
        stats_details(
            params=params_list, flops=flops_list, activations=activations_list
        ),
    )
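# Usage sketch (illustrative): dump stats and a TensorBoard graph for a model;
# "model.mge", the log dir, and the input shape are placeholders.
import numpy as np

total, details = visualize(
    model_path="model.mge",
    log_path="logs/graph",
    input=np.random.random((1, 3, 224, 224)).astype(np.float32),
)
print(total.param_size, total.flops, total.act_size)
# view with: tensorboard --logdir logs/graph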
def node_proto(name,
               op='UnSpecified',
               inputs=None,
               output_shapes=None,
               need_grad=None,
               info=None):
    """Converts a node to `proto`.

    Args:
        name (str): Name of the node.
        op (str, optional): Name of the operator. Defaults to 'UnSpecified'.
        inputs (list of str, optional): A list of inputs. Defaults to None.
        output_shapes (list, optional): A list of tuples of integers containing
            the output shapes. Defaults to None.
        need_grad (bool, optional): Value of the node's `need_grad` flag.
            Defaults to None.
        info (dict, optional): Extra attributes to attach to the node; values
            may be bool, int, float, str, or lists of int/float. Defaults to None.

    Returns:
        proto: A node with `proto` format.
    """
    inputs = inputs or []
    attributes = dict()
    if output_shapes is not None:
        attributes['_output_shapes'] = AttrValue(
            list=AttrValue.ListValue(
                shape=[tensor_shape_proto(o) for o in output_shapes]
            )
        )
    if need_grad is not None:
        attributes['need_grad'] = AttrValue(b=need_grad)
    if info is not None:
        for k, v in info.items():
            if type(v) == bool:
                value = AttrValue(b=v)
            elif type(v) == int:
                value = AttrValue(i=v)
            elif type(v) == float:
                value = AttrValue(f=v)
            elif type(v) == str:
                # the `s` field expects bytes, so encode the string
                value = AttrValue(s=v.encode(encoding='utf_8'))
            elif type(v) == list:
                if len(v) == 0 or type(v[0]) == int:
                    value = AttrValue(list=AttrValue.ListValue(i=v))
                else:
                    value = AttrValue(list=AttrValue.ListValue(f=v))
            else:
                continue
            attributes[k] = value

    proto = NodeDef(
        name=name.encode(encoding='utf_8'),
        op=op,
        input=inputs,
        attr=attributes
    )
    return proto
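# Usage sketch (illustrative): a NodeDef for an affine node with one input and
# one output shape; the names and the `info` keys are made up for the example.
affine_node = node_proto(
    "Affine_1",
    op="Affine",
    inputs=["Input_0"],
    output_shapes=[(1, 10)],
    need_grad=True,
    info={"base_axis": 1, "with_bias": True},
)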
def visualize(
    model_path: str,
    log_path: str,
    bar_length_max: int = 20,
    log_params: bool = True,
    log_flops: bool = True,
):
    r"""
    Load a megengine dumped model and visualize its graph structure with
    tensorboard log files. Can also record and print the model's statistics,
    like :func:`~.module_stats`.

    :param model_path: path of the megengine dumped model.
    :param log_path: dir path for the tensorboard graph log.
    :param bar_length_max: maximum length of the bar indicating the largest
        flops or parameter size in the net stats.
    :param log_params: whether to print and record the parameter size.
    :param log_flops: whether to print and record op flops.
    """
    if log_path:
        try:
            from tensorboard.compat.proto.attr_value_pb2 import AttrValue
            from tensorboard.compat.proto.config_pb2 import RunMetadata
            from tensorboard.compat.proto.graph_pb2 import GraphDef
            from tensorboard.compat.proto.node_def_pb2 import NodeDef
            from tensorboard.compat.proto.step_stats_pb2 import (
                AllocatorMemoryUsed,
                DeviceStepStats,
                NodeExecStats,
                StepStats,
            )
            from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
            from tensorboard.compat.proto.versions_pb2 import VersionDef
            from tensorboardX import SummaryWriter
        except ImportError:
            logger.error(
                "TensorBoard and TensorboardX are required for visualize.",
                exc_info=True,
            )
            return

    # FIXME: remove this after resolving "span dist too large" warning
    old_level = set_mgb_log_level(logging.ERROR)

    enable_receptive_field()

    graph = Network.load(model_path)

    def process_name(name):
        # node names that start with a dot or contain a float constant
        # break the TensorBoard display
        if not re.match(r"^[+-]?\d*\.\d*", name):
            name = name.replace(".", "/")
        return name.encode(encoding="utf-8")

    summary = [["item", "value"]]
    node_list = []
    flops_list = []
    params_list = []

    for node in graph.all_oprs:
        if hasattr(node, "output_idx"):
            node_oup = node.outputs[node.output_idx]
        else:
            if len(node.outputs) != 1:
                logger.warning(
                    "OpNode {} has more than one output and no 'output_idx' attr.".format(
                        node
                    )
                )
            node_oup = node.outputs[0]

        inp_list = [process_name(var.owner.name) for var in node.inputs]
        if log_path:
            # for the detailed format, see tensorboard/compat/proto/attr_value.proto
            attr = {
                "_output_shapes": AttrValue(
                    list=AttrValue.ListValue(
                        shape=[
                            TensorShapeProto(
                                dim=[
                                    TensorShapeProto.Dim(size=d)
                                    for d in node_oup.shape
                                ]
                            )
                        ]
                    )
                ),
                "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
                "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
            }

        flops_stats = get_op_stats(node, node.inputs, node.outputs)
        if flops_stats is not None:
            # add op flops attr (flops_stats is a dict, so check for the
            # key instead of using hasattr, which is always False on a dict)
            if log_path and "flops" in flops_stats:
                attr["flops"] = AttrValue(
                    s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
                )
            flops_stats["name"] = node.name
            flops_stats["class_name"] = node.type
            flops_list.append(flops_stats)

        if node.type == "ImmutableTensor":
            param_stats = get_param_stats(node.numpy())
            # add tensor size attr
            if log_path:
                attr["size"] = AttrValue(
                    s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
                )
            param_stats["name"] = node.name
            params_list.append(param_stats)

        if log_path:
            node_list.append(
                NodeDef(
                    name=process_name(node.name),
                    op=node.type,
                    input=inp_list,
                    attr=attr,
                )
            )

    # summary
    extra_info = {
        "#ops": len(graph.all_oprs),
        "#params": len(params_list),
    }

    total_flops, total_param_dims, total_param_size = 0, 0, 0
    if log_params:
        total_param_dims, total_param_size = print_param_stats(
            params_list, bar_length_max
        )
        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
    if log_flops:
        total_flops = print_op_stats(flops_list, bar_length_max)
        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
    if log_params and log_flops:
        extra_info["flops/param_size"] = "{:3.3f}".format(
            total_flops / total_param_size
        )

    if log_path:
        graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))

        device = "/device:CPU:0"
        stepstats = RunMetadata(
            step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
        )
        writer = SummaryWriter(log_path)
        writer._get_file_writer().add_graph((graph_def, stepstats))

    print_summary(**extra_info)

    # FIXME: remove this after resolving "span dist too large" warning
    _imperative_rt_logger.set_log_level(old_level)

    return total_param_size, total_flops
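# Usage sketch (illustrative): record a graph log and print net stats;
# "model.mge" and the log dir are placeholders.
param_size, flops = visualize("model.mge", "logs/graph", bar_length_max=30)
print(sizeof_fmt(param_size), sizeof_fmt(flops, suffix="OPs"))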