コード例 #1
0
ファイル: import-caffe.py プロジェクト: tinyloop/matconvnet
    match = r.match(line)
    synsets_wnid.append(match.group('wnid'))
    synsets_name.append(match.group('name'))

if args.class_names:
  # Class labels given explicitly on the command line take precedence.
  # `make_tuple` parses the argument string (presumably ast.literal_eval
  # imported under that alias -- confirm at the import site), and the
  # resulting sequence serves as both the IDs and the display names.
  # Both names are bound to the very same list object.
  synsets_name = synsets_wnid = list(make_tuple(args.class_names))

# --------------------------------------------------------------------
#                                                          Load layers
# --------------------------------------------------------------------

# Caffe stores the network structure and data into two different files
# We load them both and merge them into a single MATLAB structure

# `net` receives the text-format network definition; `data` receives the
# binary-serialized NetParameter and stays empty when no data file is given.
net, data = caffe_pb2.NetParameter(), caffe_pb2.NetParameter()

print('Loading Caffe CNN structure from {}'.format(args.caffe_proto.name))
google.protobuf.text_format.Merge(args.caffe_proto.read(), net)

if args.caffe_data:
  print('Loading Caffe CNN parameters from {}'.format(args.caffe_data.name))
  # Binary protobuf wire format, hence MergeFromString rather than text_format.
  data.MergeFromString(args.caffe_data.read())

# --------------------------------------------------------------------
#                                   Read layers in a CaffeModel object
# --------------------------------------------------------------------

# Pick the repeated protobuf field that holds the layer definitions.
# This variant stores them in `net.layer` (singular).
# NOTE(review): this excerpt is truncated right after this branch; presumably
# other variants are handled by an `else` that reads a differently named
# field (e.g. `layers`, as in the second excerpt) -- TODO confirm.
if args.caffe_variant in ['caffe_b590f1d']:
  layers_list = net.layer
コード例 #2
0
if args.synsets:
    # Each line of the synsets file is expected to be a WordNet ID
    # ("n" + 8 digits), a single space, then the class name(s), e.g.
    # "n01440764 tench, Tinca tinca".
    print('Loading synsets from {}'.format(args.synsets.name))
    # The lazy '?' that previously followed {8} was a no-op on a fixed-count
    # quantifier, so it is dropped; the accepted syntax is unchanged.
    r = re.compile(r'(?P<wnid>n[0-9]{8}) (?P<name>.*)')
    synsets_wnid = []
    synsets_name = []
    for line in args.synsets:
        # Strip the line terminator so a CRLF file does not leave a stray
        # '\r' at the end of every class name ('.*' matches '\r').
        line = line.rstrip('\r\n')
        # Blank lines (e.g. a trailing empty line) used to make r.match
        # return None and crash on .group(); skip them instead.
        if not line.strip():
            continue
        match = r.match(line)
        if match is None:
            # Fail with a readable message instead of an AttributeError.
            raise ValueError('Malformed synset line in {}: {!r}'.format(
                args.synsets.name, line))
        synsets_wnid.append(match.group('wnid'))
        synsets_name.append(match.group('name'))

# --------------------------------------------------------------------
#                                                          Load layers
# --------------------------------------------------------------------

# Network definition: text-format protobuf parsed into `net_param`.
net_param = caffe_pb2.NetParameter()
print('Loading Caffe CNN parameters from {}'.format(args.caffe_param.name))
google.protobuf.text_format.Merge(args.caffe_param.read(), net_param)

# Network data: binary-serialized NetParameter parsed into `net_data`.
net_data = caffe_pb2.NetParameter()
print('Loading Caffe CNN data from {}'.format(args.caffe_data.name))
net_data.MergeFromString(args.caffe_data.read())

# --------------------------------------------------------------------
#                                                       Convert layers
# --------------------------------------------------------------------

if args.caffe_variant in ['vgg-caffe', 'caffe-old']:
    layers_name_param = [x.layer.name for x in net_param.layers]
    layers_name_data = [x.layer.name for x in net_data.layers]
else:
    layers_name_param = [x.name for x in net_param.layers]