Example #1
    def collate_batch(self, batch):
        v_collate = default_collate
        visual_inputs = v_collate([d["vid"] for d in batch])  # (B, T, 3, H, W)
        # group data
        text_examples = flat_list_of_lists([d["examples"] for d in batch])
        n_examples_list = [d["n_examples"] for d in batch]  # (B, )
        # group elements data
        # directly concatenate question and option as a single seq.
        text_str_list = flat_list_of_lists([d["options_str_list"] for d in text_examples])  # (B * n_options, )
        batch_enc = self.tokenizer.batch_encode_plus(
            text_str_list,
            max_length=self.max_length,
            pad_to_max_length=True,
            return_tensors="pt"
        )
        text_input_ids = batch_enc.input_ids  # (B, L)
        text_input_mask = batch_enc.attention_mask  # (B, L)

        question_ids = [d["qid"] for d in text_examples]
        collated_batch = dict(
            visual_inputs=visual_inputs,  # (B, T, 3, H, W)
            text_input_ids=text_input_ids,
            text_input_mask=text_input_mask,
            question_ids=question_ids,  # list(int), example ids,
            meta=dict(text_examples=text_examples, text_str_list=text_str_list),
            labels=None,
            n_examples_list=n_examples_list  # used to create image feature copies.
        )
        return collated_batch
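
Every example on this page leans on a flat_list_of_lists helper that is not shown here. A one-liner consistent with how it is called above:

def flat_list_of_lists(l):
    """Flatten a list of lists [[1, 2], [3]] into a flat list [1, 2, 3]."""
    return [item for sublist in l for item in sublist]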
Example #2
    def collate_batch(self, batch):
        v_collate = default_collate
        visual_inputs = v_collate([d["vid"] for d in batch])  # (B, T, 3, H, W)
        # group data
        text_examples = flat_list_of_lists([d["examples"] for d in batch])
        n_examples_list = [d["n_examples"] for d in batch]  # (B, )
        # group elements data
        # directly concatenate question and option as a single seq.
        if self.task_type in ["action", "transition"]:
            text_str_list = flat_list_of_lists(
                [[d["q_str"] + " " + d["options_str_list"][i] for i in range(self.n_options)]
                 for d in text_examples]
            )  # (B * n_options, )
        else:
            text_str_list = [d["q_str"] for d in text_examples]  # (B, )
        batch_enc = self.tokenizer.batch_encode_plus(
            text_str_list,
            max_length=self.max_length,
            pad_to_max_length=True,
            return_tensors="pt"
        )
        text_input_ids = batch_enc.input_ids  # (B, L)
        text_input_mask = batch_enc.attention_mask  # (B, L)

        labels = default_collate([int(d["label"]) for d in text_examples]) \
            if text_examples[0]["label"] is not None else None  # (B, )
        question_ids = [d["question_id"] for d in text_examples]
        return dict(
            visual_inputs=visual_inputs,  # (B, T, 3, H, W)
            text_input_ids=text_input_ids,
            text_input_mask=text_input_mask,
            question_ids=question_ids,
            labels=labels,
            n_examples_list=n_examples_list  # used to create image feature copies.
        )
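
To make the multiple-choice branch concrete, here is a toy illustration (field values invented for the example):

# Hypothetical example with n_options = 2 and task_type == "action":
d = {"q_str": "What does the person do after waving?",
     "options_str_list": ["sit down", "jump"]}
# The branch above yields one question+option sequence per option:
#   "What does the person do after waving? sit down"
#   "What does the person do after waving? jump"
# For other task types, only d["q_str"] is encoded.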
Example #3
 def collate_batch(self, batch):
     if isinstance(batch[0]["img"], torch.Tensor):
         v_collate = default_collate
     else:
         v_collate = img_collate
     visual_inputs = v_collate([d["img"] for d in batch])  # (B, #frm=1 or T, 3, H, W)
     # group data
     text_examples = flat_list_of_lists([d["examples"] for d in batch])
     n_examples_list = [d["n_examples"] for d in batch]  # (B, )
     # group elements data
     batch_enc = self.tokenizer.batch_encode_plus(
         [d["text_str"] for d in text_examples],
         max_length=self.max_length,
         pad_to_max_length=True,
         return_tensors="pt")
     text_input_ids = batch_enc.input_ids  # (B, L)
     if self.mlm:
         text_input_ids, mlm_labels = mask_batch_text_tokens(
             text_input_ids, self.tokenizer,
             is_train=self.is_train)  # make mlm data
     else:
         text_input_ids, mlm_labels = text_input_ids, None
     text_input_mask = batch_enc.attention_mask  # (B, L)
     itm_labels = default_collate([d["itm_label"]
                                   for d in text_examples])  # (B, )
     return dict(
         visual_inputs=visual_inputs,  # (B, #frm=1 or T, 3, H, W)
         text_input_ids=text_input_ids,
         mlm_labels=mlm_labels,
         text_input_mask=text_input_mask,
         itm_labels=itm_labels,
         n_examples_list=n_examples_list  # used to create image feature copies.
     )
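
mask_batch_text_tokens is not shown on this page. Below is a minimal sketch of the standard BERT-style masking it presumably implements: 15% of non-special tokens are selected; of those, 80% become [MASK], 10% become a random token, and 10% stay unchanged; unselected positions get label -100 so the loss ignores them. This body is an assumption, not the repository's exact code:

import torch

def mask_batch_text_tokens(inputs, tokenizer, mlm_probability=0.15,
                           is_train=True):
    """Sketch of BERT-style MLM masking; returns (masked_inputs, labels).
    is_train is kept only for signature parity with the call above."""
    inputs = inputs.clone()
    labels = inputs.clone()
    # Sample mask positions, never selecting special tokens ([CLS]/[SEP]/[PAD]).
    probability_matrix = torch.full(labels.shape, mlm_probability)
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)
        for ids in labels.tolist()]
    probability_matrix.masked_fill_(
        torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # loss is computed only on masked positions

    # 80% of selected positions -> [MASK]
    indices_replaced = torch.bernoulli(
        torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(
        tokenizer.mask_token)

    # 10% of selected positions -> random token (half of the remaining 20%)
    indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
                      & masked_indices & ~indices_replaced)
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]
    # The remaining 10% keep their original tokens.
    return inputs, labels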
Example #4
 def collate_batch(self, batch):
     if isinstance(batch[0]["img"], torch.Tensor):
         v_collate = default_collate
     else:
         v_collate = img_collate
     visual_inputs = v_collate([d["img"] for d in batch])  # (B, #frm=1 or T, 3, H, W)
     # group data
     text_examples = flat_list_of_lists([d["examples"] for d in batch])
     n_examples_list = [d["n_examples"] for d in batch]  # (B, )
     # group elements data
     batch_enc = self.tokenizer.batch_encode_plus(
         [d["text_str"] for d in text_examples],
         max_length=self.max_length,
         pad_to_max_length=True,
         return_tensors="pt"
     )
     text_input_ids = batch_enc.input_ids  # (B, L)
     text_input_mask = batch_enc.attention_mask  # (B, L)
     labels = default_collate(
         [d["labels"] for d in text_examples]) \
         if text_examples[0]["labels"] is not None else None  # (B, #ans)
     question_ids = [d["question_id"] for d in text_examples]
     return dict(
         visual_inputs=visual_inputs,  # (B, #frm=1 or T, 3, H, W)
         text_input_ids=text_input_ids,
         text_input_mask=text_input_mask,
         question_ids=question_ids,
         labels=labels,
         n_examples_list=n_examples_list  # used to create image feature copies.
     )
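
Assuming this is the VQACollator referenced in Example #8, it is used by binding collate_batch as the DataLoader's collate_fn. A minimal wiring sketch (dataset and tokenizer are placeholders):

from torch.utils.data import DataLoader

# dataset yields dicts with "img", "examples", and "n_examples" fields as
# assumed above; tokenizer is a HuggingFace tokenizer. Both are placeholders.
collator = VQACollator(tokenizer=tokenizer, max_length=20)
loader = DataLoader(dataset, batch_size=8, shuffle=False,
                    collate_fn=collator.collate_batch)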
Example #5
def mk_input_group(key_grouped_examples,
                   max_n_example_per_group=2,
                   is_train=True,
                   example_unique_key=None):
    """ Re-organize examples into groups. Each input group will have a single image paired
    with X (X=max_n_example_per_img) examples. Images with total #examples > X will be
    split into multiple groups. In the case a group has < X examples, we will copy
    the examples to make the group has X examples.
    Args:
        key_grouped_examples: dict, each key is image/video id,
            each value is a list(example) associated with this image/video
        max_n_example_per_group: int, pair max #examples with each image/video.
           Note that each image can have multiple groups.
        is_train: bool, if True, copy the examples to make sure each input
            group has max_n_example_per_group examples.
        example_unique_key: str, used to make sure no inputs are discarded by matching
            the input and output ids specified by `example_unique_key`
    """
    input_groups = []  # each element is (id, list(example))
    for k, examples in key_grouped_examples.items():
        chunked_examples = chunk_list(examples,
                                      chunk_size=max_n_example_per_group,
                                      pad_to_divisible=is_train)
        for c in chunked_examples:
            input_groups.append((k, c))

    if example_unique_key is not None:
        print(f"Using example_unique_key {example_unique_key} "
              f"to check whether input and output ids match.")
        # sanity check: make sure we did not discard any input example by accident.
        input_question_ids = flat_list_of_lists(
            [[sub_e[example_unique_key] for sub_e in e]
             for e in key_grouped_examples.values()])
        output_question_ids = flat_list_of_lists(
            [[sub_e[example_unique_key] for sub_e in e[1]]
             for e in input_groups])
        assert set(input_question_ids) == set(output_question_ids), \
            "You are missing some examples!"
    return input_groups
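
A toy walk-through of the grouping behavior described in the docstring (ids invented):

# Two videos: "v1" has 3 questions, "v2" has 1; group size X = 2.
key_grouped_examples = {
    "v1": [{"question_id": 0}, {"question_id": 1}, {"question_id": 2}],
    "v2": [{"question_id": 3}],
}
groups = mk_input_group(key_grouped_examples,
                        max_n_example_per_group=2,
                        is_train=True,
                        example_unique_key="question_id")
# With is_train=True, chunk_list pads each short chunk by copying examples,
# yielding something like:
#   ("v1", [q0, q1]), ("v1", [q2, <copied example>]), ("v2", [q3, <copied example>])
# The final assert passes because copies introduce no new question_ids.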
Example #6
def repeat_tensor_rows(raw_tensor, row_repeats):
    """ repeat raw_tensor[i] row_repeats[i] times.
    Args:
        raw_tensor: (B, *)
        row_repeats: list(int), len(row_repeats) == len(raw_tensor)
    """
    assert len(raw_tensor) == len(row_repeats), "Has to be the same length"
    if sum(row_repeats) == len(row_repeats):  # every repeat is 1: nothing to copy
        return raw_tensor
    else:
        indices = torch.LongTensor(
            flat_list_of_lists([[i] * r for i, r in enumerate(row_repeats)
                                ])).to(raw_tensor.device)
        return raw_tensor.index_select(0, indices)
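
For instance, with three rows and row_repeats=[1, 2, 1] the middle row is duplicated. In these collators, row_repeats would be n_examples_list, so each image/video feature row gets one copy per grouped text example:

import torch

x = torch.arange(6).view(3, 2)        # rows: [0,1], [2,3], [4,5]
y = repeat_tensor_rows(x, [1, 2, 1])  # rows: [0,1], [2,3], [2,3], [4,5]
assert y.shape == (4, 2)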
Example #7
    def collate_batch(self, batch):
        v_collate = default_collate
        visual_inputs = v_collate([d["vid"] for d in batch])  # (B, T, 3, H, W)
        # group data
        text_examples = flat_list_of_lists([d["examples"] for d in batch])
        n_examples_list = [d["n_examples"] for d in batch]  # (B, )
        # group elements data
        # directly concatenate question and option as a single seq.
        text_str_list = [d["text_str"] for d in text_examples]  # (B, )
        batch_enc = self.tokenizer.batch_encode_plus(
            text_str_list,
            max_length=self.max_length,
            pad_to_max_length=True,
            return_tensors="pt"
        )
        text_input_ids = batch_enc.input_ids  # (B, L)
        text_input_mask = batch_enc.attention_mask  # (B, L)

        if "itm_label" in text_examples[0]:
            itm_labels = default_collate(
                [d["itm_label"] for d in text_examples])  # (B, )
        else:
            itm_labels = None

        if "id" in text_examples[0]:
            caption_ids = [d["id"] for d in text_examples]  # (B, )
        else:
            caption_ids = None
        collated_batch = dict(
            visual_inputs=visual_inputs,  # (B, T, 3, H, W)
            text_input_ids=text_input_ids,
            text_input_mask=text_input_mask,
            caption_ids=caption_ids,  # list(int), example ids,
            labels=itm_labels,
            n_examples_list=n_examples_list  # used to create image feature copies.
        )
        if "vid_id" in batch[0] and len(batch) == 1:
            collated_batch["vid_id"] = batch[0]["vid_id"]
        return collated_batch
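
Downstream, n_examples_list is presumably consumed with repeat_tensor_rows (Example #6) so each visual feature row is copied once per text example in its group; a sketch with hypothetical names:

# visual_feats: (B, D) pooled per-video features; the name is hypothetical.
visual_feats_per_example = repeat_tensor_rows(
    visual_feats, collated_batch["n_examples_list"])
# -> (sum(n_examples_list), D), row-aligned with text_input_ids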
Example #8
def mk_vqa_dataloader(anno_path, img_lmdb_dir, cfg, tokenizer, is_train=True):
    """
    Returns:
        list(dict), each dict is
        {
            "filepath": str,
            "txt": str,
        }
    """
    if isinstance(anno_path, str):
        raw_datalist = load_jsonl(anno_path)
    else:
        raw_datalist = flat_list_of_lists([load_jsonl(p) for p in anno_path])

    if cfg.data_ratio != 1.0:
        random.shuffle(raw_datalist)
        raw_datalist = raw_datalist[:int(len(raw_datalist) * cfg.data_ratio)]

    datalist = []
    for raw_d in raw_datalist:
        d = dict(
            txt=raw_d["question"],
            img_id=raw_d["image_id"],
            question_id=raw_d["question_id"],
        )
        if "labels" in raw_d:  # deal with test sets
            d["labels"] = raw_d["labels"]
        if "answer_type" in raw_d:
            d["answer_type"] = raw_d["answer_type"]
        datalist.append(d)

    grouped = defaultdict(list)  # examples grouped by image/video id
    for d in datalist:
        grouped[d["img_id"]].append(d)

    # each group has a single image with multiple questions
    group_datalist = mk_input_group(
        grouped,
        max_n_example_per_group=cfg.max_n_example_per_group
        if is_train else 1,  # force 1 in eval
        is_train=is_train,
        example_unique_key="question_id")

    ans2label = load_json(cfg.ans2label_path)
    dataset = ClipBertVQADataset(datalist=group_datalist,
                                 tokenizer=tokenizer,
                                 img_lmdb_dir=img_lmdb_dir,
                                 ans2label=ans2label,
                                 max_img_size=cfg.max_img_size,
                                 max_txt_len=cfg.max_txt_len)
    LOGGER.info(f"is_train {is_train}, dataset size {len(dataset)} groups, "
                f"each group {cfg.max_n_example_per_group if is_train else 1}")
    if cfg.do_inference:
        batch_size = cfg.inference_batch_size
    else:
        batch_size = cfg.train_batch_size if is_train else cfg.val_batch_size
    sampler = DistributedSampler(dataset,
                                 num_replicas=hvd.size(),
                                 rank=hvd.rank(),
                                 shuffle=is_train)
    vqa_collator = VQACollator(tokenizer=tokenizer, max_length=cfg.max_txt_len)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            sampler=sampler,
                            num_workers=cfg.n_workers,
                            pin_memory=cfg.pin_mem,
                            collate_fn=vqa_collator.collate_batch)
    return dataloader
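
A minimal sketch of calling the builder. The cfg attribute names for the paths are assumptions, everything else mirrors fields read above, and Horovod must already be initialized:

import horovod.torch as hvd

hvd.init()  # the DistributedSampler above uses hvd.size() / hvd.rank()
train_loader = mk_vqa_dataloader(
    anno_path=cfg.train_annotation_path,  # hypothetical cfg field
    img_lmdb_dir=cfg.img_lmdb_dir,        # hypothetical cfg field
    cfg=cfg, tokenizer=tokenizer, is_train=True)
for batch in train_loader:
    batch["text_input_ids"]  # (B, L)
    batch["visual_inputs"]   # (B, #frm=1 or T, 3, H, W)
    break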