Beispiel #1
0
    def __getitem__(self, index_tuple):
        """Load one ``(index, ratio)`` sample as a dict of blobs.

        Every blob except ``'roidb'`` and ``'image_name'`` loses its leading
        batch axis.  If the roidb entry is flagged ``need_crop`` the sample is
        cropped with ``ratio``; boxes left with zero width or zero height are
        then removed together with their per-box annotations and ``segms``.
        The roidb blob is serialized only when neither ``cfg.SEM.SEM_ON`` nor
        ``cfg.DISP.DISP_ON`` is enabled.
        """
        idx, crop_ratio = index_tuple
        blobs, _is_valid = get_minibatch([self._roidb[idx]])
        # TODO: the validity flag is ignored; dropping invalid minibatches
        # would need changes to _worker_loop in torch.utils.data.dataloader.py.

        unbatched = ('roidb', 'image_name')
        for name, value in blobs.items():
            if name not in unbatched:
                blobs[name] = value.squeeze(axis=0)

        if self._roidb[idx]['need_crop']:
            self.crop_data(blobs, crop_ratio)
            entry = blobs['roidb'][0]
            gt_boxes = entry['boxes']
            collapsed = np.logical_or(gt_boxes[:, 0] == gt_boxes[:, 2],
                                      gt_boxes[:, 1] == gt_boxes[:, 3])
            keep = np.nonzero(np.logical_not(collapsed))[0]
            if len(keep) < len(gt_boxes):
                per_box_fields = ('boxes', 'gt_classes', 'seg_areas',
                                  'gt_overlaps', 'is_crowd',
                                  'box_to_gt_ind_map', 'gt_keypoints')
                for field in per_box_fields:
                    if field in entry:
                        entry[field] = entry[field][keep]
                entry['segms'] = [entry['segms'][k] for k in keep]

        if not (cfg.SEM.SEM_ON or cfg.DISP.DISP_ON):
            # CHECK: serialization could possibly move into collate_fn.
            blobs['roidb'] = blob_utils.serialize(blobs['roidb'])
        return blobs
Beispiel #2
0
    def __getitem__(self, index_tuple):
        """Return the blobs (with serialized roidb) for one ``(index, ratio)``.

        Array blobs lose their leading batch axis.  When the roidb entry is
        flagged ``need_crop`` the sample is cropped with ``ratio`` and every
        zero-width or zero-height box is dropped along with its per-box
        annotations (including ``precomp_keypoints``, ``gt_actions`` and
        ``gt_role_id`` when present) and its ``segms`` entry.
        """
        idx, crop_ratio = index_tuple
        blobs, _is_valid = get_minibatch([self._roidb[idx]])
        # TODO: the validity flag is ignored; dropping invalid minibatches
        # would need changes to _worker_loop in torch.utils.data.dataloader.py.

        for blob_name, value in blobs.items():
            if blob_name != 'roidb':
                blobs[blob_name] = value.squeeze(axis=0)

        if self._roidb[idx]['need_crop']:
            self.crop_data(blobs, crop_ratio)
            entry = blobs['roidb'][0]
            gt_boxes = entry['boxes']
            zero_extent = (gt_boxes[:, 0] == gt_boxes[:, 2]) | \
                (gt_boxes[:, 1] == gt_boxes[:, 3])
            kept = np.nonzero(np.logical_not(zero_extent))[0]
            if len(kept) < len(gt_boxes):
                box_aligned_keys = ('boxes', 'precomp_keypoints', 'gt_classes',
                                    'seg_areas', 'gt_overlaps', 'is_crowd',
                                    'box_to_gt_ind_map', 'gt_keypoints',
                                    'gt_actions', 'gt_role_id')
                for k in box_aligned_keys:
                    if k in entry:
                        entry[k] = entry[k][kept]
                entry['segms'] = [entry['segms'][j] for j in kept]

        # CHECK: serialization could possibly move into collate_fn.
        blobs['roidb'] = blob_utils.serialize(blobs['roidb'])
        return blobs
Beispiel #3
0
def collate_minibatch(list_of_blobs):
    """Stack samples separately and return the minibatches as a dict of lists.

    A batch contains NUM_GPUS minibatches and the image size may differ
    between minibatches, so samples are stacked one minibatch at a time, in
    chunks of ``cfg.TRAIN.IMS_PER_BATCH`` consecutive samples.

    Args:
        list_of_blobs: Per-sample dataset outputs.  With ``cfg.RPN.RPN_ON``
            each item is a blob dict containing a ``'roidb'`` key, which is
            popped from the caller's dicts in place.  Otherwise the items are
            passed directly to ``get_minibatch`` as roidb entries.

    Returns:
        dict mapping each blob name to a list holding one value per
        minibatch.
    """
    ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    if cfg.RPN.RPN_ON:
        batch = {key: [] for key in list_of_blobs[0]}
        # roidb entries have variable length and cannot be batched into a
        # tensor, so they are kept as a plain Python list per minibatch.
        list_of_roidb = [blobs.pop('roidb') for blobs in list_of_blobs]
        for i in range(0, len(list_of_blobs), ims_per_batch):
            # Pad image data (presumably to a common size) before collating.
            mini_list = pad_image_data(list_of_blobs[i:i + ims_per_batch])
            minibatch = default_collate(mini_list)
            minibatch['roidb'] = list_of_roidb[i:i + ims_per_batch]
            for key in minibatch:
                batch[key].append(minibatch[key])
        return batch

    batch = {key: [] for key in get_minibatch_blob_names()}
    for i in range(0, len(list_of_blobs), ims_per_batch):
        roidb = list_of_blobs[i:i + ims_per_batch]
        blobs, _valid = get_minibatch(roidb)  # validity flag is not used
        blobs['data'] = torch.from_numpy(blobs['data'])
        for key in blobs:
            batch[key].append(blobs[key])
    return batch
    def __getitem__(self, index_tuple):  # one roidb entry per call
        """Load a single roidb entry as blobs, cropping/serializing under RPN.

        Blobs other than ``'roidb'`` are squeezed along axis 0 only when that
        axis has length 1.  Without ``cfg.RPN.RPN_ON`` the blobs are returned
        right after squeezing.  With it, a ``need_crop`` entry is cropped with
        ``ratio`` and its zero-width/zero-height boxes removed, and the roidb
        blob is serialized.
        """
        idx, crop_ratio = index_tuple
        blobs, _is_valid = get_minibatch([self._roidb[idx]])
        # TODO: minibatch validity is not checked.

        for name, value in blobs.items():
            if name != 'roidb' and value.shape[0] == 1:
                blobs[name] = value.squeeze(axis=0)

        if not cfg.RPN.RPN_ON:
            return blobs

        if self._roidb[idx]['need_crop']:
            self.crop_data(blobs, crop_ratio)
            entry = blobs['roidb'][0]
            gt_boxes = entry['boxes']
            collapsed = np.logical_or(gt_boxes[:, 0] == gt_boxes[:, 2],
                                      gt_boxes[:, 1] == gt_boxes[:, 3])
            keep = np.nonzero(np.logical_not(collapsed))[0]
            if len(keep) < len(gt_boxes):
                for field in ('boxes', 'gt_classes', 'seg_areas',
                              'gt_overlaps', 'is_crowd', 'box_to_gt_ind_map',
                              'gt_keypoints'):
                    if field in entry:
                        entry[field] = entry[field][keep]
                entry['segms'] = [entry['segms'][k] for k in keep]

        # CHECK: serialization could possibly move into collate_fn.
        blobs['roidb'] = blob_utils.serialize(blobs['roidb'])
        return blobs
Beispiel #5
0
 def get_next_unlabel_minibatch(self):
     """Return blobs for the next valid minibatch of unlabeled samples.

     Index batches come from ``self._get_next_unlabel_minibatch_inds`` and are
     resolved against ``self._roidb_tmp``; draws repeat until
     ``get_minibatch`` reports the minibatch as valid.
     """
     while True:
         inds = self._get_next_unlabel_minibatch_inds()
         blobs, ok = get_minibatch([self._roidb_tmp[j] for j in inds])
         if ok:
             return blobs
Beispiel #6
0
 def get_next_minibatch(self):
     """Return blobs for the next minibatch that ``get_minibatch`` accepts.

     Index batches are drawn with ``self._get_next_minibatch_inds`` and looked
     up in ``self._roidb``; invalid minibatches are skipped.  Documented by the
     original author as thread safe.
     """
     while True:
         inds = self._get_next_minibatch_inds()
         blobs, ok = get_minibatch([self._roidb[j] for j in inds])
         if ok:
             return blobs
Beispiel #7
0
 def get_next_minibatch(self):
     """Return blobs for the next minibatch that ``get_minibatch`` accepts.

     Repeatedly draws indices via ``self._get_next_minibatch_inds`` and builds
     a minibatch from ``self._roidb`` until one is reported valid.  Documented
     by the original author as thread safe.
     """
     while True:
         batch_inds = self._get_next_minibatch_inds()
         blobs, is_valid = get_minibatch(
             [self._roidb[k] for k in batch_inds])
         if is_valid:
             break
     return blobs
Beispiel #8
0
 def _get_next_minibatch2(shared_readonly_dict, lock, mp_cur, mp_perm):
     """Return blobs for the next valid minibatch using shared loader state.

     Index selection is delegated to ``RoIDataLoader._get_next_minibatch_inds``
     with ``lock``, ``mp_cur`` and ``mp_perm``; entries are looked up in
     ``shared_readonly_dict['roidb']``.  Draws repeat until ``get_minibatch``
     reports the minibatch as valid.  Documented by the original author as
     thread safe.
     """
     roidb = shared_readonly_dict['roidb']
     while True:
         inds = RoIDataLoader._get_next_minibatch_inds(
             shared_readonly_dict, lock, mp_cur, mp_perm)
         blobs, is_valid = get_minibatch([roidb[k] for k in inds])
         if is_valid:
             return blobs
 def get_next_minibatch(self):
     """Return blobs for the next minibatch that ``get_minibatch`` accepts.

     Index batches come from ``self._get_next_minibatch_inds``; the
     corresponding ``self._roidb`` entries are turned into blobs, and invalid
     minibatches are discarded and redrawn.  Documented by the original author
     as thread safe.
     """
     while True:
         picked = self._get_next_minibatch_inds()
         entries = [self._roidb[k] for k in picked]
         blobs, accepted = get_minibatch(entries)
         if accepted:
             return blobs
Beispiel #10
0
 def _get_next_minibatch2(shared_readonly_dict, lock, mp_cur, mp_perm):
     """Return blobs for the next valid minibatch using shared loader state.

     ``RoIDataLoader._get_next_minibatch_inds`` (given ``lock``, ``mp_cur`` and
     ``mp_perm``) chooses the indices into ``shared_readonly_dict['roidb']``;
     the draw is retried until ``get_minibatch`` accepts it.  Documented by
     the original author as thread safe.
     """
     roidb = shared_readonly_dict['roidb']
     while True:
         chosen = RoIDataLoader._get_next_minibatch_inds(
             shared_readonly_dict, lock, mp_cur, mp_perm)
         entries = [roidb[k] for k in chosen]
         blobs, accepted = get_minibatch(entries)
         if accepted:
             break
     return blobs
    def __getitem__(self, index_tuple):
        """Load one training sample and apply the configured crops.

        ``index_tuple`` is ``(index, ratio)``.  Array blobs lose their leading
        batch axis.  A roidb entry flagged ``need_crop`` is cropped with
        ``ratio``, and its zero-width/zero-height boxes are dropped with their
        per-box annotations and ``segms``.  When ``cfg.TRAIN.RANDOM_CROP > 0``
        the current ``segms`` are stored once under ``'segms_origin'`` before
        ``self.crop_data_train`` runs.  The roidb blob is serialized before
        returning.
        """
        idx, crop_ratio = index_tuple
        blobs, _is_valid = get_minibatch([self._roidb[idx]], self.transform,
                                         self.valid_keys)
        # TODO: invalid minibatches are not discarded; that would need changes
        # to _worker_loop in torch.utils.data.dataloader.py.

        for name, value in blobs.items():
            if name != 'roidb':
                blobs[name] = value.squeeze(axis=0)

        if self._roidb[idx]['need_crop']:
            self.crop_data(blobs, crop_ratio)
            entry = blobs['roidb'][0]
            gt_boxes = entry['boxes']
            collapsed = np.logical_or(gt_boxes[:, 0] == gt_boxes[:, 2],
                                      gt_boxes[:, 1] == gt_boxes[:, 3])
            keep = np.nonzero(np.logical_not(collapsed))[0]
            if len(keep) < len(gt_boxes):
                for field in ('boxes', 'gt_classes', 'seg_areas',
                              'gt_overlaps', 'is_crowd', 'box_to_gt_ind_map',
                              'gt_keypoints'):
                    if field in entry:
                        entry[field] = entry[field][keep]
                entry['segms'] = [entry['segms'][k] for k in keep]

        if cfg.TRAIN.RANDOM_CROP > 0:
            first = blobs['roidb'][0]
            if 'segms_origin' not in first:
                first['segms_origin'] = first['segms'].copy()
            self.crop_data_train(blobs)
            # Boxes are deliberately not re-filtered after crop_data_train
            # (the original author judged it unnecessary).

        # CHECK: serialization could possibly move into collate_fn.
        blobs['roidb'] = blob_utils.serialize(blobs['roidb'])
        return blobs
Beispiel #12
0
    def __getitem__(self, index_tuple):
        """Return the blobs for one ``(index, scale)`` sample.

        ``scale`` is forwarded to ``get_minibatch``.  Every blob except
        ``'roidb'`` and ``'data_flow'`` is squeezed along axis 0, and the
        roidb blob is serialized.
        """
        idx, scale = index_tuple
        blobs, _is_valid = get_minibatch([self._roidb[idx]], scale)
        # TODO: minibatch validity is not checked (see _worker_loop in
        # torch.utils.data.dataloader.py).
        keep_batched = {'roidb', 'data_flow'}
        for name, value in blobs.items():
            if name not in keep_batched:
                blobs[name] = value.squeeze(axis=0)

        # CHECK: serialization could possibly move into collate_fn.
        blobs['roidb'] = blob_utils.serialize(blobs['roidb'])
        return blobs
Beispiel #13
0
 def _get_next_minibatch(self):
     """Return the blobs to be used for the next minibatch. DEPRECATED.

     This only exists for debugging (in train_net.py) and for benchmarking.
     Indices are drawn via ``self._get_next_minibatch_inds`` until
     ``get_minibatch`` reports a valid minibatch.  The cursor position after
     the draw is stored back in ``self._cur``.
     """
     roidb = self._roidb
     # Build the cursor Value once so every retry continues from where the
     # previous draw stopped, then persist it.  Rebuilding it from self._cur on
     # each iteration would discard the advance made by
     # _get_next_minibatch_inds (assumed to increment mp_cur.value, as in the
     # shared-state loader; TODO confirm). An invalid minibatch would then be
     # redrawn with identical indices forever.
     cur = multiprocessing.Value('i', self._cur, lock=False)
     valid = False
     while not valid:
         db_inds = self._get_next_minibatch_inds(
             {'roidb': roidb}, self._lock, cur, self._perm)
         minibatch_db = [roidb[i] for i in db_inds]
         blobs, valid = get_minibatch(minibatch_db)
     self._cur = cur.value
     return blobs
Beispiel #14
0
    def __getitem__(self, index_tuple):
        """Return the minibatch blobs for one roidb entry.

        ``index_tuple`` is ``(index, ratio)``; ``ratio`` is unused here.
        ``self._num_classes`` is forwarded to ``get_minibatch``.  Only the
        ``'data'`` blob has its leading batch axis squeezed; all other blobs
        are returned as ``get_minibatch`` produced them.
        """
        idx, _unused_ratio = index_tuple
        blobs, _is_valid = get_minibatch([self._roidb[idx]], self._num_classes)
        # TODO: minibatch validity is not checked (see _worker_loop in
        # torch.utils.data.dataloader.py).
        image = blobs['data']
        blobs['data'] = image.squeeze(axis=0)
        return blobs
Beispiel #15
0
 def _get_next_minibatch(self):
     """Return the blobs to be used for the next minibatch. DEPRECATED.

     This only exists for debugging (in train_net.py) and for benchmarking.
     Indices are drawn via ``self._get_next_minibatch_inds`` until
     ``get_minibatch`` reports a valid minibatch.  The cursor position after
     the draw is stored back in ``self._cur``.
     """
     roidb = self._roidb
     # One cursor Value for all retries, persisted afterwards.  Creating a new
     # Value from self._cur per iteration would throw away the advance made by
     # _get_next_minibatch_inds (assumed to increment mp_cur.value; TODO
     # confirm), so retries and later calls would keep the same indices.
     cur = multiprocessing.Value('i', self._cur, lock=False)
     valid = False
     while not valid:
         db_inds = self._get_next_minibatch_inds(
             {'roidb': roidb}, self._lock, cur, self._perm)
         minibatch_db = [roidb[i] for i in db_inds]
         blobs, valid = get_minibatch(minibatch_db)
     self._cur = cur.value
     return blobs
Beispiel #16
0
    def __getitem__(self, index_tuple):
        """Return blobs for one sample, or the raw roidb entry without RPN.

        Without ``cfg.RPN.RPN_ON`` the ``self._roidb`` entry is returned
        untouched.  Otherwise array blobs lose their leading batch axis; a
        ``need_crop`` entry is cropped with ``ratio`` and its
        zero-width/zero-height boxes dropped; and the roidb blob is
        serialized.
        """
        idx, crop_ratio = index_tuple
        roidb_entry = self._roidb[idx]
        if not cfg.RPN.RPN_ON:
            return roidb_entry

        blobs, _is_valid = get_minibatch([roidb_entry])
        for name, value in blobs.items():
            if name != 'roidb':
                blobs[name] = value.squeeze(axis=0)

        if self._roidb[idx]['need_crop']:
            self.crop_data(blobs, crop_ratio)
            entry = blobs['roidb'][0]
            gt_boxes = entry['boxes']
            collapsed = np.logical_or(gt_boxes[:, 0] == gt_boxes[:, 2],
                                      gt_boxes[:, 1] == gt_boxes[:, 3])
            keep = np.nonzero(np.logical_not(collapsed))[0]
            if len(keep) < len(gt_boxes):
                for field in ('boxes', 'gt_classes', 'seg_areas',
                              'gt_overlaps', 'is_crowd', 'box_to_gt_ind_map',
                              'gt_keypoints'):
                    if field in entry:
                        entry[field] = entry[field][keep]
                entry['segms'] = [entry['segms'][k] for k in keep]

        # CHECK: serialization could possibly move into collate_fn.
        blobs['roidb'] = blob_utils.serialize(blobs['roidb'])
        return blobs
Beispiel #17
0
    def __getitem__(self, index_tuple):
        """Build one target sample together with its query.

        Args:
            index_tuple: ``(index, ratio)``. ``index`` selects the roidb entry;
                ``ratio`` is passed to ``crop_data`` when a training entry is
                flagged ``need_crop``.

        Returns:
            dict of blobs with the roidb serialized and the extra keys
            ``'query'`` and ``'query_type'``. ``query_type`` is 1 when the
            query comes from a category present in ``gt_cats`` and 0 for a
            background patch or a negative category. In evaluation mode
            ``query_type`` is always 1, the query comes from
            ``self.crop_query``, and ``'choice'`` (``self.cat_list[index]``)
            is added. The ``'gt_cats'`` and ``'binary_mask'`` blobs are
            removed before returning.
        """
        index, ratio = index_tuple
        single_db = [self._roidb[index]]
        blobs, valid = get_minibatch(single_db)
        # TODO: Check if minibatch is valid ? If not, abandon it.
        # Need to change _worker_loop in torch.utils.data.dataloader.py.

        # Squeeze the batch dim; 'gt_cats' and 'binary_mask' are handled below.
        for key in blobs:
            if key not in ('roidb', 'gt_cats', 'binary_mask'):
                blobs[key] = blobs[key].squeeze(axis=0)

        # Keep only the categories that belong to self.list.
        blobs['gt_cats'] = np.array(
            [x for x in blobs['gt_cats'] if x in self.list])
        # Resize the mask by the last im_info element (the image scale) with
        # nearest-neighbour interpolation, then dilate it with a 5x5 kernel.
        scale = blobs['im_info'][-1]
        mask = cv2.resize(blobs['binary_mask'],
                          None,
                          None,
                          fx=scale,
                          fy=scale,
                          interpolation=cv2.INTER_NEAREST)
        kernel = np.ones((5, 5), np.uint8)
        blobs['binary_mask'] = cv2.dilate(mask, kernel, iterations=1)
        query_type = 1

        if self.training:
            # Randomly choose the query category.
            positive_category = blobs['gt_cats']
            # NOTE(review): np.random.choice raises on an empty candidate set;
            # this assumes self.cat_list has categories outside gt_cats.
            negative_category = np.array(
                list(set(self.cat_list) - set(positive_category)))

            r = random.random()
            if r <= cfg.TRAIN.QUERY_POSITIVE_RATE:
                # Positive query: a category present in the image, weighted
                # by self.show_time.
                query_type = 1
                cand = np.unique(positive_category)
                if len(cand) == 1:
                    choice = cand[0]
                else:
                    # Cast to float: show_time weights may be integer counts,
                    # and in-place true division on an int array raises a
                    # casting error.
                    p = np.array([self.show_time[i] for i in cand],
                                 dtype=np.float64)
                    p /= p.sum()
                    choice = np.random.choice(cand, 1, p=p)[0]
                query = self.load_query(choice)
            elif r <= (cfg.TRAIN.QUERY_POSITIVE_RATE +
                       cfg.TRAIN.QUERY_GLOBAL_NEGATIVE_RATE):
                # Background query from self.sample_bg; fall back to a
                # negative category when it does not yield self.shot patches.
                query_type = 0
                im = blobs['data'].copy()
                binary_mask = blobs['binary_mask'].copy()
                patch = self.sample_bg(im, binary_mask)
                if len(patch) == self.shot:
                    query = patch
                else:
                    print("No bg, the number of bg is: ", len(patch))
                    choice = np.random.choice(negative_category, 1)[0]
                    query = self.load_query(choice)
            else:
                # Negative query: a category not present in the image.
                query_type = 0
                choice = np.random.choice(negative_category, 1)[0]
                query = self.load_query(choice)
        else:
            query = self.crop_query(single_db, index, single_db[0]['id'])

        blobs['query'] = query
        blobs['query_type'] = query_type

        # These intermediate blobs are not returned to the caller.
        blobs.pop('gt_cats', None)
        blobs.pop('binary_mask', None)

        if not self.training:
            blobs['roidb'] = blob_utils.serialize(blobs['roidb'])
            blobs['choice'] = self.cat_list[index]
            return blobs

        if self._roidb[index]['need_crop']:
            self.crop_data(blobs, ratio)
            # Drop boxes that collapsed to zero width or height.
            entry = blobs['roidb'][0]
            boxes = entry['boxes']
            invalid = (boxes[:, 0] == boxes[:, 2]) | (boxes[:, 1]
                                                      == boxes[:, 3])
            valid_inds = np.nonzero(~invalid)[0]
            if len(valid_inds) < len(boxes):
                for key in [
                        'boxes', 'gt_classes', 'seg_areas', 'gt_overlaps',
                        'is_crowd', 'box_to_gt_ind_map', 'gt_keypoints'
                ]:
                    if key in entry:
                        entry[key] = entry[key][valid_inds]
                entry['segms'] = [
                    entry['segms'][ind] for ind in valid_inds
                ]

        # CHECK: serialization could possibly move into collate_fn.
        blobs['roidb'] = blob_utils.serialize(blobs['roidb'])
        return blobs
Beispiel #18
0
    def __getitem__(self, index_tuple):
        """Build a query sample plus a few-shot support set for it.

        Args:
            index_tuple: ``(index, ratio)``. ``index`` addresses both
                ``self._roidb`` and the ``'index'`` column of
                ``self.index_pd``. ``ratio`` is passed to ``crop_data`` when
                the query entry has ``need_crop`` set.

        Returns:
            dict of blobs for the query image, plus:
              * ``'support_data'``: float32 array of shape
                ``(support_way * support_shot, 3, 320, 320)`` filled from
                ``self.crop_support``;
              * in the roidb entry (serialized before returning):
                ``'support_boxes'`` (float32, ``(support_way * support_shot,
                4)``), ``'support_id'``, ``'support_cls'``, ``'query_id'`` and
                ``'target_cls'``.

        The first ``support_shot`` supports share the query's class. Each of
        the remaining ``support_way - 1`` ways uses a distinct class that does
        not appear in the query image. No support reuses an image or roidb
        index already used for this sample. Sampling passes
        ``random_state=index`` to pandas, so for a fixed ``index_pd`` the
        supports are deterministic per ``index``.

        Raises:
            IndexError from ``.tolist()[0]`` if a lookup matches no row. The
            pandas ``.sample`` call presumably raises when no candidate
            remains (TODO confirm exception type).
        """
        index, ratio = index_tuple  # this index is just the index of roidb, not the roidb_index
        # Get support roidb, support cls is same with query cls, and support image is different from query image.
        #query_cls = self.full_info_list[index, 1]
        #query_image = self.full_info_list[index, 2]
        # Query class/image: first index_pd row whose 'index' equals index.
        query_cls = self.index_pd.loc[self.index_pd['index'] == index,
                                      'cls_ls'].tolist()[0]
        query_img = self.index_pd.loc[self.index_pd['index'] == index,
                                      'img_ls'].tolist()[0]
        # Every class listed for the query image (used to exclude them from
        # the extra support ways).
        all_cls = self.index_pd.loc[self.index_pd['img_ls'] == query_img,
                                    'cls_ls'].tolist()
        #support_blobs, support_valid = get_minibatch(support_db)
        single_db = [self._roidb[index]]
        blobs, valid = get_minibatch(single_db)
        #TODO: Check if minibatch is valid ? If not, abandon it.
        # Need to change _worker_loop in torch.utils.data.dataloader.py.
        # Squeeze batch dim
        for key in blobs:
            if key != 'roidb':
                blobs[key] = blobs[key].squeeze(axis=0)

        if self._roidb[index]['need_crop']:
            self.crop_data(blobs, ratio)
            # Check bounding box: drop boxes with zero width or zero height
            # (and their per-box annotations / segms) after cropping.
            entry = blobs['roidb'][0]
            boxes = entry['boxes']
            invalid = (boxes[:, 0] == boxes[:, 2]) | (boxes[:, 1] == boxes[:,
                                                                           3])
            valid_inds = np.nonzero(~invalid)[0]
            if len(valid_inds) < len(boxes):
                for key in [
                        'boxes', 'gt_classes', 'seg_areas', 'gt_overlaps',
                        'is_crowd', 'box_to_gt_ind_map', 'gt_keypoints'
                ]:
                    if key in entry:
                        entry[key] = entry[key][valid_inds]
                entry['segms'] = [entry['segms'][ind] for ind in valid_inds]
        # Crop support data and get new support box in the support data
        support_way = 2  # number of classes in the support set (hard-coded)
        support_shot = 5  # supports per class (hard-coded)
        # One slot per support; assumes crop_support returns a (3, 320, 320)
        # array for support_data — TODO confirm.
        support_data_all = np.zeros((support_way * support_shot, 3, 320, 320),
                                    dtype=np.float32)
        support_box_all = np.zeros((support_way * support_shot, 4),
                                   dtype=np.float32)
        # Images / roidb indices already used; supports must avoid them.
        used_img_ls = [query_img]
        used_index_ls = [index]
        #used_cls_ls = [query_cls]
        # Classes excluded from the extra ways: all classes of the query image.
        used_cls_ls = list(set(all_cls))
        support_cls_ls = []
        mixup_i = 0  # next free slot in support_data_all / support_box_all

        # Way 0: support_shot supports of the query's own class.
        for shot in range(support_shot):
            # Support image and box
            support_index = self.index_pd.loc[
                (self.index_pd['cls_ls'] == query_cls) &
                (~self.index_pd['img_ls'].isin(used_img_ls)) &
                (~self.index_pd['index'].isin(used_index_ls)),
                'index'].sample(random_state=index).tolist()[0]
            support_cls = self.index_pd.loc[
                self.index_pd['index'] == support_index, 'cls_ls'].tolist()[0]
            support_img = self.index_pd.loc[
                self.index_pd['index'] == support_index, 'img_ls'].tolist()[0]
            used_index_ls.append(support_index)
            used_img_ls.append(support_img)

            support_db = [self._roidb[support_index]]
            support_data, support_box = self.crop_support(support_db)
            support_data_all[mixup_i] = support_data
            # Only the first box returned by crop_support is kept.
            support_box_all[mixup_i] = support_box[0]
            support_cls_ls.append(support_cls)  # class ids stored unshifted
            #assert support_cls - 1 >= 0
            mixup_i += 1

        if support_way == 1:
            pass  # no extra ways to sample
        else:
            # Ways 1..support_way-1: each picks a new class not yet used and
            # samples support_shot supports of it.
            for way in range(support_way - 1):
                other_cls = self.index_pd.loc[
                    (~self.index_pd['cls_ls'].isin(used_cls_ls)),
                    'cls_ls'].drop_duplicates().sample(
                        random_state=index).tolist()[0]
                used_cls_ls.append(other_cls)
                for shot in range(support_shot):
                    # Support image and box

                    support_index = self.index_pd.loc[
                        (self.index_pd['cls_ls'] == other_cls) &
                        (~self.index_pd['img_ls'].isin(used_img_ls)) &
                        (~self.index_pd['index'].isin(used_index_ls)),
                        'index'].sample(random_state=index).tolist()[0]

                    support_cls = self.index_pd.loc[self.index_pd['index'] ==
                                                    support_index,
                                                    'cls_ls'].tolist()[0]
                    support_img = self.index_pd.loc[self.index_pd['index'] ==
                                                    support_index,
                                                    'img_ls'].tolist()[0]

                    used_index_ls.append(support_index)
                    used_img_ls.append(support_img)
                    support_db = [self._roidb[support_index]]
                    support_data, support_box = self.crop_support(support_db)
                    support_data_all[mixup_i] = support_data
                    support_box_all[mixup_i] = support_box[0]
                    support_cls_ls.append(support_cls)  # class ids stored unshifted
                    #assert support_cls - 1 >= 0
                    mixup_i += 1

        blobs[
            'support_data'] = support_data_all  # all supports, in sampling order
        blobs['roidb'][0][
            'support_boxes'] = support_box_all  # one box per support
        # NOTE(review): support_db here is the last support sampled, so only
        # that support's id is recorded — confirm this is intended.
        blobs['roidb'][0]['support_id'] = support_db[0]['id']
        #blobs['roidb'][0]['gt_classes'] = blobs['roidb'][0]['gt_classes'] #np.array([1] * (len(blobs['roidb'][0]['gt_classes'])))

        blobs['roidb'][0]['support_cls'] = support_cls_ls
        blobs['roidb'][0]['query_id'] = single_db[0]['id']
        blobs['roidb'][0]['target_cls'] = single_db[0]['target_cls']

        blobs['roidb'] = blob_utils.serialize(
            blobs['roidb'])  # CHECK: maybe we can serialize in collate_fn
        return blobs