def from_graph_def(self, graph_def):
    """Build a TensorBoard graph event from ``graph_def`` and write it.

    ``graph_def`` must expose four dict attributes:

    * ``parameters`` / ``inputs`` / ``variables``: name -> object with a
      ``.shape`` iterable of dimension sizes.
    * ``functions``: name -> dict with ``'type'``, ``'inputs'`` and
      ``'outputs'`` (lists of variable names).

    Each parameter becomes a ``Parameter`` node and each graph input a
    ``Variable`` node. Every variable a function reads or writes becomes a
    ``Variable`` node, emitted at most once. Each function becomes a node
    whose ``input`` lists its input variable names. The resulting
    ``GraphDef`` is serialized into an event and handed to
    ``self.file_writer.add_event``.
    """
    variables = graph_def.variables
    parameters = graph_def.parameters
    functions = graph_def.functions
    inputs = graph_def.inputs

    nodes = []
    # Variable name -> NodeDef already emitted for it. Used to avoid emitting
    # the same variable twice and to attach a producer to an existing node.
    scope = {}

    def add_var_node(name, op, shape, node_inputs):
        # Emit one float tensor node of the given shape and remember it.
        shape_proto = TensorShapeProto(
            dim=[TensorShapeProto.Dim(size=d) for d in shape])
        node = NodeDef(
            name=name.encode(encoding='utf-8'),
            op=op,
            input=node_inputs,
            attr={
                'shape': AttrValue(shape=shape_proto),
                'dtype': AttrValue(type=DT_FLOAT)
            }
        )
        nodes.append(node)
        scope[name] = node
        return node

    for n, v in parameters.items():
        add_var_node(n, 'Parameter', v.shape, [])
    for n, v in inputs.items():
        add_var_node(n, 'Variable', v.shape, [])

    for func_name, func in functions.items():
        for o in func['outputs']:
            if o in scope:
                # Node already exists (e.g. emitted earlier as an input):
                # record this function as its producer.
                scope[o].input.extend([func_name])
            elif o in variables:
                add_var_node(o, 'Variable', variables[o].shape, [func_name])
        for i in func['inputs']:
            # Only emit input variables that have not been emitted yet.
            if i not in scope and i in variables:
                add_var_node(i, 'Variable', variables[i].shape, [])
        nodes.append(NodeDef(
            name=func_name,
            op=func['type'],
            input=func['inputs'],
            # NOTE(review): placeholder argument string kept unchanged; the
            # function's real arguments are not read from ``func`` here.
            attr={"arguments": AttrValue(s='a=1'.encode(encoding='utf-8'))}
        ))

    current_graph = GraphDef(node=nodes, versions=VersionDef(producer=22))
    event = event_pb2.Event(
        graph_def=current_graph.SerializeToString())
    self.file_writer.add_event(event)
def from_variable(self, leaf, output_name="output"):
    """Trace the computation graph back from ``leaf`` and write it to TensorBoard.

    Starting at ``leaf``, the walk follows ``v.parent`` (the producing function)
    and ``parent.inputs`` towards the roots. It emits one ``Variable`` NodeDef
    per distinct variable and one NodeDef per distinct function (op taken from
    ``parent.info.type_name``, args from ``parent.info.args``). The node list is
    reversed, wrapped in a ``GraphDef`` and written via
    ``self.file_writer.add_event``.

    Naming rules:

    * Registered parameters (from ``get_parameters(grad_only=False)``) keep
      their parameter name.
    * Variables without a parent are named ``Input<k>``.
    * ``leaf`` gets ``output_name`` when it has a parent and is not a parameter.
    * Any other variable is named ``<scope>/variable<-<func><k>``.

    Args:
        leaf: variable at which the traversal starts.
        output_name: node name used for ``leaf`` (see above).
    """
    nodes = []
    variables = {}          # variable -> assigned node name
    func_names = {}         # function -> assigned node name
    func_set = set()        # names of function nodes already emitted
    emitted_params = set()  # names of parameter nodes already emitted
    unique_func_names = set()
    unique_var_names = set()
    parameters = {v.data: k for k, v in get_parameters(grad_only=False).items()}

    def unique_name(base, taken):
        # Return ``base`` + the smallest integer suffix not yet in ``taken``.
        num = 0
        name = base + str(num)
        while name in taken:
            num += 1
            name = base + str(num)
        taken.add(name)
        return name

    def get_func_name(func):
        # Stable, unique node name for ``func``. It is scoped under the
        # namespace of its first parameter input, if it has one.
        func_name = func_names.get(func, None)
        if func_name:
            return func_name
        name_scope = ''
        for v in func.inputs:
            v_name = parameters.get(v.data, None)
            if v_name:
                name_scope = '/'.join(v_name.split('/')[:-1])
                break
        if name_scope:
            func_name_base = '/'.join([name_scope, func.name])
        else:
            func_name_base = func.name
        func_name = unique_name(func_name_base, unique_func_names)
        func_names[func] = func_name
        return func_name

    def get_variable_name(v):
        # Return (name, already_named) for a non-parameter variable.
        v_name = variables.get(v, None)
        if v_name:
            return v_name, True
        if v.parent is None:
            v_name = unique_name("Input", unique_var_names)
        elif not nodes:
            # Nothing emitted yet, so this is the traversal's leaf.
            v_name = output_name
        else:
            f_name_sections = get_func_name(v.parent).split("/")
            base_name = "variable<-{}".format(f_name_sections[-1])
            v_name = unique_name(
                "/".join(f_name_sections[:-1] + [base_name]), unique_var_names)
        variables[v] = v_name
        return v_name, False

    def add_variable(v):
        # Emit a Variable node for ``v`` at most once; return its name.
        v_name = parameters.get(v.data, None)
        if v_name:
            # Parameters can be reached many times (shared weights, and the
            # walk visits them both as inputs and as nodes): emit only once.
            is_new = v_name not in emitted_params
            emitted_params.add(v_name)
        else:
            v_name, exists = get_variable_name(v)
            is_new = not exists
        if is_new:
            shape_proto = TensorShapeProto(
                dim=[TensorShapeProto.Dim(size=d) for d in v.shape])
            node_inputs = [] if v.parent is None else [get_func_name(v.parent)]
            nodes.append(NodeDef(
                name=v_name.encode(encoding='utf-8'),
                op='Variable',
                input=node_inputs,
                attr={
                    'shape': AttrValue(shape=shape_proto),
                    'dtype': AttrValue(type=DT_FLOAT)
                }
            ))
        return v_name

    def add_func(v):
        # Emit the node for v's producing function. Returns False if that
        # function was already emitted, so its subgraph is not walked again.
        func = v.parent
        input_names = [add_variable(v_input) for v_input in func.inputs]
        f_name = get_func_name(func)
        if f_name in func_set:
            return False
        attr_str = ','.join(
            "{}={}".format(k, a) for k, a in func.info.args.items()
        ).encode(encoding='utf-8')
        nodes.append(NodeDef(
            name=f_name,
            op=func.info.type_name,
            input=input_names,
            attr={"parameters": AttrValue(s=attr_str)}
        ))
        func_set.add(f_name)
        return True

    # Iterative pre-order walk instead of recursion, so deep graphs do not hit
    # Python's recursion limit. Inputs are pushed in reverse so they are
    # popped (visited) in their original order.
    stack = [leaf]
    while stack:
        v = stack.pop()
        if not nodes:
            add_variable(v)
        if v.parent is None:
            add_variable(v)
        elif add_func(v):
            stack.extend(list(v.parent.inputs)[::-1])

    nodes.reverse()
    current_graph = GraphDef(node=nodes, versions=VersionDef(producer=22))
    event = event_pb2.Event(
        graph_def=current_graph.SerializeToString())
    self.file_writer.add_event(event)
def graph(node_list):
    """Wrap ``node_list`` in a GraphDef and pair it with CPU-only run metadata.

    Args:
        node_list: NodeDef messages that make up the graph.

    Returns:
        tuple: ``(graph_def, stepstats)``. ``graph_def`` is a ``GraphDef``
        with producer version 22. ``stepstats`` is a ``RunMetadata`` holding a
        single ``DeviceStepStats`` entry for ``/device:CPU:0``.
    """
    graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
    cpu_device_stats = DeviceStepStats(device="/device:CPU:0")
    step_stats = StepStats(dev_stats=[cpu_device_stats])
    return graph_def, RunMetadata(step_stats=step_stats)