Beispiel #1
0
def attr_value_proto(dtype, shape, s):
    """Build the attribute mapping attached to a TensorBoard ``NodeDef``.

    The layout follows
    https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto
    and the values were reverse engineered from standard TensorBoard
    logged data.

    Args:
        dtype: Accepted as part of the signature; not read by this function.
        shape: Shape handed to ``tensor_shape_proto``. When not ``None`` the
            resulting proto is stored under ``'_output_shapes'`` as a
            one-element shape list.
        s: Text stored UTF-8 encoded under ``'attr'`` when not ``None``.

    Returns:
        dict: Attribute name -> ``AttrValue``. Empty when both ``shape``
        and ``s`` are ``None``.
    """
    result = {}

    if s is not None:
        encoded = s.encode(encoding='utf_8')
        result['attr'] = AttrValue(s=encoded)

    # Nothing more to add without a shape.
    if shape is None:
        return result

    shape_list = AttrValue.ListValue(shape=[tensor_shape_proto(shape)])
    result['_output_shapes'] = AttrValue(list=shape_list)
    return result
Beispiel #2
0
def parse(graph):
    """Convert a graph description into a TensorBoard ``GraphDef``.

    Two kinds of nodes are emitted, in this order:

    1. One ``"Variable"`` node for every entry of ``graph.input`` followed
       by every entry of ``graph.output``. Each carries a ``"dtype"`` and a
       ``"shape"`` attribute read from ``entry.type.tensor_type``.
    2. One node for every entry of ``graph.node``. It is named after the
       entry's first output, uses ``op_type`` as its op, and stores a
       textual rendering of the entry's attributes under ``"parameters"``.

    The name of every emitted node is printed to stdout as it is processed.

    Args:
        graph: Object exposing ``input``, ``output`` and ``node`` iterables
            (presumably an ONNX ``GraphProto`` -- confirm against callers).

    Returns:
        GraphDef: Graph definition holding all emitted nodes, with
        ``versions.producer`` set to 22.
    """
    import itertools

    nodes = []

    # Graph inputs and outputs are rendered as shape/dtype-carrying variables.
    for value_info in itertools.chain(graph.input, graph.output):
        print(value_info.name)
        tensor_type = value_info.type.tensor_type
        shapeproto = TensorShapeProto(dim=[
            TensorShapeProto.Dim(size=d.dim_value)
            for d in tensor_type.shape.dim
        ])
        nodes.append(
            NodeDef(
                name=value_info.name.encode(encoding="utf_8"),
                op="Variable",
                input=[],
                attr={
                    "dtype": AttrValue(type=tensor_type.elem_type),
                    "shape": AttrValue(shape=shapeproto),
                },
            ))

    # Operator nodes: each attribute is flattened to "field = field = ..."
    # using the values of the fields that are set on it.
    for op_node in graph.node:
        rendered = [
            " = ".join(str(field[1]) for field in attribute.ListFields())
            for attribute in op_node.attribute
        ]
        parameters = ", ".join(rendered).encode(encoding="utf_8")
        print(op_node.output[0])
        nodes.append(
            NodeDef(
                name=op_node.output[0].encode(encoding="utf_8"),
                op=op_node.op_type,
                input=op_node.input,
                attr={"parameters": AttrValue(s=parameters)},
            ))

    return GraphDef(node=nodes, versions=VersionDef(producer=22))
Beispiel #3
0
def graph(model):
    """Converts a crypten.nn graph for consumption by TensorBoard.

    A model that is not already an ``nn.Graph`` is first wrapped into a
    graph with a single input named ``"input"`` and a single module named
    ``"output"``.

    Args:
        model: The ``nn.Module`` to convert.

    Returns:
        GraphDef: One ``"Variable"`` node per model input followed by one
        node per entry of ``model._graph``. Module nodes are named
        ``"<ClassName>_<module name>"`` and carry the sizes of their named
        parameters as a UTF-8 string under the ``"attr"`` attribute.

    Raises:
        AssertionError: If ``model`` is not an ``nn.Module``.
    """

    # convert individual module to graph:
    assert isinstance(model, nn.Module), "model must be crypten.nn.Module"
    if not isinstance(model, nn.Graph):
        wrapper = nn.Graph("input", "output")
        wrapper.add_module("output", model, ["input"])
        model = wrapper

    # create mapping to more interpretable node naming:
    mapping = {input_name: input_name for input_name in model.input_names}
    modules = dict(model.named_modules())
    for name, module in modules.items():
        # Take the class name directly. Slicing str(type(module)) at a fixed
        # offset only yields the class name when the module path in the repr
        # has one exact length, and produces truncated names otherwise.
        class_name = type(module).__qualname__
        mapping[name] = "%s_%s" % (class_name, name)

    # create input variables:
    nodes = [
        NodeDef(
            name=mapping[input_name].encode(encoding="utf_8"),
            op="Variable",
            input=[],
        ) for input_name in model.input_names
    ]

    # loop all graph connections:
    for output_name, input_names in model._graph.items():

        # get parameters and type of module:
        module = modules[output_name]
        op = str(type(module))
        renamed_inputs = [mapping[name] for name in input_names]
        parameters = [
            "%s: %s" % (name, parameter.size())
            for name, parameter in module.named_parameters()
        ]
        parameter_string = "; ".join(parameters).encode(encoding="utf_8")

        # add to graph:
        nodes.append(
            NodeDef(
                name=mapping[output_name].encode(encoding="utf_8"),
                op=op,
                input=renamed_inputs,
                attr={"attr": AttrValue(s=parameter_string)},
            ))

    # return graph definition:
    return GraphDef(node=nodes, versions=VersionDef(producer=22))
Beispiel #4
0
def visualize(
    model_path: str,
    log_path: str,
    input: np.ndarray = None,
    inp_dict: dict = None,
    cal_params: bool = True,
    cal_flops: bool = True,
    cal_activations: bool = True,
    logging_to_stdout: bool = True,
    bar_length_max: int = 20,
):
    r"""
    Load megengine dumped model and visualize graph structure with tensorboard log files.
    Can also record and print model's statistics like :func:`~.module_stats`

    :param model_path: dir path for megengine dumped model.
    :param log_path: dir path for tensorboard graph log. When empty, no tensorboard
        log is written and only the statistics are computed.
    :param input: user defined input data for running model and calculating stats, alternative with inp_dict, used when the model has only one input.
    :param inp_dict: input dict for running model and calculating stats, alternative with input, used when the model has more than one input. When both input and inp_dict are None, a random input will be used.
    :param cal_params: whether calculate and record params size.
    :param cal_flops: whether calculate and record op flops.
    :param cal_activations: whether calculate and record op activations.
    :param logging_to_stdout: whether print all calculated statistic details.
    :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
    :return: a pair ``(total_stats, stats_details)`` of named tuples. ``total_stats``
        has the fields ``param_size``, ``flops`` and ``act_size``; ``stats_details``
        has the fields ``params``, ``flops`` and ``activations``. Returns ``None``
        when ``log_path`` is given but the tensorboard packages cannot be imported.

    """
    if log_path:
        try:
            from tensorboard.compat.proto.attr_value_pb2 import AttrValue
            from tensorboard.compat.proto.config_pb2 import RunMetadata
            from tensorboard.compat.proto.graph_pb2 import GraphDef
            from tensorboard.compat.proto.node_def_pb2 import NodeDef
            from tensorboard.compat.proto.step_stats_pb2 import (
                AllocatorMemoryUsed,
                DeviceStepStats,
                NodeExecStats,
                StepStats,
            )
            from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
            from tensorboard.compat.proto.versions_pb2 import VersionDef
            from tensorboardX import SummaryWriter
        except ImportError:
            logger.error(
                "TensorBoard and TensorboardX are required for visualize.",
                exc_info=True,
            )
            return

    enable_receptive_field()

    graph = Network.load(model_path)
    graph.reset_batch_size(1)

    # Replace the graph's input variables by constants built from user data.
    has_input = False
    if input is not None or inp_dict is not None:
        has_input = True
        repl_dict = {}
        inp_vars = graph.input_vars
        if inp_dict is not None:
            assert len(inp_dict) == len(
                inp_vars
            ), "Inputs are not sufficient for calculation."
            for v in inp_vars:
                new_input = graph.make_const(inp_dict[v.name], name=v.name)
                repl_dict[v] = new_input
        else:
            assert len(inp_vars) == 1, "The graph needs more than one input."
            inp_var = inp_vars[0]
            repl_dict[inp_var] = graph.make_const(input, name=inp_var.name)
        graph.replace_vars(repl_dict=repl_dict)

    graph._compile()

    def process_name(name):
        # nodes that start with point or contain float const will lead to display bug
        if not re.match(r"^[+-]?\d*\.\d*", name):
            name = name.replace(".", "/")
        return name.encode(encoding="utf-8")

    node_list = []
    flops_list = []
    params_list = []
    activations_list = []
    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
    stats_details = namedtuple("module_stats", ["params", "flops", "activations"])

    for node in tqdm(graph.all_oprs):
        if hasattr(node, "output_idx"):
            node_oup = node.outputs[node.output_idx]
        else:
            if len(node.outputs) != 1:
                logger.warning(
                    "OpNode {} has more than one output and not has 'output_idx' attr.".format(
                        node
                    )
                )
            node_oup = node.outputs[0]

        inp_list = [process_name(var.owner.name) for var in node.inputs]
        if log_path:
            # detail format see tensorboard/compat/proto/attr_value.proto
            attr = {
                "_output_shapes": AttrValue(
                    list=AttrValue.ListValue(
                        shape=[
                            TensorShapeProto(
                                dim=[
                                    TensorShapeProto.Dim(size=d) for d in node_oup.shape
                                ]
                            )
                        ]
                    )
                ),
                "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
                "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
            }

        if cal_flops:
            flops_stats = get_op_stats(node, node.inputs, node.outputs)
            if flops_stats is not None:
                # add op flops attr
                # The stats object is accessed by key everywhere in this
                # function, so test for the key that is read below; an
                # attribute lookup for "flops_num" does not see keys and
                # left this attribute unwritten.
                if log_path and "flops" in flops_stats:
                    attr["flops"] = AttrValue(
                        s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
                    )
                flops_stats["name"] = node.name
                flops_stats["class_name"] = node.type
                flops_list.append(flops_stats)

        if cal_activations:
            acts = get_activation_stats(node_oup.numpy(), has_input=has_input)
            acts["name"] = node.name
            acts["class_name"] = node.type
            activations_list.append(acts)

        if cal_params:
            if node.type == "ImmutableTensor":
                param_stats = get_param_stats(node.numpy())
                # add tensor size attr
                if log_path:
                    attr["size"] = AttrValue(
                        s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
                    )
                param_stats["name"] = node.name
                params_list.append(param_stats)

        if log_path:
            node_list.append(
                NodeDef(
                    name=process_name(node.name),
                    op=node.type,
                    input=inp_list,
                    attr=attr,
                )
            )
    # summary
    extra_info = {
        "#ops": len(graph.all_oprs),
        "#params": len(params_list),
    }

    (
        total_flops,
        total_param_dims,
        total_param_size,
        total_act_dims,
        total_act_size,
    ) = (0, 0, 0, 0, 0)

    if cal_params:
        total_param_dims, total_param_size, params_list = sum_param_stats(
            params_list, bar_length_max
        )
        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
        if logging_to_stdout:
            print_param_stats(params_list)

    if cal_flops:
        total_flops, flops_list = sum_op_stats(flops_list, bar_length_max)
        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
        if logging_to_stdout:
            print_op_stats(flops_list)

    if cal_activations:
        total_act_dims, total_act_size, activations_list = sum_activations_stats(
            activations_list, bar_length_max
        )
        extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
        extra_info["total_act_size"] = sizeof_fmt(total_act_size)
        if logging_to_stdout:
            print_activations_stats(activations_list, has_input=has_input)

    # The ratio is undefined for a model without parameter tensors; leave
    # the entry out instead of failing with ZeroDivisionError.
    if cal_flops and cal_params and total_param_size:
        extra_info["flops/param_size"] = "{:3.3f}".format(
            total_flops / total_param_size
        )

    if log_path:
        graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))

        device = "/device:CPU:0"
        stepstats = RunMetadata(
            step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
        )
        writer = SummaryWriter(log_path)
        try:
            writer._get_file_writer().add_graph((graph_def, stepstats))
        finally:
            # Flush the event file and release it even if add_graph fails.
            writer.close()

    print_summary(**extra_info)

    return (
        total_stats(
            param_size=total_param_size, flops=total_flops, act_size=total_act_size,
        ),
        stats_details(
            params=params_list, flops=flops_list, activations=activations_list
        ),
    )
Beispiel #5
0
def node_proto(name, op='UnSpecified', inputs=None, output_shapes=None,
               need_grad=None, info=None):
    """Converts a node to `proto`.

    Args:
        name (str): Name of the node.
        op (str, optional): Name of the operator. Defaults to 'UnSpecified'.
        inputs (list of str, optional): A list of inputs. Defaults to None.
        output_shapes (list, optional): A list of tuple of integers containing the output shapes. Defaults to None.
        need_grad (bool, optional): When not None, stored as the boolean
            attribute 'need_grad'. Defaults to None.
        info (dict, optional): Extra attributes keyed by attribute name.
            Values whose type is exactly bool, int, float, str or list are
            converted; values of any other type are skipped. A list is
            stored as an integer list when it is empty or its first element
            is an int, and as a float list otherwise. Defaults to None.

    Returns:
        proto: A node with `proto` format.
    """
    inputs = inputs or []
    attributes = {}
    if output_shapes is not None:
        attributes['_output_shapes'] = AttrValue(
            list=AttrValue.ListValue(
                shape=[tensor_shape_proto(o) for o in output_shapes]
            )
        )

    if need_grad is not None:
        attributes['need_grad'] = AttrValue(b=need_grad)

    if info is not None:
        for key, value in info.items():
            # Exact type matches are kept (rather than isinstance) so that
            # values of subclass types continue to be skipped as before.
            value_type = type(value)
            if value_type is bool:
                attr_value = AttrValue(b=value)
            elif value_type is int:
                attr_value = AttrValue(i=value)
            elif value_type is float:
                attr_value = AttrValue(f=value)
            elif value_type is str:
                # `s` is a bytes field, so text has to be encoded first.
                attr_value = AttrValue(s=value.encode(encoding='utf_8'))
            elif value_type is list:
                if len(value) == 0 or type(value[0]) is int:
                    attr_value = AttrValue(list=AttrValue.ListValue(i=value))
                else:
                    attr_value = AttrValue(list=AttrValue.ListValue(f=value))
            else:
                continue
            attributes[key] = attr_value

    proto = NodeDef(
        name=name.encode(encoding='utf_8'), op=op,
        input=inputs, attr=attributes
    )

    return proto
Beispiel #6
0
def visualize(
    model_path: str,
    log_path: str,
    bar_length_max: int = 20,
    log_params: bool = True,
    log_flops: bool = True,
):
    r"""
    Load megengine dumped model and visualize graph structure with tensorboard log files.
    Can also record and print model's statistics like :func:`~.module_stats`

    :param model_path: dir path for megengine dumped model.
    :param log_path: dir path for tensorboard graph log. When empty, no tensorboard
        log is written and only the statistics are computed.
    :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
    :param log_params: whether print and record params size.
    :param log_flops: whether print and record op flops.
    :return: a tuple ``(total_param_size, total_flops)``; each entry is 0 when the
        corresponding ``log_*`` flag is off. Returns ``None`` when ``log_path`` is
        given but the tensorboard packages cannot be imported.
    """
    if log_path:
        try:
            from tensorboard.compat.proto.attr_value_pb2 import AttrValue
            from tensorboard.compat.proto.config_pb2 import RunMetadata
            from tensorboard.compat.proto.graph_pb2 import GraphDef
            from tensorboard.compat.proto.node_def_pb2 import NodeDef
            from tensorboard.compat.proto.step_stats_pb2 import (
                AllocatorMemoryUsed,
                DeviceStepStats,
                NodeExecStats,
                StepStats,
            )
            from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
            from tensorboard.compat.proto.versions_pb2 import VersionDef
            from tensorboardX import SummaryWriter
        except ImportError:
            logger.error(
                "TensorBoard and TensorboardX are required for visualize.",
                exc_info=True,
            )
            return
    # FIXME: remove this after resolving "span dist too large" warning
    old_level = set_mgb_log_level(logging.ERROR)

    # Everything below runs with the lowered log level; the finally clause
    # restores the previous level even when loading or processing fails.
    try:
        enable_receptive_field()

        graph = Network.load(model_path)

        def process_name(name):
            # nodes that start with point or contain float const will lead to display bug
            if not re.match(r"^[+-]?\d*\.\d*", name):
                name = name.replace(".", "/")
            return name.encode(encoding="utf-8")

        node_list = []
        flops_list = []
        params_list = []
        for node in graph.all_oprs:
            if hasattr(node, "output_idx"):
                node_oup = node.outputs[node.output_idx]
            else:
                if len(node.outputs) != 1:
                    logger.warning(
                        "OpNode {} has more than one output and not has 'output_idx' attr."
                        .format(node))
                node_oup = node.outputs[0]

            inp_list = [process_name(var.owner.name) for var in node.inputs]
            if log_path:
                # detail format see tensorboard/compat/proto/attr_value.proto
                attr = {
                    "_output_shapes":
                    AttrValue(list=AttrValue.ListValue(shape=[
                        TensorShapeProto(dim=[
                            TensorShapeProto.Dim(size=d) for d in node_oup.shape
                        ])
                    ])),
                    "params":
                    AttrValue(s=str(node.params).encode(encoding="utf-8")),
                    "dtype":
                    AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
                }
            flops_stats = get_op_stats(node, node.inputs, node.outputs)
            if flops_stats is not None:
                # add op flops attr
                # The stats object is accessed by key everywhere in this
                # function, so test for the key that is read below; an
                # attribute lookup for "flops_num" does not see keys and
                # left this attribute unwritten.
                if log_path and "flops" in flops_stats:
                    attr["flops"] = AttrValue(
                        s=sizeof_fmt(flops_stats["flops"]).encode(
                            encoding="utf-8"))
                flops_stats["name"] = node.name
                flops_stats["class_name"] = node.type
                flops_list.append(flops_stats)

            if node.type == "ImmutableTensor":
                param_stats = get_param_stats(node.numpy())
                # add tensor size attr
                if log_path:
                    attr["size"] = AttrValue(
                        s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8"))
                param_stats["name"] = node.name
                params_list.append(param_stats)

            if log_path:
                node_list.append(
                    NodeDef(
                        name=process_name(node.name),
                        op=node.type,
                        input=inp_list,
                        attr=attr,
                    ))
        # summary
        extra_info = {
            "#ops": len(graph.all_oprs),
            "#params": len(params_list),
        }

        total_flops, total_param_dims, total_param_size = 0, 0, 0
        if log_params:
            total_param_dims, total_param_size = print_param_stats(
                params_list, bar_length_max)
            extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
            extra_info["total_param_size"] = sizeof_fmt(total_param_size)
        if log_flops:
            total_flops = print_op_stats(flops_list, bar_length_max)
            extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
        # The ratio is undefined for a model without parameter tensors; leave
        # the entry out instead of failing with ZeroDivisionError.
        if log_params and log_flops and total_param_size:
            extra_info["flops/param_size"] = "{:3.3f}".format(total_flops /
                                                              total_param_size)

        if log_path:
            graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))

            device = "/device:CPU:0"
            stepstats = RunMetadata(step_stats=StepStats(
                dev_stats=[DeviceStepStats(device=device)]))
            writer = SummaryWriter(log_path)
            try:
                writer._get_file_writer().add_graph((graph_def, stepstats))
            finally:
                # Flush the event file and release it even if add_graph fails.
                writer.close()

        print_summary(**extra_info)

        return total_param_size, total_flops
    finally:
        # FIXME: remove this after resolving "span dist too large" warning
        _imperative_rt_logger.set_log_level(old_level)