Example #1
0
 def __call__(self, graph):
     """Transpose reshapable nodes' weights into the target ordering.

     For nodes whose kind is in ``self.reshaped_node_types``, the first data
     blob is transposed according to ``self.map(node.kind)``.  A fully
     connected (InnerProduct) layer fed by a spatial parent is first re-wired
     so its input ordering matches the transposed spatial layout.  When
     ``self.replace`` is set, the transposed blobs overwrite the originals in
     a second pass.

     Returns the (mutated) graph.
     """
     for node in graph.nodes:
         if node.data is None:
             continue
         if node.kind not in self.reshaped_node_types:
             # Check for 2+ dimensional data
             if any(len(tensor.shape) > 1 for tensor in node.data):
                 print_stderr('Warning: parameters not reshaped for node: {}'.format(node))
             continue
         transpose_order = self.map(node.kind)
         weights = node.data[0]
         if (node.kind == NodeKind.InnerProduct) and self.has_spatial_parent(node):
             # The FC layer connected to the spatial layer needs to be
             # re-wired to match the new spatial ordering.
             in_shape = node.get_only_parent()[0].output_shape
             fc_shape = weights.shape
             output_channels = fc_shape[0]
             weights = weights.reshape((output_channels, in_shape.channels, in_shape.height,
                                        in_shape.width))
             weights = weights.transpose(self.map(NodeKind.Convolution))
             node.reshaped_data = weights.reshape(fc_shape[transpose_order[0]],
                                                  fc_shape[transpose_order[1]])
         else:
             node.reshaped_data = weights.transpose(transpose_order)
     if self.replace:
         for node in graph.nodes:
             if hasattr(node, 'reshaped_data'):
                 # Set the weights
                 node.data[0] = node.reshaped_data
                 del node.reshaped_data
     return graph
Example #2
0
 def __call__(self, graph):
     """Apply the configured transpose ordering to every reshapable node.

     Nodes outside ``self.reshaped_node_types`` are skipped (with a warning
     when they carry multi-dimensional data).  An InnerProduct layer with a
     spatial parent is re-wired before transposing.  If ``self.replace`` is
     true, the transposed blobs replace the originals before returning.
     """
     for node in graph.nodes:
         if node.data is None:
             continue
         if node.kind not in self.reshaped_node_types:
             # Warn when multi-dimensional parameters are being skipped.
             multi_dim = any(
                 len(tensor.shape) > 1 for tensor in node.data)
             if multi_dim:
                 print_stderr(
                     'Warning: parameters not reshaped for node: {}'.format(
                         node))
             continue
         order = self.map(node.kind)
         blob = node.data[0]
         fc_on_spatial = (node.kind == NodeKind.InnerProduct
                          and self.has_spatial_parent(node))
         if fc_on_spatial:
             # The FC layer connected to the spatial layer needs to be
             # re-wired to match the new spatial ordering.
             parent_shape = node.get_only_parent()[0].output_shape
             fc_shape = blob.shape
             blob = blob.reshape((fc_shape[0], parent_shape.channels,
                                  parent_shape.height, parent_shape.width))
             blob = blob.transpose(self.map(NodeKind.Convolution))
             node.reshaped_data = blob.reshape(fc_shape[order[0]],
                                               fc_shape[order[1]])
         else:
             node.reshaped_data = blob.transpose(order)
     if self.replace:
         # Second pass: commit the transposed blobs onto the nodes.
         for node in graph.nodes:
             if hasattr(node, 'reshaped_data'):
                 node.data[0] = node.reshaped_data
                 del node.reshaped_data
     return graph
Example #3
0
 def __call__(self, graph):
     """Attach the loaded parameters to their corresponding graph nodes.

     Each ``(name, params)`` pair in ``self.params`` is adjusted via
     ``self.adjust_parameters`` and stored on the node of the same name;
     names with no matching node are reported and skipped.
     """
     for name, params in self.params:
         if name in graph:
             node = graph.get_node(name)
             node.data = self.adjust_parameters(node, params)
         else:
             print_stderr('Ignoring parameters for non-existent layer: %s' % name)
     return graph
Example #4
0
 def __call__(self, graph):
     """Copy loaded layer parameters onto the graph's nodes.

     Unmatched layer names are reported via ``print_stderr`` and dropped.
     Returns the mutated graph.
     """
     for layer_name, blob in self.params:
         # Guard clause: nothing to attach when the layer is absent.
         if layer_name not in graph:
             print_stderr('Ignoring parameters for non-existent layer: %s' %
                          layer_name)
             continue
         target = graph.get_node(layer_name)
         target.data = self.adjust_parameters(target, blob)
     return graph
Example #5
0
 def __call__(self, graph):
     """Replace each node's positional data list with a name -> blob dict.

     Convolution/Deconvolution/InnerProduct nodes get ``weights`` (plus
     ``bias`` when the layer has a bias term); BatchNorm nodes get
     ``mean``/``var`` (plus ``scale``/``bias`` when four blobs are present).
     Other parameterized kinds are warned about and left untouched.
     """
     learnable = (NodeKind.Convolution, NodeKind.Deconvolution, NodeKind.InnerProduct)
     for node in graph.nodes:
         if node.data is None:
             continue
         if node.kind in learnable:
             names = ['weights']
             if node.parameters.bias_term:
                 names.append('bias')
         elif node.kind == NodeKind.BatchNorm:
             names = ['mean', 'var']
             if len(node.data) == 4:
                 names.extend(['scale', 'bias'])
         else:
             print_stderr('WARNING: Unhandled parameters: {}'.format(node.kind))
             continue
         assert len(names) == len(node.data)
         node.data = dict(zip(names, node.data))
     return graph
Example #6
0
 def __call__(self, graph):
     """Convert positional parameter lists into named dictionaries.

     Convolution/InnerProduct nodes are keyed ``weights`` (and ``bias`` when
     a bias term is present); BatchNorm nodes are keyed ``mean``/``var``
     (and ``scale``/``bias`` when four blobs exist).  Any other node kind
     carrying data triggers a warning and is left as-is.
     """
     for node in graph.nodes:
         if node.data is None:
             continue
         kind = node.kind
         if kind in (NodeKind.Convolution, NodeKind.InnerProduct):
             names = (('weights', 'bias')
                      if node.parameters.bias_term else ('weights', ))
         elif kind == NodeKind.BatchNorm:
             names = (('mean', 'var', 'scale', 'bias')
                      if len(node.data) == 4 else ('mean', 'var'))
         else:
             print_stderr('WARNING: Unhandled parameters: {}'.format(
                 kind))
             continue
         assert len(names) == len(node.data)
         node.data = dict(zip(names, node.data))
     return graph
Example #7
0
 def __init__(self, def_path, data_path, target_toolkit, input_shape=None, phase='test'):
     """Build and transform a graph from a Caffe prototxt/caffemodel pair.

     def_path: path to the graph prototxt, or None to generate one from the
         caffemodel (input_shape is then required).
     data_path: path to the caffemodel weights, or None to skip parameter
         loading and the associated transforms.
     target_toolkit: destination toolkit name; anything other than
         'caffe'/'caffe2' triggers reshaping to TensorFlow's ordering.
     input_shape: input shape without the batch dimension (a leading 1 is
         prepended here); required for generated or train_val prototxts.
     phase: network phase handed to GraphBuilder (default 'test').

     Raises ConversionError when input_shape is needed but missing.
     """
     # NOTE(review): presumably maps original -> renamed layer names;
     # populated elsewhere — confirm against the rest of the class.
     self.layer_name_map = {}
     self.data_injector = None
     self.is_train_proto = False
     self.input_shape = input_shape
     if def_path is None:
         # No prototxt provided: synthesize one from the caffemodel.
         if self.input_shape is None:
             raise ConversionError('if the graph prototxt is not provided, the input shape should be provided')
         self.input_shape = [1] + self.input_shape  # prepend batch dimension
         def_path, self.data_injector = self.gen_prototxt_from_caffemodel(data_path, self.input_shape)
         self.is_train_proto = True
     else:
         # Parse the prototxt and detect training-only layers.
         model = get_caffe_resolver().NetParameter()
         with open(def_path, 'r') as f:
             text_format.Merge(f.read(), model)
         layers = model.layers or model.layer
         if len([layer for layer in layers if NodeKind.map_raw_kind(layer.type) in LAYER_IN_TRAIN_PROTO]) > 0:
             # train_val prototxt: the input shape must be supplied explicitly.
             if self.input_shape is None:
                 raise ConversionError('the train_val.prototxt should be provided with the input shape')
             self.input_shape = [1] + self.input_shape
             self.is_train_proto = True
     graph = GraphBuilder(def_path, self.input_shape, self.is_train_proto, phase).build()
     if self.is_train_proto:
         # Use the prototxt the builder produced for the train graph.
         def_path = graph.prototxt
     if data_path is not None:
         graph = graph.transformed([
             self.data_injector if self.data_injector else DataInjector(def_path, data_path), # Load and associate learned parameters
             BatchNormScaleBiasFuser(),
             BatchNormPreprocessor() # Pre-process batch normalization data
         ])
         target_toolkit = target_toolkit.lower()
         if target_toolkit not in ('caffe', 'caffe2'):
             graph = graph.transformed([DataReshaper({ # Reshape the parameters to TensorFlow's ordering
                 NodeKind.Convolution: (2, 3, 1, 0), # (c_o, c_i, h, w) -> (h, w, c_i, c_o)
                 NodeKind.Deconvolution: (2, 3, 1, 0), # (c_o, c_i, h, w) -> (h, w, c_i, c_o)
                 NodeKind.InnerProduct: (1, 0) # (c_o, c_i) -> (c_i, c_o)
             }),
                 ParameterNamer() # Convert parameters to dictionaries
             ])
     self.graph = graph
     #  self.graph = NodeRenamer()(graph)
     print_stderr(self.graph)
Example #8
0
 def __init__(self,
              def_path,
              data_path,
              target_toolkit,
              input_shape=None,
              phase='test'):
     """Build, transform, and rename a graph from Caffe model files.

     def_path: path to the graph prototxt, or None to generate one from the
         caffemodel (input_shape is then required).
     data_path: path to the caffemodel weights, or None to skip parameter
         loading and the associated transforms.
     target_toolkit: destination toolkit name; anything other than
         'caffe'/'caffe2' triggers reshaping to TensorFlow's ordering.
     input_shape: input shape without the batch dimension (a leading 1 is
         prepended here); required for generated or train_val prototxts.
     phase: network phase handed to GraphBuilder (default 'test').

     Raises ConversionError when input_shape is needed but missing.
     """
     # NOTE(review): presumably maps original -> renamed layer names;
     # populated elsewhere — confirm against the rest of the class.
     self.layer_name_map = {}
     self.data_injector = None
     self.is_train_proto = False
     self.input_shape = input_shape
     if def_path is None:
         # No prototxt provided: synthesize one from the caffemodel.
         if self.input_shape is None:
             raise ConversionError(
                 'if the graph prototxt is not provided, the input shape should be provided'
             )
         self.input_shape = [1] + self.input_shape  # prepend batch dimension
         def_path, self.data_injector = self.gen_prototxt_from_caffemodel(
             data_path, self.input_shape)
         self.is_train_proto = True
     else:
         # Parse the prototxt and detect training-only layers.
         model = get_caffe_resolver().NetParameter()
         with open(def_path, 'r') as f:
             text_format.Merge(f.read(), model)
         layers = model.layers or model.layer
         if len([
                 layer for layer in layers if NodeKind.map_raw_kind(
                     layer.type) in LAYER_IN_TRAIN_PROTO
         ]) > 0:
             # train_val prototxt: the input shape must be supplied explicitly.
             if self.input_shape is None:
                 raise ConversionError(
                     'the train_val.prototxt should be provided with the input shape'
                 )
             self.input_shape = [1] + self.input_shape
             self.is_train_proto = True
     graph = GraphBuilder(def_path, self.input_shape, self.is_train_proto,
                          phase).build()
     if self.is_train_proto:
         # Use the prototxt the builder produced for the train graph.
         def_path = graph.prototxt
     if data_path is not None:
         graph = graph.transformed([
             self.data_injector if self.data_injector else DataInjector(
                 def_path,
                 data_path),  # Load and associate learned parameters
             BatchNormScaleBiasFuser(),
             BatchNormPreprocessor()  # Pre-process batch normalization data
         ])
         target_toolkit = target_toolkit.lower()
         if target_toolkit not in ('caffe', 'caffe2'):
             graph = graph.transformed([
                 DataReshaper(
                     {  # Reshape the parameters to TensorFlow's ordering
                         NodeKind.Convolution:
                         (2, 3, 1,
                          0),  # (c_o, c_i, h, w) -> (h, w, c_i, c_o)
                         NodeKind.InnerProduct:
                         (1, 0)  # (c_o, c_i) -> (c_i, c_o)
                     }),
                 ParameterNamer()  # Convert parameters to dictionaries
             ])
     self.graph = NodeRenamer()(graph)
     print_stderr(self.graph)