def set_parameters(self, params, verbose=0, create_blobs=False):
    """Load parameter values from ``params`` into the blobs of each layer.

    Args:
        params: mapping from layer name to a sequence of values, one per
            blob of that layer. Layers of ``self.layers`` whose name is not
            a key of ``params`` are left untouched.
        verbose: if truthy, print a message for each layer of the network
            that has no entry in ``params``.
        create_blobs: if True, a layer that currently has no blobs gets one
            new ``gpudm.BlobFloat`` per value, shaped like that value (each
            value must be a 4-D array). If False, layers without blobs are
            skipped.

    For layers that already have blobs, numpy arrays (including ndarray
    subclasses) are copied into the matching blob after a shape check; any
    other value is handed to ``set_blob_to``.

    Raises:
        AssertionError: on a value/blob count mismatch, an array/blob shape
            mismatch, or a non-4-D array when creating blobs.
    """
    for layer_name, layer in self.layers:
        if layer_name not in params:
            if verbose:
                print("Layer %s not found in parameters" % (layer_name))
            continue
        layer_blobs = layer.blobs()
        layer_params = params[layer_name]
        if create_blobs and len(layer_blobs) == 0:
            # Allocate one blob per value, sized from the value itself.
            for npdata in layer_params:
                assert npdata.ndim == 4, 'error: paramter is not a blob'
                blob = gpudm.BlobFloat(*npdata.shape)
                blob.mutable_to_numpy_ref()[:] = npdata
                layer_blobs.push_back(blob)
        elif len(layer_blobs) > 0:
            assert len(layer_params) == len(layer_blobs), (
                "expected %d blobs for layer %s, received %d" % (
                    len(layer_blobs), layer_name, len(layer_params)))
            for i, npdata in enumerate(layer_params):
                if isinstance(npdata, np.ndarray):
                    # Write through the blob's numpy view. Check the shape
                    # first so numpy broadcasting cannot silently fill the
                    # blob from a smaller array.
                    blob = layer_blobs[i].mutable_to_numpy_ref()
                    assert blob.shape == npdata.shape, (
                        "Error: parameters shapes differ: blob=%s vs weights=%s" % (
                            str(blob.shape), str(npdata.shape)))
                    blob[:] = npdata
                else:
                    set_blob_to(layer_blobs[i], npdata)
def __init__(self, input_blob):
    """Create an empty network whose input is ``input_blob``.

    Args:
        input_blob: either an existing blob object, used as-is, or a size
            given as a tuple or list. For a size, a new ``gpudm.BlobFloat``
            is allocated and reshaped to its first four entries.
    """
    if isinstance(input_blob, (tuple, list)):
        # Only a size was given: allocate a fresh blob of that shape.
        size = input_blob
        input_blob = gpudm.BlobFloat()
        input_blob.Reshape(size[0], size[1], size[2], size[3])
    self.input_blob = input_blob
    # Ordered (name, blob) pairs; add_layer reads the last entry as the
    # bottom blob of the next layer and appends that layer's output.
    self.activation_blobs = [("data", input_blob)]
    # Ordered (name, layer) pairs, appended by add_layer.
    self.layers = []
def add_layer(self, name, layer_class, args, inplace=False):
    """Create a layer on top of the current top blob and set it up.

    Args:
        name: layer name, validated with ``self.check_name``.
        layer_class: callable that builds the layer from ``args``.
        args: constructor arguments; a value whose type is exactly
            ``tuple`` is unpacked, anything else is passed as the single
            argument.
        inplace: if True, the output reuses the current top blob. When that
            blob has ``ShareData``, a new ``gpudm.BlobFloat`` sharing its
            data and diff is used instead of the blob object itself.
            If False, a fresh ``gpudm.BlobFloat`` is the output.

    The new (name, layer) pair is appended to ``self.layers`` and the
    (name, output blob) pair to ``self.activation_blobs``.
    """
    self.check_name(name)
    if inplace:
        current = self.activation_blobs[-1][1]
        if hasattr(current, 'ShareData'):
            shared = gpudm.BlobFloat(*blob_shape(current))
            shared.ShareData(current)
            shared.ShareDiff(current)
            output = shared
        else:
            output = current
    else:
        # single output blob per layer
        output = gpudm.BlobFloat()

    ctor_args = args if type(args) is tuple else (args,)
    layer = layer_class(*ctor_args)
    self.layers.append((name, layer))

    bottom = self.activation_blobs[-1][1]
    layer.SetUp(self.new_BlobPtrVector(bottom),
                self.new_BlobPtrVector(output))
    self.activation_blobs.append((name, output))
def arrToBlob(arr):
    """Copy the raw contents of ``arr`` into a new 1x1x1xN ``gpudm.BlobFloat``.

    The blob's numpy buffer is reinterpreted (``view``) with ``arr``'s dtype
    and the flattened array is copied in, so element bit patterns are stored
    unchanged rather than converted. This only lines up when ``arr``'s items
    are as wide as the blob's elements.

    Args:
        arr: numpy array; N is ``arr.size``.

    Returns:
        The newly allocated blob.

    Raises:
        ValueError: if ``arr`` is non-empty and its item size differs from
            the blob buffer's item size.
    """
    # dirty function but simpler for now
    bb = gpudm.BlobFloat(1, 1, 1, arr.size)
    buf = bb.mutable_to_numpy_ref()
    # A width mismatch would make the view have a different element count
    # than arr and fail deep inside numpy; report it clearly instead.
    # Empty arrays are let through, as before.
    if arr.size and arr.dtype.itemsize != buf.dtype.itemsize:
        raise ValueError(
            "arrToBlob: dtype %s is %d bytes wide, blob elements are %d bytes"
            % (arr.dtype, arr.dtype.itemsize, buf.dtype.itemsize))
    buf.view(arr.dtype)[:] = arr.ravel()
    return bb