Beispiel #1
0
    def from_graph_def(self, graph_def):
        """Convert a network definition into a TensorBoard graph event.

        One data node is emitted per parameter, per graph input and per
        variable that a function reads or writes, plus one node per
        function. The resulting ``GraphDef`` is serialized into an event
        and passed to ``self.file_writer.add_event``.

        Args:
            graph_def: Object exposing ``parameters``, ``inputs`` and
                ``variables`` (name -> object with a ``shape``) and
                ``functions`` (function name -> dict with the keys
                ``'type'``, ``'inputs'`` and ``'outputs'``).
        """
        variables = graph_def.variables
        nodes = []
        # Name -> NodeDef of every data node declared so far. It is used to
        # avoid declaring a variable twice and to attach a producer to a
        # node that was declared before its producing function was visited.
        scope = {}

        def declare(name, op, shape, producers):
            """Build a data node, append it to ``nodes`` and register it."""
            shape_proto = TensorShapeProto(
                dim=[TensorShapeProto.Dim(size=d) for d in shape])
            node = NodeDef(
                name=name.encode(encoding='utf-8'),
                op=op,
                input=producers,
                attr={
                    'shape': AttrValue(shape=shape_proto),
                    'dtype': AttrValue(type=DT_FLOAT)
                }
            )
            nodes.append(node)
            scope[name] = node
            return node

        for name, param in graph_def.parameters.items():
            declare(name, 'Parameter', param.shape, [])

        for name, var in graph_def.inputs.items():
            declare(name, 'Variable', var.shape, [])

        for func_name, func in graph_def.functions.items():
            for out_name in func['outputs']:
                if out_name in scope:
                    # Already declared (e.g. as an input of an earlier
                    # function): just record this function as a producer.
                    scope[out_name].input.extend([func_name])
                elif out_name in variables:
                    declare(out_name, 'Variable',
                            variables[out_name].shape, [func_name])
            for in_name in func['inputs']:
                # Declare a producer-less placeholder only for variables
                # that have not been declared yet; the node is named after
                # the input itself, not after the last output visited.
                if in_name not in scope and in_name in variables:
                    declare(in_name, 'Variable',
                            variables[in_name].shape, [])
            nodes.append(NodeDef(
                name=func_name,
                op=func['type'],
                input=func['inputs'],
                attr={"arguments": AttrValue(s='a=1'.encode(encoding='utf-8'))}
            ))

        current_graph = GraphDef(node=nodes, versions=VersionDef(producer=22))
        event = event_pb2.Event(
            graph_def=current_graph.SerializeToString())
        self.file_writer.add_event(event)
Beispiel #2
0
    def from_variable(self, leaf, output_name="output"):
        """Write the graph that produces ``leaf`` as a TensorBoard graph event.

        The graph is walked backwards from ``leaf`` through ``v.parent`` and
        ``parent.inputs``. A ``NodeDef`` is collected for every variable and
        function encountered; the collected list is then reversed, wrapped
        in a ``GraphDef`` and passed to ``self.file_writer.add_event``.

        Args:
            leaf: Variable the traversal starts from. The code reads its
                ``data``, ``shape`` and ``parent`` attributes.
            output_name: Name given to a variable that has a parent function
                and is named while no node has been collected yet (in
                practice ``leaf`` itself).
        """
        def parse_variable(v, var_num):
            """Collect nodes for ``v`` and, recursively, for its ancestors."""
            def add_variable(v, v_idx):
                """Return the node name of ``v``, appending its node if new."""
                # Variables whose data is a registered parameter keep the
                # parameter's name.
                v_name = parameters.get(v.data, None)
                exist = False
                if not v_name:
                    v_name, exist = get_variable_name(v, v_idx)
                # NOTE(review): for parameters `exist` is never set to True,
                # so a node is appended on every call for the same parameter
                # -- confirm that repeated nodes are acceptable.
                if not exist:
                    shape_proto = TensorShapeProto(
                        dim=[TensorShapeProto.Dim(size=d) for d in v.shape])

                    # A variable produced by a function lists that function
                    # as its only input; otherwise it has no inputs.
                    if v.parent is None:
                        inputs = []
                    else:
                        inputs = [get_func_name(v.parent)]
                    # print("Variable: {}:{}".format(v_name, inputs))
                    nodes.append(NodeDef(
                        name=v_name.encode(encoding='utf-8'),
                        op='Variable',
                        input=inputs,
                        attr={
                            'shape': AttrValue(shape=shape_proto),
                            'dtype': AttrValue(type=DT_FLOAT)
                        }
                    ))
                return v_name

            def get_unique_variable_name(v_name_base):
                """Return ``v_name_base`` with the lowest unused integer suffix.

                The returned name is recorded in ``unique_var_names``.
                """
                v_num = 0
                v_name = v_name_base + str(v_num)
                while v_name in unique_var_names:
                    v_num += 1
                    v_name = v_name_base + str(v_num)
                unique_var_names.add(v_name)
                return v_name

            def get_variable_name(v, v_idx):
                """Return ``(name, already_named)`` for ``v``.

                ``v_idx`` is not used by the visible code.
                """
                v_name = variables.get(v, None)
                if v_name:
                    return v_name, True
                else:
                    if v.parent is None:
                        # No producing function: numbered "Input<N>" name.
                        v_name_base = "Input"
                        v_name = get_unique_variable_name(v_name_base)
                    elif not nodes:
                        # Nothing collected yet: this is the variable the
                        # traversal started from.
                        v_name = output_name
                    else:
                        # Named after its producing function and placed in
                        # the same "/"-separated scope as that function.
                        f_name_sections = get_func_name(v.parent).split("/")
                        f_name = f_name_sections[-1]
                        f_scope = f_name_sections[:-1]
                        base_name = "variable<-{}".format(f_name)
                        v_name_base = "/".join(f_scope + [base_name])
                        v_name = get_unique_variable_name(v_name_base)

                    variables[v] = v_name
                    return v_name, False

            def get_func_name(func):
                """Return the unique node name of ``func``, creating it once."""
                func_name = func_names.get(func, None)
                if func_name:
                    return func_name
                name_scope = loc_var['name_scope']
                # The scope of the first input that is a named parameter
                # (everything before the last "/") overrides the default.
                # NOTE(review): this reads `self.parameters`, while
                # add_variable uses the local `parameters` dict built in
                # from_variable -- confirm the attribute exists and holds the
                # same data -> name mapping.
                for v in func.inputs:
                    v_name = self.parameters.get(v.data, None)
                    if v_name:
                        name_scope = '/'.join(v_name.split('/')[:-1])
                        break
                if name_scope:
                    func_name_base = '/'.join([name_scope, func.name])
                else:
                    func_name_base = func.name
                # Append the lowest integer suffix not already taken.
                func_num = 0
                func_name = func_name_base + str(func_num)
                while func_name in unique_func_names:
                    func_num += 1
                    func_name = func_name_base + str(func_num)
                unique_func_names.add(func_name)
                func_names[func] = func_name
                return func_name

            def add_func(v):
                """Collect the node of ``v.parent`` and nodes for its inputs.

                Returns False when the function node was already collected,
                True otherwise. Input variables are processed before that
                check is made.
                """
                input_names = []
                for index, v_input in enumerate(v.parent.inputs):
                    v_name = add_variable(v_input, index)
                    input_names.append(v_name)
                # print("Function: {}:{}".format(get_func_name(v.parent), input_names))
                f_name = get_func_name(v.parent)
                if f_name in func_set:
                    return False
                # Function arguments are flattened to a "k=v,k=v" string.
                attrs = []
                for k, a in v.parent.info.args.items():
                    attr = "{}={}".format(k, a)
                    attrs.append(attr)
                attr_str = ','.join(attrs).encode(encoding='utf-8')
                nodes.append(NodeDef(
                    name=f_name,
                    op=v.parent.info.type_name,
                    input=input_names,
                    attr={"parameters": AttrValue(s=attr_str)}
                ))
                func_set.add(f_name)
                return True

            # NOTE(review): `name_scope` is only pushed to and popped from
            # `name_scope_stack` below and is not otherwise read here.
            name_scope = loc_var['name_scope']
            # The first variable visited is always given a node.
            if not nodes:
                add_variable(v, var_num)
            if v.parent is None:
                add_variable(v, var_num)
            else:
                # Stop descending when the producing function was already
                # collected through another path.
                if not add_func(v):
                    return
                for idx, in_var in enumerate(v.parent.inputs):
                    name_scope_stack.append(name_scope)
                    parse_variable(in_var, idx)
                    name_scope = name_scope_stack.pop()

        # State shared by the nested helpers above.
        nodes = []
        variables = {}  # variable -> assigned node name
        loc_var = {}
        loc_var['name_scope'] = ''
        name_scope_stack = []
        func_names = {}  # function -> assigned node name
        func_set = set()  # names of function nodes already collected
        unique_func_names = set()
        unique_var_names = set()
        # Reverse lookup from a parameter's data to its registered name.
        parameters = {v.data: k for k,
                      v in get_parameters(grad_only=False).items()}
        parse_variable(leaf, 0)
        # Nodes were collected starting from the leaf; reverse the list
        # before building the graph.
        nodes = nodes[::-1]

        current_graph = GraphDef(node=nodes, versions=VersionDef(producer=22))
        event = event_pb2.Event(
            graph_def=current_graph.SerializeToString())
        self.file_writer.add_event(event)
Beispiel #3
0
def graph(node_list):
    """Build a ``GraphDef`` from ``node_list`` together with run metadata.

    Args:
        node_list: Nodes placed in the ``node`` field of the graph.

    Returns:
        A ``(graph_def, run_metadata)`` tuple. The graph carries producer
        version 22; the metadata holds step stats with a single device
        entry named ``/device:CPU:0``.
    """
    producer_version = VersionDef(producer=22)
    graph_proto = GraphDef(node=node_list, versions=producer_version)

    cpu_device = DeviceStepStats(device="/device:CPU:0")
    step_stats = StepStats(dev_stats=[cpu_device])
    run_metadata = RunMetadata(step_stats=step_stats)

    return graph_proto, run_metadata