def __getitem__(self, index_tuple):
    """Build the blobs for one roidb entry.

    Args:
        index_tuple: ``(index, ratio)``. ``index`` selects the entry in
            ``self._roidb``; ``ratio`` is forwarded to ``self.crop_data``.

    Returns:
        The blob dict from ``get_minibatch`` with axis 0 squeezed out of
        every blob except ``'roidb'`` and ``'image_name'``.
        ``blobs['roidb']`` is serialized only when both ``cfg.SEM.SEM_ON``
        and ``cfg.DISP.DISP_ON`` are falsy.
    """
    index, ratio = index_tuple
    roidb_entry = self._roidb[index]
    # TODO: the validity flag is ignored. Abandoning an invalid minibatch
    # would need a change to _worker_loop in torch.utils.data.dataloader.py.
    blobs, _ = get_minibatch([roidb_entry])

    # Squeeze the batch dimension of every array blob.
    for name in blobs:
        if name in ('roidb', 'image_name'):
            continue
        blobs[name] = blobs[name].squeeze(axis=0)

    if roidb_entry['need_crop']:
        self.crop_data(blobs, ratio)
        # After cropping, drop boxes whose columns 0/2 or 1/3 coincide
        # (presumably zero width or height -- confirm the box layout).
        entry = blobs['roidb'][0]
        boxes = entry['boxes']
        nondegenerate = (boxes[:, 0] != boxes[:, 2]) & (boxes[:, 1] != boxes[:, 3])
        valid_inds = np.nonzero(nondegenerate)[0]
        if len(valid_inds) < len(boxes):
            per_box_fields = (
                'boxes', 'gt_classes', 'seg_areas', 'gt_overlaps',
                'is_crowd', 'box_to_gt_ind_map', 'gt_keypoints',
            )
            for field in per_box_fields:
                if field in entry:
                    entry[field] = entry[field][valid_inds]
            entry['segms'] = [entry['segms'][ind] for ind in valid_inds]

    if not (cfg.SEM.SEM_ON or cfg.DISP.DISP_ON):
        # CHECK: maybe we can serialize in collate_fn
        blobs['roidb'] = blob_utils.serialize(blobs['roidb'])
    return blobs
def __getitem__(self, index_tuple):
    """Build the blobs for one roidb entry and serialize its roidb.

    Args:
        index_tuple: ``(index, ratio)``. ``index`` selects the entry in
            ``self._roidb``; ``ratio`` is forwarded to ``self.crop_data``.

    Returns:
        The blob dict from ``get_minibatch`` with axis 0 squeezed out of
        every blob except ``'roidb'``, and ``blobs['roidb']`` replaced by
        ``blob_utils.serialize`` output.
    """
    index, ratio = index_tuple
    single_db = [self._roidb[index]]
    # (The original carried a placeholder note here: _add_proposals(xxx).)
    # TODO: the validity flag is ignored. Abandoning an invalid minibatch
    # would need a change to _worker_loop in torch.utils.data.dataloader.py.
    blobs, _ = get_minibatch(single_db)

    # Squeeze the batch dimension.
    array_keys = [name for name in blobs if name != 'roidb']
    for name in array_keys:
        blobs[name] = blobs[name].squeeze(axis=0)

    if self._roidb[index]['need_crop']:
        self.crop_data(blobs, ratio)
        # Filter out boxes whose columns 0/2 or 1/3 are equal after the crop.
        entry = blobs['roidb'][0]
        boxes = entry['boxes']
        same_cols_02 = boxes[:, 0] == boxes[:, 2]
        same_cols_13 = boxes[:, 1] == boxes[:, 3]
        valid_inds = np.nonzero(~(same_cols_02 | same_cols_13))[0]
        if len(valid_inds) < len(boxes):
            per_box_fields = (
                'boxes', 'precomp_keypoints', 'gt_classes', 'seg_areas',
                'gt_overlaps', 'is_crowd', 'box_to_gt_ind_map',
                'gt_keypoints', 'gt_actions', 'gt_role_id',
            )
            for field in per_box_fields:
                if field in entry:
                    entry[field] = entry[field][valid_inds]
            entry['segms'] = [entry['segms'][ind] for ind in valid_inds]

    # CHECK: maybe we can serialize in collate_fn
    blobs['roidb'] = blob_utils.serialize(blobs['roidb'])
    return blobs
def collate_minibatch(list_of_blobs):
    """Stack samples separately and return a dict of per-minibatch lists.

    A batch contains NUM_GPUS minibatches, and the image size may differ
    between minibatches, so the samples of each minibatch are stacked
    separately. Samples are grouped in consecutive chunks of
    ``cfg.TRAIN.IMS_PER_BATCH``.

    Args:
        list_of_blobs: With ``cfg.RPN.RPN_ON``, a list of blob dicts whose
            ``'roidb'`` key is popped (the dicts are mutated). Otherwise the
            items are passed to ``get_minibatch`` in chunks.

    Returns:
        A dict mapping each blob name to a list with one element per
        minibatch.
    """
    if not cfg.RPN.RPN_ON:
        batch = {name: [] for name in get_minibatch_blob_names()}
        step = cfg.TRAIN.IMS_PER_BATCH
        for start in range(0, len(list_of_blobs), step):
            blobs, _ = get_minibatch(list_of_blobs[start:start + step])
            blobs['data'] = torch.from_numpy(blobs['data'])
            for name, value in blobs.items():
                batch[name].append(value)
        return batch

    # The keys are taken before 'roidb' is popped, so 'roidb' stays a key.
    batch = {name: [] for name in list_of_blobs[0]}
    # roidb entries have variable length and cannot be batched into a
    # tensor, so roidb is kept as a "list of ndarray".
    roidbs = [sample.pop('roidb') for sample in list_of_blobs]
    step = cfg.TRAIN.IMS_PER_BATCH
    for start in range(0, len(list_of_blobs), step):
        stop = start + step
        # Pad the image data, then collate the chunk.
        padded = pad_image_data(list_of_blobs[start:stop])
        minibatch = default_collate(padded)
        minibatch['roidb'] = roidbs[start:stop]
        for name, value in minibatch.items():
            batch[name].append(value)
    return batch
def __getitem__(self, index_tuple):
    """Build the blobs for one roidb entry (one entry per call).

    Args:
        index_tuple: ``(index, ratio)``. ``index`` selects the entry in
            ``self._roidb``; ``ratio`` is forwarded to ``self.crop_data``.

    Returns:
        The blob dict from ``get_minibatch``. Axis 0 is squeezed only from
        non-roidb blobs whose first dimension is 1. ``blobs['roidb']`` is
        replaced by ``blob_utils.serialize`` output.
    """
    index, ratio = index_tuple
    roidb_entry = self._roidb[index]
    # TODO: the validity flag is ignored; invalid minibatches are not abandoned.
    blobs, _ = get_minibatch([roidb_entry])

    # Squeeze the batch dimension where it is a singleton.
    for name in blobs:
        if name != 'roidb' and blobs[name].shape[0] == 1:
            blobs[name] = blobs[name].squeeze(axis=0)

    # Cropping and the box check only run when RPN is enabled.
    if cfg.RPN.RPN_ON and roidb_entry['need_crop']:
        self.crop_data(blobs, ratio)
        # Keep only boxes whose columns 0/2 and 1/3 both differ.
        entry = blobs['roidb'][0]
        boxes = entry['boxes']
        nondegenerate = (boxes[:, 0] != boxes[:, 2]) & (boxes[:, 1] != boxes[:, 3])
        valid_inds = np.nonzero(nondegenerate)[0]
        if len(valid_inds) < len(boxes):
            for field in ('boxes', 'gt_classes', 'seg_areas', 'gt_overlaps',
                          'is_crowd', 'box_to_gt_ind_map', 'gt_keypoints'):
                if field in entry:
                    entry[field] = entry[field][valid_inds]
            entry['segms'] = [entry['segms'][ind] for ind in valid_inds]

    # CHECK: maybe we can serialize in collate_fn
    blobs['roidb'] = blob_utils.serialize(blobs['roidb'])
    return blobs
def get_next_unlabel_minibatch(self):
    """Return the blobs of the next minibatch drawn from ``self._roidb_tmp``.

    Indices come from ``self._get_next_unlabel_minibatch_inds``. The method
    keeps drawing until ``get_minibatch`` reports the blobs as valid.
    """
    while True:
        db_inds = self._get_next_unlabel_minibatch_inds()
        blobs, valid = get_minibatch([self._roidb_tmp[i] for i in db_inds])
        if valid:
            return blobs
def get_next_minibatch(self):
    """Return the blobs to be used for the next minibatch.

    Draws indices from ``self._get_next_minibatch_inds`` and retries until
    ``get_minibatch`` reports a valid result.
    NOTE(review): the original docstring says "Thread safe"; that depends on
    the index helper, which is not visible here -- confirm.
    """
    while True:
        db_inds = self._get_next_minibatch_inds()
        minibatch_db = [self._roidb[i] for i in db_inds]
        blobs, valid = get_minibatch(minibatch_db)
        if valid:
            return blobs
def _get_next_minibatch2(shared_readonly_dict, lock, mp_cur, mp_perm):
    """Return the blobs to be used for the next minibatch.

    Args:
        shared_readonly_dict: Mapping whose ``'roidb'`` item is the roidb.
        lock, mp_cur, mp_perm: Forwarded unchanged to
            ``RoIDataLoader._get_next_minibatch_inds``.

    Returns:
        The first blobs that ``get_minibatch`` reports as valid.

    NOTE(review): the original docstring says "Thread safe"; that relies on
    the index helper honoring ``lock`` -- confirm.
    """
    roidb = shared_readonly_dict['roidb']
    while True:
        db_inds = RoIDataLoader._get_next_minibatch_inds(
            shared_readonly_dict, lock, mp_cur, mp_perm)
        blobs, valid = get_minibatch([roidb[i] for i in db_inds])
        if valid:
            return blobs
def get_next_minibatch(self):
    """Return the blobs to be used for the next minibatch.

    Indices are drawn repeatedly from ``self._get_next_minibatch_inds``
    until ``get_minibatch`` flags the result as valid. ``self._roidb`` is
    only read, never written back.
    NOTE(review): the original docstring says "Thread safe" -- confirm
    against the index helper.
    """
    blobs, valid = None, False
    while not valid:
        selected = [self._roidb[i] for i in self._get_next_minibatch_inds()]
        blobs, valid = get_minibatch(selected)
    return blobs
def __getitem__(self, index_tuple):
    """Build the blobs for one roidb entry, with optional random crop.

    Args:
        index_tuple: ``(index, ratio)``. ``index`` selects the entry in
            ``self._roidb``; ``ratio`` is forwarded to ``self.crop_data``.

    Returns:
        The blob dict from ``get_minibatch`` (called with
        ``self.transform`` and ``self.valid_keys``), with axis 0 squeezed
        out of every non-roidb blob and ``blobs['roidb']`` serialized.
    """
    index, ratio = index_tuple
    roidb_entry = self._roidb[index]
    # TODO: the validity flag is ignored. Abandoning an invalid minibatch
    # would need a change to _worker_loop in torch.utils.data.dataloader.py.
    blobs, _ = get_minibatch([roidb_entry], self.transform, self.valid_keys)

    # Squeeze the batch dimension.
    for name in blobs:
        if name == 'roidb':
            continue
        blobs[name] = blobs[name].squeeze(axis=0)

    if roidb_entry['need_crop']:
        self.crop_data(blobs, ratio)
        # Drop boxes whose columns 0/2 or 1/3 are equal after the crop.
        entry = blobs['roidb'][0]
        boxes = entry['boxes']
        degenerate = (boxes[:, 0] == boxes[:, 2]) | (boxes[:, 1] == boxes[:, 3])
        valid_inds = np.nonzero(~degenerate)[0]
        if len(valid_inds) < len(boxes):
            for field in ('boxes', 'gt_classes', 'seg_areas', 'gt_overlaps',
                          'is_crowd', 'box_to_gt_ind_map', 'gt_keypoints'):
                if field in entry:
                    entry[field] = entry[field][valid_inds]
            entry['segms'] = [entry['segms'][ind] for ind in valid_inds]

    if cfg.TRAIN.RANDOM_CROP > 0:
        first_entry = blobs['roidb'][0]
        # Snapshot the segms once, before the first training crop.
        if 'segms_origin' not in first_entry.keys():
            first_entry['segms_origin'] = first_entry['segms'].copy()
        self.crop_data_train(blobs)
        # The original author noted that no further bounding-box check is
        # necessary after this crop.

    # CHECK: maybe we can serialize in collate_fn
    blobs['roidb'] = blob_utils.serialize(blobs['roidb'])
    return blobs
def __getitem__(self, index_tuple):
    """Build the blobs for one roidb entry at a given scale.

    Args:
        index_tuple: ``(index, scale)``. ``index`` selects the entry in
            ``self._roidb``; ``scale`` is passed to ``get_minibatch``.

    Returns:
        The blob dict with axis 0 squeezed out of every blob except
        ``'roidb'`` and ``'data_flow'``, and ``blobs['roidb']`` serialized.
    """
    index, scale = index_tuple
    # TODO: the validity flag is ignored. Abandoning an invalid minibatch
    # would need a change to _worker_loop in torch.utils.data.dataloader.py.
    blobs, _ = get_minibatch([self._roidb[index]], scale)

    # Squeeze the batch dimension.
    keep_batch_axis = ('roidb', 'data_flow')
    for name in blobs:
        if name not in keep_batch_axis:
            blobs[name] = blobs[name].squeeze(axis=0)

    # CHECK: maybe we can serialize in collate_fn
    blobs['roidb'] = blob_utils.serialize(blobs['roidb'])
    return blobs
def _get_next_minibatch(self):
    """Return the blobs to be used for the next minibatch.

    DEPRECATED. This only exists for debugging (in train_net.py) and for
    benchmarking.

    NOTE(review): a fresh ``multiprocessing.Value`` seeded from
    ``self._cur`` is built on every attempt and never written back. Whether
    the cursor advances between attempts depends on
    ``_get_next_minibatch_inds`` -- confirm.
    """
    roidb = self._roidb
    while True:
        shared = {'roidb': roidb}
        cursor = multiprocessing.Value('i', self._cur, lock=False)
        db_inds = self._get_next_minibatch_inds(
            shared, self._lock, cursor, self._perm)
        blobs, valid = get_minibatch([roidb[i] for i in db_inds])
        if valid:
            return blobs
def __getitem__(self, index_tuple):
    """Build the blobs for one roidb entry, squeezing only ``'data'``.

    Args:
        index_tuple: ``(index, ratio)``. Only ``index`` is used; it selects
            the entry in ``self._roidb``.

    Returns:
        The blob dict from ``get_minibatch`` (called with
        ``self._num_classes``), with axis 0 squeezed out of
        ``blobs['data']``. Other blobs, including ``'roidb'``, are left as
        returned.
    """
    index, _ratio = index_tuple
    # TODO: the validity flag is ignored. Abandoning an invalid minibatch
    # would need a change to _worker_loop in torch.utils.data.dataloader.py.
    blobs, _ = get_minibatch([self._roidb[index]], self._num_classes)
    # Only the image blob loses its batch dimension here.
    image = blobs['data']
    blobs['data'] = image.squeeze(axis=0)
    return blobs
def __getitem__(self, index_tuple):
    """Return per-sample blobs with RPN on, or the raw roidb entry otherwise.

    Args:
        index_tuple: ``(index, ratio)``. ``index`` selects the entry in
            ``self._roidb``; ``ratio`` is forwarded to ``self.crop_data``.

    Returns:
        With ``cfg.RPN.RPN_ON``: the blob dict from ``get_minibatch`` with
        axis 0 squeezed out of every non-roidb blob and ``blobs['roidb']``
        serialized. Otherwise: ``self._roidb[index]`` unchanged.
    """
    index, ratio = index_tuple
    roidb_entry = self._roidb[index]
    if not cfg.RPN.RPN_ON:
        return roidb_entry

    blobs, _ = get_minibatch([roidb_entry])

    # Squeeze the batch dimension.
    for name in blobs:
        if name == 'roidb':
            continue
        blobs[name] = blobs[name].squeeze(axis=0)

    if roidb_entry['need_crop']:
        self.crop_data(blobs, ratio)
        # Drop boxes whose columns 0/2 or 1/3 are equal after the crop.
        entry = blobs['roidb'][0]
        boxes = entry['boxes']
        nondegenerate = (boxes[:, 0] != boxes[:, 2]) & (boxes[:, 1] != boxes[:, 3])
        valid_inds = np.nonzero(nondegenerate)[0]
        if len(valid_inds) < len(boxes):
            for field in ('boxes', 'gt_classes', 'seg_areas', 'gt_overlaps',
                          'is_crowd', 'box_to_gt_ind_map', 'gt_keypoints'):
                if field in entry:
                    entry[field] = entry[field][valid_inds]
            entry['segms'] = [entry['segms'][ind] for ind in valid_inds]

    # CHECK: maybe we can serialize in collate_fn
    blobs['roidb'] = blob_utils.serialize(blobs['roidb'])
    return blobs
def __getitem__(self, index_tuple):
    """Build the blobs for one roidb entry and attach a query.

    In training mode a query is drawn at random: a positive category, a
    background patch, or a negative category, depending on one uniform draw
    compared against ``cfg.TRAIN.QUERY_POSITIVE_RATE`` and
    ``cfg.TRAIN.QUERY_GLOBAL_NEGATIVE_RATE``. In evaluation mode the query
    comes from ``self.crop_query``.

    Args:
        index_tuple: ``(index, ratio)``. ``index`` selects the entry in
            ``self._roidb``; ``ratio`` is forwarded to ``self.crop_data``.

    Returns:
        The blob dict with ``'query'`` and ``'query_type'`` added (1 for a
        positive query, 0 otherwise), ``'gt_cats'`` and ``'binary_mask'``
        removed, and ``blobs['roidb']`` serialized. In evaluation mode
        ``'choice'`` is set to ``self.cat_list[index]``.
    """
    index, ratio = index_tuple
    single_db = [self._roidb[index]]
    # TODO: the validity flag is ignored. Abandoning an invalid minibatch
    # would need a change to _worker_loop in torch.utils.data.dataloader.py.
    blobs, _ = get_minibatch(single_db)

    # Squeeze the batch dimension of the plain array blobs.
    keep_batch_axis = ('roidb', 'gt_cats', 'binary_mask')
    for name in blobs:
        if name not in keep_batch_axis:
            blobs[name] = blobs[name].squeeze(axis=0)

    # Keep only the ground-truth categories listed in self.list.
    blobs['gt_cats'] = np.array(
        [cat for cat in blobs['gt_cats'] if cat in self.list])

    # Rescale the mask by the last im_info value, then dilate it 5x5.
    scale = blobs['im_info'][-1]
    mask = cv2.resize(blobs['binary_mask'], None, None, fx=scale, fy=scale,
                      interpolation=cv2.INTER_NEAREST)
    mask = cv2.dilate(mask, np.ones((5, 5), np.uint8), iterations=1)
    blobs['binary_mask'] = mask

    query_type = 1
    if self.training:
        # Randomly choose the query category.
        positive_catgory = blobs['gt_cats']
        negative_catgory = np.array(
            list(set(self.cat_list) - set(positive_catgory)))
        r = random.random()
        positive_rate = cfg.TRAIN.QUERY_POSITIVE_RATE
        if r <= positive_rate:
            query_type = 1
            cand = np.unique(positive_catgory)
            if len(cand) == 1:
                choice = cand[0]
            else:
                # Weight the candidates by their normalized show_time values.
                p = np.array([self.show_time[cat] for cat in cand])
                p /= p.sum()
                choice = np.random.choice(cand, 1, p=p)[0]
            query = self.load_query(choice)
        else:
            query_type = 0
            use_negative_category = True
            if r <= positive_rate + cfg.TRAIN.QUERY_GLOBAL_NEGATIVE_RATE:
                # Try background patches first.
                patch = self.sample_bg(blobs['data'].copy(),
                                       blobs['binary_mask'].copy())
                if len(patch) == self.shot:
                    query = patch
                    use_negative_category = False
                else:
                    print("No bg, the number of bg is: ", len(patch))
            if use_negative_category:
                choice = np.random.choice(negative_catgory, 1)[0]
                query = self.load_query(choice)
    else:
        query = self.crop_query(single_db, index, single_db[0]['id'])

    blobs['query'] = query
    blobs['query_type'] = query_type

    # The helper blobs are not returned.
    for helper_key in ('gt_cats', 'binary_mask'):
        if helper_key in blobs:
            del blobs[helper_key]

    if self.training and self._roidb[index]['need_crop']:
        self.crop_data(blobs, ratio)
        # Drop boxes whose columns 0/2 or 1/3 are equal after the crop.
        entry = blobs['roidb'][0]
        boxes = entry['boxes']
        degenerate = (boxes[:, 0] == boxes[:, 2]) | (boxes[:, 1] == boxes[:, 3])
        valid_inds = np.nonzero(~degenerate)[0]
        if len(valid_inds) < len(boxes):
            for field in ('boxes', 'gt_classes', 'seg_areas', 'gt_overlaps',
                          'is_crowd', 'box_to_gt_ind_map', 'gt_keypoints'):
                if field in entry:
                    entry[field] = entry[field][valid_inds]
            entry['segms'] = [entry['segms'][ind] for ind in valid_inds]

    # CHECK: maybe we can serialize in collate_fn
    blobs['roidb'] = blob_utils.serialize(blobs['roidb'])
    if not self.training:
        blobs['choice'] = self.cat_list[index]
    return blobs
def __getitem__(self, index_tuple):
    """Build the query blobs for one roidb entry plus its support set.

    The support set has 2 ways x 5 shots: first five crops of the query
    class, then five crops of one other class that does not appear in the
    query image. Every draw uses ``random_state=index``, so the selection
    is a deterministic function of ``index`` and ``self.index_pd``.

    Args:
        index_tuple: ``(index, ratio)``. ``index`` is the position in
            ``self._roidb`` (not a roidb_index) and is matched against the
            ``'index'`` column of ``self.index_pd``; ``ratio`` is forwarded
            to ``self.crop_data``.

    Returns:
        The blob dict with ``'support_data'`` added, and the roidb entry
        extended with ``'support_boxes'``, ``'support_id'``,
        ``'support_cls'``, ``'query_id'`` and ``'target_cls'`` before
        serialization.
    """
    index, ratio = index_tuple
    index_pd = self.index_pd

    def lookup(row_index, column):
        # First value of `column` among rows whose 'index' equals row_index.
        return index_pd.loc[index_pd['index'] == row_index, column].tolist()[0]

    # The support class matches the query class, but the support images
    # must differ from the query image.
    query_cls = lookup(index, 'cls_ls')
    query_img = lookup(index, 'img_ls')
    all_cls = index_pd.loc[index_pd['img_ls'] == query_img, 'cls_ls'].tolist()

    single_db = [self._roidb[index]]
    # TODO: the validity flag is ignored. Abandoning an invalid minibatch
    # would need a change to _worker_loop in torch.utils.data.dataloader.py.
    blobs, _ = get_minibatch(single_db)

    # Squeeze the batch dimension.
    for name in blobs:
        if name == 'roidb':
            continue
        blobs[name] = blobs[name].squeeze(axis=0)

    if self._roidb[index]['need_crop']:
        self.crop_data(blobs, ratio)
        # Drop boxes whose columns 0/2 or 1/3 are equal after the crop.
        entry = blobs['roidb'][0]
        boxes = entry['boxes']
        nondegenerate = (boxes[:, 0] != boxes[:, 2]) & (boxes[:, 1] != boxes[:, 3])
        valid_inds = np.nonzero(nondegenerate)[0]
        if len(valid_inds) < len(boxes):
            for field in ('boxes', 'gt_classes', 'seg_areas', 'gt_overlaps',
                          'is_crowd', 'box_to_gt_ind_map', 'gt_keypoints'):
                if field in entry:
                    entry[field] = entry[field][valid_inds]
            entry['segms'] = [entry['segms'][ind] for ind in valid_inds]

    # Crop the support data and collect the new support boxes.
    support_way = 2
    support_shot = 5
    num_support = support_way * support_shot
    support_data_all = np.zeros((num_support, 3, 320, 320), dtype=np.float32)
    support_box_all = np.zeros((num_support, 4), dtype=np.float32)
    used_img_ls = [query_img]
    used_index_ls = [index]
    used_cls_ls = list(set(all_cls))
    support_cls_ls = []

    slot = 0
    for way in range(support_way):
        if way == 0:
            target_cls = query_cls
        else:
            # Pick a class that has not been used yet.
            unused_cls = index_pd.loc[
                ~index_pd['cls_ls'].isin(used_cls_ls), 'cls_ls']
            target_cls = unused_cls.drop_duplicates().sample(
                random_state=index).tolist()[0]
            used_cls_ls.append(target_cls)

        for _ in range(support_shot):
            # Pick a support row of the target class from an unused image.
            candidates = index_pd.loc[
                (index_pd['cls_ls'] == target_cls)
                & (~index_pd['img_ls'].isin(used_img_ls))
                & (~index_pd['index'].isin(used_index_ls)),
                'index']
            support_index = candidates.sample(random_state=index).tolist()[0]
            support_cls = lookup(support_index, 'cls_ls')
            support_img = lookup(support_index, 'img_ls')
            used_index_ls.append(support_index)
            used_img_ls.append(support_img)

            support_db = [self._roidb[support_index]]
            support_data, support_box = self.crop_support(support_db)
            support_data_all[slot] = support_data
            support_box_all[slot] = support_box[0]
            support_cls_ls.append(support_cls)
            slot += 1

    blobs['support_data'] = support_data_all
    query_entry = blobs['roidb'][0]
    query_entry['support_boxes'] = support_box_all
    # support_db is the last support entry drawn above.
    query_entry['support_id'] = support_db[0]['id']
    query_entry['support_cls'] = support_cls_ls
    query_entry['query_id'] = single_db[0]['id']
    query_entry['target_cls'] = single_db[0]['target_cls']

    # CHECK: maybe we can serialize in collate_fn
    blobs['roidb'] = blob_utils.serialize(blobs['roidb'])
    return blobs