def __init__(self, *args):
    """Create a ``Net``.

    Two positional call forms are accepted: ``(net_file, phase)`` and
    ``(net_file, param_file, phase)``.

    Parameters
    ----------
    net_file : str
        The path of text proto file to load network.
    param_file : str, optional
        The path of binary proto file to load parameters.
    phase : {'TRAIN', 'TEST'}
        The phase tag; always the last positional argument.

    Raises
    ------
    ValueError
        If not called with exactly 2 or 3 positional arguments.

    """
    if len(args) == 2:
        net_file, self._phase = args
        param_file = None
    elif len(args) == 3:
        net_file, param_file, self._phase = args
    else:
        raise ValueError('Expected 2 or 3 args.')
    self._blobs = {}
    self._layers = []
    self._layer_blobs = []
    self._losses = []
    self._params = []
    self._blob_dict = None
    self._param_dict = None
    self._input_list = None
    self._output_list = None
    # Parse the network file.
    with open(net_file, 'r') as f:
        self._proto = text_format.Parse(f.read(), caffe_pb2.NetParameter())
    # Construct the layer class from proto; layers rejected by
    # ``_filter_layer`` are skipped.
    for layer_param in self._proto.layer:
        if not self._filter_layer(layer_param):
            continue
        cls = getattr(layer_factory, layer_param.type)
        with context.name_scope(layer_param.name):
            self._layers.append(cls(layer_param))
    # Prepare for the legacy net inputs: a top-level ``input`` list is
    # wrapped into a synthetic ``Input`` layer placed before all others.
    if len(self._proto.input) > 0:
        layer_param = caffe_pb2.LayerParameter(
            name='data',
            type='Input',
            top=self._proto.input,
            input_param=caffe_pb2.InputParameter(
                shape=self._proto.input_shape))
        cls = getattr(layer_factory, layer_param.type)
        with context.name_scope(layer_param.name):
            self._layers.insert(0, cls(layer_param))
    # Call layers sequentially to get outputs.
    self._setup()
    # Collect losses and parameters, using the same filter as above.
    for layer in self._proto.layer:
        if not self._filter_layer(layer):
            continue
        self._collect_losses_and_params(layer)
    # Load the pre-trained weights if necessary.
    if param_file is not None:
        self.copy_from(param_file)
def to_proto(self):
    """Build a ``NetParameter`` from every top registered on this spec.

    Returns
    -------
    NetParameter
        A message whose ``layer`` field holds the values collected by each
        top's ``_to_proto`` into a shared ordered mapping.

    """
    # Reverse lookup (top -> assigned name) handed to every top.
    top_names = dict((top, name) for name, top in self.tops.items())
    name_counter = collections.Counter()
    collected = collections.OrderedDict()
    for top in self.tops.values():
        top._to_proto(collected, top_names, name_counter)
    proto = caffe_pb2.NetParameter()
    proto.layer.extend(collected.values())
    return proto
def to_proto(*tops):
    """Return a ``NetParameter`` holding every layer the given tops need.

    Each argument's ``fn._to_proto`` contributes layers into one shared
    ordered mapping; the mapping's values become the message's layers.
    """
    name_counter = collections.Counter()
    collected = collections.OrderedDict()
    for top in tops:
        # A fresh, empty name mapping is supplied for every top.
        top.fn._to_proto(collected, {}, name_counter)
    proto = caffe_pb2.NetParameter()
    proto.layer.extend(collected.values())
    return proto
def to_proto(self):
    """Serialize this net back into its protocol buffer form.

    Returns
    -------
    NetParameter
        A message carrying the original net name and one serialized
        entry per constructed layer, in ``self._layers`` order.

    """
    serialized_layers = []
    for layer in self._layers:
        serialized_layers.append(layer.to_proto())
    return caffe_pb2.NetParameter(
        name=self._proto.name,
        layer=serialized_layers,
    )
def __init__(self, network_file, phase='TEST', weights=None):
    """Create a ``Net``.

    Parameters
    ----------
    network_file : str
        The path of text proto file to load network.
    phase : str, optional, default='TEST'
        The execution phase.
    weights : str, optional
        The path of binary proto file to load weights.

    """
    # Parse the network file.
    with open(network_file, 'r') as f:
        self._proto = google.protobuf.text_format.Parse(
            f.read(), caffe_pb2.NetParameter())
    self._phase = phase
    self._layers = []
    self._learnable_blobs = []
    self._net_blobs = dict()
    self._net_outputs = set()
    # Construct the layers from proto.
    # A layer whose name was already used is linked (via ``_call_layer``)
    # to the FIRST layer constructed with that name. A dict gives O(1)
    # lookup and avoids keeping a name list index-aligned with
    # ``self._layers``.
    first_layer_by_name = {}
    for layer_param in self._proto.layer:
        if not self._filter_layer(layer_param):
            continue
        call_layer = first_layer_by_name.get(layer_param.name)
        cls = getattr(layer_factory, layer_param.type)
        self._layers.append(cls(layer_param))
        self._layers[-1]._call_layer = call_layer
        first_layer_by_name.setdefault(layer_param.name, self._layers[-1])
    # Add an input layer for the legacy inputs; it is placed first.
    if len(self._proto.input) > 0:
        layer_param = caffe_pb2.LayerParameter(
            name='data',
            type='Input',
            top=self._proto.input,
            input_param=caffe_pb2.InputParameter(
                shape=self._proto.input_shape))
        cls = getattr(layer_factory, layer_param.type)
        with context.name_scope(layer_param.name):
            self._layers.insert(0, cls(layer_param))
    # Connect layers to get outputs.
    self._init()
    # Load the pre-trained weights if necessary.
    if weights is not None:
        self.copy_from(weights)
def copy_from(self, other):
    """Load layer data into this net from another source.

    Parameters
    ----------
    other : Union[str, NetParameter]
        Anything exposing a callable ``ParseFromString`` is treated as an
        already-built message and used directly; any other value is read
        as binary bytes and deserialized into a ``NetParameter``.

    """
    if callable(getattr(other, 'ParseFromString', None)):
        proto = other
    else:
        raw = serialization.load_bytes(other)
        proto = serialization.deserialize_proto(raw, caffe_pb2.NetParameter())
    self.from_proto(proto)