コード例 #1
0
ファイル: graph.py プロジェクト: brianhhu/MMdnn
 def compute_output_shapes(self, model):
     """Attach an output TensorShape to every node in the graph.

     Serializes *model* to a temporary ``.prototxt`` file. When pycaffe is
     available, the file is loaded as a ``caffe.Net`` and shapes are read
     from the network blobs; any node still unresolved falls back to the
     static ``NodeKind.compute_output_shape``. Without pycaffe, every
     node's shape is computed statically.

     Args:
         model: protobuf message (caffe NetParameter) describing the net.

     Side effects: sets ``self.prototxt`` to the temp-file path and sets
     ``node.output_shape`` on every node in topological order.
     """
     sorted_nodes = self.topologically_sorted()
     (tmp_handle, tmp_prototxt) = tempfile.mkstemp(suffix=".prototxt")
     with open(tmp_prototxt, 'w') as f:
         f.write(text_format.MessageToString(model))
     # The raw descriptor from mkstemp is never written through (we wrote
     # via open() above); close it now so it is not leaked on the
     # non-pycaffe branch, which previously skipped os.close entirely.
     os.close(tmp_handle)
     self.prototxt = tmp_prototxt
     if has_pycaffe():
         caffe = get_caffe_resolver().caffe
         net = caffe.Net(tmp_prototxt, caffe.TEST)
         for key, value in net.blobs.items():
             try:
                 node = self.get_node(key)
                 dims = list(value.shape)
                 # Pad to the canonical 4-D layout with trailing 1s.
                 dims = dims + [1] * (4 - len(dims))
                 node.output_shape = TensorShape(*dims)
             except Exception:
                 # Blobs with no matching graph node are best-effort
                 # skipped; was a bare except, narrowed to Exception so
                 # KeyboardInterrupt/SystemExit still propagate.
                 continue
         for node in sorted_nodes:
             if node.output_shape is None:
                 node.output_shape = TensorShape(
                     *NodeKind.compute_output_shape(node))
         # The on-disk prototxt was only needed to construct the Net;
         # remove it so the temp file is not leaked.
         os.remove(tmp_prototxt)
     else:
         for node in sorted_nodes:
             node.output_shape = TensorShape(
                 *NodeKind.compute_output_shape(node))
コード例 #2
0
ファイル: graph.py プロジェクト: skybigzhou/MMdnn
 def compute_output_shapes(self, model):
     """Attach an output TensorShape to every node in the graph.

     Serializes *model* to a temporary ``.prototxt``. With pycaffe, the
     file is loaded as a ``caffe.Net`` and shapes are taken from the
     network blobs, with ``NodeKind.compute_output_shape`` as fallback for
     unresolved nodes; without pycaffe, all shapes are computed statically.

     Args:
         model: protobuf message (caffe NetParameter) describing the net.

     Side effects: sets ``self.prototxt`` to the temp-file path and sets
     ``node.output_shape`` on every node in topological order.
     """
     sorted_nodes = self.topologically_sorted()
     (tmp_handle, tmp_prototxt) = tempfile.mkstemp(suffix=".prototxt")
     with open(tmp_prototxt, 'w') as f:
         f.write(text_format.MessageToString(model))
     # mkstemp's descriptor is never written through (content went in via
     # open() above); close it unconditionally — previously it was only
     # closed on the pycaffe branch and leaked on the else branch.
     os.close(tmp_handle)
     self.prototxt = tmp_prototxt
     if has_pycaffe():
         caffe = get_caffe_resolver().caffe
         net = caffe.Net(tmp_prototxt, caffe.TEST)
         for key, value in net.blobs.items():
             try:
                 node = self.get_node(key)
                 dims = list(value.shape)
                 # Pad to the canonical 4-D layout with trailing 1s.
                 dims = dims + [1] * (4 - len(dims))
                 node.output_shape = TensorShape(*dims)
             except Exception:
                 # Blobs with no matching graph node are best-effort
                 # skipped; narrowed from a bare except so that
                 # KeyboardInterrupt/SystemExit still propagate.
                 continue
         for node in sorted_nodes:
             if node.output_shape is None:
                 node.output_shape = TensorShape(*NodeKind.compute_output_shape(node))
         # The on-disk prototxt was only needed to construct the Net.
         os.remove(tmp_prototxt)
     else:
         for node in sorted_nodes:
             node.output_shape = TensorShape(*NodeKind.compute_output_shape(node))
コード例 #3
0
ファイル: transformer.py プロジェクト: Wang-Yujue/MMdnn
 def __init__(self, def_path, data_path):
     """Record model paths and immediately load the parameters.

     Args:
         def_path: path to the ``.prototxt`` file defining the graph.
         data_path: path to the ``.caffemodel`` file with learned weights.
     """
     self.def_path = def_path        # .prototxt graph definition
     self.data_path = data_path      # .caffemodel learned parameters
     self.did_use_pb = False         # True once the protobuf fallback runs
     self.params = None              # (layer name, parameters) tuples
     self.caffemodel = None
     # Prefer the native caffe loader when pycaffe is importable and a
     # graph definition was supplied; otherwise parse the protobuf directly.
     native_loading_possible = has_pycaffe() and self.def_path
     if native_loading_possible:
         self.load_using_caffe()
     else:
         self.load_using_pb()
コード例 #4
0
ファイル: transformer.py プロジェクト: zouxinhua/MMdnn
 def __init__(self, def_path, data_path):
     """Store the model file locations and load the parameters.

     Args:
         def_path: ``.prototxt`` file that defines the network graph.
         data_path: ``.caffemodel`` file holding the learned parameters.
     """
     # Paths to the graph definition and the trained weights.
     self.def_path = def_path
     self.data_path = data_path
     # Flipped to True if the protocol-buffer fallback backend is used.
     self.did_use_pb = False
     # Will hold a list of (layer name, parameters) tuples once loaded.
     self.params = None
     self.caffemodel = None
     # Dispatch: native caffe loading when available, protobuf otherwise.
     loader = (self.load_using_caffe
               if has_pycaffe() and self.def_path
               else self.load_using_pb)
     loader()