class DataLoader(object):
    """Thread-safe loader that yields training items from a DartDB-style
    database, optionally augmented with a nearest-neighbour (NN) instance.

    Behaviour is driven by ``params`` keys: 'im_shape', 'nn_shape',
    'hist_eq', 'shuffle', 'subtract_mean', 'load_nn', 'nn_query_size'
    and (when 'load_nn' is set) 'nn_db_size'.
    """

    def __init__(self, params, db, nn_db):
        # Guards self.cur / self.indexes inside get_next_id().
        self.lock = Lock()
        self.db = db
        # Start one past the end so the very first get_next_id() call
        # wraps to position 0 (and shuffles, if enabled).
        self.cur = db.length
        self.im_shape = params['im_shape']
        self.nn_shape = params['nn_shape']
        self.hist_eq = params['hist_eq']
        self.indexes = np.arange(db.length)
        self.shuffle = params['shuffle']
        self.subtract_mean = params['subtract_mean']
        if self.subtract_mean:
            # NOTE(review): mean_img is stored but never used by
            # load_next_data -- presumably consumed by callers; confirm.
            self.mean_img = self.db.read_mean_img(self.im_shape)
        self.load_nn = params['load_nn']
        self.nn_query_size = params['nn_query_size']
        if self.load_nn:
            self.nn_db = nn_db
            #nn_ignore = 1 if db.db_root == nn_db.db_root else 0
            nn_ignore = 0
            self.nn = NN(nn_db, params['nn_db_size'], nn_ignore)

    def load_next_data(self):
        """Read the next instance and return it as a flat dict.

        Keys: 'jp', plus 'img_<shape>'/'seg_<shape>' per image scale
        (images transposed HWC -> CHW); when load_nn is set, also
        'nn_jp' and 'nn_img_<shape>'/'nn_seg_<shape>' for one randomly
        chosen nearest neighbour.
        """
        nid = self.get_next_id()
        jp, imgs, segs = self.db.read_instance(nid, size=self.im_shape)
        item = {'jp': jp}
        for i, img in enumerate(imgs):
            if self.hist_eq:
                img = correct_hist(img)
            key = shape_str(self.im_shape[i])
            item['img_' + key] = img.transpose((2, 0, 1))  # HWC -> CHW
            item['seg_' + key] = segs[i]
        if self.load_nn:
            nn_id = self.nn.nn_ids(jp, self.nn_query_size)
            # nn_ids returns a scalar for a query size of 1, otherwise a
            # sequence -- pick one neighbour uniformly at random.
            if hasattr(nn_id, '__len__'):
                nn_id = random.choice(nn_id)
            nn_jp, nn_imgs, nn_segs = self.nn_db.read_instance(
                nn_id, size=self.nn_shape)
            item['nn_jp'] = nn_jp
            for i, nn_img in enumerate(nn_imgs):
                if self.hist_eq:
                    nn_img = correct_hist(nn_img)
                key = shape_str(self.nn_shape[i])
                item['nn_img_' + key] = nn_img.transpose((2, 0, 1))
                item['nn_seg_' + key] = nn_segs[i]
        return item

    def get_next_id(self):
        """Return the next database id, wrapping around (and reshuffling,
        if enabled) at the end of each epoch.  Thread-safe.
        """
        # Hold the lock for the entire read-modify-read: the original
        # released it before indexing, letting a concurrent caller move
        # `cur` (or reshuffle `indexes`) between the release and the
        # read, which produced duplicate/skipped ids.  `with` also
        # guarantees release if shuffle raises.
        with self.lock:
            if self.cur >= len(self.indexes) - 1:
                self.cur = 0
                if self.shuffle:
                    random.shuffle(self.indexes)
            else:
                self.cur += 1
            return self.indexes[self.cur]
class EKFNNLayer(caffe.Layer):
    """Caffe Python layer that, given a batch of joint-position vectors on
    bottom[0], looks up nearest-neighbour instances in a DartDB and emits
    ``nn_num`` groups of four tops per neighbour:
    (image CxHxW, segmentation 1xHxW, joint positions, weight=1).
    """

    def setup(self, bottom, top):
        # SECURITY(review): eval() on param_str executes arbitrary code
        # from the prototxt.  This is the historical caffe Python-layer
        # convention, but treat prototxt files as trusted input only.
        params = eval(self.param_str)
        # Fill in defaults and validate required keys (nn_root/nn_shape
        # have no default and must be supplied).
        check_params(params, nn_root=None, nn_shape=None, nn_query_size=1, nn_num=1, nn_db_size=np.inf, nn_ignore=1)
        self.params = Map(params)
        self.nn_db = DartDB(self.params.nn_root)
        self.nn = NN(self.nn_db, self.params.nn_db_size, self.params.nn_ignore)
        # We sample nn_num neighbours (without replacement) from a pool
        # of nn_query_size candidates, so the pool must be large enough.
        assert self.params.nn_num <= self.params.nn_query_size

    def reshape(self, bottom, top):
        #Reshape tops
        batch_size = bottom[0].shape[0]
        # Joint-position dimensionality of the query must match the DB.
        assert self.nn_db.jps.shape[1] == bottom[0].shape[1]
        cur_top = 0
        # Four tops per neighbour, laid out consecutively:
        # +0 image (3 channels), +1 segmentation (1 channel),
        # +2 joint positions, +3 scalar weight.
        for nn_id in range(self.params.nn_num):
            top[cur_top + 0].reshape(batch_size, 3, self.params.nn_shape[0], self.params.nn_shape[1])
            top[cur_top + 1].reshape(batch_size, 1, self.params.nn_shape[0], self.params.nn_shape[1])
            top[cur_top + 2].reshape(batch_size, self.nn_db.jps.shape[1])
            top[cur_top + 3].reshape(batch_size, 1)
            cur_top += 4
            #self.top_names.extend(['nn_img_' + str(nn_id), 'nn_seg_' + str(nn_id)])
            #self.top_names.append('nn_jp_' + str(nn_id))
            #self.top_names.append('nn_w_' + str(nn_id))

    def forward(self, bottom, top):
        # For each item in the batch: query neighbours of its joint
        # positions and copy each sampled neighbour into its top group.
        for itt in range(bottom[0].shape[0]):
            jp = bottom[0].data[itt]
            nn_ids = self.nn.nn_ids(jp, self.params.nn_query_size)
            # nn_ids returns a scalar when the query size is 1,
            # otherwise a sequence we subsample without replacement.
            if hasattr(nn_ids, '__len__'):
                nn_ids = np.random.choice(nn_ids, size=self.params.nn_num, replace=False)
            else:
                nn_ids = [nn_ids]
            for i in range(len(nn_ids)):
                nn_id = nn_ids[i]
                nn_jp, nn_img, nn_seg = self.nn_db.read_instance(nn_id, size=self.params.nn_shape)
                # read_instance returns per-scale lists; only the first
                # scale is used here.  Images are stored HWC -> emit CHW.
                top[i * 4 + 0].data[itt, ...] = nn_img[0].transpose((2,0,1))
                top[i * 4 + 1].data[itt, ...] = nn_seg[0]
                top[i * 4 + 2].data[itt, ...] = nn_jp
                # Constant weight of 1 for every neighbour.
                top[i * 4 + 3].data[itt, ...] = 1

    def backward(self, top, propagate_down, bottom):
        # Data-lookup layer: no gradient flows to the query bottom.
        bottom[0].diff[...] = 0

    def forward_jv(self, top, bottom):
        # Jacobian-vector product variant: the lookup is non-differentiable,
        # so all top diffs are zeroed.
        for top_id in range(self.params.nn_num * 4):
            top[top_id].diff[...] = 0