    def add_lvis_image_feature(self, question):
        try:
            if "COCO" in question["image"]:
                img_name = question["image"].split("/")[1] + ".pth"
                path = os.path.join(self.dir_coco_lvis, img_name)
            elif "VG" in question["image"]:
                # image paths look like VG_100K/4.jpg
                img_name = question["image"].split("/")[1]
                path = os.path.join(self.dir_vg_lvis, img_name) + ".pth"
            else:
                raise ValueError(f"Unknown image source: {question['image']}")
            features = torch.load(path)
            question["visual_lvis"] = features["pooled_feat"]
            question["coord_lvis"] = features["rois"]
            norm_rois = features.get("norm_rois_lvis", None)
            if norm_rois is None:
                # normalize each coordinate to [0, 1] by its per-dimension maximum
                rois = question["coord_lvis"]
                rois_max, _ = rois.max(dim=0)
                question["norm_coord_lvis"] = question["coord_lvis"] / rois_max
            else:
                question["norm_coord_lvis"] = norm_rois
            question["nb_regions_lvis"] = question["visual_lvis"].size(0)
        except FileNotFoundError:
            Logger()(
                f"Missing LVIS features for image {question['image']}",
                log_level=Logger.ERROR,
            )
            # fall back to zero features so the batch can still be collated
            question["visual_lvis"] = torch.zeros(100, 1024)
            question["coord_lvis"] = torch.zeros(100, 4)
            question["norm_coord_lvis"] = torch.zeros(100, 4)
            question["nb_regions_lvis"] = 0
        return question
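# When the precomputed features lack normalized boxes, the fallback above divides
# every box by the per-dimension maximum over the image's regions. A minimal
# sketch of that normalization on a made-up `rois` tensor:

import torch

# hypothetical (x1, y1, x2, y2) boxes for three regions of one image
rois = torch.tensor([[10., 20., 50., 80.],
                     [ 0., 10., 40., 60.],
                     [ 5.,  5., 60., 90.]])
rois_max, _ = rois.max(dim=0)  # per-coordinate max over regions: [10., 20., 60., 90.]
norm_rois = rois / rois_max    # every coordinate now lies in [0, 1]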
    def stack_tensors(self, batch, key=None):
        if isinstance(batch, collections.abc.Mapping):
            out = {}
            for key, value in batch.items():
                if key not in self.avoid_keys:
                    out[key] = self.stack_tensors(value, key=key)
                else:
                    out[key] = value
            return out
        elif isinstance(batch, tuple):
            return tuple(self.stack_tensors(elt) for elt in batch)
        elif isinstance(batch, collections.abc.Sequence) and torch.is_tensor(
                batch[0]):
            out = None
            if self.use_shared_memory:
                # If we're in a background process, concatenate directly into a
                # shared memory tensor to avoid an extra copy
                numel = sum(x.numel() for x in batch)
                storage = batch[0].storage()._new_shared(numel)
                out = batch[0].new(storage)
            try:
                return torch.stack(batch, 0, out=out)
            except Exception:
                traceback.print_exc()
                Logger()(f"Failed stacking tensor with key {key}",
                         log_level=Logger.ERROR)
                # drop into a debugger to inspect the offending batch
                import ipdb
                ipdb.set_trace()
        else:
            return batch
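# Given the dict-of-lists layout produced by ListDictsToDictLists, the recursion
# above turns every list of same-shaped tensors into one stacked tensor and leaves
# other values (e.g. lists of strings) untouched. A minimal illustration with
# plain torch.stack, no shared memory:

import torch

batch = {
    "visual": [torch.zeros(36, 2048), torch.zeros(36, 2048)],  # per-item tensors
    "raw_question": ["how many cats", "how many dogs"],        # kept as a list
}
stacked = torch.stack(batch["visual"], 0)  # what stack_tensors does per key
print(stacked.shape)  # torch.Size([2, 36, 2048])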
    def __getitem__(self, index):
        # get image
        question = deepcopy(self.questions[index])
        question["index"] = index
        full_question = question["question"]
        # question tokens
        tokens = tokenize_mcb(full_question.strip())
        # token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        question["question_tokens"] = tokens
        question["raw_question"] = full_question
        # encode tokens as word ids (padding happens later, in the collate function)
        question["question"] = torch.tensor(
            [self.word_to_wid[token] for token in tokens]
        )
        question["lengths"] = torch.tensor([len(tokens)])
        question["class_id"] = torch.tensor([question["answer"]])
        question["label"] = {str(question["answer"]): 1}  # for lxmert
        question["target"] = torch.zeros(len(self.aid_to_ans))
        question["target"][question["answer"]] = 1.0
        question["original_question"] = question["raw_question"]

        if self.no_features:
            return question
        try:
            if self.image_features == "default":
                question = self.add_image_features(question)
            elif self.image_features == "default+lvis":
                question = self.add_image_features(question)
                question = self.add_lvis_image_feature(question)
            elif self.image_features == "lxmert":
                question = self.add_lxmert_image_features(question)
            elif self.image_features == "resnet":
                question = self.add_resnet_image_features(question)
            elif self.image_features == "vilbert":
                question = self.add_vilbert_image_features(question)
            if self.background:
                question = self.add_resnet_image_features(question, key="background_")

        except FileNotFoundError:
            Logger()(
                f"Missing image {question['image']} of question {question['question_id']}"
            )
            question["question"] = torch.tensor([0])
            question["class_id"] = torch.tensor([0])
            question["visual"] = torch.zeros(100, 2048)
            question["coord"] = torch.zeros(100, 4)
            question["norm_coord"] = torch.zeros(100, 4)
            question["nb_regions"] = 100

        return question
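# Each call returns a plain dict; a quick sanity check of the fields built above,
# assuming `dataset` is an already-constructed instance (the feature keys depend
# on image_features):

item = dataset[0]
print(item["raw_question"])   # original question string
print(item["question"])       # 1-D LongTensor of word ids
print(item["target"].shape)   # torch.Size([16]): one-hot over counts 0..15
assert item["target"][item["answer"]] == 1.0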
    def get_token_list(self):
        Logger()("Getting tokens list")

        tokens = set()
        for split in ["train", "test"]:
            q_path = os.path.join(self.dir_data, f"{split}.json")  # train or test
            with open(q_path) as f:
                questions = json.load(f)
            for q in tqdm(questions):
                tokens = tokens.union(tokenize_mcb(q["question"]))
        # sort for a deterministic word-id assignment across runs
        return sorted(tokens)
    def __init__(
        self,
        loss="mse",
        entropy_loss_weight=0.0,
    ):
        """
        loss: regression loss on the predicted count ("mse" or "huber").
        entropy_loss_weight: weight of the term-by-term entropy regularizer.
        """
        super().__init__()
        if loss == "mse":
            self.loss = nn.MSELoss()
        elif loss == "huber":
            self.loss = nn.SmoothL1Loss()
        else:
            raise ValueError(loss)

        self.entropy_loss_weight = entropy_loss_weight
        Logger()(f"entropy_loss_weight={entropy_loss_weight}")
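# The forward pass is not shown in this excerpt; a minimal sketch of how such a
# weighted objective is typically combined. The entropy term below is an
# assumption (the tensor it actually applies to is not visible here):

import torch
import torch.nn as nn

regression_loss = nn.SmoothL1Loss()  # the "huber" branch
pred, target = torch.rand(8), torch.rand(8)
scores = torch.rand(8, 36)           # assumed per-region scores in (0, 1)
eps = 1e-8
# term-by-term (elementwise) binary entropy, averaged over all entries
entropy = -(scores * (scores + eps).log()
            + (1 - scores) * (1 - scores + eps).log()).mean()
loss = regression_loss(pred, target) + 0.5 * entropy  # 0.5 = entropy_loss_weight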
    def __init__(
        self, *args, proportion_opposite=0.0, train_selection="even", **kwargs
    ):
        """
        proportion_opposite: proportion of opposite-parity questions mixed into the
            split (odd questions if train_selection is "even", and vice versa).
        """
        super().__init__(*args, **kwargs)
        assert 0.0 <= proportion_opposite <= 1.0
        assert train_selection in ("even", "odd")
        self.train_selection = train_selection
        self.proportion_opposite = proportion_opposite

        Logger()(f"Even/odd questions path: {self.get_path_questions_even_odd()}")
        if not os.path.exists(self.get_path_questions_even_odd()):
            self.process_even_odd()

        if "train" in self.split or "val" in self.split:
            mode = "train"
        else:
            mode = "test"

        if (mode, self.train_selection) in [("train", "even"), ("test", "odd")]:
            self.own_numbers = [0, 2, 4, 6, 8, 10, 12, 14]
            self.opposite_numbers = [1, 3, 5, 7, 9, 11, 13, 15]
        elif (mode, self.train_selection) in [("train", "odd"), ("test", "even")]:
            self.own_numbers = [1, 3, 5, 7, 9, 11, 13, 15]
            self.opposite_numbers = [0, 2, 4, 6, 8, 10, 12, 14]
        else:
            raise ValueError((mode, self.train_selection))

        with open(self.get_path_questions_even_odd()) as f:
            self.questions = json.load(f)
        Logger()(f"Number of questions in split {self.split}: {len(self.questions)}")
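# The own/opposite assignment above is symmetric between modes; the same rule as
# a small standalone function, useful for checking which counts a split "owns":

def parity_numbers(mode, train_selection):
    even, odd = [0, 2, 4, 6, 8, 10, 12, 14], [1, 3, 5, 7, 9, 11, 13, 15]
    if (mode, train_selection) in [("train", "even"), ("test", "odd")]:
        return even, odd  # own, opposite
    if (mode, train_selection) in [("train", "odd"), ("test", "even")]:
        return odd, even
    raise ValueError((mode, train_selection))

own, opposite = parity_numbers("test", "even")  # own == the odd numbers here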
    def compute_accuracy(self):
        accs = {}
        rmses = {}
        for t in self.answers.keys():
            if len(self.answers[t]["gt"]) > 0:
                gt = torch.tensor(self.answers[t]["gt"]).long()
                if len(self.answers[t]["pred"]) > 0:
                    pred = torch.tensor(self.answers[t]["pred"])
                    acc = (pred.round().long() == gt).sum().float() / len(pred)
                else:
                    pred = torch.tensor(self.answers[t]["ans"])
                    acc = (pred.long() == gt).sum().float() / len(pred)
                accs[t] = acc
                Logger().log_value(
                    f"{self.mode}_epoch.tally_acc.{t}",
                    acc.item() * 100,
                    should_print=True,
                )

                # compute L1, L2, RMSE
                diff = (pred - gt).float()
                l1 = torch.abs(diff).mean()
                Logger().log_value(f"{self.mode}_epoch.tally_l1.{t}",
                                   l1.item())
                l2 = (diff**2).mean()
                Logger().log_value(f"{self.mode}_epoch.tally_l2.{t}",
                                   l2.item())
                rmse = torch.sqrt(l2)
                rmses[t] = rmse.item()
                Logger().log_value(f"{self.mode}_epoch.tally_rmse.{t}",
                                   rmse.item(),
                                   should_print=True)

                # compute score with hard threshold
                if len(self.answers[t]["hard.ans"]) > 0:
                    hard_pred = torch.tensor(
                        self.answers[t]["hard.ans"]).long()
                    acc = (hard_pred == gt).sum().float() / len(hard_pred)
                    Logger().log_value(
                        f"{self.mode}_epoch.tally_thresh_acc.{t}",
                        acc.item() * 100,
                        should_print=True,
                    )

        # compute mean acc complex and simple
        if "simple" in accs and "complex" in accs:
            mean_acc = (accs["simple"].item() + accs["complex"].item()) / 2
            Logger().log_value(
                f"{self.mode}_epoch.tally_acc.mean_complex_simple",
                mean_acc * 100,
                should_print=True,
            )

        # compute normalized accuracy
        for t in ("overall", "simple", "complex", "positional", "type"):

            accs_t = [
                accs[f"{t}-{number}"] for number in range(16)
                if f"{t}-{number}" in accs
            ]
            if accs_t:
                # normalized arithmetic accuracy
                m_rel_acc = np.mean(accs_t)
                Logger().log_value(
                    f"{self.mode}_epoch.tally_acc.m-rel.{t}",
                    m_rel_acc.item() * 100,
                    should_print=True,
                )

                # normalized harmonic accuracy
                try:
                    normalized_harmonic_acc = scipy.stats.hmean(accs_t).item()
                except ValueError:
                    # a zero accuracy makes the harmonic mean undefined
                    normalized_harmonic_acc = 0.0
                Logger().log_value(
                    f"{self.mode}_epoch.tally_acc.norm_harmonic.{t}",
                    normalized_harmonic_acc * 100,
                )

            # normalized RMSE
            rmses_t = [
                rmses[f"{t}-{number}"] for number in range(16)
                if f"{t}-{number}" in rmses
            ]
            if rmses_t:
                m_rel_rmse = np.mean(rmses_t)
                Logger().log_value(f"{self.mode}_epoch.tally_rmse.m-rel.{t}",
                                   m_rel_rmse.item())

                # normalized harmonic RMSE
                try:
                    normalized_harmonic_rmse = scipy.stats.hmean(rmses_t).item()
                except ValueError:
                    # a zero RMSE (should not happen for regression)
                    normalized_harmonic_rmse = 0.0
                Logger().log_value(
                    f"{self.mode}_epoch.tally_rmse.norm_harmonic.{t}",
                    normalized_harmonic_rmse,
                )

        # Metrics for COCO grounding
        if self.all_ious:
            all_ious = pd.DataFrame(self.all_ious, columns=IOU_COLUMNS)
            exp_dir = self.exp_dir
            threshold = self.score_threshold_grounding
            iou_path = os.path.join(exp_dir, f"iou-{threshold}.pickle")
            all_ious.to_pickle(iou_path)
            # log other metrics
            columns_to_log = [
                "iou",
                "iou_sum",
                "ioo",
                "ioo_sum",
                "iogt",
                "iogt_sum",
                "iou_boxes_norm",
                "ioo_boxes_norm",
                "iogt_boxes_norm",
            ]
            iou_nonzero = all_ious[all_ious["answer"] != 0]
            for metric in columns_to_log:
                # overall
                m = all_ious[metric].mean()
                Logger().log_value(f"{self.mode}_epoch.{metric}.overall",
                                   m,
                                   should_print=True)
                m = iou_nonzero[metric].mean()
                Logger().log_value(
                    f"{self.mode}_epoch.{metric}.overall.nonzero",
                    m,
                    should_print=True)
                # by answer value
                for i in range(16):
                    m = all_ious[all_ious["answer"] == i][metric].mean()
                    Logger().log_value(f"{self.mode}_epoch.{metric}.{i}",
                                       m,
                                       should_print=True)

                # by object
                for name in all_ious.name.unique():
                    m = all_ious[all_ious.name == name][metric].mean()
                    Logger().log_value(f"{self.mode}_epoch.{metric}.{name}",
                                       m,
                                       should_print=False)
                    m = iou_nonzero[iou_nonzero.name == name][metric].mean()
                    Logger().log_value(
                        f"{self.mode}_epoch.{metric}.{name}.nonzero",
                        m,
                        should_print=False,
                    )

            # additionally, do global normalization
            for metric in ["iou_boxes", "ioo_boxes", "iogt_boxes"]:
                m = all_ious[metric].sum() / all_ious["pred"].sum()
                Logger().log_value(f"{self.mode}_epoch.{metric}.overall",
                                   m,
                                   should_print=True)
                m = iou_nonzero[metric].sum() / iou_nonzero["pred"].sum()
                Logger().log_value(
                    f"{self.mode}_epoch.{metric}.overall.nonzero",
                    m,
                    should_print=True)

                for i in range(16):
                    all_ious_i = all_ious[all_ious["answer"] == i]
                    m = all_ious_i[metric].sum() / all_ious_i["pred"].sum()
                    Logger().log_value(f"{self.mode}_epoch.{metric}.{i}",
                                       m,
                                       should_print=True)

                for name in all_ious.name.unique():
                    all_ious_name = all_ious[all_ious.name == name]
                    m = all_ious_name[metric].sum() / all_ious_name["pred"].sum()
                    Logger().log_value(f"{self.mode}_epoch.{metric}.{name}",
                                       m,
                                       should_print=False)
                    ious_nz_name = iou_nonzero[iou_nonzero.name == name]
                    m = ious_nz_name[metric].sum() / ious_nz_name["pred"].sum()
                    Logger().log_value(
                        f"{self.mode}_epoch.{metric}.{name}.nonzero",
                        m,
                        should_print=False,
                    )

            # mean average precision
            for t in THRESHOLDS_mAP:
                scores = self.mean_ap[t].get_scores()
                Logger().log_value(f"{self.mode}_epoch.mAP.{t}.overall",
                                   scores,
                                   should_print=True)
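# For reference, the harmonic mean used for the "normalized" scores above heavily
# penalizes any weak per-count accuracy, unlike the arithmetic mean:

import scipy.stats

per_number_acc = [0.9, 0.9, 0.3]              # accuracies for three answer values
arithmetic = sum(per_number_acc) / 3          # 0.70
harmonic = scipy.stats.hmean(per_number_acc)  # ~0.54, dragged down by the 0.3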
    def forward(self, cri_out, net_out, batch):
        out = {}
        logits = net_out["logits"].data.cpu()
        class_id = batch["class_id"]
        acc_out = accuracy(logits, class_id.data.cpu(), topk=self.topk)

        for i, k in enumerate(self.topk):
            out["accuracy_top{}".format(k)] = acc_out[i]

        # compute accuracy on simple and difficult examples
        answers = torch.argmax(logits, dim=1)

        for i in range(len(net_out["logits"])):
            pred = answers[i].item()
            gt = batch["answer"][i]

            categories = {"overall"}
            if "issimple" in batch and batch["issimple"][i]:
                main_cat = "simple"
                categories.add("simple")
            elif "issimple" in batch and not batch["issimple"][i]:
                main_cat = "complex"
                categories.add("complex")
            else:
                main_cat = None

            # add categories per answer value
            categories.add(f"overall-{gt}")
            if main_cat is not None:
                categories.add(f"{main_cat}-{gt}")

            if any(word in batch["raw_question"][i] for word in [
                    "left of",
                    "right of",
                    "behind",
                    "front of",
            ]):
                categories.add("positional")

            if any(word in batch["raw_question"][i]
                   for word in ["type", "types"]):
                categories.add("type")

            if int(batch["answer"][i]) % 2 == 0:
                categories.add("even")
            else:
                categories.add("odd")

            if hasattr(self.engine.dataset[self.mode], "own_numbers"):
                own_numbers = self.engine.dataset[self.mode].own_numbers
                opposite_numbers = self.engine.dataset[
                    self.mode].opposite_numbers
                if int(batch["answer"][i]) in own_numbers:
                    categories.add("overall-own")
                    if main_cat is not None:
                        categories.add(f"{main_cat}-own")
                if int(batch["answer"][i]) in opposite_numbers:
                    categories.add("overall-opposite")
                    if main_cat is not None:
                        categories.add(f"{main_cat}-opposite")

            for cat in categories:
                if "pred" in net_out:
                    self.answers[cat]["pred"].append(net_out["pred"][i].item())
                self.answers[cat]["ans"].append(pred)
                self.answers[cat]["gt"].append(gt)

                if "final_attention_map" in net_out:
                    thresh_prediction = (net_out["final_attention_map"][i] >
                                         0.5).sum()
                    self.answers[cat]["hard.ans"].append(
                        thresh_prediction.item())

        # GROUNDING
        if "scores" in net_out and "gt_bboxes" in batch:
            Logger()("Computing COCO grounding")
            bsize = logits.shape[0]
            # compute grounding
            ious = []
            threshold = self.score_threshold_grounding
            for i in range(bsize):
                gt = batch["answer"][i]
                scores = net_out["scores"][i]  # (regions, 1)
                selection = (scores >= threshold).view((scores.shape[0], ))
                coords = batch["coord"][i]
                coord_thresh = coords[selection]
                iou, inter, union, ioo, iogt = compute_iou(
                    batch["gt_bboxes"][i],
                    coord_thresh.cpu().numpy())

                ious.append(iou)
                self.ious["overall"].append(iou)
                self.ious[gt].append(iou)
                if batch["answer"][i] != 0:
                    self.ious_nonzero["overall"].append(iou)
                    self.ious_nonzero[gt].append(iou)

                # try another method
                width = batch["img_width"][i]
                height = batch["img_height"][i]
                img_gt = np.full((width, height), False, dtype=bool)  # (x, y)
                img_proposed = np.zeros((width, height))
                for bbox in batch["gt_bboxes"][i]:
                    x, y, x2, y2 = [round(v) for v in bbox]
                    img_gt[x:x2, y:y2] = True
                scores = net_out["scores"][i]
                candidate_bbox = list(
                    zip(
                        batch["coord"][i].tolist(),
                        scores.view((scores.shape[0], )).cpu().tolist(),
                    ))

                for bbox, score in candidate_bbox:
                    x, y, x2, y2 = [round(v) for v in bbox]
                    img_proposed[x:x2, y:y2] += score
                thresh = img_proposed >= threshold
                intersection = thresh & img_gt
                union = thresh | img_gt
                union_sum = union.sum()
                inter_sum = intersection.sum()
                thresh_sum = thresh.sum()
                img_gt_sum = img_gt.sum()

                if union_sum == 0:
                    iou_sum = 1.0
                else:
                    iou_sum = inter_sum / union_sum
                if thresh_sum != 0:
                    ioo_sum = inter_sum / thresh_sum
                else:
                    ioo_sum = 1.0

                if img_gt_sum != 0:
                    iogt_sum = inter_sum / img_gt_sum
                else:
                    iogt_sum = 1.0

                self.ious_sum["overall"].append(iou_sum)
                self.ious_sum[gt].append(iou_sum)
                if batch["answer"][i] != 0:
                    self.ious_sum_nonzero["overall"].append(iou_sum)
                    self.ious_sum_nonzero[gt].append(iou_sum)

                # try a third method
                iou_boxes = 0
                iogt_boxes = 0
                ioo_boxes = 0
                for bbox, score in candidate_bbox:
                    iou_box, _, _, ioo_box, iogt_box = compute_iou(
                        batch["gt_bboxes"][i], [bbox])
                    iou_boxes += iou_box * score
                    ioo_boxes += ioo_box * score
                    iogt_boxes += iogt_box * score

                if "pred" in net_out:
                    pred = net_out["pred"][i].item()
                elif "counter-pred" in net_out:
                    pred = net_out["counter-pred"][i].item()
                else:
                    raise KeyError("net_out must contain 'pred' or 'counter-pred'")

                iou_boxes_norm = iou_boxes / pred
                ioo_boxes_norm = ioo_boxes / pred
                iogt_boxes_norm = iogt_boxes / pred

                # average precision
                # Predicted bounding boxes : numpy array [n, 4]
                # Predicted classes: numpy array [n]
                # Predicted confidences: numpy array [n]
                # Ground truth bounding boxes:numpy array [m, 4]
                # Ground truth classes: numpy array [m]
                # pred_bb1 = coords
                pred_bb = coords.cpu().numpy()
                pred_cls = np.zeros((len(pred_bb)))
                pred_conf = scores.view((scores.shape[0], )).cpu().numpy()
                gt_bb = np.array(batch["gt_bboxes"][i])
                gt_cls = np.zeros(len(gt_bb))
                if len(gt_bb) > 0:
                    for t in THRESHOLDS_mAP:
                        try:
                            self.mean_ap[t].evaluate(pred_bb, pred_cls,
                                                     pred_conf, gt_bb, gt_cls)
                        except IndexError:
                            # log and move on rather than drop into a debugger
                            traceback.print_exc()

                self.all_ious.append([
                    batch["question_id"][i],
                    batch["name"][i],
                    gt,
                    batch["gt_bboxes"][i],
                    pred,
                    round(pred),
                    candidate_bbox,
                    iou,
                    iou_sum,
                    ioo,
                    ioo_sum,
                    iogt,
                    iogt_sum,
                    iou_boxes,
                    iou_boxes_norm,
                    ioo_boxes,
                    ioo_boxes_norm,
                    iogt_boxes,
                    iogt_boxes_norm,
                ])

            iou = np.mean(ious)
            out["iou"] = iou
        return out
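# The pixel-level grounding variant above (the "second method") is easy to check
# in isolation; a minimal numpy sketch with one ground-truth box and one scored
# proposal, assuming the same threshold semantics:

import numpy as np

width, height, threshold = 100, 100, 0.5
img_gt = np.zeros((width, height), dtype=bool)
img_gt[10:50, 10:50] = True            # one ground-truth box

img_proposed = np.zeros((width, height))
img_proposed[30:70, 30:70] += 0.8      # one proposal with score 0.8

thresh = img_proposed >= threshold     # pixels kept after thresholding
inter = (thresh & img_gt).sum()        # 20 * 20 = 400 pixels
union = (thresh | img_gt).sum()        # 1600 + 1600 - 400 = 2800 pixels
iou_sum = inter / union                # ~0.143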
    def process_even_odd(self):
        Logger()(
            f"Creating EvenOdd split for "
            f"train={self.train_selection} and proportion_opposite={self.proportion_opposite}"
        )
        even_questions = []
        odd_questions = []
        for q in self.questions:
            if q["answer"] % 2 == 0:
                even_questions.append(q)
            else:
                odd_questions.append(q)

        # filter answers
        if "train" in self.split or "val" in self.split:
            mode = "train"
        else:
            mode = "test"

        if (mode, self.train_selection) in [("train", "even"), ("test", "odd")]:
            own_questions = even_questions
            opposite_questions = odd_questions
        elif (mode, self.train_selection) in [("train", "odd"), ("test", "even")]:
            own_questions = odd_questions
            opposite_questions = even_questions
        else:
            raise ValueError((mode, self.train_selection))

        opposite_questions_by_ans = defaultdict(list)
        for q in opposite_questions:
            if "issimple" in q:
                key = str(q["answer"]) + ("simple" if q["issimple"] else "complex")
                opposite_questions_by_ans[key].append(q)
            else:
                opposite_questions_by_ans[q["answer"]].append(q)

        # select proportion  opposite
        Logger()(
            f"Number of opposite questions by ans: "
            f"{ {a: len(opposite_questions_by_ans[a]) for a in opposite_questions_by_ans} }"
        )
        opposite_selected = []
        for ans in opposite_questions_by_ans:
            tot = len(opposite_questions_by_ans[ans])
            num_sample = int(tot * self.proportion_opposite)
            if num_sample == 0 and self.proportion_opposite != 0:
                # keep at least a few opposite examples when a nonzero proportion is requested
                num_sample = min(10, len(opposite_questions_by_ans[ans]))
            selected = random.sample(opposite_questions_by_ans[ans], num_sample)
            Logger()(f"Number of questions selected for ans {ans}: {len(selected)}")
            opposite_selected += selected

        Logger()(
            f"Dataset : Split: {self.split}, train_selection={self.train_selection}"
        )
        Logger()(f"Number of own questions: {len(own_questions)}")
        Logger()(f"Number of opposite questions: {len(opposite_selected)}")
        own_questions = own_questions + opposite_selected

        # save
        path = self.get_path_questions_even_odd()
        with open(path, "w") as f:
            json.dump(own_questions, f)
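# The per-answer sampling above is a stratified draw with a small floor; the same
# selection rule on a toy pool:

import random

proportion_opposite = 0.1
pool = {3: list(range(50)), 5: list(range(4))}  # toy questions grouped by answer
selected = []
for ans, questions in pool.items():
    num_sample = int(len(questions) * proportion_opposite)
    if num_sample == 0 and proportion_opposite != 0:
        num_sample = min(10, len(questions))    # keep a handful anyway
    selected += random.sample(questions, num_sample)
# answer 3 contributes 5 questions; answer 5 contributes all 4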
    def __init__(
        self,
        dir_data,
        dir_coco,
        dir_vg,
        split,
        val_size=0.05,
        image_features="default",
        background_coco=None,
        background_vg=None,
        background=False,
        background_merge=2,
        proportion_opposite=0.0,  # not used
        train_selection=None,  # not used
        no_features=False,
        path_questions=None,
        sampling=None,
        shuffle=None,
        batch_size=None,
    ):
        super().__init__()
        self.dir_data = dir_data
        self.dir_coco = dir_coco
        self.dir_vg = dir_vg
        self.split = split
        self.image_features = image_features
        self.dir_coco_lvis = "data/vqa/coco/extract_rcnn/lvis"
        self.dir_vg_lvis = "data/vqa/vgenome/extract_rcnn/lvis"
        self.background_coco = background_coco
        self.background_vg = background_vg
        self.background = background
        self.background_merge = background_merge
        self.no_features = no_features
        self.val_size = val_size
        self.path_questions = path_questions  # to override path to questions (default dir_data/split.json)
        self.sampling = sampling
        self.shuffle = shuffle
        self.batch_size = batch_size

        if self.dir_coco.endswith(".zip"):
            self.zip_coco = None  # lazy loading zipfile.ZipFile(self.dir_coco)
        if self.dir_vg.endswith(".zip"):
            self.zip_vg = None  # lazy loading zipfile.ZipFile(self.dir_vg)
        if self.background_coco is not None and self.background_coco.endswith(".zip"):
            self.zip_bg_coco = None  # zipfile.ZipFile(self.background_coco)
        if self.background_vg is not None and self.background_vg.endswith(".zip"):
            self.zip_bg_vg = None  # lazy loading zipfile.ZipFile(self.background_vg)
        if self.dir_coco.endswith(".lmdb"):
            self.lmdb_coco = None

        if self.split not in ["train", "test"]:
            self.process_split()

        q_path = self.get_path_questions()  # train or test
        Logger()("Loading questions")
        with open(q_path) as f:
            self.questions = json.load(f)

        self.path_wid_to_word = os.path.join(
            self.dir_data, "processed", "wid_to_word.pth"
        )
        if os.path.exists(self.path_wid_to_word):
            self.wid_to_word = torch.load(self.path_wid_to_word)
        else:
            os.makedirs(os.path.join(self.dir_data, "processed"), exist_ok=True)
            word_list = self.get_token_list()
            self.wid_to_word = {wid + 1: word for wid, word in enumerate(word_list)}
            torch.save(self.wid_to_word, self.path_wid_to_word)

        self.word_to_wid = {word: wid for wid, word in self.wid_to_word.items()}

        self.aid_to_ans = [str(a) for a in range(16)]  # answers are the counts 0..15
        self.ans_to_aid = {ans: i for i, ans in enumerate(self.aid_to_ans)}
        self.collate_fn = bootstrap_tf.Compose(
            [
                bootstrap_tf.ListDictsToDictLists(),
                bootstrap_tf.PadTensors(
                    use_keys=[
                        "question",
                        "pooled_feat",
                        "cls_scores",
                        "rois",
                        "cls",
                        "cls_oh",
                        "norm_rois",
                    ]
                ),
                # bootstrap_tf.SortByKey(key='lengths'), # no need for the current implementation
                bootstrap_tf.StackTensors(),
            ]
        )
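# A hypothetical end-to-end usage sketch; the class name and data paths below are
# placeholders (the class definition is not visible in this excerpt):

from torch.utils.data import DataLoader

dataset = CountingDataset(
    dir_data="data/tallyqa",
    dir_coco="data/vqa/coco/extract_rcnn",
    dir_vg="data/vqa/vgenome/extract_rcnn",
    split="train",
    image_features="default",
)
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=dataset.collate_fn,  # pads variable-length questions, then stacks
)
batch = next(iter(loader))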