コード例 #1
0
ファイル: weights.py プロジェクト: anglebinbin/Barista-tool
def loadNetParamFromString(paramstring):
    """ Return a NetParameter protocol buffer parsed from a serialized string.

    Args:
        paramstring: the serialized NetParameter message (the same data that
            NetParameter.ParseFromString accepts).

    Returns:
        The parsed NetParameter, or None if parsing raised an exception.
    """
    from backend.caffe.path_loader import PathLoader
    proto = PathLoader().importProto()
    net = proto.NetParameter()
    try:
        net.ParseFromString(paramstring)
    except Exception:
        # Best-effort: unparsable input yields None. Catching Exception rather
        # than using a bare except lets KeyboardInterrupt/SystemExit propagate.
        return None
    return net
コード例 #2
0
ファイル: weights.py プロジェクト: anglebinbin/Barista-tool
def loadNetParameter(caffemodel):
    """ Return a NetParameter protocol buffer loaded from the caffemodel.

    Args:
        caffemodel: path of the file to read; its raw bytes are parsed with
            NetParameter.ParseFromString.

    Returns:
        The parsed NetParameter, or None if opening, reading or parsing the
        file raised an exception.
    """
    from backend.caffe.path_loader import PathLoader
    proto = PathLoader().importProto()
    net = proto.NetParameter()

    try:
        with open(caffemodel, 'rb') as f:
            net.ParseFromString(f.read())
    except Exception:
        # Best-effort: an unreadable or unparsable file yields None. Catching
        # Exception (not a bare except) keeps KeyboardInterrupt/SystemExit alive.
        return None
    return net
コード例 #3
0
def bareNet(name):
    """ Build the dictionary of an empty network called *name*.

    Network-level parameters are taken from the NetParameter description
    (default values where required). The layer containers start out empty:
    "layers" is an empty dict and "layerOrder" an empty list.
    """
    from backend.caffe.path_loader import PathLoader
    protoModule = PathLoader().importProto()
    emptyNet = protoModule.NetParameter()
    descriptor = info.ParameterGroupDescriptor(emptyNet)
    netParams = descriptor.parameter().copy()
    # Both the current and the deprecated layer fields are excluded from the
    # generic parameter extraction; layers get their own keys below.
    for layerKey in ("layer", "layers"):
        del netParams[layerKey]
    netDict = _extract_param(emptyNet, netParams)
    netDict["layers"] = {}
    netDict["layerOrder"] = []
    netDict["name"] = unicode(name)
    return netDict
コード例 #4
0
def _import_dictionary(netdict):
    """Build a NetParameter (prototxt net) from the network dictionary *netdict*.

    The "layers" entry is written via _extract_layer, using the "layerOrder"
    entry for ordering. "layerOrder" is not written on its own. Every other
    entry is written via _insert.
    """
    from backend.caffe.path_loader import PathLoader
    protoModule = PathLoader().importProto()
    netProto = protoModule.NetParameter()

    for key in netdict:
        if key == "layers":
            _extract_layer(netdict["layers"], netdict["layerOrder"], netProto)
        elif key != "layerOrder":
            _insert(key, netdict[key], netProto)
    return netProto
コード例 #5
0
ファイル: loader.py プロジェクト: anglebinbin/Barista-tool
def loadNet(netstring):
    """ Load the prototxt string "netstring" into a dictionary.

        The dictionary has roughly the following form (schematic; the exact
        nesting of the parameter values is produced by _extract_param and
        _load_layers):

        {
            "name": "Somenetwork",
            "input_dim": [1, 2, 1, 1],
            "state": {
                "phase": "TRAIN"
            },
            ...
            "layers": {
                "somerandomid1": {
                    "type": <LayerType instance of Pooling-Layer>,
                    "parameters": {
                        "pooling_param": {
                            "kernel_size": 23,
                            "engine": "DEFAULT"
                        },
                        ...
                        "input_param": [
                            {"shape": {"dim": [...], ...}},
                            {"shape": {"dim": [...], ...}},
                        ]
                    }
                },
                "somerandomid2": {"type": ..., "parameters": ...}
            },
            "layerOrder": ["somerandomid1", "somerandomid2", ...]
        }

        Deprecated "layers" definitions in the prototxt are not parsed. A log
        message is emitted and they are dropped.

        Raises:
            ParseException: if text_format.Merge rejects netstring.
            ValueError: if the parameter description contains a "layerOrder"
                key, which would clash with the result key of the same name.
    """
    from backend.caffe.path_loader import PathLoader
    proto = PathLoader().importProto()
    # Load Protoclass for parsing
    net = proto.NetParameter()

    # Get DESCRIPTION for meta infos
    descr = info.ParameterGroupDescriptor(net)
    # "Parse" the netdefinition in prototxt-format
    try:
        text_format.Merge(netstring, net)
    except ParseError as ex:
        raise ParseException(str(ex))
    params = descr.parameter().copy()  # All Parameters of the network

    # add logger output if deprecated layers have been found, to inform the user that those can't be parsed yet
    if len(net.layers) > 0:
        callerId = Log.getCallerId('protoxt-parser')
        Log.log(
            "The given network contains deprecated layer definitions which are not supported and will be dropped.",
            callerId)

    # "layers" is deprecated; "layer" is handled separately and linked to the "layers" key of the result
    del params["layers"]
    del params["layer"]
    # "in" works on Python 2 and 3 (dict.has_key was removed in Python 3).
    if "layerOrder" in params:
        raise ValueError('Key layerOrder not expected!')

    # Extract every other parameters
    res = _extract_param(net, params)

    res["layers"], res["layerOrder"] = _load_layers(net.layer)

    # NOTE(review): presumably a deep copy so the result shares no mutable
    # objects with the parsed net / descriptor — confirm.
    res = copy.deepcopy(res)
    return res