def __init__(self, net=None):
    """Create the wrapped native NNModule, optionally from an existing net.

    Args:
        net: ``None`` to create an empty module via ``C.NNModule()``, or a
            ``core.Net`` / ``caffe2_pb2.NetDef`` whose serialized protobuf
            is handed to ``C.NNModuleFromProtobuf``.

    Raises:
        Exception: if ``net`` is not ``None`` and is neither a ``core.Net``
            nor a ``caffe2_pb2.NetDef``.
    """
    if net is None:
        self._NNModule = C.NNModule()
        return

    # Both accepted input types are reduced to serialized protobuf bytes.
    if isinstance(net, core.Net):
        proto_bytes = net.Proto().SerializeToString()
    elif isinstance(net, caffe2_pb2.NetDef):
        proto_bytes = net.SerializeToString()
    else:
        raise Exception(
            "NNModule can be constructed with core.Net or caffe2_pb2.NetDef types"
        )

    self._NNModule = C.NNModuleFromProtobuf(proto_bytes)
def __init__(self, net=None, device_map=None):
    """Create the wrapped native NNModule, optionally from an existing net.

    Args:
        net: ``None`` to create an empty module via ``C.NNModule()``, or a
            ``core.Net`` / ``caffe2_pb2.NetDef`` to convert from its
            serialized protobuf.
        device_map: optional mapping whose values expose
            ``SerializeToString()`` (presumably protobuf messages such as
            per-device options -- TODO confirm). When given together with
            ``net``, the module is built with
            ``C.NNModuleFromProtobufDistributed``. It is ignored when
            ``net`` is ``None``.

    Raises:
        Exception: if ``net`` is not ``None`` and is neither a ``core.Net``
            nor a ``caffe2_pb2.NetDef``.
    """
    if net is None:
        # NOTE(review): device_map is ignored on this path (unchanged behavior).
        self._NNModule = C.NNModule()
        return

    # Validate the input type before choosing a construction path. Previously
    # an unsupported type combined with a device_map reached the distributed
    # constructor with serialized_proto=None instead of raising.
    if isinstance(net, core.Net):
        serialized_proto = net.Proto().SerializeToString()
    elif isinstance(net, caffe2_pb2.NetDef):
        serialized_proto = net.SerializeToString()
    else:
        raise Exception(
            "NNModule can be constructed with core.Net or caffe2_pb2.NetDef types"
        )

    if device_map is not None:
        # Distributed: every device_map value is serialized under its key.
        serialized_device_map = {
            k: device_map[k].SerializeToString() for k in device_map
        }
        self._NNModule = C.NNModuleFromProtobufDistributed(
            serialized_proto, serialized_device_map)
    else:
        # Default path. An empty NetDef serializes to b"", which is still a
        # valid proto, so the bytes must not be tested for truthiness.
        # NOTE(review): only this path sets self._OpList -- confirm other
        # methods do not rely on it after distributed/empty construction.
        self._NNModule, self._OpList = C.NNModuleFromProtobuf(serialized_proto)