Example No. 1
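This example defines a PyTorch Dataset over the RefCOCO+ referring-expression annotations: each item pairs a precomputed image embedding, a precomputed sentence embedding, and the referred object's bounding box normalized by image size.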
import os
import random
import time

import numpy as np
import torch
from torch.utils.data import Dataset

from refer import REFER  # lichengunc/refer toolkit


class RefDataset(Dataset):
    def __init__(self, split, data_root='data'):
        # REFER requires a data_root pointing at the refer annotations;
        # 'data' is an assumed default -- adjust to your layout.
        self.refer = REFER(data_root, dataset='refcoco+', splitBy='unc')
        self.ref_ids = self.refer.getRefIds(split=split)

        # Precomputed image embeddings plus the image ids they are indexed by.
        self.image_embeds = np.load(
            os.path.join("data", "embeddings", "FINALImageEmbeddings.npy"))
        self.image_ids = list(
            np.load(os.path.join("data", "embeddings", "FINALImageIDs.npy")))
        # The sentence embeddings were saved in two shards; stitch them back
        # together along the sentence axis and time the load.
        before_text_embeds = time.time()
        self.text_embeds = np.concatenate(
            (np.load(
                os.path.join("data", "embeddings",
                             "FINALTextEmbeddings1of2.npy")),
             np.load(
                 os.path.join("data", "embeddings",
                              "FINALTextEmbeddings2of2.npy"))),
            axis=0)
        after_text_embeds = time.time()
        print("Text Embedding Time:", after_text_embeds - before_text_embeds)
        # Sanity checks: one embedding per RefCOCO+ sentence, feature dim 3072.
        assert len(self.text_embeds) == 141564
        assert self.text_embeds[0].shape[1] == 3072
        print('Found {} referred objects in {} split.'.format(
            len(self.ref_ids), split))

    def __len__(self):
        return len(self.ref_ids)

    def __getitem__(self, i):
        ref_id = self.ref_ids[i]
        ref = self.refer.loadRefs(ref_id)[0]

        # Look up this ref's image embedding by the position of its image id
        # in the saved id list.
        image_id = ref['image_id']
        image = self.refer.Imgs[image_id]
        image_idx = self.image_ids.index(image_id)
        image_embed = self.image_embeds[image_idx, :, :, :]

        height = image['height']
        width = image['width']
        # getRefBox returns [x, y, w, h]; normalize all coordinates to [0, 1].
        bound_box = torch.Tensor(self.refer.getRefBox(ref_id))
        bound_box[0] /= width
        bound_box[1] /= height
        bound_box[2] /= width
        bound_box[3] /= height

        # Each ref carries several sentences; sample one per access so
        # different epochs see different phrasings of the same object.
        sent = random.choice(ref['sentences'])
        text_id = sent['sent_id']
        # sent_id doubles as the row index into the stacked text embeddings.
        text_embed = torch.from_numpy(self.text_embeds[text_id])

        return image_embed, text_embed, bound_box
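A minimal usage sketch (not part of the original example; the split name and batch size are assumptions):

from torch.utils.data import DataLoader

train_set = RefDataset(split='train')
loader = DataLoader(train_set, batch_size=32, shuffle=True)

for image_embed, text_embed, bound_box in loader:
    # bound_box is (batch, 4) in normalized [x, y, w, h] coordinates.
    break

Example No. 2

This example walks the RefCOCO testB refs, loads each image, and runs a Mask R-CNN detector over it. The hyp/ref files are opened for an evaluation step that the excerpt does not show.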
import os

import imageio
from tqdm import tqdm

from refer import REFER

data_root = '/root/refer/data'  # inferred from images_dir below
dataset = 'refcoco'
splitBy = 'unc'
refer = REFER(data_root, dataset, splitBy)

ref_ids = refer.getRefIds(split='testB')
images_dir = '/root/refer/data/images/mscoco/images/train2014/'

# One hypothesis file and four reference files, presumably for a
# multi-reference text metric over the generated expressions.
hyp = open("hyp.txt", "w")
ref1 = open("ref1.txt", "w")
ref2 = open("ref2.txt", "w")
ref3 = open("ref3.txt", "w")
ref4 = open("ref4.txt", "w")

for ref_id in tqdm(ref_ids):
    ref = refer.Refs[ref_id]
    x, y, w, h = refer.getRefBox(ref_id)  # [x, y, w, h]
    x1, y1, x2, y2 = x, y, x + w, y + h
    image_path = os.path.join(images_dir,
                              refer.Imgs[ref['image_id']]['file_name'])

    # scipy.misc.imread was removed in SciPy 1.2; imageio is a drop-in here.
    image = imageio.imread(image_path)
    if len(image.shape) != 3:
        continue  # skip grayscale images; the detector expects 3 channels

    # Run detection; `model` is assumed to be a Mask R-CNN instance
    # (e.g. matterport's Mask_RCNN) built and loaded with weights elsewhere.
    results = model.detect([image], verbose=0)
    r = results[0]

    # Optional visualization:
    # visualize.display_instances(image, r['rois'], r['masks'],
    #                             r['class_ids'], class_names, r['scores'])
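The excerpt computes corner coordinates for the ground-truth box but never consumes them, and the opened files are never closed. A plausible continuation (an assumption, not part of the original) scores each detection against the referred box by IoU; note that matterport's Mask R-CNN returns rois in [y1, x1, y2, x2] order:

def box_iou(gt, det):
    # gt is (x1, y1, x2, y2); det is a Mask R-CNN roi in (y1, x1, y2, x2).
    gx1, gy1, gx2, gy2 = gt
    dy1, dx1, dy2, dx2 = det
    ix1, iy1 = max(gx1, dx1), max(gy1, dy1)
    ix2, iy2 = min(gx2, dx2), min(gy2, dy2)
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    union = ((gx2 - gx1) * (gy2 - gy1)
             + (dx2 - dx1) * (dy2 - dy1) - inter)
    return inter / union if union > 0 else 0.0

# Inside the loop, each detection could then be scored as:
# ious = [box_iou((x1, y1, x2, y2), roi) for roi in r['rois']]

for f in (hyp, ref1, ref2, ref3, ref4):
    f.close()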