Ejemplo n.º 1
0
class DataLoader(object):
    """Thread-safe sequential (optionally shuffled) reader over a sample database.

    Each call to ``load_next_data`` picks the next sample id, reads that
    instance from ``db`` and returns it as a dict.  When ``load_nn`` is
    enabled, one nearest-neighbour instance from ``nn_db`` is added as well.
    """

    def __init__(self, params, db, nn_db):
        """Store configuration and prepare the iteration state.

        Args:
            params: mapping providing the keys ``im_shape``, ``nn_shape``,
                ``hist_eq``, ``shuffle``, ``subtract_mean``, ``load_nn`` and
                ``nn_query_size``; ``nn_db_size`` is read only when
                ``load_nn`` is truthy.
            db: database object exposing ``length``, ``read_instance`` and
                ``read_mean_img``.
            nn_db: database searched for nearest neighbours; only kept when
                ``load_nn`` is truthy.
        """
        self.lock = Lock()
        self.db = db
        # Start past the end so the very first get_next_id() call wraps to
        # index 0 (and shuffles first, if shuffling is enabled).
        self.cur = db.length
        self.im_shape = params['im_shape']
        self.nn_shape = params['nn_shape']
        self.hist_eq = params['hist_eq']
        self.indexes = np.arange(db.length)
        self.shuffle = params['shuffle']
        self.subtract_mean = params['subtract_mean']
        if self.subtract_mean:
            # NOTE(review): mean_img is loaded here but no method visible in
            # this class subtracts it -- confirm it is consumed elsewhere.
            self.mean_img = self.db.read_mean_img(self.im_shape)

        self.load_nn = params['load_nn']
        self.nn_query_size = params['nn_query_size']
        if self.load_nn:
            self.nn_db = nn_db
            # nn_ignore is fixed at 0; an earlier variant used 1 when db and
            # nn_db shared the same db_root.
            nn_ignore = 0
            self.nn = NN(nn_db, params['nn_db_size'], nn_ignore)

    def load_next_data(self):
        """Read the next sample and return it as a dict.

        Returns:
            dict with ``'jp'`` plus, for every image ``i`` returned by the
            database, ``'img_<shape>'`` (image transposed with axes
            ``(2, 0, 1)``) and ``'seg_<shape>'``, where ``<shape>`` is
            ``shape_str(self.im_shape[i])``.  With ``load_nn`` enabled the
            same entries are added for one neighbour under the ``'nn_'``
            prefix, using ``self.nn_shape``.
        """
        nid = self.get_next_id()
        jp, imgs, segs = self.db.read_instance(nid, size=self.im_shape)
        item = {'jp': jp}
        for i, img in enumerate(imgs):
            if self.hist_eq:
                img = correct_hist(img)
            suffix = shape_str(self.im_shape[i])
            item['img_' + suffix] = img.transpose((2, 0, 1))
            item['seg_' + suffix] = segs[i]

        if self.load_nn:
            nn_id = self.nn.nn_ids(jp, self.nn_query_size)
            # The query may return several candidates; pick one at random.
            if hasattr(nn_id, '__len__'):
                nn_id = random.choice(nn_id)
            nn_jp, nn_imgs, nn_segs = self.nn_db.read_instance(
                nn_id, size=self.nn_shape)
            item['nn_jp'] = nn_jp
            for i, nn_img in enumerate(nn_imgs):
                if self.hist_eq:
                    nn_img = correct_hist(nn_img)
                suffix = shape_str(self.nn_shape[i])
                item['nn_img_' + suffix] = nn_img.transpose((2, 0, 1))
                item['nn_seg_' + suffix] = nn_segs[i]
        return item

    def get_next_id(self):
        """Advance the cursor and return the sample id at the new position.

        When the cursor reaches the last index it wraps to 0, reshuffling
        ``self.indexes`` first if ``self.shuffle`` is set.  The whole
        update *and* the lookup happen under ``self.lock`` so a concurrent
        caller cannot move the cursor or reshuffle between the two, and the
        lock is released even if an exception is raised.
        """
        with self.lock:
            if self.cur >= len(self.indexes) - 1:
                self.cur = 0
                if self.shuffle:
                    random.shuffle(self.indexes)
            else:
                self.cur += 1
            return self.indexes[self.cur]
Ejemplo n.º 2
0
class EKFNNLayer(caffe.Layer):
    """Caffe Python layer that outputs nearest-neighbour samples per input row.

    For every row of ``bottom[0]`` the layer queries ``self.nn`` and writes
    up to ``nn_num`` neighbours into the tops.  Neighbour slot ``k`` uses
    four consecutive tops:

    * ``top[4k + 0]`` -- neighbour image, ``(batch, 3, H, W)``
    * ``top[4k + 1]`` -- neighbour segmentation, ``(batch, 1, H, W)``
    * ``top[4k + 2]`` -- neighbour ``jp`` vector, ``(batch, jp_dim)``
    * ``top[4k + 3]`` -- weight, ``(batch, 1)``, always written as 1

    where ``H, W = nn_shape[0], nn_shape[1]`` and ``jp_dim`` is
    ``nn_db.jps.shape[1]``.
    """

    def setup(self, bottom, top):
        """Parse ``param_str`` and open the neighbour database and index."""
        # SECURITY NOTE: eval() executes param_str as arbitrary Python code.
        # This is only safe while the network definition comes from a
        # trusted source; never load untrusted layer parameters through it.
        params = eval(self.param_str)
        # Presumably check_params fills in these defaults for missing keys
        # (it is defined outside this block) -- TODO confirm.
        check_params(params,
                     nn_root=None,
                     nn_shape=None,
                     nn_query_size=1,
                     nn_num=1,
                     nn_db_size=np.inf,
                     nn_ignore=1)

        self.params = Map(params)
        self.nn_db = DartDB(self.params.nn_root)
        self.nn = NN(self.nn_db, self.params.nn_db_size,
                     self.params.nn_ignore)
        # forward() samples nn_num ids without replacement from the query
        # result, so it cannot ask for more than the query returns.
        assert self.params.nn_num <= self.params.nn_query_size, \
            'nn_num must not exceed nn_query_size'

    def reshape(self, bottom, top):
        """Resize the four tops of every neighbour slot to the batch size."""
        batch_size = bottom[0].shape[0]
        jp_dim = self.nn_db.jps.shape[1]
        assert jp_dim == bottom[0].shape[1], \
            'bottom[0] width must match the jp dimension of the nn database'
        height = self.params.nn_shape[0]
        width = self.params.nn_shape[1]
        for slot in range(self.params.nn_num):
            base = slot * 4
            top[base + 0].reshape(batch_size, 3, height, width)
            top[base + 1].reshape(batch_size, 1, height, width)
            top[base + 2].reshape(batch_size, jp_dim)
            top[base + 3].reshape(batch_size, 1)

    def forward(self, bottom, top):
        """Fill the tops with neighbours looked up for each input row."""
        for row in range(bottom[0].shape[0]):
            jp = bottom[0].data[row]
            nn_ids = self.nn.nn_ids(jp, self.params.nn_query_size)
            if hasattr(nn_ids, '__len__'):
                # Several candidates: draw nn_num distinct ones at random.
                nn_ids = np.random.choice(
                    nn_ids, size=self.params.nn_num, replace=False)
            else:
                # NOTE(review): a scalar result fills slot 0 only; with
                # nn_num > 1 the remaining slots keep whatever they held.
                nn_ids = [nn_ids]

            for slot, nn_id in enumerate(nn_ids):
                nn_jp, nn_img, nn_seg = self.nn_db.read_instance(
                    nn_id, size=self.params.nn_shape)
                base = slot * 4
                # Only the first image/segmentation returned is used.
                top[base + 0].data[row, ...] = nn_img[0].transpose((2, 0, 1))
                top[base + 1].data[row, ...] = nn_seg[0]
                top[base + 2].data[row, ...] = nn_jp
                top[base + 3].data[row, ...] = 1

    def backward(self, top, propagate_down, bottom):
        """Write a zero gradient to ``bottom[0]``."""
        bottom[0].diff[...] = 0

    def forward_jv(self, top, bottom):
        """Zero the ``diff`` of every top blob produced by this layer."""
        for top_id in range(self.params.nn_num * 4):
            top[top_id].diff[...] = 0