Ejemplo n.º 1
0
################################
# # process the train question #
################################
# Builds two frequency tables over the train2014 split:
#   question_dict_count: processed question token -> occurrence count
#   answer_dict_count:   processed answer         -> occurrence count
# Relies on dataDir, VQA, process_sentence and process_answer being defined
# before this point.
taskType = 'OpenEnded'
dataType = 'mscoco'
dataSubType = 'train2014'
annFile = '%s/Annotations/%s_%s_annotations.json' % (dataDir, dataType,
                                                     dataSubType)
quesFile = '%s/Questions/%s_%s_%s_questions.json' % (dataDir, taskType,
                                                     dataType, dataSubType)
imgDir = '%s/Images/%s/' % (dataDir, dataSubType)

vqa = VQA(annFile, quesFile)
train_question_ids = vqa.getQuesIds()
train_image_ids = vqa.getImgIds()
train_questions = []
train_answers = []
question_dict_count = dict()
answer_dict_count = dict()

for q_id in train_question_ids:
    # Count every whitespace-separated token of the processed question text.
    question = vqa.qqa[q_id]['question']
    question = process_sentence(question)
    question = question.split()
    for word in question:
        question_dict_count[word] = question_dict_count.get(word, 0) + 1

    # 'answer' is the text of the first entry in the question's 'answers'
    # list, i.e. one string. It is processed as a whole: iterating over the
    # string itself would hand process_answer one character at a time and
    # fill answer_dict_count with character counts.
    answer = vqa.loadQA(q_id)[0]['answers'][0]['answer']
    answer_new = [process_answer(answer)]
    for word in answer_new:
        answer_dict_count[word] = answer_dict_count.get(word, 0) + 1
Ejemplo n.º 2
0
# Parenthesised form prints the same text on Python 2 and Python 3; the bare
# print statement is a SyntaxError on Python 3.
print('finished processing features')

################################
# # process the train question #
################################
# Builds two frequency tables over the train2014 split:
#   question_dict_count: processed question token -> occurrence count
#   answer_dict_count:   processed answer item    -> occurrence count
# Relies on dataDir, VQA, process_sentence and process_answer being defined
# before this point.
taskType = 'OpenEnded'
dataType = 'mscoco'
dataSubType = 'train2014'
annFile = '%s/Annotations/%s_%s_annotations.json' % (
    dataDir, dataType, dataSubType)
quesFile = '%s/Questions/%s_%s_%s_questions.json' % (
    dataDir, taskType, dataType, dataSubType)
imgDir = '%s/Images/%s/' % (dataDir, dataSubType)

vqa = VQA(annFile, quesFile)
train_question_ids = vqa.getQuesIds()
train_image_ids = vqa.getImgIds()
train_questions = []
train_answers = []
question_dict_count = dict()
answer_dict_count = dict()

for q_id in train_question_ids:
    # Count every whitespace-separated token of the processed question.
    question = vqa.loadQuestion(q_id)[0]
    question = process_sentence(question)
    question = question.split()
    for word in question:
        question_dict_count[word] = question_dict_count.get(word, 0) + 1

    # NOTE(review): this iterates over whatever vqa.loadAnswer(q_id)[0]
    # returns; presumably a list of answer strings -- confirm against the
    # VQA helper, since a plain string would be walked character by character.
    answer = vqa.loadAnswer(q_id)[0]
    answer_new = [process_answer(ans) for ans in answer]
    for word in answer_new:
        answer_dict_count[word] = answer_dict_count.get(word, 0) + 1
Ejemplo n.º 3
0
"""
def _display_coco_image(img_id):
    """Show the image for *img_id* with matplotlib if its file exists.

    The file name is built as ``COCO_<dataSubType>_<id zero-padded to 12>.jpg``
    and looked up under the module-level ``imgDir``. A missing file is
    skipped without any output.
    """
    img_path = imgDir + 'COCO_' + dataSubType + '_' + str(img_id).zfill(12) + '.jpg'
    if not os.path.isfile(img_path):
        return
    plt.imshow(io.imread(img_path))
    plt.axis('off')
    plt.show()


# Pick one random yes/no annotation, print it, and show its image.
annIds = vqa.getQuesIds(ansTypes='yes/no')
anns = vqa.loadQA(annIds)
if anns:  # random.choice raises IndexError on an empty sequence
    randomAnn = random.choice(anns)
    vqa.showQA([randomAnn])
    _display_coco_image(randomAnn['image_id'])

# load and display QA annotations for given images
"""
Usage: vqa.getImgIds(quesIds=[], quesTypes=[], ansTypes=[])
Above method can be used to retrieve imageIds for given question Ids or given question types or given answer types.
"""
ids = vqa.getImgIds()
# Sample up to 5 images; random.sample raises ValueError when asked for more
# items than the population holds.
annIds = vqa.getQuesIds(imgIds=random.sample(ids, min(5, len(ids))))
anns = vqa.loadQA(annIds)
if anns:
    randomAnn = random.choice(anns)
    vqa.showQA([randomAnn])
    _display_coco_image(randomAnn['image_id'])
class VqaDataset(Dataset):
    """Dataset of (image, encoded question, answer index) samples for VQA.

    Questions and annotations are read through a ``VQA`` helper built from
    the two JSON paths; images are loaded from ``image_dir``.

    Args:
        image_dir: Directory holding the image files.
        question_json_file_path: Questions JSON path, passed to ``VQA``.
        annotation_json_file_path: Annotations JSON path, passed to ``VQA``.
        image_filename_pattern: File-name pattern containing ``{}``; the text
            before the first ``{}`` is treated as the prefix that precedes
            the numeric image id in each file name.
        collate: When False, image names/ids, question ids and the
            vocabularies are computed here (overwriting the corresponding
            arguments) and items are returned as dicts. When True, the
            supplied values are used as given and items are returned as
            tuples.
        q2i: Mapping from question token to index. Must contain ``"<sos>"``
            and ``"<eos>"``, and ``"<pad>"`` when ``collate`` is False.
        a2i: Mapping from answer string to index.
        i2a: Stored on the instance; not otherwise used by this class.
        a2i_count: Mapping from answer string to a count, used to choose
            among the annotator answers of a question.
        img_names: Image file names, parallel to ``img_ids``.
        img_ids: Integer image ids, parallel to ``img_names``.
        ques_ids: Question ids; one dataset item per entry.
        method: ``'simple'`` resizes images to 224x224 and encodes questions
            as a multi-hot vector over ``q2i``; any other value resizes to
            448x448 and encodes questions as a sequence of token indices.
        dataset_type: Accepted for interface compatibility; unused here.
        enc_dir: Stored on the instance; not otherwise used by this class.
    """

    # Questions are right-padded with "<pad>" up to this many tokens when
    # collate is False; longer questions are left at their own length.
    _MIN_QUESTION_LEN = 8

    def __init__(self,
                 image_dir,
                 question_json_file_path,
                 annotation_json_file_path,
                 image_filename_pattern,
                 collate=False,
                 q2i=None,
                 a2i=None,
                 i2a=None,
                 a2i_count=None,
                 img_names=None,
                 img_ids=None,
                 ques_ids=None,
                 method='simple',
                 dataset_type='train',
                 enc_dir=''):
        print(method)
        self.image_dir = image_dir
        self.qjson = question_json_file_path
        self.ajson = annotation_json_file_path
        img_prefix = image_filename_pattern.split('{}')[0]
        self.collate = collate
        self.q2i = q2i
        self.a2i = a2i
        self.i2a = i2a
        self.a2i_count = a2i_count
        self.img_ids = img_ids
        self.ques_ids = ques_ids
        self.img_names = img_names
        self.method = method
        self.vqa = VQA(self.ajson, self.qjson)
        # Built on first use by _image_name_for, so img_ids / img_names may
        # still be set or replaced after construction.
        self._img_name_by_id = None

        # Only the target resolution depends on the method.
        side = 224 if self.method == 'simple' else 448
        self.transform = transforms.Compose(
            [transforms.Resize((side, side)),
             transforms.ToTensor()])

        if not collate:
            self.img_names = [
                f for f in os.listdir(self.image_dir) if '.jpg' in f
            ]
            # The image id is whatever follows the prefix in the part of the
            # file name before its first '.'.
            self.img_ids = [
                int(fname.split('.')[0].rpartition(img_prefix)[-1])
                for fname in self.img_names
            ]

            self.ques_ids = self.vqa.getQuesIds(self.img_ids)

            self.q2i, self.a2i, self.i2a, self.a2i_count = pre_process_dataset(
                image_dir, self.qjson, self.ajson, img_prefix)

        self.q2i_len = len(self.q2i)
        self.a2i_len = len(self.a2i)
        self.q2i_keys = self.q2i.keys()
        self.enc_dir = enc_dir

    def __len__(self):
        """Return the number of questions, i.e. the number of items."""
        return len(self.ques_ids)

    def _image_name_for(self, img_id):
        """Return the file name paired with *img_id*.

        Uses a dict built once from ``img_ids`` / ``img_names`` instead of a
        linear ``list.index`` scan per item. When an id occurs more than once
        the first pairing wins, matching ``list.index``.

        Raises:
            ValueError: If *img_id* is not among ``img_ids``.
        """
        lookup = getattr(self, '_img_name_by_id', None)
        if lookup is None:
            lookup = {}
            for known_id, name in zip(self.img_ids, self.img_names):
                lookup.setdefault(known_id, name)
            self._img_name_by_id = lookup
        if img_id not in lookup:
            raise ValueError('%r is not in list' % (img_id,))
        return lookup[img_id]

    def _most_frequent_answer(self, answers):
        """Pick the in-vocabulary answer with the highest ``a2i_count``.

        Answer texts are lower-cased before lookup. Ties keep the earliest
        answer. Returns ``""`` when no answer is in ``a2i`` with a positive
        count.
        """
        best_count = 0
        best_answer = ""
        for entry in answers:
            text = entry['answer'].lower()
            if text in self.a2i and self.a2i_count[text] > best_count:
                best_count = self.a2i_count[text]
                best_answer = text
        return best_answer

    def __getitem__(self, idx):
        """Return the sample for the question at position *idx*.

        Returns:
            When ``collate`` is False, a dict with keys ``'img'``, ``'ques'``
            and ``'gt'``. Otherwise the tuple ``(imgT, quesT, gT, img_name,
            ques, answer)``, where ``ques`` is the question text without its
            last character and ``answer`` is the chosen answer string
            (``""`` if none was in the vocabulary).

        Raises:
            ValueError: If the question's image id is not in ``img_ids``.
        """
        ques_id = self.ques_ids[idx]
        img_id = self.vqa.getImgIds([ques_id])[0]

        qa = self.vqa.loadQA(ques_id)[0]
        qqa = self.vqa.loadQQA(ques_id)[0]
        img_name = self._image_name_for(img_id)

        # Image loading is the same for every method; only the transform set
        # up in __init__ differs.
        img = default_loader(self.image_dir + '/' + img_name)
        imgT = self.transform(img).float()

        # Drop the final character of the question (presumably the trailing
        # '?') and keep only tokens present in the question vocabulary.
        ques = qqa['question'][:-1]
        quesI = [self.q2i["<sos>"]]
        quesI.extend(self.q2i[token]
                     for token in (word.lower() for word in ques.split(" "))
                     if token in self.q2i)
        quesI.append(self.q2i["<eos>"])
        if not self.collate:
            # A non-positive repeat count yields an empty list, so longer
            # questions are left untouched.
            quesI.extend([self.q2i["<pad>"]] *
                         (self._MIN_QUESTION_LEN - len(quesI)))

        if self.method == 'simple':
            # Multi-hot vector over the question vocabulary.
            quesT = torch.zeros(self.q2i_len).float()
            for token_index in quesI:
                quesT[token_index] = 1
        else:
            quesT = torch.from_numpy(np.array(quesI)).long()

        answer = self._most_frequent_answer(qa['answers'])

        if answer == "":  # only for validation
            # No annotator answer is in the vocabulary: use the index one
            # past the last valid answer index.
            gT = torch.from_numpy(np.array([self.a2i_len])).long()
        else:
            gT = torch.from_numpy(np.array([self.a2i[answer]])).long()

        if not self.collate:
            return {'img': imgT, 'ques': quesT, 'gt': gT}

        return imgT, quesT, gT, img_name, ques, answer