def pre_process_dataset(image_dir, qjson, ajson, img_prefix):
    """Build the question-word and answer vocabularies for the VQA data."""
    print('Preprocessing dataset.\n')
    vqa = VQA(ajson, qjson)

    # Collect (at most 30000) image files and recover their numeric ids from
    # the filename, e.g. <img_prefix>000000123456.jpg -> 123456.
    img_names = [f for f in os.listdir(image_dir) if '.jpg' in f]
    img_names = img_names[:30000]
    print("length: ", len(img_names))

    img_ids = []
    for fname in img_names:
        img_id = fname.split('.')[0].rpartition(img_prefix)[-1]
        img_ids.append(int(img_id))
    print("Done collecting image ids")

    ques_ids = vqa.getQuesIds(img_ids)

    # Question-word vocabulary: looking up a new key in this defaultdict
    # assigns it the next free index. Reserve the special tokens first.
    q2i = defaultdict(lambda: len(q2i))
    pad = q2i["<pad>"]
    start = q2i["<sos>"]
    end = q2i["<eos>"]
    UNK = q2i["<unk>"]

    a2i_count = {}
    for ques_id in ques_ids:
        qa = vqa.loadQA(ques_id)[0]
        qqa = vqa.loadQQA(ques_id)[0]

        # Drop the trailing '?' and add every question word to q2i
        # (the list comprehension is used purely for its side effect).
        ques = qqa['question'][:-1]
        [q2i[x] for x in ques.lower().strip().split(" ")]

        # Count how often each confident ("yes") answer occurs.
        answers = qa['answers']
        for ans in answers:
            if ans['answer_confidence'] != 'yes':
                continue
            ans = ans['answer'].lower()
            if ans not in a2i_count:
                a2i_count[ans] = 1
            else:
                a2i_count[ans] += 1
    print("Done collecting Q/A")

    # Keep only the 1000 most frequent answers as the answer vocabulary.
    a_sort = sorted(a2i_count.items(), key=operator.itemgetter(1), reverse=True)

    i2a = {}
    count = 0
    a2i = defaultdict(lambda: len(a2i))
    for word, _ in a_sort:
        a2i[word]
        i2a[a2i[word]] = word
        count += 1
        if count == 1000:
            break
    print("Done collecting words")

    return q2i, a2i, i2a, a2i_count
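# The defaultdict trick above is worth calling out: merely looking up a new
# key assigns it the next free index, so iterating over question words builds
# the vocabulary as a side effect. A small standalone illustration (not part
# of the pipeline itself):
#
#     >>> from collections import defaultdict
#     >>> q2i = defaultdict(lambda: len(q2i))
#     >>> [q2i[w] for w in "is the cat black".split()]
#     [0, 1, 2, 3]
#     >>> q2i["the"]   # repeated words keep their original index
#     1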
class VqaDataset(Dataset):
    """Dataset pairing a VQA image with one question and its best answer."""

    def __init__(self, image_dir, question_json_file_path,
                 annotation_json_file_path, image_filename_pattern,
                 collate=False, q2i=None, a2i=None, i2a=None, a2i_count=None,
                 img_names=None, img_ids=None, ques_ids=None,
                 method='simple', dataset_type='train', enc_dir=''):
        print(method)
        self.image_dir = image_dir
        self.qjson = question_json_file_path
        self.ajson = annotation_json_file_path
        # Everything before the '{}' placeholder is the image filename prefix.
        img_prefix = image_filename_pattern.split('{}')[0]
        self.collate = collate
        self.q2i = q2i
        self.a2i = a2i
        self.i2a = i2a
        self.a2i_count = a2i_count
        self.img_ids = img_ids
        self.ques_ids = ques_ids
        self.img_names = img_names
        self.method = method
        self.vqa = VQA(self.ajson, self.qjson)

        # The 'simple' model uses 224x224 inputs; any other method gets 448x448.
        if self.method == 'simple':
            self.transform = transforms.Compose(
                [transforms.Resize((224, 224)), transforms.ToTensor()])
        else:
            self.transform = transforms.Compose(
                [transforms.Resize((448, 448)), transforms.ToTensor()])

        # When the vocabularies are not handed in (collate=False), scan the
        # image directory and build them with pre_process_dataset.
        if not collate:
            self.img_names = [f for f in os.listdir(self.image_dir) if '.jpg' in f]
            #print(self.img_names)
            self.img_ids = []
            for fname in self.img_names:
                img_id = fname.split('.')[0].rpartition(img_prefix)[-1]
                self.img_ids.append(int(img_id))

            self.ques_ids = self.vqa.getQuesIds(self.img_ids)

            self.q2i, self.a2i, self.i2a, self.a2i_count = pre_process_dataset(
                image_dir, self.qjson, self.ajson, img_prefix)

        #print("Image names ", self.img_names)
        self.q2i_len = len(self.q2i)
        self.a2i_len = len(self.a2i.keys())
        self.q2i_keys = self.q2i.keys()
        self.enc_dir = enc_dir

    def __len__(self):
        return len(self.ques_ids)

    def __getitem__(self, idx):
        ques_id = self.ques_ids[idx]
        img_id = self.vqa.getImgIds([ques_id])[0]

        qa = self.vqa.loadQA(ques_id)[0]
        qqa = self.vqa.loadQQA(ques_id)[0]

        img_name = self.img_names[self.img_ids.index(img_id)]

        # Load and transform the image (both branches currently do the same
        # thing apart from the resize chosen in __init__).
        if self.method == 'simple':
            img = default_loader(self.image_dir + '/' + img_name)
            #imgT = self.transform(img).permute(1, 2, 0)
            imgT = self.transform(img).float()
        else:
            img = default_loader(self.image_dir + '/' + img_name)
            imgT = self.transform(img).float()

        # Encode the question as <sos> + known words + <eos>, padded to
        # length 8 when items are returned individually (collate=False).
        ques = qqa['question'][:-1]
        quesI = ([self.q2i["<sos>"]]
                 + [self.q2i[x.lower()] for x in ques.split(" ")
                    if x.lower() in self.q2i_keys]
                 + [self.q2i["<eos>"]])
        if not self.collate:
            quesI = quesI + [self.q2i["<pad>"]] * (8 - len(quesI))

        if self.method == 'simple':
            # Bag-of-words: a multi-hot vector over the question vocabulary.
            quesT = torch.zeros(self.q2i_len).float()
            for word_idx in quesI:
                quesT[word_idx] = 1
        else:
            quesT = torch.from_numpy(np.array(quesI)).long()

        # Pick the in-vocabulary answer with the highest global count.
        answers = qa['answers']
        max_count = 0
        answer = ""
        for ans in answers:
            #if not ans['answer_confidence'] == 'yes':
            #    continue
            ans = ans['answer'].lower()
            if ans in self.a2i.keys() and self.a2i_count[ans] > max_count:
                max_count = self.a2i_count[ans]
                answer = ans

        # Out-of-vocabulary answers map to index a2i_len (only for validation).
        if answer == "":
            gT = torch.from_numpy(np.array([self.a2i_len])).long()
        else:
            gT = torch.from_numpy(np.array([self.a2i[answer]])).long()

        if not self.collate:
            return {'img': imgT, 'ques': quesT, 'gt': gT}

        return imgT, quesT, gT, img_name, ques, answer
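# Minimal usage sketch. The data paths and filename pattern below are
# illustrative assumptions, not values taken from this repository; point them
# at wherever the VQA images and json files actually live.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    train_set = VqaDataset(
        image_dir='data/train2014',
        question_json_file_path='data/OpenEnded_mscoco_train2014_questions.json',
        annotation_json_file_path='data/mscoco_train2014_annotations.json',
        image_filename_pattern='COCO_train2014_{}.jpg',
        method='simple')

    # With collate=False (the default) each item is a dict of fixed-size
    # tensors, so PyTorch's default collate_fn can batch them directly.
    loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
    batch = next(iter(loader))
    print(batch['img'].shape, batch['ques'].shape, batch['gt'].shape)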