コード例 #1
0
 def __init__(self, net=None):
     """Initialize the wrapped ``_NNModule``.

     Args:
         net: ``None`` to create an empty module via ``C.NNModule()``,
             otherwise a ``core.Net`` or a ``caffe2_pb2.NetDef`` whose
             serialized bytes are passed to ``C.NNModuleFromProtobuf``.

     Raises:
         Exception: If ``net`` is given but is neither a ``core.Net`` nor
             a ``caffe2_pb2.NetDef``.
     """
     # No source net: start from an empty module.
     if net is None:
         self._NNModule = C.NNModule()
         return

     # Reduce both accepted input types to the same serialized form.
     if isinstance(net, core.Net):
         serialized = net.Proto().SerializeToString()
     elif isinstance(net, caffe2_pb2.NetDef):
         serialized = net.SerializeToString()
     else:
         raise Exception(
             "NNModule can be constructed with core.Net or caffe2_pb2.NetDef types"
         )

     self._NNModule = C.NNModuleFromProtobuf(serialized)
コード例 #2
0
ファイル: nomnigraph.py プロジェクト: awthomp/pytorch-dev
    def __init__(self, net=None, device_map=None):
        """Initialize the wrapped ``_NNModule``, optionally from a net.

        Args:
            net: ``None`` to create an empty module via ``C.NNModule()``,
                otherwise a ``core.Net`` or a ``caffe2_pb2.NetDef`` that is
                serialized before being handed to the ``C`` bindings.
            device_map: Optional mapping whose values expose
                ``SerializeToString()``. When provided (and ``net`` is not
                ``None``), the module is built with
                ``C.NNModuleFromProtobufDistributed``; otherwise
                ``C.NNModuleFromProtobuf`` is used. Ignored when ``net`` is
                ``None``.

        Raises:
            Exception: If ``net`` is given but is neither a ``core.Net`` nor
                a ``caffe2_pb2.NetDef``.

        Note:
            ``self._OpList`` is only assigned on the non-distributed path,
            where the binding's result is unpacked into two values.
            NOTE(review): confirm callers do not read ``_OpList`` after a
            distributed or empty construction.
        """
        if net is None:
            self._NNModule = C.NNModule()
            return

        # Validate the input type before choosing a construction path, so an
        # unsupported ``net`` always gets the descriptive error -- including
        # when ``device_map`` is supplied -- instead of ``None`` being passed
        # on to the distributed binding.
        if isinstance(net, core.Net):
            serialized_proto = net.Proto().SerializeToString()
        elif isinstance(net, caffe2_pb2.NetDef):
            serialized_proto = net.SerializeToString()
        else:
            raise Exception(
                "NNModule can be constructed with core.Net or caffe2_pb2.NetDef types"
            )

        if device_map is not None:
            # Distributed: every device_map value is serialized as well.
            serialized_device_map = {
                k: device_map[k].SerializeToString() for k in device_map
            }
            self._NNModule = C.NNModuleFromProtobufDistributed(
                serialized_proto, serialized_device_map)
        else:
            # Default. The branch is chosen by type validity rather than by
            # truthiness of the bytes: an empty NetDef serializes to b"" and
            # must not be reported as a wrong-type error.
            self._NNModule, self._OpList = C.NNModuleFromProtobuf(
                serialized_proto)