Example #1
def get_tuple(splits: str, bs: int, shuffle=False, drop_last=False, topk=-1, num_workers = 0, limit_source = [], restrict_source = None) -> DataTuple:
    # Decide which QA datasets would be used in pre-training.
    # Options: vqa, gqa, visual7w
    # Note: visual7w is part of vgqa; we only use the name here.
    qa_sets = args.qa_sets
    if qa_sets is not None:
        qa_sets = set(qa_set.lower().strip() for qa_set in qa_sets.split(","))

    # Build dataset, data loader, and evaluator.
    dset = LXMERTDataset(splits, qa_sets=qa_sets)
    tset = LXMERTTorchDataset(
        dset, 
        topk, 
        limit_source = limit_source, 
        use_visual_tag_flag = args.get("allow_tag_for_eval", False)  # As this function is called for evaluation in our context
        )

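    # If no custom collate function is configured, fall back to an identity
    # collate_fn that returns the raw list of examples unchanged.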
    data_loader = DataLoader(
        tset, batch_size=bs,
        shuffle=shuffle, num_workers=num_workers,
        collate_fn=tset.custom_collact_fn if args.get('custom_collact_fn', False) else lambda x: x,
        drop_last=drop_last, pin_memory=args.get("pin_memory", True)
    )
    evaluator = LXMERTEvaluator(dset)
    print()

    return DataTuple(dataset=dset, torchdset=tset, loader=data_loader, evaluator=evaluator, vl_torchdset=tset)
Example #2
    def random_mask_features(self, feats, boxes=None):
        mask_feats = deepcopy(feats)  #.copy()
        feat_mask = np.zeros(len(feats), dtype=np.float32)
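        # feat_mask marks each region: 0 = keep, 1 = masked; 2.0 (set below when
        # "inbatch_random" is on) flags a position to be replaced by a feature
        # from another example at collate time.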

        for i in range(len(feats)):
            prob = random.random()
            # mask token with probability
            if prob < args.obj_mask_rate:
                feat_mask[i] = 1.

                prob /= args.obj_mask_rate

                # 80% randomly change token to zero feat
                if prob < 0.8:
                    mask_feats[i, :] = 0.

                # 10% randomly change token to random feat
                elif prob < 0.9:
                    if not args.get("disable_random_feat",
                                    False) and not args.get(
                                        "inbatch_random", False):
                        mask_feats[i, :] = self.random_feat()
                    if args.get("inbatch_random", False):
                        feat_mask[i] = 2.0  # special mark
                # -> rest 10% randomly keep current feat

                # Need to predict this feat
        return mask_feats, feat_mask
Example #3
    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))
                        ) if args.tqdm else (lambda x: x)

        best_valid = 0.
        train_results = []
        report_every = args.get("report_every", 100)
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, batch in iter_wrapper(enumerate(loader)):
                ques_id, feats, boxes, sent, tags, target = zip(*batch)
                self.model.train()
                self.optim.zero_grad()

                target = torch.stack(target).cuda()
                logit = self.model(feats, boxes, sent, tags)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
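                # BCEWithLogitsLoss averages over every element; rescale by the
                # number of answer classes so each example contributes the sum of
                # its per-answer losses.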
                loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()
                train_results.append(
                    pd.Series({"loss": loss.detach().mean().item()}))

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans

                if i % report_every == 0 and i > 0:
                    print("Epoch: {}, Iter: {}/{}".format(
                        epoch, i, len(loader)))
                    print("    {}\n~~~~~~~~~~~~~~~~~~\n".format(
                        pd.DataFrame(train_results[-report_every:]).mean()))

            log_str = "\nEpoch %d: Train %0.2f\n" % (
                epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid and not args.get(
                        "special_test", False):
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)
            if epoch >= 5:
                self.save("Epoch{}".format(epoch))
            print(log_str, end='')
            print(args.output)

        self.save("LAST")
Example #4
def create_tags_pretrain(obj_labels, attr_labels, obj_confs, attr_confs, tokenizer, symbolic_vocab, visual_tags_box, feat_mask, use_bert_input = True):
    obj_labels_transformed = transfer_object_labels_to_symbolic_ids(obj_labels, attr_labels, symbolic_vocab, obj_confs, attr_confs)
    visual_tags_bert_words = []
    visual_tags_box_bert_input = []
    visual_tags_mlm_labels = []
    visual_tags_segment_ids = []

    for tag_index, tag in enumerate(obj_labels_transformed):
        tag_word = symbolic_vocab.id2word[tag]
        if args.get("use_segment_id_for_attr", False):
            seg_id = symbolic_vocab.get_seg_id(tag)
        sub_tokens = tokenizer.tokenize(tag_word)

        prob = random.random() 
        if prob < args.get('tag_mask_ratio', 0.15) or (feat_mask[tag_index] != 0 and random.random() < args.get("tag_joint_mask_ratio", 0.5)):

            new_prob = random.random()
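            # BERT-style corruption: 80% of masked tags become [MASK], 10% become a
            # random vocabulary token, and the remaining 10% keep their sub-tokens.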
            if new_prob < 0.8:
                for sub_token in sub_tokens:
                    visual_tags_bert_words.append("[MASK]")
            elif new_prob < 0.9:
                for sub_token in sub_tokens:
                    visual_tags_bert_words.append(random.choice(list(tokenizer.vocab.keys())))
            else:
                visual_tags_bert_words.extend(sub_tokens)
            
            for sub_token in sub_tokens:
                try:
                    visual_tags_mlm_labels.append(tokenizer.vocab[sub_token])
                except KeyError:
                    # For unknown words (should not occur with the WordPiece vocab)
                    visual_tags_mlm_labels.append(tokenizer.vocab["[UNK]"])
                    logging.warning("Cannot find sub_token '{}' in vocab. Using [UNK] instead".format(sub_token))

        else:
            for sub_token in sub_tokens:
                # no masking token (will be ignored by loss function later)
                visual_tags_bert_words.append(sub_token)
                visual_tags_mlm_labels.append(-1)

        # duplicate box
        for sub_token in sub_tokens:
            visual_tags_box_bert_input.append(visual_tags_box[tag_index])
            if args.get("use_segment_id_for_attr", False):
                visual_tags_segment_ids.append(seg_id)
    visual_tags = tokenizer.convert_tokens_to_ids(visual_tags_bert_words)
    visual_tags_objective = visual_tags_mlm_labels
    visual_tags_mask = [1] * len(visual_tags)
    visual_tags_box = visual_tags_box_bert_input

    visual_tags_segment_ids = None

    return visual_tags, visual_tags_objective, visual_tags_mask, visual_tags_box, visual_tags_segment_ids
Example #5
    def __init__(self):
        # Datasets
        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True)
        if args.valid != "":
            valid_bsize = args.get("valid_batch_size", 16)
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=valid_bsize,
                                              shuffle=False,
                                              drop_last=False)
        else:
            self.valid_tuple = None

        # Model
        self.model = VQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.get("load_lxmert_pretrain", None) is not None:
            load_lxmert_from_pretrain_noqa(args.load_lxmert_pretrain,
                                           self.model)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()
            self.model.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)
Example #6
    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(args.output, "%s_LXRT.pth" % name))

        if args.get("save_optimizer", False) and "Step" not in name:
            torch.save(self.optim.state_dict(),
                       os.path.join(args.output, "%s_LXRT_optimizer.pth" % name))
Example #7
    def __init__(self, max_seq_length):
        super().__init__()
        self.max_seq_length = max_seq_length

        self.tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased",
            do_lower_case=True
        )

        # Build model
        self.model = LXRTPretraining.from_pretrained(
            "bert-base-uncased",
            args=args,
            task_mask_lm=args.task_mask_lm,
            task_obj_predict=args.task_obj_predict,
            task_matched=args.task_matched,
            task_qa=args.task_qa,
            visual_losses=args.visual_losses,
            num_answers=args.num_answers if args.get("num_answers", None) else train_tuple.dataset.answer_table.num_answers
        )

        # Weight initialization and loading
        if args.from_scratch:
            print("Train from Scratch: re-initialize all BERT weights.")
            self.model.apply(self.model.init_bert_weights)

        if args.get("use_tag_symbolic_embedding", False):
            self.model.bert.embeddings.initialize_symbolic_embeddings(symbolic_vocab.get_symbolic_list(self.tokenizer))
            self.model.special_initialize_pretraining_head()
        
        if args.get("hybrid_embedding", False):
            self.model.bert.embeddings.initialize_visual_position_type_embeddings()
        
        if args.load_lxmert is not None:
            # Load lxmert would not load the answer head.
            self.load_lxmert(args.load_lxmert)
        
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model = nn.DataParallel(self.model)
        
        self.global_step = 0
Example #8
def create_tags(obj_labels, attr_labels, obj_confs, attr_confs, tokenizer, symbolic_vocab, visual_tags_box, use_bert_input = True, record_index = None):
    obj_labels_transformed = transfer_object_labels_to_symbolic_ids(obj_labels, attr_labels, symbolic_vocab, obj_confs, attr_confs)
    visual_tags_bert_words = []
    visual_tags_box_bert_input = []
    #visual_tags_mlm_labels = []
    visual_tags_segment_ids = []
    
    recorded_indexes = []
    counter = 0
    for tag_index, tag in enumerate(obj_labels_transformed):
        tag_word = symbolic_vocab.id2word[tag]
        if args.get("use_segment_id_for_attr", False):
            seg_id = symbolic_vocab.get_seg_id(tag)
        sub_tokens = tokenizer.tokenize(tag_word)

        for sub_token in sub_tokens:
            # no masking token (will be ignored by loss function later)
            visual_tags_bert_words.append(sub_token)
            #visual_tags_mlm_labels.append(-1)
            if tag_index == record_index:
                recorded_indexes.append(counter)

            counter += 1

        # duplicate box
        for sub_token in sub_tokens:
            visual_tags_box_bert_input.append(visual_tags_box[tag_index])
            if args.get("use_segment_id_for_attr", False):
                visual_tags_segment_ids.append(seg_id)

    visual_tags = tokenizer.convert_tokens_to_ids(visual_tags_bert_words)
    visual_tags_mask = [1] * len(visual_tags)
    visual_tags_box = visual_tags_box_bert_input
    visual_tags_segment_ids = None
    visual_tags_type = None

    if record_index is not None:
        return visual_tags, visual_tags_mask, visual_tags_box, visual_tags_type, visual_tags_segment_ids, recorded_indexes

    return visual_tags, visual_tags_mask, visual_tags_box, visual_tags_type, visual_tags_segment_ids
Example #9
    def load(self, path, t_total):
        print("Load model from %s" % path)
        state_dict = torch.load("%s_LXRT.pth" % path)
        #self.model.load_state_dict(state_dict)
        from qa_answer_table import load_state_dict_flexible
        load_state_dict_flexible(self.model, state_dict)

        optimizer_path = "{}_LXRT_optimizer.pth".format(path)
        if os.path.exists(optimizer_path) and args.get("load_optimizer", True):
            print("Load optimizer from {}".format(optimizer_path))

            loaded_optim = torch.load(optimizer_path)
            if args.get("reset_schedule", False):
                for group in loaded_optim["param_groups"]:
                    group['lr'] = args.lr
                    group['warmup'] = args.warmup_ratio
                    group["t_total"] = t_total

                    for p in group['params']:
                        loaded_optim["state"][p]["step"] = 0
            self.optim.load_state_dict(loaded_optim)
Example #10
def transfer_object_labels_to_symbolic_ids(obj_labels, attribute_labels, symbolic_vocab, obj_confs = None, attr_confs = None):
    return_list = []

    for index in range(len(obj_labels)):
        prob = random.random()
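        # With probability insert_attr_ratio, emit the attribute word for this box
        # instead of the object word.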
        if prob < args.get("insert_attr_ratio", 0.0):
            if args.get("kl_divergence", False):
                if args.get("non_top1_sampling", False):
                    p = attr_confs[index][attribute_labels[index]]
                    p = p / p.sum()
                    attr_label_i = np.random.choice(attribute_labels[index], p=p)
                    #attr_label_i = np.random.choice(attr_confs.shape[-1], p=attr_confs[index])
                else:
                    attr_label_i = attribute_labels[index, 0]
            else:
                attr_label_i = attribute_labels[index]
            return_list.append(symbolic_vocab.word2id[symbolic_vocab.attr_id2word(attr_label_i)])
        else:
            if args.get("kl_divergence", False):
                if args.get("non_top1_sampling", False):
                    new_obj_confs = deepcopy(obj_confs)
                    new_obj_confs[new_obj_confs<0.1] = 0
                    p = new_obj_confs[index][obj_labels[index]]
                    sum_p = p.sum()
                    if sum_p == 0:
                        obj_label_i = obj_labels[index, 0]
                    else:
                        p = p / sum_p
                        obj_label_i = np.random.choice(obj_labels[index], p=p)
                        #obj_label_i = np.random.choice(obj_confs.shape[-1], p=obj_confs[index])
                else:
                    obj_label_i = obj_labels[index, 0]
            else:
                obj_label_i = obj_labels[index]
            return_list.append(symbolic_vocab.word2id[symbolic_vocab.obj_id2word(obj_label_i)])
    return np.array(return_list, dtype=np.int64)
Example #11
    def __init__(self,
                 ann_file,
                 pretrained_model_name,
                 tokenizer=None,
                 seq_len=64,
                 min_seq_len=64,
                 encoding="utf-8",
                 on_memory=True,
                 **kwargs):
        assert on_memory, "only support on_memory mode!"

        self.tokenizer = tokenizer if tokenizer is not None else BertTokenizer.from_pretrained(
            pretrained_model_name)
        self.vocab = self.tokenizer.vocab
        self.seq_len = seq_len
        self.min_seq_len = min_seq_len
        self.on_memory = on_memory
        self.ann_file = ann_file
        self.encoding = encoding
        self.test_mode = False

        self.do_no_fill = False

        self.use_mismatch_objective = args.get("task_matched", False)
        #self.load_corpus_with_passages()

        # load samples into memory
        if on_memory:
            if self.use_mismatch_objective:
                #self.corpus = self.load_corpus_with_passages_preprocess()
                self.load_corpus_with_passages_preprocess()
            else:
                self.corpus = self.load_corpus()
        if args.get("presegment_sentence", False):
            self.presegment_sentence()
        print("Using {} with {} data.\n\n".format(self.ann_file, len(self)))
Example #12
    def __init__(self,
                 datasets,
                 batch_size,
                 upsample_ratios=[1, 1, 1],
                 reduce_to_non_batch_sampler=False):
        self.datasets = datasets
        self.batch_size = batch_size

        self.lengths = [len(i) for i in self.datasets]
        self.upsample_ratios = upsample_ratios
        self.rotate_index = [0] * len(self.upsample_ratios)
        self.reduce_to_non_batch_sampler = reduce_to_non_batch_sampler

        _flag = False
        for i in self.upsample_ratios:
            if i < 1:
                _flag = True

        self.all_indexes = [torch.randperm(i).tolist() for i in self.lengths]
        assert (not args.get("old_sampler", False))

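        # With gradient accumulation, first build batches of
        # batch_size * gradient_accumulation_steps; prepare_indexes() later splits
        # them back into micro-batches of the original batch_size.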
        if args.get("gradient_accumulation_steps", None):
            self.batch_size = batch_size * args.gradient_accumulation_steps
        self.prepare_indexes()
Example #13
    def train_batch(self, optim, batch):
        
        gradient_accumulation_steps = args.get("gradient_accumulation_steps", 1)
        # Zero the gradients only at the start of an accumulation window so that
        # gradients from the preceding sub-batches are kept until optim.step()
        if self.global_step % gradient_accumulation_steps == 0:
            optim.zero_grad()
        loss, losses, ans_logit, losses_dict = self.forward(batch)
        if args.multiGPU:
            loss = loss.mean()
            losses = losses.mean(0)
        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps

        loss.backward()
        if (self.global_step + 1) % gradient_accumulation_steps == 0:
            nn.utils.clip_grad_norm_(self.model.parameters(), 1.)
            optim.step()

        return loss.item(), losses.cpu().numpy(), ans_logit, losses_dict
Example #14
    def create_in_batch_random_feat(self, example, example_index,
                                    all_examples):
        if args.get("inbatch_random",
                    False) and example.visual_feats[0] is not None:
            feats, _ = example.visual_feats
            feat_mask = example.obj_labels["feat"][1]
            #original_feats = example.obj_labels["feat"][0]
            for i in range(len(feat_mask)):
                if feat_mask[i] == 2:
                    feat_mask[i] = 1
                    select_index = random.randint(0, len(all_examples) - 1)
                    while select_index == example_index:
                        select_index = random.randint(0, len(all_examples) - 1)
                    select_index_j = random.randint(0, len(feat_mask) - 1)
                    while select_index_j == i:
                        select_index_j = random.randint(0, len(feat_mask) - 1)
                    feats[i] = all_examples[select_index].obj_labels["feat"][
                        0][select_index_j]
        return example
Example #15
    def prepare_indexes(self):
        self.all_batched_indexes = []
        current_index = 0
        for index, i in enumerate(self.lengths):
            #if args.get("debug", False):
            #    random_indexes = list(range(i))
            #else:
            tmp_indexes = []

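            # An upsample ratio < 1 means downsampling: take every (1/ratio)-th index
            # starting from a rotating offset so a different subset is used each pass.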
            if self.upsample_ratios[index] < 1:
                sample_num = int(1 / self.upsample_ratios[index])
                random_indexes = self.all_indexes[index][
                    self.rotate_index[index]:][::sample_num]

                self.rotate_index[
                    index] = self.rotate_index[index] + 1  #% sample_num
                if self.rotate_index[index] == sample_num:
                    self.all_indexes[index] = torch.randperm(i).tolist()
                    self.rotate_index[index] = 0  # Reset rotate index

                random.shuffle(random_indexes)
                random_indexes = [j + current_index for j in random_indexes]
                random_indexes = chunks(random_indexes, self.batch_size)
                #self.all_batched_indexes.extend(random_indexes)
            else:
                random_indexes = torch.randperm(i).tolist()
                random_indexes = [j + current_index for j in random_indexes]
                random_indexes = chunks(random_indexes, self.batch_size)
                #self.all_batched_indexes.extend(random_indexes)

            random.shuffle(random_indexes)
            self.all_batched_indexes.append(random_indexes)

            if self.upsample_ratios[index] > 1:
                for k in range(self.upsample_ratios[index] - 1):
                    #if args.get("debug", False):
                    #    random_indexes = list(range(i))
                    #else:
                    random_indexes = torch.randperm(i).tolist()

                    random_indexes = [
                        j + current_index for j in random_indexes
                    ]

                    random_indexes = chunks(random_indexes, self.batch_size)
                    #self.all_batched_indexes.extend(random_indexes)

                    random.shuffle(random_indexes)
                    self.all_batched_indexes[index].extend(random_indexes)

            current_index += i

        all_flattened_indexes = []
        original_recorder = [len(i) for i in self.all_batched_indexes]
        original_recorder = [
            i / sum(original_recorder) for i in original_recorder
        ]
        index_recorder = np.array(
            [len(i) - 1 for i in self.all_batched_indexes])

        while np.any(index_recorder >= 0):
            chosen_index = np.random.choice(len(original_recorder),
                                            p=original_recorder)
            if index_recorder[chosen_index] >= 0:
                all_flattened_indexes.append(
                    self.all_batched_indexes[chosen_index][
                        index_recorder[chosen_index]])
                index_recorder[chosen_index] -= 1

        self.all_batched_indexes = all_flattened_indexes

        if self.reduce_to_non_batch_sampler:
            new_ = []
            for i in self.all_batched_indexes:
                for j in i:
                    new_.append([j])
            self.all_batched_indexes = new_

        if args.get("gradient_accumulation_steps", None):
            flattened_indexes = []
            for indexes in self.all_batched_indexes:
                flattened_indexes.extend(indexes)
            self.all_batched_indexes = chunks(
                flattened_indexes,
                self.batch_size // args.gradient_accumulation_steps)
        return current_index
Example #16
    def custom_collact_fn(self, examples):

        hybrid_num = random.randint(args.get("hybrid_min", 2),
                                    args.get("hybrid_max", 34))

        train_features = [
            self.create_in_batch_random_feat(example,
                                             example_index,
                                             all_examples=examples)
            for example_index, example in enumerate(examples)
        ]

        if train_features[0].input_ids is not None:
            # language Inputs
            input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
            input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
            segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)

            # Language Prediction
            lm_labels = torch.tensor([f.lm_label_ids for f in train_features],
                                     dtype=torch.long)
        else:
            input_ids = None
            input_mask = None
            segment_ids = None
            lm_labels = None

        if train_features[0].visual_feats[0] is not None:
            # Visual Inputs
            if isinstance(train_features[0].visual_feats[0],
                          torch.FloatTensor):
                feats = torch.stack(
                    [f.visual_feats[0] for f in train_features])
            else:
                feats = torch.from_numpy(
                    np.stack([f.visual_feats[0] for f in train_features]))
            pos = torch.from_numpy(
                np.stack([f.visual_feats[1] for f in train_features]))
            # Visual Prediction
            obj_labels = {}
            for key in args.visual_losses.split(
                    ","):  #('obj', 'attr', 'feat'):
                visn_labels = torch.from_numpy(
                    np.stack([f.obj_labels[key][0] for f in train_features]))
                #if self.custom_coco_data:
                #    visn_mask = torch.ones(visn_labels.size(0), visn_labels.size(1)).float().cuda()
                #else:
                visn_mask = torch.from_numpy(
                    np.stack([f.obj_labels[key][1] for f in train_features]))
                assert visn_labels.size(0) == visn_mask.size(
                    0) and visn_labels.size(1) == visn_mask.size(1)
                obj_labels[key] = (visn_labels, visn_mask)
            if args.get('task_nlvr2', False):
                visual_feats_seg_ids = []
                for i in range(feats.size(0)):
                    visual_feats_seg_ids.append([0] * 36 + [1] * 36)
                visual_feats_seg_ids = torch.tensor(visual_feats_seg_ids,
                                                    dtype=torch.int64)
            else:
                visual_feats_seg_ids = None
        else:
            feats = None
            pos = None
            obj_labels = None
            visual_feats_seg_ids = None

        if train_features[0].visual_tags is not None:
            # Pad all visual tag fields to the longest tag sequence in the batch
            tag_max_length = max([len(f.visual_tags) for f in train_features])
            for f in train_features:
                current_tag_length = len(f.visual_tags)
                if current_tag_length < tag_max_length:
                    f.visual_tags = f.visual_tags + [0] * (tag_max_length -
                                                           current_tag_length)
                    f.visual_tags_objective = f.visual_tags_objective + [
                        -1
                    ] * (tag_max_length - current_tag_length)
                    f.visual_tags_mask = f.visual_tags_mask + [0] * (
                        tag_max_length - current_tag_length)
                    f.visual_tags_box = f.visual_tags_box + [
                        np.array([0.0, 0.0, 0.0, 0.0], dtype=np.float32)
                    ] * (tag_max_length - current_tag_length)
                    f.visual_tags_box = np.stack(f.visual_tags_box)
                    if f.visual_tags_segment_ids is not None:
                        f.visual_tags_segment_ids = f.visual_tags_segment_ids + [
                            0
                        ] * (tag_max_length - current_tag_length)

            visual_tags = torch.tensor([f.visual_tags for f in train_features],
                                       dtype=torch.long)
            visual_tags_mask = torch.tensor(
                [f.visual_tags_mask for f in train_features], dtype=torch.long)
            visual_tags_box = torch.from_numpy(
                np.stack([f.visual_tags_box for f in train_features]))
            visual_tags_objective = torch.tensor(
                [f.visual_tags_objective for f in train_features],
                dtype=torch.long)
            if train_features[0].visual_tags_mismatch is not None:
                visual_tags_mismatch = torch.tensor(
                    [f.visual_tags_mismatch for f in train_features],
                    dtype=torch.long)
            else:
                visual_tags_mismatch = None
            if train_features[0].visual_tags_segment_ids is not None:
                visual_tags_segment_ids = torch.tensor(
                    [f.visual_tags_segment_ids for f in train_features],
                    dtype=torch.long)
            else:
                visual_tags_segment_ids = None

            if args.get(
                    "tag_hard_max_length", None
            ) is not None and tag_max_length > args.tag_hard_max_length:
                # truncate the tag sequence
                visual_tags = visual_tags[:, :args.
                                          tag_hard_max_length].contiguous()
                visual_tags_mask = visual_tags_mask[:, :args.
                                                    tag_hard_max_length].contiguous(
                                                    )
                visual_tags_box = visual_tags_box[:, :args.
                                                  tag_hard_max_length].contiguous(
                                                  )
                visual_tags_objective = visual_tags_objective[:, :args.
                                                              tag_hard_max_length].contiguous(
                                                              )
                if visual_tags_mismatch is not None:
                    visual_tags_mismatch = visual_tags_mismatch[:, :args.
                                                                tag_hard_max_length].contiguous(
                                                                )
                if visual_tags_segment_ids is not None:
                    visual_tags_segment_ids = visual_tags_segment_ids[:, :args.
                                                                      tag_hard_max_length].contiguous(
                                                                      )

        else:
            visual_tags = None
            visual_tags_mask = None
            visual_tags_box = None
            visual_tags_objective = None
            visual_tags_mismatch = None
            visual_tags_segment_ids = None

        if train_features[0].is_matched is not None:
            matched_labels = torch.tensor(
                [f.is_matched for f in train_features], dtype=torch.long)
        else:
            matched_labels = None
        ans = torch.from_numpy(np.stack([f.ans for f in train_features]))

        if args.get("lxmert_style_nlvr", False):
            # Reorganize the inputs
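            # NLVR2 keeps two images per example: duplicate each language input and
            # split the concatenated visual features, boxes, and tags back into two
            # per-image halves.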
            input_ids = input_ids.unsqueeze(1).expand(
                input_ids.size(0), 2, input_ids.size(-1)).contiguous().view(
                    -1, input_ids.size(-1)).contiguous()
            lm_labels = lm_labels.unsqueeze(1).expand(
                lm_labels.size(0), 2, lm_labels.size(-1)).contiguous().view(
                    -1, lm_labels.size(-1)).contiguous()
            input_mask = input_mask.unsqueeze(1).expand(
                input_mask.size(0), 2, input_mask.size(-1)).contiguous().view(
                    -1, input_mask.size(-1)).contiguous()

            visual_feats_seg_ids = None
            feats = feats.view(-1,
                               feats.size(1) // 2,
                               feats.size(-1)).contiguous()
            pos = pos.view(-1, pos.size(1) // 2, pos.size(-1)).contiguous()
            if args.get("use_visual_tag_flag", False):
                visual_tags = visual_tags.view(-1,
                                               visual_tags.size(1) //
                                               2).contiguous()
                visual_tags_box = visual_tags_box.view(
                    -1,
                    visual_tags_box.size(1) // 2,
                    visual_tags_box.size(-1)).contiguous()
                visual_tags_objective = visual_tags_objective.view(
                    -1,
                    visual_tags_objective.size(1) // 2).contiguous()
                visual_tags_mask = visual_tags_mask.view(
                    -1,
                    visual_tags_mask.size(1) // 2).contiguous()
        return [
            input_ids, segment_ids, input_mask, lm_labels, feats, pos,
            obj_labels, matched_labels, ans, visual_feats_seg_ids, visual_tags,
            visual_tags_mask, visual_tags_box, visual_tags_objective,
            visual_tags_mismatch, visual_tags_segment_ids
        ]
Example #17
    def convert_example_to_features(self, example: InputExample,
                                    max_seq_length, tokenizer):

        if example.mlm_labels is not None:  # The data is already pre-masked
            input_ids = example.token_ids
            lm_label_ids = example.mlm_labels
            max_seq_len = example.max_seq_len + 2
            # Add [CLS] and [SEP]
            input_ids = tokenizer.convert_tokens_to_ids([
                "[CLS]"
            ]) + input_ids + tokenizer.convert_tokens_to_ids(["[SEP]"])
            lm_label_ids = [-1] + lm_label_ids + [-1]
            input_mask = [1] * len(input_ids)
            segment_ids = [0] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_len:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)
                lm_label_ids.append(-1)

            features = InputFeatures(input_ids=input_ids,
                                     input_mask=input_mask,
                                     segment_ids=segment_ids,
                                     lm_label_ids=lm_label_ids,
                                     visual_feats=(None, None),
                                     obj_labels={
                                         'obj': (None, None),
                                         'attr': (None, None),
                                         'feat': (None, None),
                                     },
                                     is_matched=None,
                                     ans=-1,
                                     visual_tags=None,
                                     visual_tags_objective=None,
                                     visual_tags_mask=None,
                                     visual_tags_box=None,
                                     visual_tags_mismatch=None)
            return features

        if example.sent is not None:

            tokens = tokenizer.tokenize(example.sent.strip())

            # Account for [CLS] and [SEP] with "- 2"
            if len(tokens) > max_seq_length - 2:
                tokens = tokens[:(max_seq_length - 2)]

            # Get random words
            masked_tokens, masked_label = random_word(tokens, tokenizer)

            # Concatenate LM labels and account for [CLS] and [SEP]
            masked_tokens = ['[CLS]'] + masked_tokens + ['[SEP]']
            input_ids = tokenizer.convert_tokens_to_ids(masked_tokens)

            # Mask & Segment Word
            lm_label_ids = ([-1] + masked_label + [-1])
            input_mask = [1] * len(input_ids)
            segment_ids = [0] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)
                lm_label_ids.append(-1)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length
            assert len(lm_label_ids) == max_seq_length
        elif args.get("insert_cls", False):
            masked_tokens = ["[CLS]"]
            input_ids = tokenizer.convert_tokens_to_ids(masked_tokens)
            input_mask = [1] * len(input_ids)
            segment_ids = [0] * len(input_ids)
            lm_label_ids = [-1]
        else:
            input_ids = None
            input_mask = None
            segment_ids = None
            lm_label_ids = None

        if example.use_visual_tag_flag and example.visual_feats[
                0] is not None:  # Let's do a hybrid embedding
            feat, boxes = example.visual_feats
            obj_labels, obj_confs = example.obj_labels
            attr_labels, attr_confs = example.attr_labels

            # Mask Image Features:
            masked_feat, feat_mask = self.random_mask_features(feat,
                                                               boxes=boxes)

            assert (args.non_exclusive_tags)
            assert (args.use_bert_input_for_tags)
            visual_tags, visual_tags_objective, visual_tags_mask, visual_tags_box, visual_tags_segment_ids = tag_data_utilis.create_tags_pretrain(
                obj_labels=obj_labels,
                attr_labels=attr_labels,
                obj_confs=obj_confs,
                attr_confs=attr_confs,
                tokenizer=self.tokenizer,
                symbolic_vocab=symbolic_vocab,
                visual_tags_box=boxes,
                feat_mask=feat_mask,
                use_bert_input=True)
        elif example.visual_feats[0] is not None:
            feat, boxes = example.visual_feats
            obj_labels, obj_confs = example.obj_labels
            attr_labels, attr_confs = example.attr_labels
            # Mask Image Features:
            masked_feat, feat_mask = self.random_mask_features(feat,
                                                               boxes=boxes)
            visual_tags = None
            visual_tags_objective = None
            visual_tags_mask = None
            visual_tags_box = None
            visual_mismatch_label = None
            obj_labels_transformed_mismatch = None
            visual_tags_box_mismatch = None
        else:
            masked_feat = None
            boxes = None
            obj_labels = None
            obj_confs = None
            attr_labels = None
            attr_confs = None
            feat_mask = None
            feat = None
            visual_tags = None
            visual_tags_objective = None
            visual_tags_mask = None
            visual_tags_box = None
            visual_mismatch_label = None
            obj_labels_transformed_mismatch = None
            visual_tags_box_mismatch = None

        # QA answer label
        if example.label is None or len(
                example.label) == 0 or example.is_matched != 1:
            # 1. No label 2. Label is pruned 3. unmatched visual + language pair
            ans = -1
        else:
            keys, values = zip(*example.label.items())
            if len(keys) == 1:
                ans = keys[0]
            else:
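                # Multiple candidate answers: sample one with probability
                # proportional to its soft label score.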
                value_sum = sum(values)
                prob = [value / value_sum for value in values]
                choice = np.random.multinomial(1, prob).argmax()
                ans = keys[choice]

        features = InputFeatures(
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=segment_ids,
            lm_label_ids=lm_label_ids,
            visual_feats=(masked_feat, boxes),
            obj_labels={
                'obj': (obj_labels, obj_confs),
                'attr': (attr_labels, attr_confs),
                'feat': (feat, feat_mask),
            },
            is_matched=example.is_matched,
            ans=ans,
            visual_tags=visual_tags,
            visual_tags_objective=visual_tags_objective,
            visual_tags_mask=visual_tags_mask,
            visual_tags_box=visual_tags_box,
            visual_tags_mismatch=None if not args.get('use_tag_mismatch', None)
            else visual_mismatch_label,
            obj_labels_transformed_mismatch=None
            if not args.get("use_tag_mismatch", None) else
            obj_labels_transformed_mismatch,
            visual_tags_box_mismatch=None
            if not args.get('use_tag_mismatch', None) else
            visual_tags_box_mismatch,
            use_visual_tag_flag=example.use_visual_tag_flag)
        return features
Example #18
    def __getitem__(self, item: int):
        datum = self.data[item]

        img_id = datum['img_id']
        ques_id = datum['question_id']
        ques = datum['sent']

        if self.custom_coco_data:
            image_index = self.ids_to_index[img_id]
            obj_num = None
            feats = self.h5_features[image_index]
            boxes = self.h5_boxes[image_index]
            img_h = self.h5_wh[image_index][1]
            img_w = self.h5_wh[image_index][0]
            obj_confs = None
            attr_labels = None
            attr_confs = None
        elif self.use_h5_file:
            '''image_index = self.ids_to_index[img_id]
            obj_num = 36
            feats = self.h5_features[image_index]
            boxes = self.h5_boxes[image_index]
            img_h = self.h5_wh[image_index][1]
            img_w = self.h5_wh[image_index][0] '''
            image_index, obj_num, feats, boxes, img_h, img_w, obj_labels, obj_confs, attr_labels, attr_confs = self.image_feature_dataset[
                img_id]
        else:
            # Get image info
            img_info = self.imgid2img[img_id]
            obj_num = img_info['num_boxes']
            feats = img_info['features'].copy()
            boxes = img_info['boxes'].copy()
            assert obj_num == len(boxes) == len(feats)
            img_h, img_w = img_info['img_h'], img_info['img_w']

        # Normalize the boxes (to 0 ~ 1)
        boxes = boxes.copy()
        boxes[:, (0, 2)] /= img_w
        boxes[:, (1, 3)] /= img_h
        np.testing.assert_array_less(boxes, 1 + 1e-5)
        np.testing.assert_array_less(-boxes, 0 + 1e-5)

        if args.get("add_tags", False):
            tags = create_tags(obj_labels=obj_labels,
                               attr_labels=attr_labels,
                               obj_confs=None,
                               attr_confs=None,
                               tokenizer=self.tokenizer,
                               symbolic_vocab=self.symbolic_vocab,
                               visual_tags_box=boxes,
                               use_bert_input=True)
        else:
            tags = None

        # Provide label (target)
        if 'label' in datum:
            label = datum['label']
            target = torch.zeros(self.raw_dataset.num_answers)
            for ans, score in label.items():
                target[self.raw_dataset.ans2label[ans]] = score
            return ques_id, feats, boxes, ques, tags, target
        else:
            return ques_id, feats, boxes, ques, tags
Example #19
    def train(self, train_tuple: DataTuple, eval_tuple: DataTuple):
        train_ld = train_tuple.loader

        # Optimizer
        from lxrt.optimization import BertAdam
        batch_per_epoch = len(train_ld)
        t_total = int(batch_per_epoch * args.epochs)
        warmup_ratio = args.get("warmup_ratio", 0.05)

        print("Total Iters: %d" % t_total)
        if args.get("t_total", None):
            t_total = args.t_total
            print("!! Changing to specified t_toal in args: {}".format(t_total))
        self.t_total = t_total
        warmup_iters = int(t_total * warmup_ratio)

        print("Batch per epoch: %d" % batch_per_epoch)
        print("Warm up Iters: %d" % warmup_iters)
        self.optim = BertAdam(self.model.parameters(), lr=args.lr, warmup=warmup_ratio, t_total=t_total)

        if args.load is not None:
            self.load(args.load, t_total = t_total)

        gradient_accumulation_steps = args.get("gradient_accumulation_steps", 1)
        # Train
        best_eval_loss = 9595.
        report_every = args.get("report_every", 100)

        custom_train_meter = TrainingMeter()
        
        for epoch in range(args.epochs):
            # Train
            self.model.train()
            total_loss = 0.
            total_losses = 0.
            uid2ans = {}

            for batch_id, batch in enumerate(tqdm(train_ld, total=len(train_ld))):
                if args.get("skip_training", False):
                    break

                loss, losses, logit, losses_dict = self.train_batch(self.optim, batch)
                total_loss += loss
                try:
                    total_losses += losses
                except:
                    pass

                if args.task_qa and batch[0].sent is not None:
                    assert(0) # Not used in our experiment

                    score, label = logit.max(1)
                    for datum, l in zip(batch, label.cpu().numpy()):
                        uid = datum.uid
                        ans = train_tuple.dataset.answer_table.id2ans(l)
                        uid2ans[uid] = ans
                
                for key, value in losses_dict.items():
                    losses_dict[key] = value.mean().item()  # make the losses scalar
                
                if "Masked LM" in losses_dict and losses_dict["Masked LM"] == 0:
                    del losses_dict["Masked LM"]

                custom_train_meter.update(losses_dict)

                if batch_id % report_every == 0 and batch_id > 0:
                    print("Folder: {} \n Epoch {} Iter: {}/{}".format(args.output, epoch, batch_id, len(train_ld)))
                    #print(pd.DataFrame(train_results[-report_every:]).mean())
                    custom_train_meter.report()
                    custom_train_meter.clean()
                    print()
                
                if args.get("save_step", -1) != -1 and self.global_step != 0 and (self.global_step // gradient_accumulation_steps) % args.save_step == 0:
                    self.save("Step{}".format(self.global_step))
                self.global_step += 1
            
            print("The training loss for Epoch %d is %0.4f" % (epoch, total_loss / batch_per_epoch))

            if args.task_qa:
                train_tuple.evaluator.evaluate(uid2ans, pprint=True)

            # Eval
            avg_eval_loss = self.evaluate_epoch(eval_tuple, iters=-1)

            if args.get("eval_on_train", False):
                print("On train set")
                self.evaluate_epoch(train_tuple, iters=-1)


            if avg_eval_loss < best_eval_loss:
                best_eval_loss = avg_eval_loss
                self.save("BEST_EVAL_LOSS")
            self.save("Epoch%02d" % (epoch+1))
Example #20
    def load_custom_h5_version2(
            h5_file_name,
            on_memory=False,
            text_only=False):  # this version is used for Conceptual Captions
        if not text_only:
            h5_file_feature = h5py.File(
                h5_file_name.replace("no_features", "features"), "r")
        h5_file = h5py.File(h5_file_name, "r")

        if on_memory:
            print("Reading h5 {}".format(
                h5_file_name.replace("no_features", "features")))
            h5_features = sharearray.cache(
                h5_file_name.replace("no_features", "features").split("/")[-1],
                lambda: h5_file_feature['image_features'])
            gc.collect()
        else:
            if not text_only:
                h5_features = h5_file_feature['image_features']

        h5_boxes = sharearray.cache(
            "{}_{}".format(h5_file_name.split("/")[-1], "boxes"),
            lambda: h5_file['boxes'])
        h5_num_boxes = sharearray.cache(
            "{}_{}".format(h5_file_name.split("/")[-1], "num_boxes"),
            lambda: h5_file['num_boxes'])

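        # Without kl_divergence, only the top-1 detection per box is kept;
        # otherwise the full label/confidence distributions are loaded.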
        if not args.get("kl_divergence", False):
            h5_objects_id = sharearray.cache(
                "{}_{}".format(h5_file_name.split("/")[-1], "object_ids"),
                lambda: np.array(h5_file['object_ids'])[:, :, 0]
            )  #deepcopy(np.array(h5_file['object_ids'])[:, :, 0])
            h5_objects_conf = sharearray.cache(
                "{}_{}".format(h5_file_name.split("/")[-1], "object_pro"),
                lambda: np.array(h5_file['object_pro'])[:, :, 0]
            )  #deepcopy(np.array(h5_file['object_pro'])[:, :, 0])
            h5_attrs_id = sharearray.cache(
                "{}_{}".format(h5_file_name.split("/")[-1], "attribute_ids"),
                lambda: np.array(h5_file['attribute_ids'])[:, :, 0]
            )  #deepcopy(np.array(h5_file['attribute_ids'])[:, :, 0])
            h5_attrs_conf = sharearray.cache(
                "{}_{}".format(h5_file_name.split("/")[-1], "attribute_pro"),
                lambda: np.array(h5_file['attribute_pro'])[:, :, 0]
            )  #deepcopy(np.array(h5_file['attribute_pro'])[:, :, 0])
        else:
            h5_objects_id = deepcopy(np.array(h5_file['object_ids']))
            h5_objects_conf = deepcopy(np.array(h5_file['object_pro']))
            h5_attrs_id = deepcopy(np.array(h5_file['attribute_ids']))
            h5_attrs_conf = deepcopy(np.array(h5_file['attribute_pro']))
        gc.collect()

        img_h = deepcopy(np.array(h5_file['img_h'])).tolist()
        img_w = deepcopy(np.array(h5_file['img_w'])).tolist()
        wh_list = []
        for i in range(len(img_h)):
            wh_list.append((img_w[i], img_h[i]))

        h5_file.close()
        del h5_file
        gc.collect()

        if text_only:
            h5_features = [0] * len(h5_num_boxes)

        return h5_features, h5_boxes, h5_objects_id, h5_objects_conf, h5_attrs_id, h5_attrs_conf, wh_list, h5_num_boxes
Example #21
def get_tuple_hybrid(splits: str, bs: int, shuffle=False, drop_last=False, num_workers=0, topk=-1, image_only_splits=None, text_only_splits = None, limit_source = [], restrict_source = None) -> DataTuple:
    # Decide which QA datasets would be used in pre-training.
    # Options: vqa, gqa, visual7w
    # Note: visual7w is part of vgqa; we only use the name here.
    qa_sets = args.qa_sets
    if qa_sets is not None:
        qa_sets = set(qa_set.lower().strip() for qa_set in qa_sets.split(","))

    # Three type of datasets: v&l, language, vision
    datasets_list_torch = []
    datasets_list = []

    if splits is not None:
        vl_dataset = LXMERTDataset(splits, qa_sets=qa_sets)
        vl_dataset_torch = LXMERTTorchDataset(vl_dataset, topk, limit_source = limit_source, randomized_pairing = args.get("randomized_pairing", False),  use_visual_tag_flag = args.get("use_visual_tag_flag", False))
        datasets_list.append(vl_dataset)
        datasets_list_torch.append(vl_dataset_torch)

    if text_only_splits is not None:
        text_only_datasets = []
        for split in text_only_splits.split("+"):
            if not("book_corpus" in split or "sbu" in split):
                text_only_dataset = LXMERTDataset(split, qa_sets=qa_sets)
                text_only_dataset_torch = LXMERTTorchDataset(text_only_dataset, topk, text_only=True, limit_source=limit_source)
                
                datasets_list.append(text_only_dataset)
                datasets_list_torch.append(text_only_dataset_torch)
                text_only_datasets.append(text_only_dataset_torch)
            else:
                text_only_dataset = None
                if "book_corpus" in split and args.get("text_shared_memory", False):
                    text_class = GeneralCorpusNP
                else:
                    #text_class = GeneralCorpus
                    pass
                text_only_dataset_torch = text_class(ann_file=args.book_corpus_path if "book_corpus" in split else args.sbu_path, pretrained_model_name="bert-base-uncased", tokenizer=None, seq_len=args.get("text_only_max_seq_len", 64), min_seq_len=args.get("text_only_min_seq_len", 64), encoding="utf-8", on_memory=True)
                datasets_list.append(text_only_dataset)
                datasets_list_torch.append(text_only_dataset_torch)
                text_only_datasets.append(text_only_dataset_torch)

    if image_only_splits is not None:
        if image_only_splits != "":
            image_only_dataset = LXMERTDataset(image_only_splits, qa_sets=qa_sets)
            image_only_dataset_torch = LXMERTTorchDataset(image_only_dataset, topk, image_only=True, use_visual_tag_flag = args.get("use_visual_tag_flag", False))
            datasets_list.append(image_only_dataset)
            datasets_list_torch.append(image_only_dataset_torch)

        if args.get("add_adhoc_google_cc_image_only", False):
            google_cc_dataset = LXMERTDataset("google_cc_train", qa_sets=qa_sets)
            google_cc_dataset_torch = LXMERTTorchDataset(google_cc_dataset, topk, image_only=True, use_visual_tag_flag=args.get("use_visual_tag_flag", False), available_split_for_cc = args.get("available_split_for_cc", [0]))
            datasets_list.append(google_cc_dataset)
            datasets_list_torch.append(google_cc_dataset_torch)
        
        if args.get("add_adhoc_open_image_image_only", False):
            open_image_dataset = LXMERTDataset("open_images_train", qa_sets=qa_sets)
            open_image_torch = LXMERTTorchDataset(open_image_dataset, topk, image_only=True, use_visual_tag_flag=args.get("use_visual_tag_flag", False))
            datasets_list.append(open_image_dataset)
            datasets_list_torch.append(open_image_torch)

    # Merge different datasets
    merged_dataset = ConcateDataset(datasets_list_torch)

    if args.task_qa:
        merged_dataset.answer_table = datasets_list[0].answer_table if datasets_list[0] is not None else None
    
    batch_sampler = CustomBatchSampler(merged_dataset.datasets, bs, upsample_ratios=args.get("upsample_ratios", [1,1,1]))
    try:
        custom_collact_fn = datasets_list_torch[0].custom_collact_fn if args.get('custom_collact_fn', False) else lambda x: x
    except:
        custom_collact_fn = datasets_list_torch[-1].custom_collact_fn if args.get('custom_collact_fn', False) else lambda x: x
    data_loader = DataLoader(
        merged_dataset, num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=custom_collact_fn,
        pin_memory=args.get("pin_memory", True)
    )
    if args.task_qa:
        evaluator = LXMERTEvaluator(datasets_list[0]) if datasets_list[0] is not None else None  # the evaluator is only needed for task_qa
    else:
        evaluator = None
    print()

    if splits is not None:
        vl_torchdset = vl_dataset_torch
    else:
        vl_torchdset = datasets_list_torch[-1] # the last dataset

    return DataTuple(dataset=merged_dataset, torchdset=merged_dataset, loader=data_loader, evaluator=evaluator, vl_torchdset=vl_torchdset)
Example #22
    def __init__(self, dataset: VQADataset, args):
        super().__init__()
        self.raw_dataset = dataset

        if args.tiny:
            topk = TINY_IMG_NUM
        elif args.fast:
            topk = FAST_IMG_NUM
        else:
            topk = None

        self.limit_to_symbolic_split = args.get("limit_to_symbolic_split",
                                                False)
        if self.limit_to_symbolic_split:
            dataDir = "/local/harold/ubert/bottom-up-attention/data/vg/"
            coco_ids = set()
            self.mapping_cocoid_to_imageid = {}
            with open(os.path.join(dataDir, 'image_data.json')) as f:
                metadata = json.load(f)
                for item in metadata:
                    if item['coco_id']:
                        coco_ids.add(int(item['coco_id']))
                        self.mapping_cocoid_to_imageid[int(
                            item['coco_id'])] = item["image_id"]

            from lib.data.vg_gqa import vg_gqa
            self.vg_gqa = vg_gqa(
                None,
                split="val" if self.raw_dataset.name == "minival" else "train",
                transforms=None,
                num_im=-1)

        self.custom_coco_data = args.get("custom_coco_data", False)
        self.use_h5_file = args.get("use_h5_file", False)
        if self.use_h5_file:
            self.image_feature_dataset = ImageFeatureDataset.create(
                dataset.splits,
                Split2ImgFeatPath,
                on_memory=args.get("on_memory", False))
            self.ids_to_index = self.image_feature_dataset.ids_to_index

            # Screen data
            used_data = []
            for datum in self.raw_dataset.data:
                if datum['img_id'] in self.ids_to_index:
                    used_data.append(datum)
        else:
            # Load detection features into img_data
            img_data = []
            for split in dataset.splits:
                # Minival is 5K images in MS COCO, which is used in evaluating VQA/LXMERT-pre-training.
                # It is saved as the top 5K features in val2014_***.tsv
                load_topk = 5000 if (split == 'minival'
                                     and topk is None) else topk
                img_data.extend(
                    load_obj_tsv(os.path.join(
                        MSCOCO_IMGFEAT_ROOT,
                        '%s_obj36.tsv' % (SPLIT2NAME[split])),
                                 topk=load_topk))

            # Convert img list to dict
            self.imgid2img = {}
            for img_datum in img_data:
                self.imgid2img[img_datum['img_id']] = img_datum

            used_data = self.raw_dataset.data

        used_data = used_data[::args.get("partial_dataset", 1)]
        self.data = used_data

        # Only keep the data whose image features were loaded
        print("Use %d data in torch dataset" % (len(self.data)))
        print()

        if args.get("add_tags", False):
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                                           do_lower_case=True)
            from lxrt.symbolic_vocabulary import SymbolicVocab
            self.symbolic_vocab = SymbolicVocab(args.objects_vocab,
                                                args.attributes_vocab)
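
An illustration only (not repo code) of the "partial_dataset" option used above: the dataset is subsampled by taking every k-th entry with a stride slice, so a value of 1 keeps everything.

data = list(range(10))
for stride in (1, 2, 5):
    kept = data[::stride]
    print(stride, len(kept), kept)   # stride 2 keeps indices 0, 2, 4, 6, 8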
Пример #23
0
    def __init__(self,
                 dataset: LXMERTDataset,
                 topk=-1,
                 sgg_dataset=None,
                 image_only=False,
                 text_only=False,
                 use_visual_tag_flag=False,
                 limit_source=[],
                 available_split_for_cc=None):
        super().__init__()
        self.raw_dataset = dataset
        self.name = '_'.join(self.raw_dataset.sources)
        if args.get('disable_mismatch_for_other_dataset', False):
            # Do not resample for datasets such as BookCorpus
            self.task_matched = args.task_matched if "book_corpus" in self.raw_dataset.sources else False
        else:
            self.task_matched = args.task_matched

        print(self.raw_dataset.sources)
        print(self.task_matched)
        print("\n\n\n")
        self.sgg_dataset = sgg_dataset
        self.image_only = image_only
        self.text_only = text_only
        self.use_visual_tag_flag = use_visual_tag_flag
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                                       do_lower_case=True)
        self.task_nlvr2 = args.get("task_nlvr2", False)

        if args.tiny:
            topk = TINY_IMG_NUM
        elif args.fast:
            topk = FAST_IMG_NUM

        #self.fake_data = args.get("fake_data", False)
        self.custom_coco_data = args.get("custom_coco_data", False)
        self.use_h5_file = args.get("use_h5_file", False)
        if self.use_h5_file:
            if "google_cc_train" in dataset.sources:
                if args.get('change_split', False):
                    available_split_for_cc = [39]
                else:
                    available_split_for_cc = args.get("available_split_for_cc",
                                                      [0])
                sources = []
                split_map = {}
                for i in available_split_for_cc:
                    sources.append("google_cc_{}".format(i))
                    split_map["google_cc_{}".format(
                        i
                    )] = "data/google_concetual/butd_feat/train_no_features_split_{}_of_40_splits.h5".format(
                        i)
                self.image_feature_dataset = ImageFeatureDataset.create(
                    sources,
                    split_map,
                    load_custom_h5_version2=True,
                    text_only=self.text_only,
                    on_memory=False)
            elif "open_images_train" in dataset.sources:
                available_split_for_open_image = args.get(
                    "available_split_for_open_image", [0])
                sources = []
                split_map = {}
                for split_i, split_j, total_split in available_split_for_open_image:
                    sources.append("open_image_{}_{}".format(split_i, split_j))
                    split_map["open_image_{}_{}".format(
                        split_i, split_j
                    )] = "data/open_image/butd_feat/train_{}_no_features_split_{}_of_{}_splits.h5".format(
                        split_i, split_j, total_split)
                self.image_feature_dataset = ImageFeatureDataset.create(
                    sources,
                    split_map,
                    load_custom_h5_version2=True,
                    on_memory=False)
            else:
                self.image_feature_dataset = ImageFeatureDataset.create(
                    dataset.sources,
                    Split2ImgFeatPath_h5,
                    text_only=self.text_only,
                    load_custom_h5_version2=True
                    if "flickr_train" in dataset.sources else False,
                    on_memory=args.get("on_memory", False))
            self.ids_to_index = self.image_feature_dataset.ids_to_index

            # Screen data
            used_data = []
            for datum in self.raw_dataset.data:
                if datum['img_id'] in self.ids_to_index:
                    used_data.append(datum)
        else:
            # Original LXMERT. Load the dataset
            img_data = []
            for source in self.raw_dataset.sources:
                img_data.extend(load_obj_tsv(Split2ImgFeatPath[source], topk))

            self.imgid2img = {}
            for img_datum in img_data:
                self.imgid2img[img_datum['img_id']] = img_datum

            # Keep only the data whose image features were loaded
            used_data = []
            for datum in self.raw_dataset.data:
                if datum['img_id'] in self.imgid2img:
                    used_data.append(datum)

        used_data = used_data[::args.get("partial_dataset", 1)]

        if sgg_dataset is not None:
            used_data = [
                datum for datum in used_data
                if str(datum["img_id"]) in self.sgg_dataset.imageids_to_index
            ]

        # Flatten the dataset into (one image, one sentence) entries
        self.data = []

        record_img_id = set()

        remaining_set = set()
        for datum in used_data:
            # datum: {'img_id': 'COCO_train2014_000000318556', 'labelf': {'vqa': [{'no': 1}, {'yes': 1}, {'no': 1}, {'blue': 1, 'blue and white': 0.3}]}, 'sentf': {'mscoco': ['A very clean and well decorated empty bathroom', 'A blue and white bathroom with butterfly themed wall tiles.', 'A bathroom with a border of butterflies and blue paint on the walls above it.', 'An angled view of a beautifully decorated bathroom.', 'A clock that blends in with the wall hangs in a bathroom. '], 'vqa': ['Is the sink full of water?', 'Are there any butterflies on the tiles?', 'Is this bathroom in a hotel?', 'What color are the walls?']}}

            sentf = datum['sentf']
            for sents_cat, sents in sentf.items():
                if sents_cat in limit_source:
                    continue

                remaining_set.add(sents_cat)

                if sents_cat in datum['labelf']:
                    labels = datum['labelf'][sents_cat]
                else:
                    labels = None
                for sent_idx, sent in enumerate(sents):
                    new_datum = {
                        'uid':
                        make_uid(datum['img_id'], sents_cat, sent_idx)
                        if args.task_qa else None,
                        'img_id':
                        datum['img_id'],  # if not self.text_only else "",
                        'sent':
                        sent  #if not self.image_only else ""
                    }
                    if image_only:  # In image-only mode, make sure each image appears only once
                        if datum["img_id"] in record_img_id:
                            continue
                        record_img_id.add(datum["img_id"])

                    if labels is not None and args.task_qa:
                        new_datum['label'] = labels[sent_idx]

                    if self.task_nlvr2:
                        new_datum['match_label'] = datum["label"]
                        new_datum['img_id_1'] = datum["img_id_1"]

                    self.data.append(new_datum)

        if image_only:
            dataset_str = "image_only"
        elif text_only:
            dataset_str = "text_only"
        else:
            dataset_str = "vision and language"

        if self.image_only and args.get("screen_image", False):
            counter = 0
            from tqdm import tqdm
            _data = []
            for data_item in tqdm(self.data):
                img_id = data_item["img_id"]
                image_index = self.image_feature_dataset.ids_to_index[img_id]
                img_h = self.image_feature_dataset.h5_wh[image_index][1]
                img_w = self.image_feature_dataset.h5_wh[image_index][0]
                if img_h == 0 or img_w == 0:
                    counter += 1
                else:
                    _data.append(data_item)

            print(
                "Screened out {} images with zero height or width, {} kept in total"
                .format(counter, len(_data)))
            self.data = _data

        print("Use {} data in {} torch dataset, {}, limit_source {}".format(
            len(self.data), dataset_str, remaining_set, limit_source))

        if text_only:
            del self.image_feature_dataset

        if text_only or image_only:
            del self.raw_dataset.data
            del self.raw_dataset

        self.compress_memory = False
        if args.get("compress_memory", False):
            # Move some data to shared memory so memory usage does not explode with multi-process data loading
            self.compress()
        print("\n\n\n")
Пример #24
0
    def __len__(self):
        if args.get("presegment_sentence",
                    False) and "sbu-captions-all.json" not in self.ann_file:
            return len(self.mapping)
        return len(self.passage_split)
Пример #25
0
    def __getitem__(self, item: int):
        if self.compress_memory:
            datum = self.decompress_getitem__(item)
        else:
            datum = self.data[item]

        uid = datum['uid']
        img_id = datum['img_id']
        sent = datum['sent'].lower()

        if not self.text_only:
            # Get image info
            if self.use_h5_file:
                image_index, obj_num, feats, boxes, img_h, img_w, obj_labels, obj_confs, attr_labels, attr_confs = self.image_feature_dataset[
                    img_id]
            else:
                img_info = self.imgid2img[img_id]
                obj_num = img_info['num_boxes']
                feats = img_info['features'].copy()
                boxes = img_info['boxes'].copy()
                obj_labels = img_info['objects_id'].copy()
                obj_confs = img_info['objects_conf'].copy()
                attr_labels = img_info['attrs_id'].copy()
                attr_confs = img_info['attrs_conf'].copy()
                assert obj_num == len(boxes) == len(feats)
                # Normalize the boxes (to 0 ~ 1)
                img_h, img_w = img_info['img_h'], img_info['img_w']
                #print(item, img_info, img_h, img_w)
            boxes = boxes.copy()
            boxes[:, (0, 2)] /= img_w
            boxes[:, (1, 3)] /= img_h

            np.testing.assert_array_less(boxes, 1 + 1e-5)
            np.testing.assert_array_less(-boxes, 0 + 1e-5)

        # If calculating the matched loss, replace the sentence with a sentence
        # corresponding to another image.
        is_matched = None
        if args.get('task_nlvr2', False):
            match_label = datum["match_label"]
            is_matched = match_label
            second_image_index, second_obj_num, second_feats, second_boxes, second_img_h, second_img_w, second_obj_labels, second_obj_confs, second_attr_labels, second_attr_confs = self.image_feature_dataset[
                datum["img_id_1"]]
            second_boxes = second_boxes.copy()
            second_boxes[:, (0, 2)] /= second_img_w
            second_boxes[:, (1, 3)] /= second_img_h
            np.testing.assert_array_less(second_boxes, 1 + 1e-5)
            np.testing.assert_array_less(-second_boxes, 0 + 1e-5)

            feats = np.concatenate((feats, second_feats))
            boxes = np.concatenate((boxes, second_boxes))
            obj_labels = np.concatenate((obj_labels, second_obj_labels))
            obj_confs = np.concatenate((obj_confs, second_obj_confs))
            attr_labels = np.concatenate((attr_labels, second_attr_labels))
            attr_confs = np.concatenate((attr_confs, second_attr_confs))

        elif self.task_matched:
            if random.random() < 0.5:
                is_matched = 0
                if self.compress_memory:
                    other_datum = self.decompress_getitem__(
                        random.randint(0,
                                       len(self.data) - 1))
                else:
                    other_datum = self.data[random.randint(
                        0,
                        len(self.data) - 1)]
                while other_datum['img_id'] == img_id:
                    if self.compress_memory:
                        other_datum = self.decompress_getitem__(
                            random.randint(0,
                                           len(self.data) - 1))
                    else:
                        other_datum = self.data[random.randint(
                            0,
                            len(self.data) - 1)]
                sent = other_datum['sent']
            else:
                is_matched = 1

        # Label, convert answer to id
        if 'label' in datum and args.task_qa:
            label = datum['label'].copy()
            for ans in list(label.keys()):
                label[self.raw_dataset.answer_table.ans2id(ans)] = label.pop(
                    ans)
        else:
            label = None

        if self.image_only:
            sent = None
        if self.text_only:
            feats = None
            boxes = None
            obj_labels = None
            obj_confs = None
            attr_labels = None
            attr_confs = None

        # Create target
        example = InputExample(uid,
                               sent, (feats, boxes), (obj_labels, obj_confs),
                               (attr_labels, attr_confs),
                               is_matched,
                               label,
                               use_visual_tag_flag=self.use_visual_tag_flag)

        #if args.get("faster_loading", False):
        return self.convert_example_to_features(example,
                                                args.get("max_seq_length", 20),
                                                self.tokenizer)
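
A minimal numpy sketch (illustration only) of the box normalization used above: pixel boxes (x1, y1, x2, y2) are divided by image width and height so every coordinate falls in [0, 1], then checked with the same assertions as in the dataset code.

import numpy as np

boxes = np.array([[ 10.,  20., 300., 200.],
                  [  0.,   0., 640., 480.]], dtype=np.float32)
img_w, img_h = 640., 480.

boxes[:, (0, 2)] /= img_w        # normalize x1, x2 by width
boxes[:, (1, 3)] /= img_h        # normalize y1, y2 by height

np.testing.assert_array_less(boxes, 1 + 1e-5)
np.testing.assert_array_less(-boxes, 0 + 1e-5)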
Пример #26
0
    def __getitem__(self, item):
        if self.use_mismatch_objective:
            i = 0
            max_seq_length = self.seq_len // 2  # We have two parts

            if args.get(
                    "presegment_sentence",
                    False) and "sbu-captions-all.json" not in self.ann_file:
                text_a_tokens, text_a_labels = self.retrieve_a_piece_preseged(
                    item, seq_len=max_seq_length)

                # Choose text_b: with probability 0.5 take our own mapped segment (match = 1), otherwise a random segment (match = 0)
                if random.random() < 0.5:
                    # Take out our own
                    b_index = self.mapping[item]
                    text_b_tokens, text_b_labels = self.retrieve_a_piece_preseged(
                        b_index, seq_len=max_seq_length)
                    match = 1
                else:
                    random_index = random.randint(0, len(self) - 1)
                    while random_index == item:
                        random_index = random.randint(0, len(self) - 1)
                    text_b_tokens, text_b_labels = self.retrieve_a_piece_preseged(
                        random_index, seq_len=max_seq_length)
                    match = 0

            else:
                text_a_tokens, text_a_labels = self.retrieve_a_piece(
                    item, seq_len=max_seq_length)

                # Choose text_b: with probability 0.5 take another piece of the same item (match = 1), otherwise a random item (match = 0)
                if random.random() < 0.5:
                    # Take out our own
                    text_b_tokens, text_b_labels = self.retrieve_a_piece(
                        item, seq_len=max_seq_length)
                    match = 1
                else:
                    random_index = random.randint(0, len(self) - 1)
                    while random_index == item:
                        random_index = random.randint(0, len(self) - 1)
                    text_b_tokens, text_b_labels = self.retrieve_a_piece(
                        random_index, seq_len=max_seq_length)
                    match = 0

            text_a_ids = self.tokenizer.convert_tokens_to_ids(text_a_tokens)
            text_b_ids = self.tokenizer.convert_tokens_to_ids(text_b_tokens)

            example = InputExample(None, (text_a_tokens, text_b_tokens),
                                   (None, None), (None, None), (None, None),
                                   match,
                                   1,
                                   mlm_labels=(text_a_labels, text_b_labels),
                                   token_ids=(text_a_ids, text_b_ids),
                                   max_seq_len=self.seq_len + 3)
            if args.get("faster_loading", False):
                return self.convert_example_to_features(
                    example, self.seq_len + 3, self.tokenizer)

        # Plain-text path (also reached when use_mismatch_objective is set but
        # faster_loading is disabled, since the branch above does not return then)
        raw = self.corpus[item]

        # tokenize
        tokens = self.tokenizer.basic_tokenizer.tokenize(raw.lower())

        if not self.do_no_fill:
            # add more tokens from following corpus entries until len(tokens) >= min_seq_len
            _cur = (item + 1) % len(self.corpus)
            while len(tokens) < self.min_seq_len:
                _cur_tokens = self.tokenizer.basic_tokenizer.tokenize(
                    self.corpus[_cur])
                tokens.extend(_cur_tokens)
                _cur = (_cur + 1) % len(self.corpus)

        # masked language modeling
        tokens, mlm_labels = self.random_word_wwm(tokens)

        # convert token to its vocab id
        ids = self.tokenizer.convert_tokens_to_ids(tokens)

        # truncate
        if len(ids) > self.seq_len:
            ids = ids[:self.seq_len]
            mlm_labels = mlm_labels[:self.seq_len]

        example = InputExample(None,
                               tokens, (None, None), (None, None),
                               (None, None),
                               None,
                               1,
                               mlm_labels=mlm_labels,
                               token_ids=ids,
                               max_seq_len=self.seq_len)
        if args.get("faster_loading", False):
            return self.convert_example_to_features(
                example, args.get("max_seq_length", 20), self.tokenizer)

        return example
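
An isolated sketch (not repo code) of the matched/mismatched sampling used above: with probability 0.5 the second segment comes from the same item (label 1), otherwise from a different, randomly drawn item (label 0). The segment list and indices are illustrative.

import random

def sample_pair(items, idx):
    # assumes len(items) > 1, as the mismatch loop does above
    if random.random() < 0.5:
        return items[idx], items[idx], 1            # matched
    other = random.randint(0, len(items) - 1)
    while other == idx:
        other = random.randint(0, len(items) - 1)
    return items[idx], items[other], 0              # mismatched

segments = ["a cat on a mat", "a dog in a park", "two birds on a wire"]
print(sample_pair(segments, 0))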
Пример #27
0
    def report(self):
        keys = list(self.counter_dict.keys())
        keys.sort()
        for key in keys:
            print("  {} : {:.7}".format(key, self.true_dict[key] / self.counter_dict[key]))
    
    def clean(self):
        self.counter_dict = defaultdict(float)
        self.true_dict = defaultdict(float)
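
A sketch of the accumulation pattern that report() and clean() above appear to operate on: counter_dict holds per-key counts, true_dict holds per-key correct predictions, so report() prints a running accuracy per key. The keys and values here are made up for illustration.

from collections import defaultdict

counter_dict = defaultdict(float)
true_dict = defaultdict(float)

for key, correct in [("vqa", True), ("vqa", False), ("gqa", True)]:
    counter_dict[key] += 1
    true_dict[key] += float(correct)

for key in sorted(counter_dict):
    print("  {} : {:.7}".format(key, true_dict[key] / counter_dict[key]))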

from torch.utils.data import DataLoader
from torch.utils.data.dataloader import _SingleProcessDataLoaderIter, _MultiProcessingDataLoaderIter


if args.get('random_seed', None) is not None:  # explicit check so that a seed of 0 is not skipped
    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    torch.random.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
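
A quick standalone check (illustration only, CPU generators only) that seeding the generators as above makes the random draws reproducible across calls.

import random
import numpy as np
import torch

def draw(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    return random.random(), float(np.random.rand()), float(torch.rand(1))

assert draw(1234) == draw(1234)   # identical triples for the same seed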


def get_tuple(splits: str, bs: int, shuffle=False, drop_last=False, topk=-1, num_workers = 0, limit_source = [], restrict_source = None) -> DataTuple:
    # Decide which QA datasets would be used in pre-training.
    # Options: vqa, gqa, visual7w
    # Note: visual7w is a part of vgqa, we take the name here.
    qa_sets = args.qa_sets
    if qa_sets is not None:
        qa_sets = set(qa_set.lower().strip() for qa_set in qa_sets.split(","))

    # Build dataset, data loader, and evaluator.