Example #1
import json

# BASEDIR and AnswerTable are project-local names defined elsewhere in the repo.
class LXMERTDataset:
    def __init__(self, splits: str, qa_sets=None):
        """
        :param splits: The data sources to be loaded
        :param qa_sets: if None, keep all answers;
                        otherwise, keep only the answers appearing in these datasets
                        and remove all unlabeled data (MSCOCO captions)
                              and remove all unlabeled data (MSCOCO captions)
        """
        self.name = splits
        self.sources = splits.split(',')

        # Load each source's captions; the file name is derived from the suffix
        # after the first underscore (e.g. 'mscoco_train' -> 'caption_train.json').
        self.data = []
        for source in self.sources:
            with open(BASEDIR +
                      "lxmert/caption_%s.json" % source.split('_')[1]) as f:
                self.data.extend(json.load(f))
        print("Loaded %d entries from %s" % (len(self.data), self.name))

        # Create answer table according to the qa_sets
        self.answer_table = AnswerTable(qa_sets)
        print("Load an answer table of size %d." %
              (len(self.answer_table.ans2id_map())))

        # Remap each answer to its canonical form; drop answers
        # that the answer table does not use.
        for datum in self.data:
            labelf = datum['labelf']
            for cat, labels in labelf.items():
                for label in labels:
                    for ans in list(label.keys()):
                        new_ans = self.answer_table.convert_ans(ans)
                        if self.answer_table.used(new_ans):
                            if ans != new_ans:
                                label[new_ans] = label.pop(ans)
                        else:
                            label.pop(ans)

    def __len__(self):
        return len(self.data)
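
A minimal usage sketch for Example #1 (the split names and QA-set string are hypothetical; BASEDIR and AnswerTable are project-local and assumed to be importable):

# Hypothetical usage: each source name needs an underscore, since the caption
# file is resolved from its suffix ('mscoco_train' -> caption_train.json).
dataset = LXMERTDataset(splits='mscoco_train,mscoco_minival', qa_sets='vqa,gqa')
print(len(dataset))  # entries remaining after answer remapping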
Example #2
import json
import math
import pickle
import random
from multiprocessing import Pool
from pathlib import Path

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import LxmertTokenizer

# AnswerTable, QAEvaluator, box_position, and get_datum are project-local helpers.
class PretrainingDataset(Dataset):
    def __init__(self, split='mscoco_minival', topk=-1, data_out=('img',), verbose=True, args=None):

        self.data_out = data_out
        self.topk = topk
        self.verbose = verbose
        self.args = args

        self.datasets_dir = Path(self.args.datasets_dir)

        # Loading datasets to data
        self.sources = split.split(',')
        if self.verbose:
            print('Data sources: ', self.sources)

        self.answer_table = AnswerTable()
        print("Loaded an answer table of size %d." % (len(self.answer_table.ans2id_map())))

        self.img_ids_to_source = {}

        data = []
        for img_source in self.sources:
            with open(self.datasets_dir.joinpath(f'data/lxmert/{img_source}.json')) as f:
                _data = json.load(f)
                if self.verbose:
                    print(f"Loaded {len(_data)} data from", img_source)
                for datum in _data:
                    self.img_ids_to_source[datum['img_id']] = img_source
                    datum['img_source'] = img_source
                    datum['caption_only'] = args.caption_only
                    datum['clustering'] = args.clustering
                    datum['max_text_length'] = args.max_text_length
                    datum['qa'] = args.task_qa

                data.extend(_data)

        # Remap each answer to its canonical form; drop answers
        # that the answer table does not use.
        if args.task_qa:
            for datum in data:
                labelf = datum['labelf']
                for _qa_source, labels in labelf.items():
                    for label in labels:
                        for ans in list(label.keys()):
                            new_ans = self.answer_table.convert_ans(ans)
                            if self.answer_table.used(new_ans):
                                if ans != new_ans:
                                    label[new_ans] = label.pop(ans)
                            else:
                                label.pop(ans)

        if self.topk > 0:
            data = data[:self.topk]
            if self.verbose:
                print(f"Use only {self.topk} data")

        if args.task_qa:
            self.evaluator = QAEvaluator(data)


        if args.clustering:
            clustering_dir = self.datasets_dir.joinpath('clustering')
            suffix = f'{args.n_centroids}_iter{args.n_iter}_d{args.feat_dim}_grid{args.grid_size}.pkl'

            def load_cluster_ids(name):
                # img_id -> cluster-id array, precomputed offline
                fname = f'{args.encoder}_{args.cluster_src}_{name}_img_id_to_cluster_id_{suffix}'
                with open(clustering_dir.joinpath(fname), 'rb') as f:
                    return pickle.load(f)

            mscoco_train_img_id_to_cluster_id = load_cluster_ids('mscoco_train')
            mscoco_valid_img_id_to_cluster_id = load_cluster_ids('mscoco_valid')
            vg_img_id_to_cluster_id = load_cluster_ids('vg')

            self.data_source_to_cluster_data = {
                'mscoco_train': mscoco_train_img_id_to_cluster_id,
                'mscoco_minival': mscoco_valid_img_id_to_cluster_id,
                'mscoco_nominival': mscoco_valid_img_id_to_cluster_id,
                'vgnococo': vg_img_id_to_cluster_id
            }

        # get_datum (project-local) turns each raw datum into a list of
        # processed entries; run it across 8 worker processes and flatten.
        with Pool(8) as pool:
            if self.verbose:
                data = [datum for _data in tqdm(pool.imap(get_datum, data), total=len(data), ncols=100) for datum in _data]
            else:
                data = [datum for _data in pool.imap(get_datum, data) for datum in _data]

        if self.args.target_exact_feat or self.args.feed_exact_feat or self.args.target_obj_id:
            if args.grid_model:
                self.data_source_to_h5_path = {
                    'mscoco_train': self.datasets_dir.joinpath(f'COCO/features/{args.encoder}_train_grid{args.grid_size}.h5'),
                    'mscoco_minival': self.datasets_dir.joinpath(f'COCO/features/{args.encoder}_valid_grid{args.grid_size}.h5'),
                    'mscoco_nominival': self.datasets_dir.joinpath(f'COCO/features/{args.encoder}_valid_grid{args.grid_size}.h5'),
                    'vgnococo': self.datasets_dir.joinpath(f'VG/features/{args.encoder}_grid{args.grid_size}.h5'),
                }

            else:
                self.data_source_to_h5_path = {
                    'mscoco_train': self.datasets_dir.joinpath('COCO/features/maskrcnn_train_boxes36.h5'),
                    'mscoco_minival': self.datasets_dir.joinpath('COCO/features/maskrcnn_valid_boxes36.h5'),
                    'mscoco_nominival': self.datasets_dir.joinpath('COCO/features/maskrcnn_valid_boxes36.h5'),
                    'vgnococo': self.datasets_dir.joinpath('VG/features/maskrcnn_boxes36.h5'),
                }

            for source, path in self.data_source_to_h5_path.items():
                assert path.is_file(), (source, path)

            # Opened lazily in __getitem__ so each DataLoader worker
            # ends up with its own HDF5 file handles.
            self.source_to_h5 = None

        self.data = data

        if args.vis_mask_COCO_only:
            COCO_data = []
            for datum in self.data:
                if datum['text_source'] == 'mscoco' and 'mscoco' in datum['img_source']:
                    COCO_data.append(datum)
            self.COCO_data = COCO_data
            if self.verbose:
                print('# COCO captions:', len(self.COCO_data))

        if self.verbose:
            if 'sent' not in self.data_out:
                print("# all images:", len(self.data))
            else:
                print("# all sentences:", len(self.data))

        self.grid_size = args.grid_size
        self.n_grids = args.n_grids
        if self.args.grid_model:
            self.boxes = box_position(args.grid_size)
        else:
            self.n_boxes = args.n_boxes
            self.boxes = None

        self.tokenizer = LxmertTokenizer.from_pretrained(
            "bert-base-uncased",
            do_lower_case=True
        )

        self.max_text_length = args.max_text_length

        ###### Pretraining Objectives ######
        tasks = []
        if self.args.task_mask_lm:
            tasks.append('word_mask')
        if self.args.task_obj_predict:
            tasks.append('vis_mask')
        if self.args.task_matched:
            tasks.append('matched')
        if self.args.task_qa:
            tasks.append('qa')
        self.tasks = tasks

        if self.verbose:
            print('data_out:', self.data_out)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):

        out_dict = {}

        datum = self.data[idx]
        uid = datum['uid']
        out_dict['uid'] = uid
        out_dict['args'] = self.args

        ###### Image ######
        img_id = datum['img_id']
        if 'cluster_id' in self.data_out:
            img_id_to_cluster_id = self.data_source_to_cluster_data[datum['img_source']]
            cluster_id = img_id_to_cluster_id[img_id]
            assert cluster_id is not None, datum

            cluster_id = torch.from_numpy(cluster_id)
            out_dict['cluster_id'] = cluster_id

        # Lazily open one HDF5 handle per source (per worker process).
        if self.source_to_h5 is None:
            self.source_to_h5 = {source: None for source in self.data_source_to_h5_path}
        source = self.img_ids_to_source[img_id]
        f = self.source_to_h5[source]
        if f is None:
            path = self.data_source_to_h5_path[source]
            f = h5py.File(path, 'r')
            self.source_to_h5[source] = f

        if 'feat' in self.data_out:
            if self.args.grid_model:
                feats = np.zeros(
                    shape=(self.grid_size, self.grid_size, self.args.feat_dim), dtype=np.float32)
                f[f'{img_id}/features'].read_direct(feats)
                feats = np.reshape(feats, (self.n_grids, self.args.feat_dim))
                feats = torch.from_numpy(feats)
            else:
                feats = np.zeros(shape=(self.n_boxes, self.args.feat_dim), dtype=np.float32)
                f[f'{img_id}/features'].read_direct(feats)
                feats = torch.from_numpy(feats)
            out_dict['vis_feats'] = feats

        if 'obj_id' in self.data_out:
            obj_id = np.zeros(shape=(self.n_boxes,), dtype=int)
            f[f'{img_id}/obj_id'].read_direct(obj_id)
            obj_id = torch.from_numpy(obj_id)
            out_dict['obj_id'] = obj_id

        if self.args.grid_model:
            boxes = self.boxes
            boxes = torch.from_numpy(boxes)
        else:
            # Normalize the boxes (to 0 ~ 1)
            img_h = f[f'{img_id}/img_h'][()]
            img_w = f[f'{img_id}/img_w'][()]
            boxes = f[f'{img_id}/boxes'][()]
            boxes[:, (0, 2)] /= img_w
            boxes[:, (1, 3)] /= img_h
            # Coordinates should already be non-negative.
            np.testing.assert_array_less(-boxes, 1e-5)
            boxes = torch.from_numpy(boxes)

            boxes.clamp_(min=0.0, max=1.0)

        out_dict['boxes'] = boxes

        ###### Text ######
        sent = datum['sent']
        input_ids, n_tokens = datum['input_ids'], datum['n_tokens']
        input_ids = torch.LongTensor(input_ids)

        out_dict['sent'] = sent
        out_dict['input_ids'] = input_ids
        out_dict['n_tokens'] = n_tokens

        # With probability 0.5, swap in a caption from a different image,
        # producing a non-matched image-text pair (matched_label = 0).
        if 'matched' in self.data_out and random.random() < 0.5:
            other_datum = self.data[random.randint(0, len(self.data) - 1)]
            while img_id == other_datum['img_id']:
                other_datum = self.data[random.randint(0, len(self.data) - 1)]
            other_sent = other_datum['sent']
            other_input_ids, other_n_tokens = other_datum['input_ids'], other_datum['n_tokens']

            other_input_ids = torch.LongTensor(other_input_ids)

            out_dict['matched_label'] = 0
            out_dict['other_sent'] = other_sent
            out_dict['other_input_ids'] = other_input_ids
            out_dict['other_n_tokens'] = other_n_tokens
        else:
            out_dict['matched_label'] = 1
            out_dict['other_n_tokens'] = n_tokens

        if self.args.task_qa:
            # Label, convert answer to id
            if 'label' in datum:
                label = datum['label'].copy()
                if len(label) > 0:

                    for ans in list(label.keys()):
                        label[self.answer_table.ans2id(ans)] = label.pop(ans)
                    keys, values = zip(*label.items())
                    # single answer
                    if len(keys) == 1:
                        ans = keys[0]
                    # multiple answers -> sample one answer
                    else:
                        value_sum = sum(values)
                        prob = [value / value_sum for value in values]
                        choice = np.random.multinomial(1, prob).argmax()
                        ans = keys[choice]
                else:
                    ans = -1
            else:
                ans = -1
            out_dict['ans'] = ans

        if self.args.vis_mask_predict:
            if self.args.square_mask:
                # Mask a random square sub-region of the grid.
                if self.args.vis_sampling:
                    grid_size = int(math.sqrt(self.args.n_vis_sampling))
                else:
                    grid_size = self.args.grid_size
                mask_size = random.randint(1, grid_size)
                vis_mask = torch.zeros(grid_size, grid_size)
                mask_position_h = random.randint(0, grid_size - mask_size)
                mask_position_w = random.randint(0, grid_size - mask_size)
                vis_mask[mask_position_h:mask_position_h + mask_size, mask_position_w:mask_position_w + mask_size] = 1
                out_dict['vis_mask'] = vis_mask.flatten()

            else:
                # Mask a uniformly random number of positions,
                # chosen without replacement.
                if self.args.vis_sampling:
                    total_idx = list(range(self.args.n_vis_sampling))
                    n_max_mask = self.args.n_vis_sampling
                else:
                    if self.args.grid_model:
                        total_idx = list(range(self.n_grids))
                        n_max_mask = self.n_grids
                    else:
                        total_idx = list(range(self.args.n_boxes))
                        n_max_mask = self.n_boxes
                n_masks = random.randint(1, n_max_mask)
                vis_mask = torch.zeros(n_max_mask)
                vis_mask_idx = np.random.choice(total_idx, n_masks, replace=False)
                vis_mask_idx = torch.from_numpy(vis_mask_idx)
                vis_mask[vis_mask_idx] = 1
                out_dict['vis_mask'] = vis_mask

        else:
            # Independent Bernoulli mask over positions at rate obj_mask_rate.
            if self.args.grid_model:
                if self.args.vis_sampling:
                    vis_mask = torch.bernoulli(
                        torch.full((self.args.n_vis_sampling,),  self.args.obj_mask_rate)).bool()
                else:
                    vis_mask = torch.bernoulli(
                        torch.full((self.n_grids,),  self.args.obj_mask_rate)).bool()
                out_dict['vis_mask'] = vis_mask
            else:
                vis_mask = torch.bernoulli(
                    torch.full((self.n_boxes,),  self.args.obj_mask_rate)).bool()
                out_dict['vis_mask'] = vis_mask


        if self.args.vis_mask_COCO_only:
            # Walk through the COCO pool in step with idx; once the remaining
            # chunk is shorter than the pool, fall back to uniform sampling.
            quotient = idx // len(self.COCO_data)
            if len(self.data) - quotient * len(self.COCO_data) < len(self.COCO_data):
                coco_idx = random.randint(0, len(self.COCO_data) - 1)
            else:
                coco_idx = idx % len(self.COCO_data)
            coco_datum = self.COCO_data[coco_idx]

            assert coco_datum['text_source'] == 'mscoco'
            assert 'mscoco' in coco_datum['img_source']

            coco_input_ids, coco_n_tokens = coco_datum['input_ids'], coco_datum['n_tokens']
            coco_input_ids = torch.LongTensor(coco_input_ids)

            out_dict['COCO_input_ids'] = coco_input_ids
            out_dict['COCO_n_tokens'] = coco_n_tokens

            if 'cluster_id' in self.data_out:
                img_id = coco_datum['img_id']
                img_id_to_cluster_id = self.data_source_to_cluster_data[coco_datum['img_source']]
                cluster_id = img_id_to_cluster_id[img_id]
                assert cluster_id is not None, coco_datum

                cluster_id = torch.from_numpy(cluster_id)
                out_dict['COCO_cluster_id'] = cluster_id

        return out_dict
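
A minimal end-to-end sketch of Example #2. The args namespace below is hypothetical: every field mirrors an attribute that __init__ or __getitem__ reads above, but the values (and the feature files they point at) are assumptions, not the repo's actual defaults:

from types import SimpleNamespace

# Hypothetical configuration; field names mirror the attributes accessed above.
args = SimpleNamespace(
    datasets_dir='/path/to/datasets',
    caption_only=False, clustering=False, max_text_length=20,
    task_qa=False, task_mask_lm=True, task_obj_predict=True, task_matched=True,
    grid_model=True, grid_size=8, n_grids=64, feat_dim=2048, encoder='resnet50',
    target_exact_feat=True, feed_exact_feat=False, target_obj_id=False,
    vis_mask_predict=False, square_mask=False, vis_sampling=False,
    obj_mask_rate=0.15, vis_mask_COCO_only=False,
)

dataset = PretrainingDataset(split='mscoco_minival', topk=100,
                             data_out=['feat', 'matched'], verbose=True, args=args)
item = dataset[0]  # dict with 'vis_feats', 'boxes', 'vis_mask', 'matched_label', ...

Because out_dict carries the raw args object and variable-length strings, batching with a DataLoader needs the repo's own collate_fn rather than the default collate; the lazy source_to_h5 pattern then gives each worker process its own HDF5 handles.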