示例#1
0
def attr_value_proto(shape, dtype, attr):
    """Build the ``attr`` mapping for a TensorBoard ``NodeDef``.

    The layout follows
    https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto
    and was reverse engineered from data logged by TensorBoard itself.

    Args:
        shape: value handed to ``tensor_shape_proto``; the result is stored
            under ``"_output_shapes"`` as a one-element shape list.
            Skipped when None.
        dtype: string stored, UTF-8 encoded, under ``"dtype"``.
            Skipped when None.
        attr: mapping of extra string attributes; each value is stored,
            UTF-8 encoded, under its own key. Skipped when None.

    Returns:
        dict mapping attribute names to ``AttrValue`` messages (empty when
        every argument is None).
    """
    result = {}
    if shape is not None:
        output_shapes = AttrValue.ListValue(shape=[tensor_shape_proto(shape)])
        result["_output_shapes"] = AttrValue(list=output_shapes)
    if dtype is not None:
        result["dtype"] = AttrValue(s=dtype.encode(encoding="utf-8"))
    if attr is not None:
        result.update(
            (name, AttrValue(s=value.encode(encoding="utf-8")))
            for name, value in attr.items()
        )
    return result
示例#2
0
            def add_variable(v, v_idx):
                """Register ``v`` as a ``'Variable'`` node and return its name.

                The name is looked up in ``parameters`` (keyed by ``v.data``).
                If that lookup gives nothing truthy, ``get_variable_name(v,
                v_idx)`` supplies the name together with a flag; when that
                flag is true no node is emitted. Otherwise (including every
                case where the name came from ``parameters``) a ``NodeDef``
                carrying ``v``'s shape and a float dtype is appended to
                ``nodes``. Its single input is ``get_func_name(v.parent)``
                when ``v`` has a parent, and it has no inputs otherwise.
                """
                already_known = False
                v_name = parameters.get(v.data, None)
                if not v_name:
                    v_name, already_known = get_variable_name(v, v_idx)
                if already_known:
                    return v_name

                dims = [TensorShapeProto.Dim(size=d) for d in v.shape]
                shape_proto = TensorShapeProto(dim=dims)
                producers = (
                    [] if v.parent is None else [get_func_name(v.parent)])
                nodes.append(NodeDef(
                    name=v_name.encode(encoding='utf-8'),
                    op='Variable',
                    input=producers,
                    attr={
                        'shape': AttrValue(shape=shape_proto),
                        'dtype': AttrValue(type=DT_FLOAT),
                    },
                ))
                return v_name
示例#3
0
 def add_func(v):
     """Emit a node for the function that produced ``v``.

     Every input of ``v.parent`` is first passed, in order and with its
     index, to ``add_variable``. The function's name comes from
     ``get_func_name``. If that name is already in ``func_set``, False is
     returned and nothing else is emitted. Otherwise a ``NodeDef`` is
     appended to ``nodes`` and the name is added to ``func_set``, and True
     is returned. The node's op is ``v.parent.info.type_name``, and its
     ``"parameters"`` attr holds the ``key=value`` pairs of
     ``v.parent.info.args``, comma-joined and UTF-8 encoded.
     """
     parent = v.parent
     input_names = [
         add_variable(src, idx) for idx, src in enumerate(parent.inputs)]
     f_name = get_func_name(parent)
     if f_name in func_set:
         return False
     info = parent.info
     joined = ','.join(
         "{}={}".format(key, val) for key, val in info.args.items())
     nodes.append(NodeDef(
         name=f_name,
         op=info.type_name,
         input=input_names,
         attr={"parameters": AttrValue(s=joined.encode(encoding='utf-8'))}))
     func_set.add(f_name)
     return True
示例#4
0
    def from_graph_def(self, graph_def):
        """Convert ``graph_def`` to a TensorBoard ``GraphDef`` and log it.

        ``graph_def.parameters`` become ``'Parameter'`` nodes and
        ``graph_def.inputs`` become ``'Variable'`` nodes. Each entry of
        ``graph_def.functions`` is a dict with ``'outputs'``, ``'inputs'``
        and ``'type'`` keys. For each one, the outputs and inputs listed in
        ``graph_def.variables`` become ``'Variable'`` nodes, and the function
        itself becomes a node whose op is ``func['type']``.

        Each data-node name is emitted at most once. When a function output
        names a node that already exists, that function is appended to the
        existing node's inputs instead of creating a duplicate.

        The nodes are wrapped in a ``GraphDef`` (producer 22), serialized
        into an ``event_pb2.Event`` and passed to
        ``self.file_writer.add_event``.

        Args:
            graph_def: object exposing the ``variables``, ``parameters``,
                ``functions`` and ``inputs`` mappings described above. Values
                of ``variables``, ``parameters`` and ``inputs`` must have a
                ``shape`` iterable of dimension sizes.
        """
        variables = graph_def.variables
        nodes = []
        # Data-node name -> NodeDef already appended to ``nodes``. Used to
        # avoid emitting a name twice and to attach producers afterwards.
        scope = {}

        def add_data_node(name, op, value, node_inputs):
            # Append a shape-annotated data node and register it in ``scope``.
            shape_proto = TensorShapeProto(
                dim=[TensorShapeProto.Dim(size=d) for d in value.shape])
            node = NodeDef(
                name=name.encode(encoding='utf-8'),
                op=op,
                input=node_inputs,
                attr={
                    'shape': AttrValue(shape=shape_proto),
                    # dtype is always reported as float; the value's own
                    # dtype is not consulted.
                    'dtype': AttrValue(type=DT_FLOAT)
                }
            )
            nodes.append(node)
            scope[name] = node

        for n, v in graph_def.parameters.items():
            add_data_node(n, 'Parameter', v, [])

        for n, v in graph_def.inputs.items():
            if n not in scope:
                add_data_node(n, 'Variable', v, [])

        for func_name, func in graph_def.functions.items():
            for o in func['outputs']:
                if o in scope:
                    # Already emitted (e.g. as an input of an earlier
                    # function): just record this function as its producer.
                    scope[o].input.append(func_name)
                elif o in variables:
                    add_data_node(o, 'Variable', variables[o], [func_name])
            for i in func['inputs']:
                # The node must be named after the input ``i`` itself; a
                # later producer will wire itself in via ``scope``.
                if i not in scope and i in variables:
                    add_data_node(i, 'Variable', variables[i], [])
            nodes.append(NodeDef(
                name=func_name,
                op=func['type'],
                input=func['inputs'],
                # NOTE(review): 'a=1' is a fixed placeholder, not the
                # function's actual arguments.
                attr={"arguments": AttrValue(s='a=1'.encode(encoding='utf-8'))}
            ))

        current_graph = GraphDef(node=nodes, versions=VersionDef(producer=22))
        event = event_pb2.Event(
            graph_def=current_graph.SerializeToString())
        self.file_writer.add_event(event)