Example 1
 def showRef(self, ref, seg_box='seg'):
     ax = plt.gca()
     # show image
     image = self.Imgs[ref['image_id']]
     I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
     ax.imshow(I)
     # show refer expression
     for sid, sent in enumerate(ref['sentences']):
         print('%s. %s' % (sid + 1, sent['sent']))
     # show segmentations
     if seg_box == 'seg':
         ann_id = ref['ann_id']
         ann = self.Anns[ann_id]
         polygons = []
         color = []
         c = 'none'
         if type(ann['segmentation'][0]) == list:
             # polygon used for refcoco*
             for seg in ann['segmentation']:
                 print(seg)
                 poly = np.array(seg).reshape((int(len(seg) / 2), 2))
                 polygons.append(Polygon(poly, closed=True, alpha=0.4))
                 color.append(c)
             p = PatchCollection(polygons,
                                 facecolors=color,
                                 edgecolors=(1, 1, 0, 1),  # alpha 1, otherwise the edge is invisible
                                 linewidths=3,
                                 alpha=1)
             ax.add_collection(p)  # thick yellow polygon
             p = PatchCollection(polygons,
                                 facecolors=color,
                                 edgecolors=(1, 0, 0, 1),
                                 linewidths=1,
                                 alpha=1)
             ax.add_collection(p)  # thin red polygon
         else:
             # mask used for refclef
             rle = ann['segmentation']
             m = mask.decode(rle)
             img = np.ones((m.shape[0], m.shape[1], 3))
             color_mask = np.array([2.0, 166.0, 101.0]) / 255
             for i in range(3):
                 img[:, :, i] = color_mask[i]
             ax.imshow(np.dstack((img, m * 0.5)))
     # show bounding-box
     elif seg_box == 'box':
         ann_id = ref['ann_id']
         ann = self.Anns[ann_id]
         bbox = self.getRefBox(ref['ref_id'])
         box_plot = Rectangle((bbox[0], bbox[1]),
                              bbox[2],
                              bbox[3],
                              fill=False,
                              edgecolor='green',
                              linewidth=3)
         ax.add_patch(box_plot)
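A minimal usage sketch for the method above (an assumption, not part of the example): it presumes a REFER-style loader from the refer toolkit exposing showRef, getRefIds and Refs; the import path, data_root, dataset and split are placeholders.

import matplotlib.pyplot as plt
from refer import REFER  # assumed module layout of the refer toolkit

# hypothetical setup; paths and dataset names are placeholders
refer = REFER(data_root='data', dataset='refcoco', splitBy='unc')
ref_id = refer.getRefIds(split='train')[0]
plt.figure()
refer.showRef(refer.Refs[ref_id], seg_box='seg')  # prints the sentences and overlays the segmentation
plt.show()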
Example 2
	def getMask(self, ref):
		# return the binary mask and its area for this ref
		ann = self.refToAnn[ref['ref_id']]
		image = self.Imgs[ref['image_id']]
		if type(ann['segmentation'][0]) == list: # polygon
			rle = mask.frPyObjects(ann['segmentation'], image['height'], image['width'])
		else:
			rle = ann['segmentation']
		m = mask.decode(rle)
		m = np.sum(m, axis=2)  # sometimes there are multiple binary maps (corresponding to multiple segs)
		m = m.astype(np.uint8) # convert to np.uint8
		# compute area
		area = sum(mask.area(rle))  # should be close to ann['area']
		return {'mask': m, 'area': area}
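As a follow-up, a hedged sketch of how the returned mask might be compared against a prediction; mask_iou, refer and predicted_mask are illustrative names and not part of the loader above.

import numpy as np

def mask_iou(gt_mask, pred_mask):
    # intersection-over-union between two HxW binary masks
    gt = gt_mask.astype(bool)
    pred = pred_mask.astype(bool)
    union = np.logical_or(gt, pred).sum()
    return np.logical_and(gt, pred).sum() / float(union) if union else 0.0

# hypothetical usage with the getMask method above:
# iou = mask_iou(refer.getMask(ref)['mask'], predicted_mask)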
Example 3
	def showRef(self, ref, seg_box='seg'):
		ax = plt.gca()
		# show image
		image = self.Imgs[ref['image_id']]
		I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
		ax.imshow(I)
		# show refer expression
		for sid, sent in enumerate(ref['sentences']):
			print('%s. %s' % (sid + 1, sent['sent']))
		# show segmentations
		if seg_box == 'seg':
			ann_id = ref['ann_id']
			ann = self.Anns[ann_id]
			polygons = []
			color = []
			c = 'none'
			if type(ann['segmentation'][0]) == list:
				# polygon used for refcoco*
				for seg in ann['segmentation']:
					poly = np.array(seg).reshape((len(seg) // 2, 2))
					polygons.append(Polygon(poly, closed=True, alpha=0.4))
					color.append(c)
				p = PatchCollection(polygons, facecolors=color, edgecolors=(1, 1, 0, 1), linewidths=3, alpha=1)
				ax.add_collection(p)  # thick yellow polygon
				p = PatchCollection(polygons, facecolors=color, edgecolors=(1, 0, 0, 1), linewidths=1, alpha=1)
				ax.add_collection(p)  # thin red polygon
			else:
				# mask used for refclef
				rle = ann['segmentation']
				m = mask.decode(rle)
				img = np.ones( (m.shape[0], m.shape[1], 3) )
				color_mask = np.array([2.0,166.0,101.0])/255
				for i in range(3):
					img[:,:,i] = color_mask[i]
				ax.imshow(np.dstack( (img, m*0.5) ))
		# show bounding-box
		elif seg_box == 'box':
			ann_id = ref['ann_id']
			ann = self.Anns[ann_id]
			bbox = self.getRefBox(ref['ref_id'])
			box_plot = Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3)
			ax.add_patch(box_plot)
Example 4
def generateMasksAndQueries(MASK_DIR, QUERY_FILE_NAME, refer_obj):
	img_sg_counts = {}
	query_file_str = {}
	print "Creating {} mask files..".format(len(refer_obj.data['annotations']))
	actual_processed = 0
	for i,ann in enumerate(refer_obj.data['annotations']):
		if i % 1000 == 0:
			with open(QUERY_FILE_NAME, 'w') as fp:
				json.dump(query_file_str, fp)
			print "On annotation #{}, processed = {}".format(i,actual_processed)
		if ann['iscrowd'] == 1 or ann['id'] not in refer_obj.annToRef:
			continue
		actual_processed += 1
		image = refer_obj.Imgs[ann['image_id']]
		if type(ann['segmentation']) == list: # polygon
			rle = mask.frPyObjects(ann['segmentation'], image['height'], image['width'])
		else:
			rle = ann['segmentation']
		m = mask.decode(rle)
		m = np.sum(m, axis=2)
		matlab_dict = {'segimg_t': m}
		img_sg_counts[ann['image_id']] = img_sg_counts.get(ann['image_id'], 0) + 1
		filename = "{}_{}".format(ann['image_id'], img_sg_counts[ann['image_id']])
		io.savemat(osp.join(MASK_DIR, filename), matlab_dict)
		ref = refer_obj.annToRef[ann['id']]
		val = query_file_str.get(filename, [])
		for sentence in ref['sentences']:
			val.append(sentence['raw'])
		query_file_str[filename] = val
	# final write so annotations processed after the last periodic dump are also saved
	with open(QUERY_FILE_NAME, 'w') as fp:
		json.dump(query_file_str, fp)
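A hedged call sketch for the helper above; the REFER import, the output paths and the split are placeholder assumptions.

import os
import os.path as osp
from refer import REFER  # assumed module layout of the refer toolkit

refer_obj = REFER(data_root='data', dataset='refcoco', splitBy='unc')
MASK_DIR = 'output/masks'
os.makedirs(MASK_DIR, exist_ok=True)
generateMasksAndQueries(MASK_DIR, osp.join('output', 'queries.json'), refer_obj)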
Example 5
    def getTestBatch(self, split):  # revised from getBatch
        # options
        batch_size = 1  #opt.get('batch_size', 5) ####
        seq_per_ref = 1  #3 #opt.get('seq_per_ref', 3)
        split_ix = self.split_ix[split]
        max_index = len(split_ix) - 1  # don't forget to -1
        wrapped = False

        # fetch image_ids
        batch_image_ids = []
        for i in range(batch_size):
            ri = self.iterators[split]
            ri_next = ri + 1
            if ri_next > max_index:
                ri_next = 0
                wrapped = True
            self.iterators[split] = ri_next
            image_id = split_ix[ri]
            batch_image_ids += [image_id]

        # fetch feats
        batch_pos_ann_ids, batch_pos_sent_ids = [], []
        batch_pos_category_ids = []
        first = 1

        blob, im_scales = self._get_image_blob(batch_image_ids)

        for image_id in batch_image_ids:
            ref_ids = self.Images[image_id]['ref_ids']
            file_name = self.Images[image_id]['file_name']  ####

            # get image related ids
            #image_pos_ann_ids, image_neg_ann_ids = [], []
            image_pos_ref_ids = []  #### ++++

            for ref_id in ref_ids:
                ref_ann_id = self.Refs[ref_id]['ann_id']
                ref_category_id = self.Refs[ref_id]['category_id']

                rle = self.Refs[ref_id]['rle']
                m = mask.decode(rle)

                #print(ref_id, np.unique(m), m.shape)
                #for i in range(m.shape[2]):
                #  im = Image.fromarray((m[:,:,i]*255).astype(np.uint8), mode='P')
                #  im.save('fig/{}_{}_{}.png'.format(image_id, ref_id, i))

                m = np.sum(
                    m, axis=2
                )  # sometimes there are multiple binary maps (corresponding to multiple segs)
                m[m > 0] = 1

                ref_mask = m.astype(np.uint8)
                ref_mask = imresize(ref_mask,
                                    size=(blob.shape[1], blob.shape[2]),
                                    interp='nearest')
                ref_mask = np.expand_dims(ref_mask, axis=0)

                # pos ids
                #pos_ann_ids = [ref_ann_id] * seq_per_ref
                pos_ref_ids = [ref_id] * seq_per_ref  #### ++++
                pos_category_ids = [ref_category_id] * seq_per_ref
                #pos_sent_ids = self.fetch_sent_ids_by_ref_id(ref_id, seq_per_ref) #### random choose 3 sentence for each reference object (bbox)

                #sent_ids = list(self.Refs[ref_id]['sent_ids']) ####
                #pos_sent_ids = [random.choice(sent_ids)] ################

                for sent_id in self.Refs[ref_id]['sent_ids']:
                    # add to image and batch
                    #image_pos_ann_ids += pos_ann_ids
                    image_pos_ref_ids += pos_ref_ids  #### ++++
                    batch_pos_sent_ids += [sent_id]  #pos_sent_ids
                    batch_pos_category_ids += pos_category_ids
                    if first:
                        batch_mask = ref_mask
                        first = 0
                    else:
                        batch_mask = np.concatenate((batch_mask, ref_mask),
                                                    axis=0)

            # fetch feats
            #pos_ann_boxes = xywh_to_xyxy(np.vstack([self.Anns[ann_id]['box'] for ann_id in image_pos_ann_ids]))
            pos_ref_boxes = xywh_to_xyxy(
                np.vstack([
                    self.Refs[ref_id]['box'] for ref_id in image_pos_ref_ids
                ]))  #### ++++
            #print('----------------------')
            #print(pos_ann_boxes)
            #print(pos_ref_boxes) #### ++++

        # get feats and labels
        pos_labels = np.vstack(
            [self.fetch_seq(sent_id) for sent_id in batch_pos_sent_ids])

        # convert to Variable
        pos_labels = Variable(torch.from_numpy(pos_labels).long().cuda())

        # chunk pos_labels and neg_labels using max_len
        max_len = (pos_labels != 0).sum(1).max().data[0]
        pos_labels = pos_labels[:, :max_len]

        # return
        data = {}
        data['data'] = blob
        data['im_info'] = np.array(
            [[blob.shape[1], blob.shape[2], im_scales[0]]]).astype(np.float32)

        pos_ref_boxes = np.concatenate(
            (pos_ref_boxes * im_scales[0], np.array([batch_pos_category_ids
                                                     ]).T),
            axis=1)
        data['gt_boxes'] = pos_ref_boxes.astype(np.float32)  ####
        data['gt_masks'] = batch_mask  ####

        data['labels'] = pos_labels  ####
        data['file_name'] = file_name
        data['bounds'] = {
            'it_pos_now': ri,
            'it_max': max_index,
            'wrapped': wrapped
        }

        return data
Example 6
    def getBatch(self, split, batch_size=1):  #, opt):
        # options
        #print('batch_size:', batch_size)
        #batch_size = batch_size #opt.get('batch_size', 5) ####
        #seq_per_ref = 1 #3 #opt.get('seq_per_ref', 3)
        #sample_ratio = 0.3 #opt.get('visual_sample_ratio', 0.3)  # sample ratio, st vs dt
        split_ix = self.split_ix[split]
        max_index = len(split_ix) - 1  # don't forget to -1
        wrapped = False

        # fetch image_ids
        batch_image_ids = []
        for i in range(batch_size):
            ri = self.iterators[split]
            ri_next = ri + 1
            if ri_next > max_index:
                print('number of images in split {}: {}'.format(
                    split, len(split_ix)))
                self.perm[split] = np.random.permutation(len(split_ix))
                print('perm', split, 'shuffled:', self.perm[split])
                ri_next = 0
                wrapped = True
            self.iterators[split] = ri_next
            image_id = split_ix[self.perm[split][ri]]
            batch_image_ids += [image_id]

        # fetch feats
        #batch_ref_ids = []
        batch_pos_ref_ids = []
        batch_pos_sent_ids = []
        #batch_pos_ann_ids, batch_pos_sent_ids, batch_pos_pool5, batch_pos_fc7, batch_pos_C4_feat = [], [], [], [], []
        #batch_pos_cxt_fc7, batch_pos_cxt_lfeats = [], []
        #batch_neg_ann_ids, batch_neg_sent_ids, batch_neg_pool5, batch_neg_fc7, batch_neg_C4_feat = [], [], [], [], []
        #batch_neg_cxt_fc7, batch_neg_cxt_lfeats = [], []
        batch_pos_category_ids = []
        first = 1

        blob, im_scales = self._get_image_blob(batch_image_ids)

        for image_id in batch_image_ids:
            ref_ids = self.Images[image_id]['ref_ids']
            #ref_ids = [random.choice(ref_ids)] ####
            file_name = self.Images[image_id]['file_name']  ####
            #batch_ref_ids += self.expand_list(ref_ids, seq_per_ref)
            # fetch head and im_info
            #head, im_info = self.image_to_head(image_id)
            #head = Variable(torch.from_numpy(head).cuda())

            # get image related ids
            #image_pos_ann_ids, image_neg_ann_ids = [], []
            #image_pos_ref_ids = [] #### ++++

            for ref_id in ref_ids:
                #ref_ann_id = self.Refs[ref_id]['ann_id']
                ref_category_id = self.Refs[ref_id]['category_id']

                rle = self.Refs[ref_id]['rle']
                m = mask.decode(rle)

                #print(ref_id, np.unique(m), m.shape)
                #for i in range(m.shape[2]):
                #  im = Image.fromarray((m[:,:,i]*255).astype(np.uint8), mode='P')
                #  im.save('fig/{}_{}_{}.png'.format(image_id, ref_id, i))

                m = np.sum(
                    m, axis=2
                )  # sometimes there are multiple binary maps (corresponding to multiple segs)
                m[m > 0] = 1

                ref_mask = m.astype(np.uint8)
                ref_mask = imresize(ref_mask,
                                    size=(blob.shape[1], blob.shape[2]),
                                    interp='nearest')
                ref_mask = np.expand_dims(ref_mask, axis=0)

                # pos ids
                pos_sent_ids = list(self.Refs[ref_id]['sent_ids'])  ####
                seq_num = len(pos_sent_ids)

                #pos_ann_ids = [ref_ann_id] * seq_per_ref
                pos_ref_ids = [ref_id] * seq_num  #### ++++
                pos_category_ids = [ref_category_id] * seq_num
                #pos_sent_ids = self.fetch_sent_ids_by_ref_id(ref_id, seq_per_ref) #### random choose 3 sentence for each reference object (bbox)

                #sent_ids = list(self.Refs[ref_id]['sent_ids']) ####
                #pos_sent_ids = [random.choice(sent_ids)] ####
                #print(ref_ids, pos_sent_ids)

                # neg ids
                #neg_ann_ids, neg_sent_ids = self.sample_neg_ids(ref_ann_id, seq_per_ref, sample_ratio)

                # add to image and batch
                #image_pos_ann_ids += pos_ann_ids
                #image_pos_ref_ids += pos_ref_ids #### ++++
                batch_pos_ref_ids += pos_ref_ids
                #image_neg_ann_ids += neg_ann_ids
                batch_pos_sent_ids += pos_sent_ids
                #batch_neg_sent_ids += neg_sent_ids
                batch_pos_category_ids += pos_category_ids
                for i in range(seq_num):
                    if first:
                        batch_mask = ref_mask
                        first = 0
                    else:
                        batch_mask = np.concatenate((batch_mask, ref_mask),
                                                    axis=0)

            # fetch feats
            #pos_ann_boxes = xywh_to_xyxy(np.vstack([self.Anns[ann_id]['box'] for ann_id in image_pos_ann_ids]))
            #pos_ref_boxes = xywh_to_xyxy(np.vstack([self.Refs[ref_id]['box'] for ref_id in image_pos_ref_ids])) #### ++++
            pos_ref_boxes = xywh_to_xyxy(
                np.vstack([
                    self.Refs[ref_id]['box'] for ref_id in batch_pos_ref_ids
                ]))
            #print('----------------------')
            #print(pos_ann_boxes)
            #print(pos_ref_boxes) #### ++++

            #image_pos_pool5, image_pos_fc7 = self.fetch_grid_feats(pos_ann_boxes, head, im_info)  # (num_pos, k, 7, 7)
            #image_pos_C4_feat = head.repeat(len(pos_ann_boxes),1,1,1) ####

            #batch_pos_pool5 += [image_pos_pool5]
            #batch_pos_fc7   += [image_pos_fc7]
            #batch_pos_C4_feat  += [image_pos_C4_feat]

            #neg_ann_boxes = xywh_to_xyxy(np.vstack([self.Anns[ann_id]['box'] for ann_id in image_neg_ann_ids]))
            #image_neg_pool5, image_neg_fc7 = self.fetch_grid_feats(neg_ann_boxes, head, im_info)  # (num_neg, k, 7, 7)
            #image_neg_C4_feat = head.repeat(len(neg_ann_boxes),1,1,1) ####

            #batch_neg_pool5 += [image_neg_pool5]
            #batch_neg_fc7   += [image_neg_fc7]
            #batch_neg_C4_feat  += [image_neg_C4_feat]

            # add to batch
            #batch_pos_ann_ids += image_pos_ann_ids
            #batch_neg_ann_ids += image_neg_ann_ids

        # get feats and labels
        #pos_C4_feat  = torch.cat(batch_pos_C4_feat, 0); pos_C4_feat.detach()
        #pos_fc7   = torch.cat(batch_pos_fc7, 0); pos_fc7.detach()
        #pos_pool5 = torch.cat(batch_pos_pool5, 0); pos_pool5.detach()
        #pos_lfeats = self.compute_lfeats(batch_pos_ann_ids)
        #pos_dif_lfeats = self.compute_dif_lfeats(batch_pos_ann_ids)
        pos_labels = np.vstack(
            [self.fetch_seq(sent_id) for sent_id in batch_pos_sent_ids])
        #neg_C4_feat  = torch.cat(batch_neg_C4_feat, 0); neg_C4_feat.detach()
        #neg_fc7   = torch.cat(batch_neg_fc7, 0); neg_fc7.detach()
        #neg_pool5 = torch.cat(batch_neg_pool5, 0); neg_pool5.detach()
        #neg_lfeats = self.compute_lfeats(batch_neg_ann_ids)
        #neg_dif_lfeats = self.compute_dif_lfeats(batch_neg_ann_ids)
        #neg_labels = np.vstack([self.fetch_seq(sent_id) for sent_id in batch_neg_sent_ids])

        # fetch cxt_fc7 and cxt_lfeats
        #pos_cxt_fc7, pos_cxt_lfeats, pos_cxt_ann_ids = self.fetch_cxt_feats(batch_pos_ann_ids, opt)
        #neg_cxt_fc7, neg_cxt_lfeats, neg_cxt_ann_ids = self.fetch_cxt_feats(batch_neg_ann_ids, opt)
        #pos_cxt_fc7 = Variable(torch.from_numpy(pos_cxt_fc7).cuda())
        #pos_cxt_lfeats = Variable(torch.from_numpy(pos_cxt_lfeats).cuda())
        #neg_cxt_fc7 = Variable(torch.from_numpy(neg_cxt_fc7).cuda())
        #neg_cxt_lfeats = Variable(torch.from_numpy(neg_cxt_lfeats).cuda())

        # fetch attributes for batch_pos_ann_ids ONLY
        #att_labels, select_ixs = self.fetch_attribute_label(batch_pos_ann_ids)

        # convert to Variable
        #pos_lfeats = Variable(torch.from_numpy(pos_lfeats).cuda())
        #pos_dif_lfeats = Variable(torch.from_numpy(pos_dif_lfeats).cuda())
        pos_labels = Variable(torch.from_numpy(pos_labels).long().cuda())
        #neg_lfeats = Variable(torch.from_numpy(neg_lfeats).cuda())
        #neg_dif_lfeats = Variable(torch.from_numpy(neg_dif_lfeats).cuda())
        #neg_labels = Variable(torch.from_numpy(neg_labels).long().cuda())

        # chunk pos_labels and neg_labels using max_len
        #max_len = max((pos_labels != 0).sum(1).max().data[0],
        #              (neg_labels != 0).sum(1).max().data[0])
        max_len = (pos_labels != 0).sum(1).max().data[0]
        pos_labels = pos_labels[:, :max_len]
        #neg_labels = neg_labels[:, :max_len]

        # return
        data = {}
        data['data'] = blob
        data['im_info'] = np.array(
            [[blob.shape[1], blob.shape[2], im_scales[0]]]).astype(np.float32)

        pos_ref_boxes = np.concatenate(
            (pos_ref_boxes * im_scales[0], np.array([batch_pos_category_ids
                                                     ]).T),
            axis=1)
        data['gt_boxes'] = pos_ref_boxes.astype(np.float32)
        data['gt_masks'] = batch_mask

        #data['ref_ann_ids'] = batch_pos_ann_ids
        #data['ref_sent_ids'] = batch_pos_sent_ids
        #data['ref_cxt_ann_ids'] = pos_cxt_ann_ids
        #data['Feats'] = {'C4_feat': pos_C4_feat, 'fc7': pos_fc7, 'pool5': pos_pool5, 'lfeats': pos_lfeats, 'dif_lfeats': pos_dif_lfeats,
        #                 'cxt_fc7': pos_cxt_fc7, 'cxt_lfeats': pos_cxt_lfeats}
        data['labels'] = pos_labels
        data['file_name'] = file_name

        #data['neg_ann_ids'] = batch_neg_ann_ids
        #data['neg_sent_ids'] = batch_neg_sent_ids
        #data['neg_Feats'] = {'C4_feat': neg_C4_feat, 'fc7': neg_fc7, 'pool5': neg_pool5, 'lfeats': neg_lfeats, 'dif_lfeats': neg_dif_lfeats,
        #                     'cxt_fc7': neg_cxt_fc7, 'cxt_lfeats': neg_cxt_lfeats}
        #data['neg_labels'] = neg_labels
        #data['neg_cxt_ann_ids'] = neg_cxt_ann_ids
        #data['att_labels'] = att_labels  # (num_pos_ann_ids, num_atts)
        #data['select_ixs'] = select_ixs  # variable size
        #data['bounds'] = {'it_pos_now': self.iterators[split], 'it_max': max_index, 'wrapped': wrapped}
        return data
Example 7
 def showRef(self, ref, seg_box='seg'):
     ax = plt.gca()
     # show image
     image = self.Imgs[ref['image_id']]
     I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
     ax.imshow(I)
     # show refer expression
     for sid, sent in enumerate(ref['sentences']):
         if self.data['dataset'] != 'refgta':
             print('%s. %s' % (sid + 1, sent['sent']))
         else:
             print('%s. %s' % (sid + 1, sent['sent']))
             print('[Acc]:{:.2f}%, [time] median:{:.2f},mean:{:.2f}'.format(
                 100 * np.mean([o['if_true'] for o in sent['info']]),
                 np.median([1e-3 * o['time'] for o in sent['info']]),
                 np.mean(
                     sorted([1e-3 * o['time']
                             for o in sent['info']])[1:4])))
     # show segmentations
     if seg_box == 'seg':
         assert self.data['dataset'] != 'refgta', \
             'segmentation is not supported for refgta'
         ann_id = ref['ann_id']
         ann = self.Anns[ann_id]
         polygons = []
         color = []
         c = 'none'
         if type(ann['segmentation'][0]) == list:
             # polygon used for refcoco*
             for seg in ann['segmentation']:
                 poly = np.array(seg).reshape((len(seg) // 2, 2))
                 polygons.append(Polygon(poly, closed=True, alpha=0.4))
                 color.append(c)
             p = PatchCollection(polygons,
                                 facecolors=color,
                                 edgecolors=(1, 1, 0, 1),  # alpha 1, otherwise the edge is invisible
                                 linewidths=3,
                                 alpha=1)
             ax.add_collection(p)  # thick yellow polygon
             p = PatchCollection(polygons,
                                 facecolors=color,
                                 edgecolors=(1, 0, 0, 1),
                                 linewidths=1,
                                 alpha=1)
             ax.add_collection(p)  # thin red polygon
         else:
             # mask used for refclef
             rle = ann['segmentation']
             m = mask.decode(rle)
             img = np.ones((m.shape[0], m.shape[1], 3))
             color_mask = np.array([2.0, 166.0, 101.0]) / 255
             for i in range(3):
                 img[:, :, i] = color_mask[i]
             ax.imshow(np.dstack((img, m * 0.5)))
     # show bounding-box
     elif seg_box == 'box':
         ann_id = ref['ann_id']
         bbox = self.getRefBox(ref['ref_id'])
         box_plot = Rectangle((bbox[0], bbox[1]),
                              bbox[2],
                              bbox[3],
                              fill=False,
                              edgecolor='red',
                              linewidth=3)
         ax.add_patch(box_plot)
         for others in self.imgToAnns[ref['image_id']]:
             if others['id'] != ann_id:
                 bbox = others['bbox']
                 box_plot = Rectangle((bbox[0], bbox[1]),
                                      bbox[2],
                                      bbox[3],
                                      fill=False,
                                      edgecolor='blue',
                                      linewidth=3)
                 ax.add_patch(box_plot)
Example 8
    def getTestBatch(self, split, batch_size=1):  # revised from getBatch
        # options
        #print('batch_size:', batch_size)
        #batch_size = batch_size #opt.get('batch_size', 5) ####
        #seq_per_ref = 1 #3 #opt.get('seq_per_ref', 3)
        #sample_ratio = 0.3 #opt.get('visual_sample_ratio', 0.3)  # sample ratio, st vs dt
        split_ix = self.split_ix[split]
        max_index = len(split_ix) - 1  # don't forget to -1
        wrapped = False

        # fetch image_ids
        batch_image_ids = []
        for i in range(batch_size):
            ri = self.iterators[split]
            ri_next = ri + 1
            if ri_next > max_index:
                print('number of images in split {}: {}'.format(
                    split, len(split_ix)))
                self.perm[split] = np.random.permutation(len(split_ix))
                print('perm', split, 'shuffled:', self.perm[split])
                ri_next = 0
                wrapped = True
            self.iterators[split] = ri_next
            image_id = split_ix[self.perm[split][ri]]
            batch_image_ids += [image_id]

        # fetch feats
        #batch_ref_ids = []
        batch_pos_ref_ids = []
        batch_pos_sent_ids = []
        batch_pos_category_ids = []
        first = 1

        blob, im_scales = self._get_image_blob(batch_image_ids)

        for image_id in batch_image_ids:
            ref_ids = self.Images[image_id]['ref_ids']
            #ref_ids = [random.choice(ref_ids)] ####
            file_name = self.Images[image_id]['file_name']  ####
            for ref_id in ref_ids:
                #ref_ann_id = self.Refs[ref_id]['ann_id']
                ref_category_id = self.Refs[ref_id]['category_id']

                rle = self.Refs[ref_id]['rle']
                m = mask.decode(rle)

                m = np.sum(
                    m, axis=2
                )  # sometimes there are multiple binary maps (corresponding to multiple segs)
                m[m > 0] = 1

                ref_mask = m.astype(np.uint8)
                ref_mask = imresize(ref_mask,
                                    size=(blob.shape[1], blob.shape[2]),
                                    interp='nearest')
                ref_mask = np.expand_dims(ref_mask, axis=0)

                # pos ids
                pos_sent_ids = list(self.Refs[ref_id]['sent_ids'])  ####

                pos_ref_ids = [ref_id]  #### ++++
                # one category per ref, so it lines up with the gt_boxes / gt_masks rows
                pos_category_ids = [ref_category_id]

                batch_pos_ref_ids += pos_ref_ids
                #image_neg_ann_ids += neg_ann_ids
                batch_pos_sent_ids += pos_sent_ids
                #batch_neg_sent_ids += neg_sent_ids
                batch_pos_category_ids += pos_category_ids
                if first:
                    batch_mask = ref_mask
                    first = 0
                else:
                    batch_mask = np.concatenate((batch_mask, ref_mask), axis=0)

            # fetch feats
            pos_ref_boxes = xywh_to_xyxy(
                np.vstack([
                    self.Refs[ref_id]['box'] for ref_id in batch_pos_ref_ids
                ]))

        pos_labels = np.vstack(
            [self.fetch_seq(sent_id) for sent_id in batch_pos_sent_ids])

        max_len = (pos_labels != 0).sum(1).max()
        cap_labels = pos_labels[:, :max_len]

        label_batch = np.zeros([cap_labels.shape[0], cap_labels.shape[1] + 2],
                               dtype='int')
        mask_batch = np.zeros([cap_labels.shape[0], cap_labels.shape[1] + 2],
                              dtype='float32')

        label_batch[:, 1:-1] = cap_labels

        # generate mask
        nonzeros = np.array(
            list(map(lambda x: (x != 0).sum() + 2, label_batch)))
        for ix, row in enumerate(mask_batch):
            row[:nonzeros[ix]] = 1

        # convert to Variable
        pos_labels = Variable(torch.from_numpy(pos_labels).long().cuda())  ####

        # chunk pos_labels and neg_labels using max_len
        #max_len = max((pos_labels != 0).sum(1).max().data[0],
        #              (neg_labels != 0).sum(1).max().data[0])
        max_len = (pos_labels != 0).sum(1).max().data[0]  ####
        pos_labels = pos_labels[:, :max_len]  ####

        # return
        data = {}
        data['data'] = blob
        data['im_info'] = np.array(
            [[blob.shape[1], blob.shape[2], im_scales[0]]]).astype(np.float32)

        pos_ref_boxes = np.concatenate(
            (pos_ref_boxes * im_scales[0], np.array([batch_pos_category_ids
                                                     ]).T),
            axis=1)
        data['gt_boxes'] = pos_ref_boxes.astype(np.float32)
        data['gt_masks'] = batch_mask

        data['labels'] = pos_labels
        data['file_name'] = file_name
        data['ref_ids'] = batch_pos_ref_ids

        data['cap_labels'] = label_batch
        data['cap_masks'] = mask_batch

        data['bounds'] = {
            'it_pos_now': self.iterators[split],
            'it_max': max_index,
            'wrapped': wrapped
        }

        return data
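A hedged sketch of how a caller might drain this loader for one split: iterate until the returned bounds report wrapping. The loader/model names and the split key are placeholders; bounds['wrapped'] is set on the batch that contains the last image.

# hypothetical evaluation loop
loader.iterators['val'] = 0
while True:
    data = loader.getTestBatch('val')
    # data['data']: image blob; data['gt_boxes'] / data['gt_masks']: one row per ref;
    # data['labels'] / data['cap_labels']: one row per sentence
    # ... run the model / scorer on `data` here ...
    if data['bounds']['wrapped']:
        break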