コード例 #1
0
class CLS3(Dataset):
    def __init__(self,
                 root_path=None,
                 image_set='train',
                 transform=None,
                 test_mode=False,
                 zip_mode=False,
                 cache_mode=False,
                 cache_db=True,
                 tokenizer=None,
                 pretrained_model_name=None,
                 add_image_as_a_box=False,
                 mask_size=(14, 14),
                 aspect_grouping=False,
                 **kwargs):
        """
        Image + text classification dataset read from ``*.entity.jsonl`` files.

        :param root_path: directory containing the annotation files and the
            images they reference; tokenizer files are cached in
            ``<root_path>/cache``
        :param image_set: split to load, one of 'train', 'val', 'train+val'
            or 'test'
        :param transform: optional callable invoked as
            ``transform(image, boxes, None, im_info, flipped)``
        :param test_mode: test mode means no labels are returned
        :param zip_mode: when true a ZipReader is created for zip archives
        :param cache_mode: not supported, must be falsy
        :param cache_db: accepted for interface compatibility, unused here
        :param tokenizer: tokenizer to store on the dataset; defaults to a
            BertTokenizer loaded from ``pretrained_model_name``
        :param pretrained_model_name: model name for the default tokenizer
            ('bert-base-uncased' when None)
        :param add_image_as_a_box: add whole image as a box
        :param mask_size: accepted for interface compatibility, unused here
        :param aspect_grouping: stored only; grouping is currently disabled
        :param kwargs: the *presence* of an 'anno_aug' key selects the
            augmented training annotations
        """
        super(CLS3, self).__init__()
        assert not cache_mode, 'currently not support cache mode!'

        categories = [
            '__background__', 'person', 'bicycle', 'car', 'motorcycle',
            'airplane', 'bus', 'train', 'truck', 'boat', 'trafficlight',
            'firehydrant', 'stopsign', 'parkingmeter', 'bench', 'bird', 'cat',
            'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
            'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase',
            'frisbee', 'skis', 'snowboard', 'sportsball', 'kite',
            'baseballbat', 'baseballglove', 'skateboard', 'surfboard',
            'tennisracket', 'bottle', 'wineglass', 'cup', 'fork', 'knife',
            'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
            'broccoli', 'carrot', 'hotdog', 'pizza', 'donut', 'cake', 'chair',
            'couch', 'pottedplant', 'bed', 'diningtable', 'toilet', 'tv',
            'laptop', 'mouse', 'remote', 'keyboard', 'cellphone', 'microwave',
            'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
            'scissors', 'teddybear', 'hairdrier', 'toothbrush'
        ]
        self.category_to_idx = {c: i for i, c in enumerate(categories)}
        self.data_split = image_set  # HACK: reuse old parameter

        # Raw strings keep the patterns byte-identical while avoiding the
        # "invalid escape sequence" warning raised for "\d" in plain strings.
        self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
        self.commaStrip = re.compile(r"(\d)(\,)(\d)")
        self.punct = [
            ';', r"/", '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_',
            '-', '>', '<', '@', '`', ',', '?', '!'
        ]

        self.test_mode = test_mode

        self.root_path = root_path

        self.box_bank = {}

        self.transform = transform
        self.zip_mode = zip_mode

        self.aspect_grouping = aspect_grouping
        self.add_image_as_a_box = add_image_as_a_box

        self.cache_dir = os.path.join(root_path, 'cache')
        model_name = 'bert-base-uncased' if pretrained_model_name is None else pretrained_model_name
        # The fast tokenizer produces the token ids and character offsets
        # consumed by __getitem__.
        # NOTE(review): it always loads 'bert-base-uncased', even when
        # pretrained_model_name is given -- confirm this is intended.
        self.fast_tokenizer = AutoTokenizer.from_pretrained(
            'bert-base-uncased',
            cache_dir=self.cache_dir,
            use_fast=True,
            return_offsets_mapping=True)
        self.tokenizer = tokenizer if tokenizer is not None \
            else BertTokenizer.from_pretrained(
            model_name,
            cache_dir=self.cache_dir)
        # Token sequences longer than this are truncated in __getitem__.
        self.max_txt_token = 128

        if zip_mode:
            self.zipreader = ZipReader()

        # NOTE(review): only the key's presence is checked, so passing
        # anno_aug=False still enables the augmented file -- confirm intent.
        self.anno_aug = 'anno_aug' in kwargs
        self.database = self.load_annotations()
        self.use_img_box = True
        self.random_drop_tags = False
        # Aspect grouping is disabled for this dataset:
        # if self.aspect_grouping:
        #     self.group_ids = self.group_aspect(self.database)

    @property
    def data_names(self):
        if self.use_img_box:
            if self.test_mode:
                return [
                    'image',
                    'boxes',
                    'im_info',
                    'text',
                    'img_boxes',
                    'text_tags',
                    'id',
                ]
            else:
                return [
                    'image', 'boxes', 'im_info', 'text', 'img_boxes',
                    'text_tags', 'label', 'id'
                ]
        else:
            if self.test_mode:
                return [
                    'image',
                    'boxes',
                    'im_info',
                    'text',
                    'id',
                ]
            else:
                return ['image', 'boxes', 'im_info', 'text', 'label', 'id']

    @property
    def weights_by_class(self):
        labels = []
        num_per_class = collections.defaultdict(lambda: 0)
        for data in self.database:
            labels.append(data['label'])
            num_per_class[data['label']] += 1

        weight_per_class = {
            k: 1 / len(num_per_class) / v
            for k, v in num_per_class.items()
        }
        sampling_weight = [weight_per_class[label] for label in labels]
        return sampling_weight

    def clip_box_and_score(self, box_and_score):
        """Return copies of the given dicts with every value clamped to [0, 1].

        :param box_and_score: iterable of dicts with numeric values
        :return: new list of new dicts; the inputs are not modified
        """
        return [
            {key: min(max(value, 0), 1) for key, value in entry.items()}
            for entry in box_and_score
        ]

    def __getitem__(self, index):
        """Load one sample: image, boxes, tokenized text and (optionally) label.

        The text fed to the tokenizer is the sample text followed by the
        image-partition descriptions and the per-person race/gender tags,
        separated by ``[SEP]``. Token ids and per-token tags are computed on
        first access and cached on the database entry.

        :param index: position in ``self.database``
        :return: a tuple whose fields are named by ``self.data_names``
        """
        idb = self.database[index]

        # image, boxes, im_info
        image = self._load_image(os.path.join(self.root_path, idb['img']))
        w0, h0 = image.size

        if len(idb['boxes_and_score']) == 0:
            # No detections: fall back to a single whole-image box.
            boxes = torch.as_tensor([[0.0, 0.0, w0 - 1, h0 - 1, 0]])
        else:
            # Stored coordinates are scaled by the image size here, i.e. they
            # are treated as fractions of width/height.
            boxes = torch.as_tensor([[
                box_sc['xmin'] * w0,
                box_sc['ymin'] * h0,
                box_sc['xmax'] * w0,
                box_sc['ymax'] * h0,
                box_sc['class_id'],
            ] for box_sc in idb['boxes_and_score']])
            if self.add_image_as_a_box:
                boxes = torch.cat(
                    (torch.as_tensor([[0.0, 0.0, w0 - 1, h0 - 1, 0]]), boxes),
                    dim=0)

        race_tags = [box_sc['race'] for box_sc in idb['boxes_and_score']]
        gender_tags = [box_sc['gender'] for box_sc in idb['boxes_and_score']]

        im_info = torch.tensor([w0, h0, 1.0, 1.0])
        # clamp boxes
        w = im_info[0].item()
        h = im_info[1].item()
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clamp(min=0, max=w - 1)
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clamp(min=0, max=h - 1)

        flipped = False
        if self.transform is not None:
            image, boxes, _, im_info, flipped = self.transform(
                image, boxes, None, im_info, flipped)

        # question
        if 'token_id' not in idb:
            main_txt = idb['text']
            img_tags = [' '.join(des) for des in idb['partition_description']]
            img_tags_str = ''
            # One partition id per character of img_tags_str.
            img_tags_part = []

            # Tags are always kept unless random dropping is enabled, in
            # which case they are kept with probability 0.5.
            if not self.random_drop_tags or random.random() > 0.5:
                for p, img_tag in enumerate(img_tags):
                    if img_tag:
                        append_str = img_tag + (' [SEP] ' if
                                                p != len(img_tags) - 1 else '')
                        img_tags_str += append_str
                        img_tags_part += [p] * len(append_str)

            person_tags_str = ''
            # One partition id per character of person_tags_str.
            person_tags_part = []
            for j, (race, gend) in enumerate(zip(race_tags, gender_tags)):
                if race is not None:
                    is_last = not any(
                        rt is not None for rt in race_tags[j + 1:])
                    append_str = f"{race.replace('_', ' ')} {gend}"
                    append_str += "" if is_last else " [SEP] "
                    person_tags_str += append_str
                    # Person ids are offset past the image partitions (and
                    # the optional whole-image box).
                    person_tags_part += [
                        len(idb['image_partition']) * int(self.use_img_box) +
                        int(self.add_image_as_a_box) + j
                    ] * len(append_str)

            text_with_tag = f"{main_txt} [SEP] {img_tags_str} [SEP] {person_tags_str}"
            result = self.fast_tokenizer(text_with_tag,
                                         return_offsets_mapping=True,
                                         add_special_tokens=False)
            token_id = result['input_ids']
            token_offset = result['offset_mapping']

            if self.use_img_box:
                # Work on a copy: ``+=`` on the stored list would extend the
                # annotation entry itself and corrupt it for any later use.
                text_partition = list(idb['text_char_partition_id'])
                text_partition += [0] * len(
                    " [SEP] "
                ) + img_tags_part  # additional partition id for [SEP]
                text_partition += [0] * len(
                    " [SEP] "
                ) + person_tags_part  # additional partition id for [SEP]
                assert len(text_partition) == len(text_with_tag), \
                    F"{len(text_partition)} != {len(text_with_tag)}"

                # Each token takes the most frequent partition id among the
                # characters it spans; zero-width offsets are skipped.
                token_tags = []
                for a, b in filter(lambda x: x[1] - x[0] > 0, token_offset):
                    char_tags = text_partition[a:b]
                    cnt = collections.Counter(char_tags)
                    token_tags.append(cnt.most_common(1)[0][0])

                idb['text_tags'] = token_tags
                idb['image_partition'] = np.asarray(
                    idb['image_partition'], dtype=np.float32)[
                        ..., :4]  # HACK: remove det score from mmdet
            else:
                idb['text_tags'] = [0] * len(token_id)

            # Drop a trailing [SEP] (present when the person tags are empty).
            if token_id[-1] == self.fast_tokenizer.sep_token_id:
                token_id = token_id[:-1]
                idb['text_tags'] = idb['text_tags'][:-1]

            if len(token_id) > self.max_txt_token:
                token_id = token_id[:self.max_txt_token]
                idb['text_tags'] = idb['text_tags'][:self.max_txt_token]

            # Cache on the entry so later accesses skip tokenization.
            idb['token_id'] = token_id
            assert len(idb['token_id']) == len(idb['text_tags'])
        else:
            token_id = idb['token_id']

        if self.use_img_box:
            if self.test_mode:
                return (
                    image,
                    boxes,
                    im_info,
                    token_id,
                    idb['image_partition'],
                    idb['text_tags'],
                    idb['id'],
                )
            else:
                label = torch.Tensor([float(idb['label'])])
                return (
                    image,
                    boxes,
                    im_info,
                    token_id,
                    idb['image_partition'],
                    idb['text_tags'],
                    label,
                    idb['id'],
                )
        else:
            if self.test_mode:
                return image, boxes, im_info, token_id, idb['id']
            else:
                label = torch.Tensor([float(idb['label'])])
                return image, boxes, im_info, token_id, label, idb['id']

    @staticmethod
    def b64_decode(string):
        return base64.decodebytes(string.encode())

    def load_annotations(self):
        tic = time.time()
        img_name_to_annos = collections.defaultdict(list)

        test_json = os.path.join(self.root_path, 'test_unseen.entity.jsonl')
        dev_json = os.path.join(self.root_path, 'dev_seen.entity.jsonl')
        dev_train_json = os.path.join(self.root_path, 'dev_all.entity.jsonl')
        train_json = (os.path.join(self.root_path, 'train.entity.aug.jsonl')
                      if self.anno_aug else os.path.join(
                          self.root_path, 'train.entity.jsonl'))
        box_annos_json = os.path.join(self.root_path, 'box_annos.race.json')

        test_sample = []
        dev_sample = []
        train_sample = []
        dev_train_sample = []

        with open(train_json, mode='r') as f:
            for line in f.readlines():
                train_sample.append(json.loads(line))

        with open(dev_train_json, mode='r') as f:
            for line in f.readlines():
                dev_train_sample.append(json.loads(line))

        with open(test_json, mode='r') as f:
            for line in f.readlines():
                test_sample.append(json.loads(line))

        with open(dev_json, mode='r') as f:
            for line in f.readlines():
                dev_sample.append(json.loads(line))

        with open(box_annos_json, mode='r') as f:
            box_annos = json.load(f)

        sample_sets = []
        if self.data_split == 'train':
            sample_sets.append(train_sample)
        elif self.data_split == 'val':
            sample_sets.append(dev_sample)
        elif self.data_split == 'train+val':
            sample_sets.append(train_sample)
            sample_sets.append(dev_train_sample)
        elif self.data_split == 'test':
            sample_sets.append(test_sample)
        else:
            raise RuntimeError(f"Unknown dataset split: {self.data_split}")

        for sample_set in sample_sets:
            for sample in sample_set:
                img_name = os.path.basename(sample['img'])
                img_name_to_annos[img_name].append(sample)

        for box_anno in box_annos:
            img_name = box_anno['img_name']
            if img_name in img_name_to_annos:
                for sample in img_name_to_annos[img_name]:
                    sample.update(box_anno)

        print('Done (t={:.2f}s)'.format(time.time() - tic))

        flatten = []
        for annos in img_name_to_annos.values():
            flatten += annos
        return flatten

    @staticmethod
    def group_aspect(database):
        print('grouping aspect...')
        t = time.time()

        # get shape of all images
        widths = torch.as_tensor([idb['width'] for idb in database])
        heights = torch.as_tensor([idb['height'] for idb in database])

        # group
        group_ids = torch.zeros(len(database))
        horz = widths >= heights
        vert = 1 - horz
        group_ids[horz] = 0
        group_ids[vert] = 1

        print('Done (t={:.2f}s)'.format(time.time() - t))

        return group_ids

    def load_precomputed_boxes(self, box_file):
        if box_file in self.box_bank:
            return self.box_bank[box_file]
        else:
            in_data = {}
            with open(box_file, "r") as tsv_in_file:
                reader = csv.DictReader(tsv_in_file,
                                        delimiter='\t',
                                        fieldnames=FIELDNAMES)
                for item in reader:
                    item['image_id'] = int(item['image_id'])
                    item['image_h'] = int(item['image_h'])
                    item['image_w'] = int(item['image_w'])
                    item['num_boxes'] = int(item['num_boxes'])
                    for field in ([
                            'boxes', 'features'
                    ] if self.with_precomputed_visual_feat else ['boxes']):
                        item[field] = np.frombuffer(
                            base64.decodebytes(item[field].encode()),
                            dtype=np.float32).reshape((item['num_boxes'], -1))
                    in_data[item['image_id']] = item
            self.box_bank[box_file] = in_data
            return in_data

    def __len__(self):
        return len(self.database)

    def _load_image(self, path):
        """Open the image at ``path`` and convert it to RGB.

        Paths containing ``.zip@`` are read through ``self.zipreader``;
        every other path is opened directly.
        """
        if '.zip@' not in path:
            return Image.open(path).convert('RGB')
        return self.zipreader.imread(path).convert('RGB')

    def _load_json(self, path):
        if '.zip@' in path:
            f = self.zipreader.read(path)
            return json.loads(f.decode())
        else:
            with open(path, 'r') as f:
                return json.load(f)
コード例 #2
0
class VQA_CP(Dataset):
    def __init__(self, image_set, root_path, data_path, answer_vocab_file, use_imdb=True,
                 with_precomputed_visual_feat=False, boxes="36",
                 transform=None, test_mode=False,
                 zip_mode=False, cache_mode=False, cache_db=True, ignore_db_cache=True,
                 tokenizer=None, pretrained_model_name=None,
                 add_image_as_a_box=False, mask_size=(14, 14),
                 aspect_grouping=False, toy_dataset=False, toy_samples=128, **kwargs):
        """
        VQA-CP v2 Visual Question Answering Dataset

        :param image_set: mode(s) to load, 'train' and/or 'val' joined by '+'
        :param root_path: root path to cache database loaded from annotation file
        :param data_path: directory containing the question/annotation files,
            the COCO annotations and the precomputed boxes
        :param answer_vocab_file: text file with one answer per line
        :param use_imdb: accepted for interface compatibility, unused here
        :param with_precomputed_visual_feat: use stored box features instead
            of loading the image
        :param boxes: precomputed box variant, "36" or "10-100ada"
        :param transform: transform
        :param test_mode: test mode means no labels available
        :param zip_mode: reading images and metadata in zip archive
        :param cache_mode: not supported, must be falsy
        :param cache_db: pickle the loaded database into ``<root_path>/cache``
        :param ignore_db_cache: ignore previous cached database, reload it from annotation file
        :param tokenizer: default is a BertTokenizer loaded from ``pretrained_model_name``
        :param pretrained_model_name: model name for the default tokenizer
            ('bert-base-uncased' when None)
        :param add_image_as_a_box: add whole image as a box
        :param mask_size: size of instance mask of each object
        :param aspect_grouping: whether to group images via their aspect
        :param toy_dataset: keep only the first ``toy_samples`` entries
        :param toy_samples: size of the toy dataset
        :param kwargs: ignored
        :raises ValueError: if ``boxes`` is not a supported variant
        """
        super(VQA_CP, self).__init__()

        assert not cache_mode, 'currently not support cache mode!'

        categories = ['__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
                      'boat',
                      'trafficlight', 'firehydrant', 'stopsign', 'parkingmeter', 'bench', 'bird', 'cat', 'dog', 'horse',
                      'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
                      'suitcase', 'frisbee', 'skis', 'snowboard', 'sportsball', 'kite', 'baseballbat', 'baseballglove',
                      'skateboard', 'surfboard', 'tennisracket', 'bottle', 'wineglass', 'cup', 'fork', 'knife', 'spoon',
                      'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hotdog', 'pizza', 'donut',
                      'cake', 'chair', 'couch', 'pottedplant', 'bed', 'diningtable', 'toilet', 'tv', 'laptop', 'mouse',
                      'remote', 'keyboard', 'cellphone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
                      'clock', 'vase', 'scissors', 'teddybear', 'hairdrier', 'toothbrush']
        vqa_question = {
            "train": "vqa/vqacp_v2_train_questions.json",
            "val": "vqa/vqacp_v2_test_questions.json",
        }
        vqa_annot = {
            "train": "vqa/vqacp_v2_train_annotations.json",
            "val": "vqa/vqacp_v2_test_annotations.json",
        }

        if boxes == "36":
            precomputed_boxes = {
                'train': ("vgbua_res101_precomputed", "{}_resnet101_faster_rcnn_genome_36"),
                'val': ("vgbua_res101_precomputed", "{}_resnet101_faster_rcnn_genome_36"),
            }
        elif boxes == "10-100ada":
            precomputed_boxes = {
                'train': ("vgbua_res101_precomputed", "{}_resnet101_faster_rcnn_genome"),
                'val': ("vgbua_res101_precomputed", "{}_resnet101_faster_rcnn_genome"),
            }
        else:
            raise ValueError("Not support boxes: {}!".format(boxes))

        self.coco_dataset = {
            "train2014": os.path.join(data_path, "annotations", "instances_train2014.json"),
            "val2014": os.path.join(data_path, "annotations", "instances_val2014.json"),
            "test-dev2015": os.path.join(data_path, "annotations", "image_info_test-dev2015.json"),
            "test2015": os.path.join(data_path, "annotations", "image_info_test2015.json"),
        }

        # Raw strings keep the patterns byte-identical while avoiding the
        # "invalid escape sequence" warning raised for "\d" in plain strings.
        self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
        self.commaStrip = re.compile(r"(\d)(\,)(\d)")
        self.punct = [';', r"/", '[', ']', '"', '{', '}',
                      '(', ')', '=', '+', '\\', '_', '-',
                      '>', '<', '@', '`', ',', '?', '!']

        self.boxes = boxes
        self.test_mode = test_mode
        self.with_precomputed_visual_feat = with_precomputed_visual_feat
        self.category_to_idx = {c: i for i, c in enumerate(categories)}
        self.data_path = data_path
        self.root_path = root_path

        # load the answer vocab file: same as vqav2 dataset
        # (processPunctuation relies on the regexes and punct list set above)
        with open(answer_vocab_file, 'r', encoding='utf8') as f:
            self.answer_vocab = [w.lower().strip().strip('\r').strip('\n').strip('\r') for w in f.readlines()]
            self.answer_vocab = list(filter(lambda x: x != '', self.answer_vocab))
            self.answer_vocab = [self.processPunctuation(w) for w in self.answer_vocab]

        # The config.DATA.TRAIN_IMAGE_SET and config.DATA.VAL_IMAGE_SET have
        # a little different use here, it indicates the mode 'train' or 'val'
        self.image_sets = [iset.strip() for iset in image_set.split('+')]
        self.ann_files = [os.path.join(data_path, vqa_annot[iset]) for iset in self.image_sets] \
            if not self.test_mode else [None for iset in self.image_sets]
        self.q_files = [os.path.join(data_path, vqa_question[iset]) for iset in self.image_sets]

        self.precomputed_box_files = [
            os.path.join(data_path, precomputed_boxes[iset][0], precomputed_boxes[iset][1]) for iset in self.image_sets]

        self.box_bank = {}
        # Template filled by load_annotations with (coco_split, coco_split,
        # image_id). The id placeholder must be a live format field: the
        # previous escaped form '{{:012d}}' left a literal '{:012d}' in every
        # image file name.
        self.coco_datasets = [os.path.join(data_path, '{}', 'COCO_{}_{:012d}.jpg') for _ in self.image_sets]

        self.transform = transform
        self.zip_mode = zip_mode
        self.cache_mode = cache_mode
        self.cache_db = cache_db
        self.ignore_db_cache = ignore_db_cache
        self.aspect_grouping = aspect_grouping
        self.cache_dir = os.path.join(root_path, 'cache')
        self.add_image_as_a_box = add_image_as_a_box
        self.mask_size = mask_size

        if not os.path.exists(self.cache_dir):
            makedirsExist(self.cache_dir)
        self.tokenizer = tokenizer if tokenizer is not None \
            else BertTokenizer.from_pretrained(
            'bert-base-uncased' if pretrained_model_name is None else pretrained_model_name,
            cache_dir=self.cache_dir)

        if zip_mode:
            self.zipreader = ZipReader()

        self.database = self.load_annotations()

        # toy dataset: truncate before grouping so that group_ids always has
        # one entry per database sample
        if toy_dataset:
            print(f"Using the toy dataset!! Total samples = {toy_samples}")
            self.database = self.database[:toy_samples]

        if self.aspect_grouping:
            self.group_ids = self.group_aspect(self.database)

    @property
    def data_names(self):
        if self.test_mode:
            return ['image', 'boxes', 'im_info', 'question']
        else:
            return ['image', 'boxes', 'im_info', 'question', 'label']

    def __getitem__(self, index):
        """Load one sample: image (or None), boxes, question ids and label.

        Boxes (and, with precomputed features, box features) are decoded from
        the base64 payload of the per-image JSON file. When the transform
        reports a horizontal flip, 'left'/'right' are swapped in the question
        and in the answers.

        :param index: position in ``self.database``
        :return: ``(image, boxes, im_info, q_ids)`` in test mode, otherwise
            the same tuple followed by the soft-target ``label``
        """
        idb = self.database[index]

        # image, boxes, im_info
        boxes_data = self._load_json(idb['box_fn'])
        num_boxes = boxes_data['num_boxes']
        if self.with_precomputed_visual_feat:
            image = None
            w0, h0 = idb['width'], idb['height']

            # np.frombuffer yields a read-only view of the decoded bytes;
            # copy so the resulting tensor owns writable memory.
            boxes_features = torch.as_tensor(
                np.frombuffer(self.b64_decode(boxes_data['features']),
                              dtype=np.float32).reshape((num_boxes, -1)).copy()
            )
        else:
            image = self._load_image(idb['image_fn'])
            w0, h0 = image.size
        # Copy for the same reason: the boxes are clamped in place below, and
        # writing through a tensor backed by a read-only array is undefined.
        boxes = torch.as_tensor(
            np.frombuffer(self.b64_decode(boxes_data['boxes']),
                          dtype=np.float32).reshape((num_boxes, -1)).copy()
        )

        if self.add_image_as_a_box:
            image_box = torch.as_tensor([[0.0, 0.0, w0 - 1, h0 - 1]])
            boxes = torch.cat((image_box, boxes), dim=0)
            if self.with_precomputed_visual_feat:
                if 'image_box_feature' in boxes_data:
                    image_box_feature = torch.as_tensor(
                        np.frombuffer(
                            self.b64_decode(boxes_data['image_box_feature']), dtype=np.float32
                        ).reshape((1, -1)).copy()
                    )
                else:
                    # No stored whole-image feature: use the mean box feature.
                    image_box_feature = boxes_features.mean(0, keepdim=True)
                boxes_features = torch.cat((image_box_feature, boxes_features), dim=0)
        im_info = torch.tensor([w0, h0, 1.0, 1.0])
        flipped = False
        if self.transform is not None:
            image, boxes, _, im_info, flipped = self.transform(image, boxes, None, im_info, flipped)

        # clamp boxes
        w = im_info[0].item()
        h = im_info[1].item()
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clamp(min=0, max=w - 1)
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clamp(min=0, max=h - 1)

        # flip: 'left' -> 'right', 'right' -> 'left'
        q_tokens = self.tokenizer.tokenize(idb['question'])
        if flipped:
            q_tokens = self.flip_tokens(q_tokens, verbose=False)
        if not self.test_mode:
            answers = idb['answers']
            if flipped:
                answers_tokens = [a.split(' ') for a in answers]
                answers_tokens = [self.flip_tokens(a_toks, verbose=False) for a_toks in answers_tokens]
                answers = [' '.join(a_toks) for a_toks in answers_tokens]
            label = self.get_soft_target(answers)

        # question
        q_ids = self.tokenizer.convert_tokens_to_ids(q_tokens)

        # concat box feature to box
        if self.with_precomputed_visual_feat:
            boxes = torch.cat((boxes, boxes_features), dim=-1)

        if self.test_mode:
            return image, boxes, im_info, q_ids
        return image, boxes, im_info, q_ids, label

    @staticmethod
    def flip_tokens(tokens, verbose=True):
        changed = False
        tokens_new = [tok for tok in tokens]
        for i, tok in enumerate(tokens):
            if tok == 'left':
                tokens_new[i] = 'right'
                changed = True
            elif tok == 'right':
                tokens_new[i] = 'left'
                changed = True
        if verbose and changed:
            logging.info('[Tokens Flip] {} -> {}'.format(tokens, tokens_new))
        return tokens_new

    @staticmethod
    def b64_decode(string):
        return base64.decodebytes(string.encode())

    def answer_to_ind(self, answer):
        if answer in self.answer_vocab:
            return self.answer_vocab.index(answer)
        else:
            return self.answer_vocab.index('<unk>')

    def get_soft_target(self, answers):

        soft_target = torch.zeros(len(self.answer_vocab), dtype=torch.float)
        answer_indices = [self.answer_to_ind(answer) for answer in answers]
        gt_answers = list(enumerate(answer_indices))
        unique_answers = set(answer_indices)

        for answer in unique_answers:
            accs = []
            for gt_answer in gt_answers:
                other_answers = [item for item in gt_answers if item != gt_answer]

                matching_answers = [item for item in other_answers if item[1] == answer]
                acc = min(1, float(len(matching_answers)) / 3)
                accs.append(acc)
            avg_acc = sum(accs) / len(accs)

            if answer != self.answer_vocab.index('<unk>'):
                soft_target[answer] = avg_acc

        return soft_target

    def processPunctuation(self, inText):

        if inText == '<unk>':
            return inText

        outText = inText
        for p in self.punct:
            if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None):
                outText = outText.replace(p, '')
            else:
                outText = outText.replace(p, ' ')
        outText = self.periodStrip.sub("",
                                       outText,
                                       re.UNICODE)
        return outText

    def load_annotations(self):
        """Build (or load from cache) the list of question entries.

        A pickled database in ``<root_path>/cache`` is reused unless
        ``self.ignore_db_cache`` is set. Otherwise the question (and, outside
        test mode, annotation) files of every image set are read and merged
        with the image sizes from the COCO annotations. The result is pickled
        when ``self.cache_db`` is set.

        :return: list of dicts, one per question
        :raises ValueError: if a question refers to an unsupported COCO split
        """
        tic = time.time()
        database = []
        db_cache_name = 'vqa_cp2_boxes{}_{}'.format(self.boxes, '+'.join(self.image_sets))
        if self.with_precomputed_visual_feat:
            db_cache_name += 'visualprecomp'
        if self.zip_mode:
            db_cache_name = db_cache_name + '_zipmode'
        if self.test_mode:
            db_cache_name = db_cache_name + '_testmode'
        db_cache_root = os.path.join(self.root_path, 'cache')
        db_cache_path = os.path.join(db_cache_root, '{}.pkl'.format(db_cache_name))

        if os.path.exists(db_cache_path):
            if not self.ignore_db_cache:
                # reading cached database
                # NOTE(security): unpickling executes arbitrary code; only
                # use cache directories that are fully trusted.
                print('cached database found in {}.'.format(db_cache_path))
                with open(db_cache_path, 'rb') as f:
                    print('loading cached database from {}...'.format(db_cache_path))
                    tic = time.time()
                    database = cPickle.load(f)
                    print('Done (t={:.2f}s)'.format(time.time() - tic))
                    return database
            else:
                print('cached database ignored.')

        # ignore or not find cached database, reload it from annotation file
        print('loading database of split {}...'.format('+'.join(self.image_sets)))
        tic = time.time()

        # Directory of precomputed boxes for each supported COCO split.
        box_dirs = {
            'train2014': 'trainval2014',
            'val2014': 'trainval2014',
            'test2015': 'test2015',
        }
        # COCO indexes are expensive to build: create each one at most once,
        # and only for splits that actually occur in the questions.
        coco_objs = {}

        for ann_file, q_file, coco_path, box_file \
                in zip(self.ann_files, self.q_files, self.coco_datasets, self.precomputed_box_files):
            qs = self._load_json(q_file)
            anns = self._load_json(ann_file) if not self.test_mode else ([None] * len(qs))

            for ann, q in zip(anns, qs):
                coco_split = q['coco_split']
                if coco_split not in box_dirs:
                    raise ValueError("COCO split in question : {} not supported".format(q['coco_split']))
                if coco_split not in coco_objs:
                    coco_objs[coco_split] = COCO(self.coco_dataset[coco_split])
                coco_obj = coco_objs[coco_split]
                box_dir = box_dirs[coco_split]

                idb = {'image_id': q['image_id'],
                       'image_fn': coco_path.format(q['coco_split'], q['coco_split'], q['image_id']),
                       'width': coco_obj.imgs[q['image_id']]['width'],
                       'height': coco_obj.imgs[q['image_id']]['height'],
                       'box_fn': os.path.join(box_file.format(box_dir), '{}.json'.format(q['image_id'])),
                       'question_id': q['question_id'],
                       'question': q['question'],
                       'answers': [a['answer'] for a in ann['answers']] if not self.test_mode else None,
                       'multiple_choice_answer': ann['multiple_choice_answer'] if not self.test_mode else None,
                       "question_type": ann['question_type'] if not self.test_mode else None,
                       "answer_type": ann['answer_type'] if not self.test_mode else None,
                       }
                database.append(idb)

        print('Done (t={:.2f}s)'.format(time.time() - tic))

        # cache database via cPickle
        if self.cache_db:
            print('caching database to {}...'.format(db_cache_path))
            tic = time.time()
            if not os.path.exists(db_cache_root):
                makedirsExist(db_cache_root)
            with open(db_cache_path, 'wb') as f:
                cPickle.dump(database, f)
            print('Done (t={:.2f}s)'.format(time.time() - tic))

        return database

    @staticmethod
    def group_aspect(database):
        print('grouping aspect...')
        t = time.time()

        # get shape of all images
        widths = torch.as_tensor([idb['width'] for idb in database])
        heights = torch.as_tensor([idb['height'] for idb in database])

        # group
        group_ids = torch.zeros(len(database))
        horz = widths >= heights
        vert = 1 - horz
        group_ids[horz] = 0
        group_ids[vert] = 1

        print('Done (t={:.2f}s)'.format(time.time() - t))

        return group_ids

    def load_precomputed_boxes(self, box_file):
        if box_file in self.box_bank:
            return self.box_bank[box_file]
        else:
            in_data = {}
            with open(box_file, "r") as tsv_in_file:
                reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=FIELDNAMES)
                for item in reader:
                    item['image_id'] = int(item['image_id'])
                    item['image_h'] = int(item['image_h'])
                    item['image_w'] = int(item['image_w'])
                    item['num_boxes'] = int(item['num_boxes'])
                    for field in (['boxes', 'features'] if self.with_precomputed_visual_feat else ['boxes']):
                        item[field] = np.frombuffer(base64.decodebytes(item[field].encode()),
                                                    dtype=np.float32).reshape((item['num_boxes'], -1))
                    in_data[item['image_id']] = item
            self.box_bank[box_file] = in_data
            return in_data

    def __len__(self):
        return len(self.database)

    def _load_image(self, path):
        """Open the image at ``path`` and convert it to RGB.

        Paths containing ``.zip@`` are read through ``self.zipreader``;
        every other path is opened directly.
        """
        if '.zip@' not in path:
            return Image.open(path).convert('RGB')
        return self.zipreader.imread(path).convert('RGB')

    def _load_json(self, path):
        if '.zip@' in path:
            f = self.zipreader.read(path)
            return json.loads(f.decode())
        else:
            with open(path, 'r') as f:
                return json.load(f)
コード例 #3
0
ファイル: multi30k_2018.py プロジェクト: phaedonmit/VL-BERT
class Multi30kDataset2018(Dataset):
    def __init__(self, ann_file, image_set, root_path, data_path, seq_len=64,
                 with_precomputed_visual_feat=False, mask_raw_pixels=True,
                 with_rel_task=True, with_mlm_task=False, with_mvrc_task=False,
                 transform=None, test_mode=False,
                 zip_mode=False, cache_mode=False, cache_db=False, ignore_db_cache=True,
                 tokenizer=None, pretrained_model_name=None,
                 add_image_as_a_box=False,
                 aspect_grouping=False, languages_used='first', MLT_vocab='bert-base-german-cased-vocab.txt', **kwargs):
        """
        Multi30k (2018 test split) dataset of word/caption pairs with Faster R-CNN boxes

        :param ann_file: accepted for interface compatibility; the annotation file
                         actually read is chosen from ``image_set`` (see ``annot`` below)
        :param image_set: split name, one of 'train', 'val' or 'test2015'
        :param root_path: root path; holds the tokenizer cache dir and the vocabulary file
        :param data_path: directory containing the annotation files, images and box records
        :param seq_len: maximum combined number of text tokens and boxes per sample
        :param with_precomputed_visual_feat: use stored box features instead of loading images
        :param mask_raw_pixels: stored on the instance and printed; not used further here
        :param with_rel_task: stored on the instance
        :param with_mlm_task: stored on the instance
        :param with_mvrc_task: stored on the instance
        :param transform: transform
        :param test_mode: test mode means no labels available
        :param zip_mode: reading images and metadata in zip archive
        :param cache_mode: cache whole dataset to RAM first (not supported, must be False)
        :param cache_db: stored on the instance
        :param ignore_db_cache: ignore previous cached database, reload it from annotation file
        :param tokenizer: default is BertTokenizer from pytorch_pretrained_bert
        :param pretrained_model_name: tokenizer to load when ``tokenizer`` is None
                                      (defaults to 'bert-base-uncased')
        :param add_image_as_a_box: add whole image as a box
        :param aspect_grouping: whether to group images via their aspect (not supported)
        :param languages_used: stored on the instance
        :param MLT_vocab: vocabulary file name under ``root_path``/model/pretrained_model
        :param kwargs: ignored
        """
        super(Multi30kDataset2018, self).__init__()

        assert not cache_mode, 'currently not support cache mode!'
        # TODO: need to remove this to allows testing
        # assert not test_mode

        annot = {'train': 'train_MLT_frcnn.json',
                 'val': 'val_MLT_frcnn.json',
                 'test2015': 'test_MLT_2018_renamed_frcnn.json'}

        self.seq_len = seq_len
        self.with_rel_task = with_rel_task
        self.with_mlm_task = with_mlm_task
        self.with_mvrc_task = with_mvrc_task
        self.data_path = data_path
        self.root_path = root_path
        self.ann_file = os.path.join(data_path, annot[image_set])
        self.with_precomputed_visual_feat = with_precomputed_visual_feat
        self.mask_raw_pixels = mask_raw_pixels
        self.image_set = image_set
        self.transform = transform
        self.test_mode = test_mode
        self.zip_mode = zip_mode
        self.cache_mode = cache_mode
        self.cache_db = cache_db
        self.ignore_db_cache = ignore_db_cache
        self.aspect_grouping = aspect_grouping
        # FM edit: added option for how many captions
        self.languages_used = languages_used
        self.cache_dir = os.path.join(root_path, 'cache')
        self.add_image_as_a_box = add_image_as_a_box
        if not os.path.exists(self.cache_dir):
            makedirsExist(self.cache_dir)
        self.tokenizer = tokenizer if tokenizer is not None \
            else BertTokenizer.from_pretrained(
            'bert-base-uncased' if pretrained_model_name is None else pretrained_model_name,
            cache_dir=self.cache_dir)

        self.zipreader = ZipReader()

        # FM: Customise for multi30k dataset
        # Read every annotation line, then close the reader instead of leaking its handle.
        with jsonlines.open(self.ann_file) as reader:
            self.database = list(reader)
        if not self.zip_mode:
            # Outside zip mode the paths must point at plain files: drop the
            # '.zip@' marker and the '.0'..'.3' substrings from the box path
            # (presumably archive shard suffixes -- TODO confirm).
            for idb in self.database:
                idb['frcnn'] = idb['frcnn'].replace('.zip@', '')\
                    .replace('.0', '').replace('.1', '').replace('.2', '').replace('.3', '')
                idb['image'] = idb['image'].replace('.zip@', '')

        if self.aspect_grouping:
            assert False, "not support aspect grouping currently!"
            self.group_ids = self.group_aspect(self.database)

        print('mask_raw_pixels: ', self.mask_raw_pixels)

        # FM: initialise vocabulary for output (one entry per line of the file)
        self.MLT_vocab_path = os.path.join(root_path, 'model/pretrained_model', MLT_vocab)
        with open(self.MLT_vocab_path) as fp:
            self.MLT_vocab = [line.strip() for line in fp]


    @property
    def data_names(self):
        return ['image', 'boxes', 'im_info', 'text',
                'relationship_label', 'mlm_labels', 'mvrc_ops', 'mvrc_labels']

    def __getitem__(self, index):
        idb = self.database[index]

        # image data
        # IN ALL CASES: boxes and cls scores are available for each image
        frcnn_data = self._load_json(os.path.join(self.data_path, idb['frcnn']))
        boxes = np.frombuffer(self.b64_decode(frcnn_data['boxes']),
                              dtype=np.float32).reshape((frcnn_data['num_boxes'], -1))
        boxes_cls_scores = np.frombuffer(self.b64_decode(frcnn_data['classes']),
                                         dtype=np.float32).reshape((frcnn_data['num_boxes'], -1))
        boxes_max_conf = boxes_cls_scores.max(axis=1)
        inds = np.argsort(boxes_max_conf)[::-1]
        boxes = boxes[inds]
        boxes_cls_scores = boxes_cls_scores[inds]
        boxes = torch.as_tensor(boxes)

        # load precomputed features or the whole image depending on setup
        if self.with_precomputed_visual_feat:
            image = None
            w0, h0 = frcnn_data['image_w'], frcnn_data['image_h']
            boxes_features = np.frombuffer(self.b64_decode(frcnn_data['features']),
                                           dtype=np.float32).reshape((frcnn_data['num_boxes'], -1))
            boxes_features = boxes_features[inds]
            boxes_features = torch.as_tensor(boxes_features)
        else:
            try:
                image = self._load_image(os.path.join(self.data_path, idb['image']))
                w0, h0 = image.size
            except:
                print("Failed to load image {}, use zero image!".format(idb['image']))
                image = None
                w0, h0 = frcnn_data['image_w'], frcnn_data['image_h']

        # append whole image to tensor of boxes (used for all linguistic tokens)
        if self.add_image_as_a_box:
            image_box = torch.as_tensor([[0.0, 0.0, w0 - 1.0, h0 - 1.0]])
            boxes = torch.cat((image_box, boxes), dim=0)
            if self.with_precomputed_visual_feat:
                image_box_feat = boxes_features.mean(dim=0, keepdim=True)
                boxes_features = torch.cat((image_box_feat, boxes_features), dim=0)

        # transform
        im_info = torch.tensor([w0, h0, 1.0, 1.0, index])
        if self.transform is not None:
            image, boxes, _, im_info = self.transform(image, boxes, None, im_info)

        if image is None and (not self.with_precomputed_visual_feat):
            w = int(im_info[0].item())
            h = int(im_info[1].item())
            image = im_info.new_zeros((3, h, w), dtype=torch.float)

        # clamp boxes
        w = im_info[0].item()
        h = im_info[1].item()
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clamp(min=0, max=w-1)
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clamp(min=0, max=h-1)

        # FM edit: remove - Task #1: Caption-Image Relationship Prediction
        word_en = idb['word_en']
        word_de = idb['word_de']
        caption_en = idb['caption_en']
        caption_de = idb['caption_de']
        
         # FM edit: add captions - tokenise words
        caption_tokens_en = self.tokenizer.tokenize(caption_en)
        # caption_tokens_de = self.tokenizer.tokenize(caption_de)
        word_tokens_en = self.tokenizer.tokenize(word_en)
        # word_tokens_de = self.tokenizer.tokenize(word_de)
        mlm_labels_en = [-1] * len(caption_tokens_en)
        mlm_labels_word_en = [-1] * len(caption_tokens_en)
        # mlm_labels_word_de = [-1] * len(caption_tokens_de)
        # mlm_labels_de = [-1] * len(caption_tokens_de)
        
        text_tokens = ['[CLS]'] + word_tokens_en + ['[SEP]'] + caption_tokens_en + ['[SEP]']
        mlm_labels = [-1] + mlm_labels_word_en + [-1] + mlm_labels_en + [-1]

        # relationship label - not used
        relationship_label = 1

        # Construct boxes
        mvrc_ops = [0] * boxes.shape[0]
        mvrc_labels = [np.zeros_like(boxes_cls_scores[0])] * boxes.shape[0]

        # store labels for masked regions
        mvrc_labels = np.stack(mvrc_labels, axis=0)

        # convert tokens to ids 
        text = self.tokenizer.convert_tokens_to_ids(text_tokens)

        if self.with_precomputed_visual_feat:
            boxes = torch.cat((boxes, boxes_features), dim=1)

        # truncate seq to max len
        if len(text) + len(boxes) > self.seq_len:
            text_len_keep = len(text)
            box_len_keep = len(boxes)
            while (text_len_keep + box_len_keep) > self.seq_len and (text_len_keep > 0) and (box_len_keep > 0):
                if box_len_keep > text_len_keep:
                    box_len_keep -= 1
                else:
                    text_len_keep -= 1
            if text_len_keep < 2:
                text_len_keep = 2
            if box_len_keep < 1:
                box_len_keep = 1
            boxes = boxes[:box_len_keep]
            text = text[:(text_len_keep - 1)] + [text[-1]]
            mlm_labels = mlm_labels[:(text_len_keep - 1)] + [mlm_labels[-1]]
            mvrc_ops = mvrc_ops[:box_len_keep]
            mvrc_labels = mvrc_labels[:box_len_keep]


        return image, boxes, im_info, text, relationship_label, mlm_labels, mvrc_ops, mvrc_labels

    @staticmethod
    def b64_decode(string):
        return base64.decodebytes(string.encode())

    @staticmethod
    def group_aspect(database):
        print('grouping aspect...')
        t = time.time()

        # get shape of all images
        widths = torch.as_tensor([idb['width'] for idb in database])
        heights = torch.as_tensor([idb['height'] for idb in database])

        # group
        group_ids = torch.zeros(len(database))
        horz = widths >= heights
        vert = 1 - horz
        group_ids[horz] = 0
        group_ids[vert] = 1

        print('Done (t={:.2f}s)'.format(time.time() - t))

        return group_ids

    def __len__(self):
        return len(self.database)

    def _load_image(self, path):
        """
        Return the image stored at ``path`` as an RGB image.

        A path containing '.zip@' is read through ``self.zipreader``; anything
        else is opened with ``Image.open``.
        """
        inside_archive = '.zip@' in path
        source = self.zipreader.imread(path) if inside_archive else Image.open(path)
        return source.convert('RGB')

    def _load_json(self, path):
        if '.zip@' in path:
            f = self.zipreader.read(path)
            return json.loads(f.decode())
        else:
            with open(path, 'r') as f:
                return json.load(f)
コード例 #4
0
class RefCOCO(Dataset):
    def __init__(self,
                 image_set,
                 root_path,
                 data_path,
                 boxes='gt',
                 proposal_source='official',
                 transform=None,
                 test_mode=False,
                 zip_mode=False,
                 cache_mode=False,
                 cache_db=False,
                 ignore_db_cache=True,
                 tokenizer=None,
                 pretrained_model_name=None,
                 add_image_as_a_box=False,
                 mask_size=(14, 14),
                 aspect_grouping=False,
                 **kwargs):
        """
        RefCOCO+ Dataset

        :param image_set: one or more referring-expression split names joined by '+'
        :param root_path: root path to cache database loaded from annotation file
        :param data_path: path to dataset
        :param boxes: boxes to use: 'gt', 'proposal', 'proposal+gt' or 'gt+proposal'
        :param proposal_source: where proposals come from, 'official' or 'vg'
        :param transform: transform
        :param test_mode: test mode means no labels available
        :param zip_mode: reading images and metadata in zip archive
        :param cache_mode: cache whole dataset to RAM first (not supported, must be False)
        :param cache_db: write the loaded database to a pickle cache under root_path
        :param ignore_db_cache: ignore previous cached database, reload it from annotation file
        :param tokenizer: default is BertTokenizer from pytorch_pretrained_bert
        :param pretrained_model_name: tokenizer to load when ``tokenizer`` is None
                                      (defaults to 'bert-base-uncased')
        :param add_image_as_a_box: add whole image as a box
        :param mask_size: size of instance mask of each object
        :param aspect_grouping: whether to group images via their aspect
        :param kwargs: ignored
        """
        super(RefCOCO, self).__init__()

        assert not cache_mode, 'currently not support cache mode!'

        categories = [
            '__background__', 'person', 'bicycle', 'car', 'motorcycle',
            'airplane', 'bus', 'train', 'truck', 'boat', 'trafficlight',
            'firehydrant', 'stopsign', 'parkingmeter', 'bench', 'bird', 'cat',
            'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
            'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase',
            'frisbee', 'skis', 'snowboard', 'sportsball', 'kite',
            'baseballbat', 'baseballglove', 'skateboard', 'surfboard',
            'tennisracket', 'bottle', 'wineglass', 'cup', 'fork', 'knife',
            'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
            'broccoli', 'carrot', 'hotdog', 'pizza', 'donut', 'cake', 'chair',
            'couch', 'pottedplant', 'bed', 'diningtable', 'toilet', 'tv',
            'laptop', 'mouse', 'remote', 'keyboard', 'cellphone', 'microwave',
            'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
            'scissors', 'teddybear', 'hairdrier', 'toothbrush'
        ]

        coco_annot_files = {
            "train2014": "annotations/instances_train2014.json",
            "val2014": "annotations/instances_val2014.json",
            "test2015": "annotations/image_info_test2015.json",
        }
        proposal_dets = 'refcoco+/proposal/res101_coco_minus_refer_notime_dets.json'
        # (archive directory, archive name) of the 'vg' proposal source
        self.vg_proposal = ("vgbua_res101_precomputed",
                            "trainval2014_resnet101_faster_rcnn_genome")
        self.proposal_source = proposal_source
        self.boxes = boxes
        self.test_mode = test_mode
        self.category_to_idx = {c: i for i, c in enumerate(categories)}
        self.data_path = data_path
        self.root_path = root_path
        self.transform = transform
        self.image_sets = [iset.strip() for iset in image_set.split('+')]
        # the COCO train2014 annotations are loaded regardless of image_set
        self.coco = COCO(annotation_file=os.path.join(
            data_path, coco_annot_files['train2014']))
        self.refer = REFER(data_path, dataset='refcoco+', splitBy='unc')
        self.refer_ids = []
        for iset in self.image_sets:
            self.refer_ids.extend(self.refer.getRefIds(split=iset))
        self.refs = self.refer.loadRefs(ref_ids=self.refer_ids)
        if 'proposal' in boxes:
            with open(os.path.join(data_path, proposal_dets), 'r') as f:
                proposal_list = json.load(f)
            # group the flat detection list by image id
            self.proposals = {}
            for proposal in proposal_list:
                self.proposals.setdefault(proposal['image_id'], []).append(proposal['box'])
        self.zip_mode = zip_mode
        self.cache_mode = cache_mode
        self.cache_db = cache_db
        self.ignore_db_cache = ignore_db_cache
        self.aspect_grouping = aspect_grouping
        self.cache_dir = os.path.join(root_path, 'cache')
        self.add_image_as_a_box = add_image_as_a_box
        self.mask_size = mask_size
        if not os.path.exists(self.cache_dir):
            makedirsExist(self.cache_dir)
        self.tokenizer = tokenizer if tokenizer is not None \
            else BertTokenizer.from_pretrained(
            'bert-base-uncased' if pretrained_model_name is None else pretrained_model_name,
            cache_dir=self.cache_dir)

        # The 'vg' proposal records are always addressed through a '.zip@'
        # path, so the zip reader is needed for them even outside zip mode.
        if zip_mode or proposal_source == 'vg':
            self.zipreader = ZipReader()

        self.database = self.load_annotations()
        if self.aspect_grouping:
            self.group_ids = self.group_aspect(self.database)

    @property
    def data_names(self):
        if self.test_mode:
            return ['image', 'boxes', 'im_info', 'expression']
        else:
            return ['image', 'boxes', 'im_info', 'expression', 'label']

    def __getitem__(self, index):
        idb = self.database[index]

        # image related
        img_id = idb['image_id']
        image = self._load_image(idb['image_fn'])
        im_info = torch.as_tensor([idb['width'], idb['height'], 1.0, 1.0])
        if not self.test_mode:
            gt_box = torch.as_tensor(idb['gt_box'])
        flipped = False
        if self.boxes == 'gt':
            ann_ids = self.coco.getAnnIds(imgIds=img_id)
            anns = self.coco.loadAnns(ann_ids)
            boxes = []
            for ann in anns:
                x_, y_, w_, h_ = ann['bbox']
                boxes.append([x_, y_, x_ + w_, y_ + h_])
            boxes = torch.as_tensor(boxes)
        elif self.boxes == 'proposal':
            if self.proposal_source == 'official':
                boxes = torch.as_tensor(self.proposals[img_id])
                boxes[:, [2, 3]] += boxes[:, [0, 1]]
            elif self.proposal_source == 'vg':
                box_file = os.path.join(
                    self.data_path, self.vg_proposal[0],
                    '{0}.zip@/{0}'.format(self.vg_proposal[1]))
                boxes_fn = os.path.join(box_file,
                                        '{}.json'.format(idb['image_id']))
                boxes_data = self._load_json(boxes_fn)
                boxes = torch.as_tensor(
                    np.frombuffer(self.b64_decode(boxes_data['boxes']),
                                  dtype=np.float32).reshape(
                                      (boxes_data['num_boxes'], -1)))
            else:
                raise NotImplemented
        elif self.boxes == 'proposal+gt' or self.boxes == 'gt+proposal':
            if self.proposal_source == 'official':
                boxes = torch.as_tensor(self.proposals[img_id])
                boxes[:, [2, 3]] += boxes[:, [0, 1]]
            elif self.proposal_source == 'vg':
                box_file = os.path.join(
                    self.data_path, self.vg_proposal[0],
                    '{0}.zip@/{0}'.format(self.vg_proposal[1]))
                boxes_fn = os.path.join(box_file,
                                        '{}.json'.format(idb['image_id']))
                boxes_data = self._load_json(boxes_fn)
                boxes = torch.as_tensor(
                    np.frombuffer(self.b64_decode(boxes_data['boxes']),
                                  dtype=np.float32).reshape(
                                      (boxes_data['num_boxes'], -1)))
            ann_ids = self.coco.getAnnIds(imgIds=img_id)
            anns = self.coco.loadAnns(ann_ids)
            gt_boxes = []
            for ann in anns:
                x_, y_, w_, h_ = ann['bbox']
                gt_boxes.append([x_, y_, x_ + w_, y_ + h_])
            gt_boxes = torch.as_tensor(gt_boxes)
            boxes = torch.cat((boxes, gt_boxes), 0)
        else:
            raise NotImplemented

        if self.add_image_as_a_box:
            w0, h0 = im_info[0], im_info[1]
            image_box = torch.as_tensor([[0.0, 0.0, w0 - 1, h0 - 1]])
            boxes = torch.cat((image_box, boxes), dim=0)

        if self.transform is not None:
            if not self.test_mode:
                boxes = torch.cat((gt_box[None], boxes), 0)
            image, boxes, _, im_info, flipped = self.transform(
                image, boxes, None, im_info, flipped)
            if not self.test_mode:
                gt_box = boxes[0]
                boxes = boxes[1:]

        # clamp boxes
        w = im_info[0].item()
        h = im_info[1].item()
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clamp(min=0, max=w - 1)
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clamp(min=0, max=h - 1)
        if not self.test_mode:
            gt_box[[0, 2]] = gt_box[[0, 2]].clamp(min=0, max=w - 1)
            gt_box[[1, 3]] = gt_box[[1, 3]].clamp(min=0, max=h - 1)

        # assign label to each box by its IoU with gt_box
        if not self.test_mode:
            boxes_ious = bbox_iou_py_vectorized(boxes, gt_box[None]).view(-1)
            label = (boxes_ious > 0.5).float()

        # expression
        exp_tokens = idb['tokens']
        exp_retokens = self.tokenizer.tokenize(' '.join(exp_tokens))
        if flipped:
            exp_retokens = self.flip_tokens(exp_retokens, verbose=True)
        exp_ids = self.tokenizer.convert_tokens_to_ids(exp_retokens)

        if self.test_mode:
            return image, boxes, im_info, exp_ids
        else:
            return image, boxes, im_info, exp_ids, label

    @staticmethod
    def flip_tokens(tokens, verbose=True):
        changed = False
        tokens_new = [tok for tok in tokens]
        for i, tok in enumerate(tokens):
            if tok == 'left':
                tokens_new[i] = 'right'
                changed = True
            elif tok == 'right':
                tokens_new[i] = 'left'
                changed = True
        if verbose and changed:
            logging.info('[Tokens Flip] {} -> {}'.format(tokens, tokens_new))
        return tokens_new

    @staticmethod
    def b64_decode(string):
        return base64.decodebytes(string.encode())

    def load_annotations(self):
        tic = time.time()
        database = []
        db_cache_name = 'refcoco+_boxes_{}_{}'.format(
            self.boxes, '+'.join(self.image_sets))
        if self.zip_mode:
            db_cache_name = db_cache_name + '_zipmode'
        if self.test_mode:
            db_cache_name = db_cache_name + '_testmode'
        db_cache_root = os.path.join(self.root_path, 'cache')
        db_cache_path = os.path.join(db_cache_root,
                                     '{}.pkl'.format(db_cache_name))

        if os.path.exists(db_cache_path):
            if not self.ignore_db_cache:
                # reading cached database
                print('cached database found in {}.'.format(db_cache_path))
                with open(db_cache_path, 'rb') as f:
                    print('loading cached database from {}...'.format(
                        db_cache_path))
                    tic = time.time()
                    database = cPickle.load(f)
                    print('Done (t={:.2f}s)'.format(time.time() - tic))
                    return database
            else:
                print('cached database ignored.')

        # ignore or not find cached database, reload it from annotation file
        print('loading database of split {}...'.format('+'.join(
            self.image_sets)))
        tic = time.time()

        for ref_id, ref in zip(self.refer_ids, self.refs):
            iset = 'train2014'
            if not self.test_mode:
                gt_x, gt_y, gt_w, gt_h = self.refer.getRefBox(ref_id=ref_id)
            if self.zip_mode:
                image_fn = os.path.join(
                    self.data_path, iset + '.zip@/' + iset,
                    'COCO_{}_{:012d}.jpg'.format(iset, ref['image_id']))
            else:
                image_fn = os.path.join(
                    self.data_path, iset,
                    'COCO_{}_{:012d}.jpg'.format(iset, ref['image_id']))
            for sent in ref['sentences']:
                idb = {
                    'sent_id':
                    sent['sent_id'],
                    'ann_id':
                    ref['ann_id'],
                    'ref_id':
                    ref['ref_id'],
                    'image_id':
                    ref['image_id'],
                    'image_fn':
                    image_fn,
                    'width':
                    self.coco.imgs[ref['image_id']]['width'],
                    'height':
                    self.coco.imgs[ref['image_id']]['height'],
                    'raw':
                    sent['raw'],
                    'sent':
                    sent['sent'],
                    'tokens':
                    sent['tokens'],
                    'category_id':
                    ref['category_id'],
                    'gt_box': [gt_x, gt_y, gt_x + gt_w, gt_y +
                               gt_h] if not self.test_mode else None
                }
                database.append(idb)

        print('Done (t={:.2f}s)'.format(time.time() - tic))

        # cache database via cPickle
        if self.cache_db:
            print('caching database to {}...'.format(db_cache_path))
            tic = time.time()
            if not os.path.exists(db_cache_root):
                makedirsExist(db_cache_root)
            with open(db_cache_path, 'wb') as f:
                cPickle.dump(database, f)
            print('Done (t={:.2f}s)'.format(time.time() - tic))

        return database

    @staticmethod
    def group_aspect(database):
        print('grouping aspect...')
        t = time.time()

        # get shape of all images
        widths = torch.as_tensor([idb['width'] for idb in database])
        heights = torch.as_tensor([idb['height'] for idb in database])

        # group
        group_ids = torch.zeros(len(database))
        horz = widths >= heights
        vert = 1 - horz
        group_ids[horz] = 0
        group_ids[vert] = 1

        print('Done (t={:.2f}s)'.format(time.time() - t))

        return group_ids

    def __len__(self):
        return len(self.database)

    def _load_image(self, path):
        """
        Return the image stored at ``path`` converted to RGB.

        Paths containing '.zip@' go through ``self.zipreader``; all other
        paths are opened with ``Image.open``.
        """
        if '.zip@' in path:
            picture = self.zipreader.imread(path)
        else:
            picture = Image.open(path)
        return picture.convert('RGB')

    def _load_json(self, path):
        if '.zip@' in path:
            f = self.zipreader.read(path)
            return json.loads(f.decode())
        else:
            with open(path, 'r') as f:
                return json.load(f)
コード例 #5
0
ファイル: parallel_text.py プロジェクト: phaedonmit/VL-BERT
class ParallelTextDataset(Dataset):
    def __init__(self,
                 ann_file,
                 image_set,
                 root_path,
                 data_path,
                 seq_len=64,
                 with_precomputed_visual_feat=False,
                 mask_raw_pixels=True,
                 with_rel_task=True,
                 with_mlm_task=True,
                 with_mvrc_task=True,
                 transform=None,
                 test_mode=False,
                 zip_mode=False,
                 cache_mode=False,
                 cache_db=False,
                 ignore_db_cache=True,
                 tokenizer=None,
                 pretrained_model_name=None,
                 add_image_as_a_box=False,
                 aspect_grouping=False,
                 **kwargs):
        """
        Parallel (English/German) caption text dataset

        :param ann_file: accepted for interface compatibility; the annotation file
                         actually read is chosen from ``image_set`` (see ``annot`` below)
        :param image_set: split name, one of 'train', 'val' or 'test'
                          ('val' and 'test' both map to test.json)
        :param root_path: root path; the tokenizer cache dir is created beneath it
        :param data_path: directory containing the annotation files
        :param seq_len: maximum number of text tokens per sample
        :param with_precomputed_visual_feat: stored on the instance
        :param mask_raw_pixels: stored on the instance and printed
        :param with_rel_task: sample mismatched caption pairs for the relationship task
        :param with_mlm_task: apply whole-word masking to the captions
        :param with_mvrc_task: stored on the instance
        :param transform: transform
        :param test_mode: test mode is not supported, must be False
        :param zip_mode: reading images and metadata in zip archive
        :param cache_mode: cache whole dataset to RAM first (not supported, must be False)
        :param cache_db: stored on the instance
        :param ignore_db_cache: ignore previous cached database, reload it from annotation file
        :param tokenizer: default is BertTokenizer from pytorch_pretrained_bert
        :param pretrained_model_name: tokenizer to load when ``tokenizer`` is None
                                      (defaults to 'bert-base-uncased')
        :param add_image_as_a_box: add whole image as a box
        :param aspect_grouping: whether to group images via their aspect (not supported)
        :param kwargs: ignored
        """
        super(ParallelTextDataset, self).__init__()

        assert not cache_mode, 'currently not support cache mode!'
        assert not test_mode

        annot = {
            'train': 'train.json',
            'val': 'test.json',
            'test': 'test.json'
        }

        self.seq_len = seq_len
        self.with_rel_task = with_rel_task
        self.with_mlm_task = with_mlm_task
        self.with_mvrc_task = with_mvrc_task
        self.data_path = data_path
        self.root_path = root_path
        self.ann_file = os.path.join(data_path, annot[image_set])
        self.with_precomputed_visual_feat = with_precomputed_visual_feat
        self.mask_raw_pixels = mask_raw_pixels
        self.image_set = image_set
        self.transform = transform
        self.test_mode = test_mode
        self.zip_mode = zip_mode
        self.cache_mode = cache_mode
        self.cache_db = cache_db
        self.ignore_db_cache = ignore_db_cache
        self.aspect_grouping = aspect_grouping
        self.cache_dir = os.path.join(root_path, 'cache')
        self.add_image_as_a_box = add_image_as_a_box
        if not os.path.exists(self.cache_dir):
            makedirsExist(self.cache_dir)
        self.tokenizer = tokenizer if tokenizer is not None \
            else BertTokenizer.from_pretrained(
            'bert-base-uncased' if pretrained_model_name is None else pretrained_model_name,
            cache_dir=self.cache_dir)

        self.zipreader = ZipReader()

        # FM: Customise for multi30k dataset
        # Read every annotation line, then close the reader instead of leaking its handle.
        with jsonlines.open(self.ann_file) as reader:
            self.database = list(reader)

        if self.aspect_grouping:
            assert False, "not support aspect grouping currently!"
            self.group_ids = self.group_aspect(self.database)

        print('mask_raw_pixels: ', self.mask_raw_pixels)

    @property
    def data_names(self):
        return ['text', 'relationship_label', 'mlm_labels']

    def __getitem__(self, index):
        idb = self.database[index]

        # Task #1: Caption-Image Relationship Prediction
        _p = random.random()
        if _p < 0.5 or (not self.with_rel_task):
            relationship_label = 1
            caption_en = idb['caption_en']
            caption_de = idb['caption_de']
        else:
            relationship_label = 0
            rand_index = random.randrange(0, len(self.database))
            while rand_index == index:
                rand_index = random.randrange(0, len(self.database))
            caption_en = self.database[rand_index]['caption_en']
            caption_de = self.database[rand_index]['caption_de']

        # Task #2: Masked Language Modeling - Adapted for two languages

        if self.with_mlm_task:
            # FM: removing joining of caption - split into two languages
            caption_tokens_en = self.tokenizer.basic_tokenizer.tokenize(
                caption_en)
            caption_tokens_en, mlm_labels_en = self.random_word_wwm(
                caption_tokens_en)
            caption_tokens_de = self.tokenizer.basic_tokenizer.tokenize(
                caption_de)
            caption_tokens_de, mlm_labels_de = self.random_word_wwm(
                caption_tokens_de)
        else:
            caption_tokens_en = self.tokenizer.tokenize(caption_en)
            caption_tokens_de = self.tokenizer.tokenize(caption_de)
            mlm_labels_en = [-1] * len(caption_tokens_en)
            mlm_labels_de = [-1] * len(caption_tokens_de)

        text_tokens = ['[CLS]'] + caption_tokens_en + [
            '[SEP]'
        ] + caption_tokens_de + ['[SEP]']
        mlm_labels = [-1] + mlm_labels_en + [-1] + mlm_labels_de + [-1]

        # convert tokens to ids
        text = self.tokenizer.convert_tokens_to_ids(text_tokens)

        # truncate seq to max len
        if len(text) > self.seq_len:
            text_len_keep = len(text)
            while (text_len_keep) > self.seq_len and (text_len_keep > 0):
                text_len_keep -= 1
            if text_len_keep < 2:
                text_len_keep = 2
            text = text[:(text_len_keep - 1)] + [text[-1]]

        return text, relationship_label, mlm_labels

    # def random_word(self, tokens):
    #     output_label = []
    #
    #     for i, token in enumerate(tokens):
    #         prob = random.random()
    #         # mask token with 15% probability
    #         if prob < 0.15:
    #             prob /= 0.15
    #
    #             # 80% randomly change token to mask token
    #             if prob < 0.8:
    #                 tokens[i] = "[MASK]"
    #
    #             # 10% randomly change token to random token
    #             elif prob < 0.9:
    #                 tokens[i] = random.choice(list(self.tokenizer.vocab.items()))[0]
    #
    #             # -> rest 10% randomly keep current token
    #
    #             # append current token to output (we will predict these later)
    #             try:
    #                 output_label.append(self.tokenizer.vocab[token])
    #             except KeyError:
    #                 # For unknown words (should not occur with BPE vocab)
    #                 output_label.append(self.tokenizer.vocab["[UNK]"])
    #                 logging.warning("Cannot find token '{}' in vocab. Using [UNK] insetad".format(token))
    #         else:
    #             # no masking token (will be ignored by loss function later)
    #             output_label.append(-1)
    #
    #     # if no word masked, random choose a word to mask
    #     if self.force_mask:
    #         if all([l_ == -1 for l_ in output_label]):
    #             choosed = random.randrange(0, len(output_label))
    #             output_label[choosed] = self.tokenizer.vocab[tokens[choosed]]
    #
    #     return tokens, output_label

    def random_word_wwm(self, tokens):
        """
        Apply whole-word masking for masked language modeling.

        Every word in ``tokens`` is split into word pieces by
        ``self.tokenizer.wordpiece_tokenizer``. With probability 0.15 a word is
        selected and all of its pieces are treated alike: replaced by
        ``"[MASK]"`` (80%), replaced by random vocabulary entries (10%), or kept
        unchanged (10%).

        :param tokens: list of words (already split by the basic tokenizer)
        :return: ``(output_tokens, output_label)``, two lists of equal length with
                 one entry per word piece. The label is the vocabulary id of the
                 original piece for selected words and -1 otherwise.
        """
        wordpiece_tokenizer = self.tokenizer.wordpiece_tokenizer
        vocab = self.tokenizer.vocab
        # Built on first use only: most calls never hit the random-replacement
        # branch, and materialising the vocabulary per sub-token is wasteful.
        vocab_tokens = None

        output_tokens = []
        output_label = []

        for token in tokens:
            sub_tokens = wordpiece_tokenizer.tokenize(token)
            prob = random.random()

            if prob >= 0.15:
                # word not selected: keep pieces, labels ignored by the loss
                output_tokens.extend(sub_tokens)
                output_label.extend([-1] * len(sub_tokens))
                continue

            # rescale to [0, 1) to pick the corruption applied to the word
            prob /= 0.15
            if prob < 0.8:
                # 80%: mask every piece of the word
                output_tokens.extend(["[MASK]"] * len(sub_tokens))
            elif prob < 0.9:
                # 10%: replace every piece by a random vocabulary entry
                if vocab_tokens is None:
                    vocab_tokens = list(vocab.keys())
                for _ in sub_tokens:
                    output_tokens.append(random.choice(vocab_tokens))
            else:
                # remaining 10%: keep the word as it is
                output_tokens.extend(sub_tokens)

            # the original pieces are the prediction targets
            for sub_token in sub_tokens:
                try:
                    output_label.append(vocab[sub_token])
                except KeyError:
                    # For unknown words (should not occur with BPE vocab)
                    output_label.append(vocab["[UNK]"])
                    logging.warning(
                        "Cannot find sub_token '{}' in vocab. Using [UNK] insetad"
                        .format(sub_token))

        return output_tokens, output_label

    def random_mask_region(self, regions_cls_scores):
        """
        Randomly select regions for masked visual region classification.

        Each region is selected with probability 0.15. A selected region keeps
        its class scores as the label and gets op 1 (mask the appearance
        feature) 90% of the time, op 0 otherwise. Unselected regions get op 0
        and an all-zero label.

        :param regions_cls_scores: 2-D array of per-region class scores
        :return: ``(output_op, output_label)``, two lists with one entry per region
        """
        # Unpacking fails for anything that is not 2-D, as in the original code.
        _num_regions, _num_classes = regions_cls_scores.shape

        ops = []
        labels = []
        for region_scores in regions_cls_scores:
            draw = random.random()
            if draw >= 0.15:
                # region left alone (ignored by the loss later)
                ops.append(0)
                labels.append(np.zeros_like(region_scores))
                continue

            # region selected: 90% mask the feature, 10% keep it
            ops.append(1 if draw / 0.15 < 0.9 else 0)
            labels.append(region_scores)

        return ops, labels

    @staticmethod
    def b64_decode(string):
        return base64.decodebytes(string.encode())

    @staticmethod
    def group_aspect(database):
        print('grouping aspect...')
        t = time.time()

        # get shape of all images
        widths = torch.as_tensor([idb['width'] for idb in database])
        heights = torch.as_tensor([idb['height'] for idb in database])

        # group
        group_ids = torch.zeros(len(database))
        horz = widths >= heights
        vert = 1 - horz
        group_ids[horz] = 0
        group_ids[vert] = 1

        print('Done (t={:.2f}s)'.format(time.time() - t))

        return group_ids

    def __len__(self):
        """Return the number of entries in ``self.database``."""
        entries = self.database
        return len(entries)

    def _load_image(self, path):
        """
        Load an image and convert it to RGB.

        Paths containing '.zip@' are read through ``self.zipreader``; any other
        path is opened with ``Image.open``.
        """
        inside_archive = '.zip@' in path
        raw = self.zipreader.imread(path) if inside_archive else Image.open(path)
        return raw.convert('RGB')

    def _load_json(self, path):
        """
        Load and parse a JSON document.

        Paths containing '.zip@' are read through ``self.zipreader``; any other
        path is opened from the file system. Both branches decode the document
        as UTF-8 so the result does not depend on the platform's locale.

        :param path: file path, or '<archive>.zip@/<member>' style path
        :return: the parsed JSON value
        """
        if '.zip@' in path:
            raw = self.zipreader.read(path)
            return json.loads(raw.decode('utf-8'))
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
コード例 #6
0
class COCOCaptionsDataset(Dataset):
    def __init__(self,
                 ann_file,
                 image_set,
                 root_path,
                 data_path,
                 seq_len=64,
                 with_precomputed_visual_feat=False,
                 mask_raw_pixels=True,
                 with_rel_task=True,
                 with_mlm_task=True,
                 with_mvrc_task=True,
                 transform=None,
                 test_mode=False,
                 zip_mode=False,
                 cache_mode=False,
                 cache_db=False,
                 ignore_db_cache=True,
                 tokenizer=None,
                 pretrained_model_name=None,
                 add_image_as_a_box=False,
                 aspect_grouping=False,
                 **kwargs):
        """
        COCO captions dataset with ground-truth instance boxes.

        :param ann_file: accepted for interface compatibility; the annotation
                         files actually used are derived from ``image_set``
        :param image_set: 'train' or 'val' (selects the 2017 annotation files
                          and the '<image_set>2017' image folder)
        :param root_path: root path; '<root_path>/cache' is used as cache dir
        :param data_path: path to the COCO data (annotations and images)
        :param seq_len: maximum combined number of text tokens and boxes
        :param with_precomputed_visual_feat: stored flag; ``__getitem__`` asserts
                                             when it is set
        :param mask_raw_pixels: zero out image pixels of masked regions
        :param with_rel_task: enable caption-image relationship prediction
        :param with_mlm_task: enable masked language modeling
        :param with_mvrc_task: enable masked visual region classification
        :param transform: transform applied to (image, boxes, masks, im_info)
        :param test_mode: must be False (asserted)
        :param zip_mode: read images from a zip archive
        :param cache_mode: must be False (asserted)
        :param cache_db: stored on the instance
        :param ignore_db_cache: stored on the instance
        :param tokenizer: tokenizer to use; defaults to a BertTokenizer loaded
                          from ``pretrained_model_name`` or 'bert-base-uncased'
        :param pretrained_model_name: name used when no tokenizer is given
        :param add_image_as_a_box: add whole image as a box
        :param aspect_grouping: must be False (asserted)
        :param kwargs: ignored
        """
        super(COCOCaptionsDataset, self).__init__()

        assert not cache_mode, 'currently not support cache mode!'
        assert not test_mode

        caption_files = {
            'train': 'annotations/captions_train2017.json',
            'val': 'annotations/captions_val2017.json'
        }
        instance_files = {
            'train': 'annotations/instances_train2017.json',
            'val': 'annotations/instances_val2017.json'
        }

        # image folder, either inside a zip archive or on disk
        image_dir_pattern = '{0}2017.zip@/{0}2017' if zip_mode else '{}2017'
        self.root = os.path.join(data_path, image_dir_pattern.format(image_set))

        # task switches and sequence budget
        self.seq_len = seq_len
        self.with_rel_task = with_rel_task
        self.with_mlm_task = with_mlm_task
        self.with_mvrc_task = with_mvrc_task

        # locations
        self.data_path = data_path
        self.root_path = root_path
        self.ann_file = os.path.join(data_path, caption_files[image_set])
        self.ann_file_inst = os.path.join(data_path, instance_files[image_set])

        # remaining options
        self.with_precomputed_visual_feat = with_precomputed_visual_feat
        self.mask_raw_pixels = mask_raw_pixels
        self.image_set = image_set
        self.transform = transform
        self.test_mode = test_mode
        self.zip_mode = zip_mode
        self.cache_mode = cache_mode
        self.cache_db = cache_db
        self.ignore_db_cache = ignore_db_cache
        self.aspect_grouping = aspect_grouping
        self.add_image_as_a_box = add_image_as_a_box

        cache_dir = os.path.join(root_path, 'cache')
        self.cache_dir = cache_dir
        if not os.path.exists(cache_dir):
            makedirsExist(cache_dir)

        if tokenizer is None:
            model_name = 'bert-base-uncased' \
                if pretrained_model_name is None else pretrained_model_name
            tokenizer = BertTokenizer.from_pretrained(model_name,
                                                      cache_dir=cache_dir)
        self.tokenizer = tokenizer

        if zip_mode:
            self.zipreader = ZipReader()

        self.coco = COCO(self.ann_file)
        self.coco_inst = COCO(self.ann_file_inst)

        # keep only images that have at least one detection annotation
        all_ids = sorted(self.coco.imgs.keys())
        self.ids = [
            img_id for img_id in all_ids
            if len(self.coco_inst.getAnnIds(imgIds=img_id, iscrowd=None)) > 0
        ]

        # COCO category ids are mapped to contiguous ids starting at 1
        self.json_category_id_to_contiguous_id = {
            json_id: contiguous_id
            for contiguous_id, json_id in enumerate(self.coco_inst.getCatIds(), 1)
        }
        self.contiguous_category_id_to_json_id = {
            contiguous_id: json_id
            for json_id, contiguous_id in
            self.json_category_id_to_contiguous_id.items()
        }
        self.id_to_img_map = dict(enumerate(self.ids))

        assert not self.aspect_grouping, "not support aspect grouping currently!"

        print('mask_raw_pixels: ', self.mask_raw_pixels)

    @property
    def data_names(self):
        return [
            'image', 'boxes', 'im_info', 'text', 'relationship_label',
            'mlm_labels', 'mvrc_ops', 'mvrc_labels'
        ]

    def __getitem__(self, index):
        """
        Build one pre-training sample for the image at position ``index``.

        The sample uses the first caption annotation of the image and the
        ground-truth instance boxes of the image. Boxes get a one-hot class
        score row each (81 columns).

        :param index: position in ``self.ids``
        :return: tuple ``(image, boxes, im_info, text, relationship_label,
                 mlm_labels, mvrc_ops, mvrc_labels)``, matching ``data_names``
        """
        img_id = self.ids[index]

        # image data
        # frcnn_data = self._load_json(os.path.join(self.data_path, idb['frcnn']))
        # boxes = np.frombuffer(self.b64_decode(frcnn_data['boxes']),
        #                       dtype=np.float32).reshape((frcnn_data['num_boxes'], -1))
        # boxes_cls_scores = np.frombuffer(self.b64_decode(frcnn_data['classes']),
        #                                  dtype=np.float32).reshape((frcnn_data['num_boxes'], -1))
        # boxes_max_conf = boxes_cls_scores.max(axis=1)
        # inds = np.argsort(boxes_max_conf)[::-1]
        # boxes = boxes[inds]
        # boxes_cls_scores = boxes_cls_scores[inds]
        # boxes = torch.as_tensor(boxes)

        # caption annotations and instance (detection) annotations of the image
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)
        ann_ids_inst = self.coco_inst.getAnnIds(imgIds=img_id)
        anns_inst = self.coco_inst.loadAnns(ann_ids_inst)
        # only the first caption annotation is used
        idb = anns[0]
        boxes = [ann_['bbox'] for ann_ in anns_inst]
        boxes = torch.as_tensor(boxes).reshape(-1, 4)
        # convert [x, y, w, h] to [xmin, ymin, xmax, ymax]; one pixel is
        # subtracted from the extent and the result is kept non-negative
        TO_REMOVE = 1
        xmin, ymin, w, h = boxes.split(1, dim=-1)
        xmax = xmin + (w - TO_REMOVE).clamp(min=0)
        ymax = ymin + (h - TO_REMOVE).clamp(min=0)
        boxes = torch.cat((xmin, ymin, xmax, ymax), dim=-1)
        # one-hot class scores: column 0 is never set because contiguous
        # category ids start at 1
        boxes_cls_scores = boxes.new_zeros((boxes.shape[0], 81))
        classes = [ann["category_id"] for ann in anns_inst]
        classes = [self.json_category_id_to_contiguous_id[c] for c in classes]
        for i, class_ in enumerate(classes):
            boxes_cls_scores[i, class_] = 1.0

        if self.with_precomputed_visual_feat:
            # precomputed features are not supported by this dataset
            assert False
            # image = None
            # w0, h0 = frcnn_data['image_w'], frcnn_data['image_h']
            # boxes_features = np.frombuffer(self.b64_decode(frcnn_data['features']),
            #                                dtype=np.float32).reshape((frcnn_data['num_boxes'], -1))
            # boxes_features = boxes_features[inds]
            # boxes_features = torch.as_tensor(boxes_features)
        else:
            path = self.coco_inst.loadImgs(img_id)[0]['file_name']
            image = self._load_image(os.path.join(self.root, path))
            w0, h0 = image.size

        if self.add_image_as_a_box:
            # prepend a box that covers the whole image
            image_box = torch.as_tensor([[0.0, 0.0, w0 - 1.0, h0 - 1.0]])
            boxes = torch.cat((image_box, boxes), dim=0)
            if self.with_precomputed_visual_feat:
                assert False
                # image_box_feat = boxes_features.mean(dim=0, keepdim=True)
                # boxes_features = torch.cat((image_box_feat, boxes_features), dim=0)

        # transform
        # im_info holds (width, height, 1.0, 1.0, index)
        im_info = torch.tensor([w0, h0, 1.0, 1.0, index])
        if self.transform is not None:
            image, boxes, _, im_info = self.transform(image, boxes, None,
                                                      im_info)

        if image is None and (not self.with_precomputed_visual_feat):
            assert False
            # w = int(im_info[0].item())
            # h = int(im_info[1].item())
            # image = im_info.new_zeros((3, h, w), dtype=torch.float)

        # clamp boxes
        w = im_info[0].item()
        h = im_info[1].item()
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clamp(min=0, max=w - 1)
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clamp(min=0, max=h - 1)

        # Task #1: Caption-Image Relationship Prediction
        _p = random.random()
        if _p < 0.5 or (not self.with_rel_task):
            relationship_label = 1
            caption = idb['caption']
        else:
            # NOTE(review): with with_rel_task enabled this branch is taken for
            # about half of the samples and fails on the assert below. The code
            # after the assert reads self.database, which __init__ of this class
            # does not set.
            assert False
            relationship_label = 0
            rand_index = random.randrange(0, len(self.database))
            while rand_index == index:
                rand_index = random.randrange(0, len(self.database))
            caption = self.database[rand_index]['caption']

        assert isinstance(caption, str)

        # Task #2: Masked Language Modeling
        if self.with_mlm_task:
            caption_tokens = self.tokenizer.basic_tokenizer.tokenize(caption)
            caption_tokens, mlm_labels = self.random_word_wwm(caption_tokens)
        else:
            caption_tokens = self.tokenizer.tokenize(caption)
            mlm_labels = [-1] * len(caption_tokens)
        # special tokens never carry an MLM label
        text_tokens = ['[CLS]'] + caption_tokens + ['[SEP]']
        mlm_labels = [-1] + mlm_labels + [-1]

        # Task #3: Masked Visual Region Classification
        if self.with_mvrc_task:
            if self.add_image_as_a_box:
                mvrc_ops, mvrc_labels = self.random_mask_region(
                    boxes_cls_scores)
                # the whole-image box is never masked and gets a zero label
                mvrc_ops = [0] + mvrc_ops
                mvrc_labels = [np.zeros_like(boxes_cls_scores[0])
                               ] + mvrc_labels
                num_real_boxes = boxes.shape[0] - 1
                num_masked_boxes = 0
                if self.with_precomputed_visual_feat:
                    assert False
                    # boxes_features[0] *= num_real_boxes
                    # for mvrc_op, box_feat in zip(mvrc_ops, boxes_features):
                    #     if mvrc_op == 1:
                    #         num_masked_boxes += 1
                    #         boxes_features[0] -= box_feat
                    # boxes_features[0] /= (num_real_boxes - num_masked_boxes + 1e-5)
            else:
                mvrc_ops, mvrc_labels = self.random_mask_region(
                    boxes_cls_scores)
            assert len(mvrc_ops) == boxes.shape[0], \
                "Error: mvrc_ops have length {}, expected {}!".format(len(mvrc_ops), boxes.shape[0])
            assert len(mvrc_labels) == boxes.shape[0], \
                "Error: mvrc_labels have length {}, expected {}!".format(len(mvrc_labels), boxes.shape[0])
        else:
            mvrc_ops = [0] * boxes.shape[0]
            mvrc_labels = [np.zeros_like(boxes_cls_scores[0])] * boxes.shape[0]

        # zero out pixels of masked RoI
        # NOTE(review): the slice assignment presumably expects `image` to be a
        # (C, H, W) tensor produced by the transform -- confirm for transform=None.
        if (not self.with_precomputed_visual_feat) and self.mask_raw_pixels:
            for mvrc_op, box in zip(mvrc_ops, boxes):
                if mvrc_op == 1:
                    x1, y1, x2, y2 = box
                    image[:, int(y1):(int(y2) + 1), int(x1):(int(x2) + 1)] = 0

        mvrc_labels = np.stack(mvrc_labels, axis=0)

        text = self.tokenizer.convert_tokens_to_ids(text_tokens)

        if self.with_precomputed_visual_feat:
            assert False
            # boxes = torch.cat((boxes, boxes_features), dim=1)

        # truncate seq to max len
        # Entries are dropped one at a time from whichever of boxes/text is
        # currently longer (text on ties) until both fit into seq_len.
        # NOTE(review): text is cut from the end, so the trailing [SEP] id is
        # removed whenever the text gets truncated.
        if len(text) + len(boxes) > self.seq_len:
            text_len_keep = len(text)
            box_len_keep = len(boxes)
            while (text_len_keep + box_len_keep) > self.seq_len:
                if box_len_keep > text_len_keep:
                    box_len_keep -= 1
                else:
                    text_len_keep -= 1
            boxes = boxes[:box_len_keep]
            text = text[:text_len_keep]
            mlm_labels = mlm_labels[:text_len_keep]
            mvrc_ops = mvrc_ops[:box_len_keep]
            mvrc_labels = mvrc_labels[:box_len_keep]

        return image, boxes, im_info, text, relationship_label, mlm_labels, mvrc_ops, mvrc_labels

    # def random_word(self, tokens):
    #     output_label = []
    #
    #     for i, token in enumerate(tokens):
    #         prob = random.random()
    #         # mask token with 15% probability
    #         if prob < 0.15:
    #             prob /= 0.15
    #
    #             # 80% randomly change token to mask token
    #             if prob < 0.8:
    #                 tokens[i] = "[MASK]"
    #
    #             # 10% randomly change token to random token
    #             elif prob < 0.9:
    #                 tokens[i] = random.choice(list(self.tokenizer.vocab.items()))[0]
    #
    #             # -> rest 10% randomly keep current token
    #
    #             # append current token to output (we will predict these later)
    #             try:
    #                 output_label.append(self.tokenizer.vocab[token])
    #             except KeyError:
    #                 # For unknown words (should not occur with BPE vocab)
    #                 output_label.append(self.tokenizer.vocab["[UNK]"])
    #                 logging.warning("Cannot find token '{}' in vocab. Using [UNK] insetad".format(token))
    #         else:
    #             # no masking token (will be ignored by loss function later)
    #             output_label.append(-1)
    #
    #     # if no word masked, random choose a word to mask
    #     if self.force_mask:
    #         if all([l_ == -1 for l_ in output_label]):
    #             choosed = random.randrange(0, len(output_label))
    #             output_label[choosed] = self.tokenizer.vocab[tokens[choosed]]
    #
    #     return tokens, output_label

    def random_word_wwm(self, tokens):
        output_tokens = []
        output_label = []

        for i, token in enumerate(tokens):
            sub_tokens = self.tokenizer.wordpiece_tokenizer.tokenize(token)
            prob = random.random()
            # mask token with 15% probability
            if prob < 0.15:
                prob /= 0.15

                # 80% randomly change token to mask token
                if prob < 0.8:
                    for sub_token in sub_tokens:
                        output_tokens.append("[MASK]")
                # 10% randomly change token to random token
                elif prob < 0.9:
                    for sub_token in sub_tokens:
                        output_tokens.append(
                            random.choice(list(self.tokenizer.vocab.keys())))
                        # -> rest 10% randomly keep current token
                else:
                    for sub_token in sub_tokens:
                        output_tokens.append(sub_token)

                        # append current token to output (we will predict these later)
                for sub_token in sub_tokens:
                    try:
                        output_label.append(self.tokenizer.vocab[sub_token])
                    except KeyError:
                        # For unknown words (should not occur with BPE vocab)
                        output_label.append(self.tokenizer.vocab["[UNK]"])
                        logging.warning(
                            "Cannot find sub_token '{}' in vocab. Using [UNK] insetad"
                            .format(sub_token))
            else:
                for sub_token in sub_tokens:
                    # no masking token (will be ignored by loss function later)
                    output_tokens.append(sub_token)
                    output_label.append(-1)

        ## if no word masked, random choose a word to mask
        # if all([l_ == -1 for l_ in output_label]):
        #    choosed = random.randrange(0, len(output_label))
        #    output_label[choosed] = self.tokenizer.vocab[tokens[choosed]]

        return output_tokens, output_label

    def random_mask_region(self, regions_cls_scores):
        num_regions, num_classes = regions_cls_scores.shape
        output_op = []
        output_label = []
        for k, cls_scores in enumerate(regions_cls_scores):
            prob = random.random()
            # mask region with 15% probability
            if prob < 0.15:
                prob /= 0.15

                if prob < 0.9:
                    # 90% randomly replace appearance feature by "MASK"
                    output_op.append(1)
                else:
                    # -> rest 10% randomly keep current appearance feature
                    output_op.append(0)

                # append class of region to output (we will predict these later)
                output_label.append(cls_scores)
            else:
                # no masking region (will be ignored by loss function later)
                output_op.append(0)
                output_label.append(np.zeros_like(cls_scores))

        # # if no region masked, random choose a region to mask
        # if all([op == 0 for op in output_op]):
        #     choosed = random.randrange(0, len(output_op))
        #     output_op[choosed] = 1
        #     output_label[choosed] = regions_cls_scores[choosed]

        return output_op, output_label

    @staticmethod
    def b64_decode(string):
        return base64.decodebytes(string.encode())

    @staticmethod
    def group_aspect(database):
        print('grouping aspect...')
        t = time.time()

        # get shape of all images
        widths = torch.as_tensor([idb['width'] for idb in database])
        heights = torch.as_tensor([idb['height'] for idb in database])

        # group
        group_ids = torch.zeros(len(database))
        horz = widths >= heights
        vert = 1 - horz
        group_ids[horz] = 0
        group_ids[vert] = 1

        print('Done (t={:.2f}s)'.format(time.time() - t))

        return group_ids

    def __len__(self):
        return len(self.ids)

    def _load_image(self, path):
        if '.zip@' in path:
            return self.zipreader.imread(path).convert('RGB')
        else:
            return Image.open(path).convert('RGB')

    def _load_json(self, path):
        if '.zip@' in path:
            f = self.zipreader.read(path)
            return json.loads(f.decode())
        else:
            with open(path, 'r') as f:
                return json.load(f)
コード例 #7
0
class VRep(Dataset):
    def __init__(self, image_set, root_path, data_path, boxes='gt', proposal_source='official',
                 transform=None, test_mode=False,
                 zip_mode=False, cache_mode=False, cache_db=False, ignore_db_cache=True,
                 tokenizer=None, pretrained_model_name=None,
                 add_image_as_a_box=False, mask_size=(14, 14),
                 aspect_grouping=False, **kwargs):
        """
        VREP referring-expression dataset.

        :param image_set: image folder name (not used by this constructor)
        :param root_path: root path; '<root_path>/cache' holds the cached database
        :param data_path: path to the dataset (annotation JSON files and images)
        :param boxes: boxes to use, 'gt' or 'proposal'
        :param proposal_source: not used by this constructor
        :param transform: transform applied in ``__getitem__``
        :param test_mode: test mode means no labels available
        :param zip_mode: reading images and metadata in zip archive
        :param cache_mode: must be False (asserted)
        :param cache_db: write the loaded database to the cache directory
        :param ignore_db_cache: ignore previous cached database, reload it from annotation file
        :param tokenizer: tokenizer to use; defaults to a BertTokenizer loaded
                          from ``pretrained_model_name`` or 'bert-base-uncased'
        :param pretrained_model_name: name used when no tokenizer is given
        :param add_image_as_a_box: add whole image as a box
        :param mask_size: size of instance mask of each object
        :param aspect_grouping: whether to group images via their aspect
        :param kwargs: ignored
        """
        super(VRep, self).__init__()

        assert not cache_mode, 'currently not support cache mode!'

        # annotation files, relative to data_path
        self.data_json = 'obj_det_res.json'
        self.ref_json = 'ref_annotations.json'

        self.boxes = boxes
        self.refer = Refer()

        # plain option storage
        self.test_mode = test_mode
        self.data_path = data_path
        self.root_path = root_path
        self.transform = transform
        self.zip_mode = zip_mode
        self.cache_mode = cache_mode
        self.cache_db = cache_db
        self.ignore_db_cache = ignore_db_cache
        self.aspect_grouping = aspect_grouping
        self.add_image_as_a_box = add_image_as_a_box
        self.mask_size = mask_size

        cache_dir = os.path.join(root_path, 'cache')
        self.cache_dir = cache_dir
        if not os.path.exists(cache_dir):
            makedirsExist(cache_dir)

        if tokenizer is None:
            model_name = 'bert-base-uncased' \
                if pretrained_model_name is None else pretrained_model_name
            tokenizer = BertTokenizer.from_pretrained(model_name,
                                                      cache_dir=cache_dir)
        self.tokenizer = tokenizer

        if zip_mode:
            self.zipreader = ZipReader()

        self.database = self.load_annotations()
        if aspect_grouping:
            self.group_ids = self.group_aspect(self.database)

    @property
    def data_names(self):
        if self.test_mode:
            return ['image', 'boxes', 'im_info', 'expression']
        else:
            return ['image', 'boxes', 'im_info', 'expression', 'label']

    def __getitem__(self, index):
        idb = self.database[index]

	#print(idb)

        # image related
        img_id = idb['image_id']
        image = self._load_image(idb['image_fn'])
        im_info = torch.as_tensor([idb['width'], idb['height'], 1.0, 1.0])
        if not self.test_mode:
            gt_box = torch.as_tensor(idb['gt_box'])
        flipped = False
        bb = self._load_json(os.path.join(self.data_path, 'bb.json'))
        if self.boxes == 'gt':
            boxes = torch.as_tensor(bb[img_id])

        if self.add_image_as_a_box:
            w0, h0 = im_info[0], im_info[1]
            image_box = torch.as_tensor([[0.0, 0.0, w0 - 1, h0 - 1]])
            boxes = torch.cat((image_box, boxes), dim=0)

        if self.transform is not None:
            if not self.test_mode:
                boxes = torch.cat((gt_box[None], boxes), 0)
            image, boxes, _, im_info, flipped = self.transform(image, boxes, None, im_info, flipped)
            if not self.test_mode:
                gt_box = boxes[0]
                boxes = boxes[1:]

        # clamp boxes
        w = im_info[0].item()
        h = im_info[1].item()
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clamp(min=0, max=w - 1)
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clamp(min=0, max=h - 1)
        if not self.test_mode:
            gt_box[[0, 2]] = gt_box[[0, 2]].clamp(min=0, max=w - 1)
            gt_box[[1, 3]] = gt_box[[1, 3]].clamp(min=0, max=h - 1)

        # assign label to each box by its IoU with gt_box
        if not self.test_mode:
            boxes_ious = bbox_iou_py_vectorized(boxes, gt_box[None]).view(-1)
            label = (boxes_ious > 0.5).float()

        # expression
        exp_tokens = idb['tokens']
        exp_retokens = self.tokenizer.tokenize(' '.join(exp_tokens))
        if flipped:
            exp_retokens = self.flip_tokens(exp_retokens, verbose=True)
        exp_ids = self.tokenizer.convert_tokens_to_ids(exp_retokens)

        if self.test_mode:
            return image, boxes, im_info, exp_ids
        else:
            return image, boxes, im_info, exp_ids, label

    @staticmethod
    def flip_tokens(tokens, verbose=True):
        changed = False
        tokens_new = [tok for tok in tokens]
        for i, tok in enumerate(tokens):
            if tok == 'left':
                tokens_new[i] = 'right'
                changed = True
            elif tok == 'right':
                tokens_new[i] = 'left'
                changed = True
        if verbose and changed:
            logging.info('[Tokens Flip] {} -> {}'.format(tokens, tokens_new))
        return tokens_new

    @staticmethod
    def b64_decode(string):
        return base64.decodebytes(string.encode())

    def load_annotations(self):
        tic = time.time()
        database = []
        db_cache_name = 'vrep_boxes'#_{}_{}'.format(self.boxes, '+'.join(self.image_sets))
        if self.zip_mode:
            db_cache_name = db_cache_name + '_zipmode'
        if self.test_mode:
            db_cache_name = db_cache_name + '_testmode'
        db_cache_root = os.path.join(self.root_path, 'cache')
        db_cache_path = os.path.join(db_cache_root, '{}.pkl'.format(db_cache_name))
        dataset = self._load_json(os.path.join(self.data_path, self.data_json))
        ref = self._load_json(os.path.join(self.data_path, self.ref_json))
        if os.path.exists(db_cache_path):
            if not self.ignore_db_cache:
                # reading cached database
                print('cached database found in {}.'.format(db_cache_path))
                with open(db_cache_path, 'rb') as f:
                    print('loading cached database from {}...'.format(db_cache_path))
                    tic = time.time()
                    database = cPickle.load(f)
                    print('Done (t={:.2f}s)'.format(time.time() - tic))
                    return database
            else:
                print('cached database ignored.')

        # ignore or not find cached database, reload it from annotation file
        #print('loading database of split {}...'.format('+'.join(self.image_sets)))
        tic = time.time()

        refer_id = 0 
	
        for data_point in dataset['images']:
            iset  = 'full_images'
            image_name = data_point['file_name'].split('/')[3]
            if True:
            	for anno in data_point['annotations']:
                    if anno['id'] == data_point['ground_truth']:
                        gt_x, gt_y, gt_w, gt_h = anno['bbox']
            if self.zip_mode:
                image_fn = os.path.join(self.data_path, iset + '.zip@/' + iset, image_name)
            else:
                image_fn = os.path.join(self.data_path, iset, image_name)
            for sent in ref[image_name]:
                idb = {
                    #'sent_id': sent['sent_id'],
                    #'ann_id': ref['ann_id'],
                    'ref_id': refer_id,
                    'image_id': image_name,
                    'image_fn': image_fn,
                    'width': 1024,
                    'height': 576,
                    'raw': sent,
                    'sent': sent,
                    'tokens': self.tokenizer.tokenize(sent),
                    #'category_id': ref['category_id'],
                    'gt_box': [gt_x, gt_y, gt_x + gt_w, gt_y + gt_h] if not self.test_mode else None
                }
                self.refer.ref_id_to_box[refer_id] = [image_name, [gt_x, gt_y, gt_w, gt_h], sent]
                database.append(idb)
                refer_id += 1

        with open('./final_refer_testset', 'w') as f:
            json.dump(self.refer.ref_id_to_box, f)

        print('Done (t={:.2f}s)'.format(time.time() - tic))

        # cache database via cPickle
        if self.cache_db:
            print('caching database to {}...'.format(db_cache_path))
            tic = time.time()
            if not os.path.exists(db_cache_root):
                makedirsExist(db_cache_root)
            with open(db_cache_path, 'wb') as f:
                cPickle.dump(database, f)
            print('Done (t={:.2f}s)'.format(time.time() - tic))

        return database

    @staticmethod
    def group_aspect(database):
        print('grouping aspect...')
        t = time.time()

        # get shape of all images
        widths = torch.as_tensor([idb['width'] for idb in database])
        heights = torch.as_tensor([idb['height'] for idb in database])

        # group
        group_ids = torch.zeros(len(database))
        horz = widths >= heights
        vert = 1 - horz
        group_ids[horz] = 0
        group_ids[vert] = 1

        print('Done (t={:.2f}s)'.format(time.time() - t))

        return group_ids

    def __len__(self):
        return len(self.database)

    def _load_image(self, path):
        if '.zip@' in path:
            return self.zipreader.imread(path).convert('RGB')
        else:
            return Image.open(path).convert('RGB')

    def _load_json(self, path):
        if '.zip@' in path:
            f = self.zipreader.read(path)
            return json.loads(f.decode())
        else:
            with open(path, 'r') as f:
                return json.load(f)
コード例 #8
0
class VCRDataset(Dataset):
    def __init__(self, ann_file, image_set, root_path, data_path, transform=None, task='Q2A', test_mode=False,
                 zip_mode=False, cache_mode=False, cache_db=False, ignore_db_cache=True,
                 basic_tokenizer=None, tokenizer=None, pretrained_model_name=None,
                 only_use_relevant_dets=False, add_image_as_a_box=False, mask_size=(14, 14),
                 aspect_grouping=False, basic_align=False, qa2r_noq=False, qa2r_aug=False,
                 seq_len=64,
                 **kwargs):
        """
        Visual Commonsense Reasoning Dataset

        :param ann_file: annotation jsonl file, joined onto data_path
        :param image_set: image folder name, e.g., 'vcr1images'
        :param root_path: root path; '<root_path>/cache' holds the cached database and tokenizer files
        :param data_path: path to vcr dataset
        :param transform: optional callable applied as transform(image, boxes, masks, im_info)
        :param task: 'Q2A' means question to answer, 'QA2R' means question and answer to rationale,
                     'Q2AR' means question to answer and rationale
        :param test_mode: test mode means no labels available
        :param zip_mode: reading images and metadata in zip archive
        :param cache_mode: cache whole dataset to RAM first, then __getitem__ read them from RAM
                           (not supported: must be False)
        :param cache_db: whether to pickle the loaded annotation database into the cache folder
        :param ignore_db_cache: ignore previous cached database, reload it from annotation file
        :param basic_tokenizer: tokenizer used when basic_align is set; defaults to
                                BasicTokenizer(do_lower_case=True)
        :param tokenizer: sub-word tokenizer; when None one is loaded from pretrained_model_name
                          (RobertaTokenizer if the name contains 'roberta', BertTokenizer otherwise)
        :param pretrained_model_name: name used to load the default tokenizer, 'bert-base-uncased' if None
        :param only_use_relevant_dets: filter out detections not used in query and response
        :param add_image_as_a_box: add whole image as a box
        :param mask_size: size of instance mask of each object
        :param aspect_grouping: whether to group images via their aspect (not supported: must be False)
        :param basic_align: align to tokens retokenized by basic_tokenizer
        :param qa2r_noq: in QA->R, the query contains only the correct answer, without question
        :param qa2r_aug: in QA->R, whether to augment choices to include those with wrong answer in query
                         (not implemented: must be False)
        :param seq_len: maximum combined token length that __getitem__ truncates text to
        :param kwargs: extra keyword arguments, accepted and ignored
        """
        super(VCRDataset, self).__init__()

        # reject option combinations that the rest of the class does not implement
        assert not cache_mode, 'currently not support cache mode!'
        assert task in ['Q2A', 'QA2R', 'Q2AR'] , 'not support task {}'.format(task)
        assert not qa2r_aug, "Not implemented!"

        self.qa2r_noq = qa2r_noq
        self.qa2r_aug = qa2r_aug

        self.seq_len = seq_len

        # category names; the position of a name in this list is its class index
        categories = ['__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
                      'trafficlight', 'firehydrant', 'stopsign', 'parkingmeter', 'bench', 'bird', 'cat', 'dog', 'horse',
                      'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
                      'suitcase', 'frisbee', 'skis', 'snowboard', 'sportsball', 'kite', 'baseballbat', 'baseballglove',
                      'skateboard', 'surfboard', 'tennisracket', 'bottle', 'wineglass', 'cup', 'fork', 'knife', 'spoon',
                      'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hotdog', 'pizza', 'donut',
                      'cake', 'chair', 'couch', 'pottedplant', 'bed', 'diningtable', 'toilet', 'tv', 'laptop', 'mouse',
                      'remote', 'keyboard', 'cellphone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
                      'clock', 'vase', 'scissors', 'teddybear', 'hairdrier', 'toothbrush']
        self.category_to_idx = {c: i for i, c in enumerate(categories)}
        self.data_path = data_path
        self.root_path = root_path
        self.ann_file = os.path.join(data_path, ann_file)
        self.image_set = image_set
        self.transform = transform
        self.task = task
        self.test_mode = test_mode
        self.zip_mode = zip_mode
        self.cache_mode = cache_mode
        self.cache_db = cache_db
        self.ignore_db_cache = ignore_db_cache
        self.aspect_grouping = aspect_grouping
        self.basic_align = basic_align
        print('Dataset Basic Align: {}'.format(self.basic_align))
        self.cache_dir = os.path.join(root_path, 'cache')
        self.only_use_relevant_dets = only_use_relevant_dets
        self.add_image_as_a_box = add_image_as_a_box
        self.mask_size = mask_size
        # the cache folder must exist before the tokenizer files are stored in it
        if not os.path.exists(self.cache_dir):
            makedirsExist(self.cache_dir)
        self.basic_tokenizer = basic_tokenizer if basic_tokenizer is not None \
            else BasicTokenizer(do_lower_case=True)
        if tokenizer is None:
            if pretrained_model_name is None:
                pretrained_model_name = 'bert-base-uncased'
            # pick the tokenizer family from the model name
            if 'roberta' in pretrained_model_name:
                tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name, cache_dir=self.cache_dir)
            else:
                tokenizer = BertTokenizer.from_pretrained(pretrained_model_name, cache_dir=self.cache_dir)
        self.tokenizer = tokenizer

        # the zip reader is only created in zip mode
        if zip_mode:
            self.zipreader = ZipReader()

        self.database = self.load_annotations(self.ann_file)
        if self.aspect_grouping:
            assert False, "Not support aspect grouping now!"
            # only reached when assertions are disabled (python -O)
            self.group_ids = self.group_aspect(self.database)

        # rotating index into GENDER_NEUTRAL_NAMES, advanced by __getitem__ for every 'person' object
        self.person_name_id = 0

    def load_annotations(self, ann_file):
        tic = time.time()
        database = []
        db_cache_name = 'vcr_nometa_{}_{}_{}'.format(self.task, self.image_set, os.path.basename(ann_file)[:-len('.jsonl')])
        if self.only_use_relevant_dets:
            db_cache_name = db_cache_name + '_only_relevant_dets'
        if self.zip_mode:
            db_cache_name = db_cache_name + '_zipped'
        db_cache_root = os.path.join(self.root_path, 'cache')
        db_cache_path = os.path.join(db_cache_root, '{}.pkl'.format(db_cache_name))

        if os.path.exists(db_cache_path):
            if not self.ignore_db_cache:
                # reading cached database
                print('cached database found in {}.'.format(db_cache_path))
                with open(db_cache_path, 'rb') as f:
                    print('loading cached database from {}...'.format(db_cache_path))
                    tic = time.time()
                    database = cPickle.load(f)
                    print('Done (t={:.2f}s)'.format(time.time() - tic))
                    return database
            else:
                print('cached database ignored.')

        # ignore or not find cached database, reload it from annotation file
        print('loading database from {}...'.format(ann_file))
        tic = time.time()

        with jsonlines.open(ann_file) as reader:
            for ann in reader:
                if self.zip_mode:
                    img_fn = os.path.join(self.data_path, self.image_set + '.zip@/' + self.image_set, ann['img_fn'])
                    metadata_fn = os.path.join(self.data_path, self.image_set + '.zip@/' + self.image_set, ann['metadata_fn'])
                else:
                    img_fn = os.path.join(self.data_path, self.image_set, ann['img_fn'])
                    metadata_fn = os.path.join(self.data_path, self.image_set, ann['metadata_fn'])

                db_i = {
                    'annot_id': ann['annot_id'],
                    'objects': ann['objects'],
                    'img_fn': img_fn,
                    'metadata_fn': metadata_fn,
                    'question': ann['question'],
                    'answer_choices': ann['answer_choices'],
                    'answer_label': ann['answer_label'] if not self.test_mode else None,
                    'rationale_choices': ann['rationale_choices'],
                    'rationale_label': ann['rationale_label'] if not self.test_mode else None,
                }
                database.append(db_i)
        print('Done (t={:.2f}s)'.format(time.time() - tic))

        # cache database via cPickle
        if self.cache_db:
            print('caching database to {}...'.format(db_cache_path))
            tic = time.time()
            if not os.path.exists(db_cache_root):
                makedirsExist(db_cache_root)
            with open(db_cache_path, 'wb') as f:
                cPickle.dump(database, f)
            print('Done (t={:.2f}s)'.format(time.time() - tic))

        return database

    @staticmethod
    def group_aspect(database):
        print('grouping aspect...')
        t = time.time()

        # get shape of all images
        widths = torch.as_tensor([idb['width'] for idb in database])
        heights = torch.as_tensor([idb['height'] for idb in database])

        # group
        group_ids = torch.zeros(len(database))
        horz = widths >= heights
        vert = 1 - horz
        group_ids[horz] = 0
        group_ids[vert] = 1

        print('Done (t={:.2f}s)'.format(time.time() - t))

        return group_ids

    def retokenize_and_convert_to_ids_with_tag(self, tokens, objects_replace_name, non_obj_tag=-1):
        parsed_tokens = []
        tags = []
        align_ids = []
        raw = []
        align_id = 0
        for mixed_token in tokens:
            if isinstance(mixed_token, list):
                tokens = [objects_replace_name[o] for o in mixed_token]
                retokenized_tokens = self.tokenizer.tokenize(tokens[0])
                raw.append(tokens[0])
                tags.extend([mixed_token[0] + non_obj_tag + 1 for _ in retokenized_tokens])
                align_ids.extend([align_id for _ in retokenized_tokens])
                align_id += 1
                for token, o in zip(tokens[1:], mixed_token[1:]):
                    retokenized_tokens.append('and')
                    tags.append(non_obj_tag)
                    align_ids.append(align_id)
                    align_id += 1
                    re_tokens = self.tokenizer.tokenize(token)
                    retokenized_tokens.extend(re_tokens)
                    tags.extend([o + non_obj_tag + 1 for _ in re_tokens])
                    align_ids.extend([align_id for _ in re_tokens])
                    align_id += 1
                    raw.extend(['and', token])
                parsed_tokens.extend(retokenized_tokens)
            else:
                if self.basic_align:
                    # basic align
                    basic_tokens = self.basic_tokenizer.tokenize(mixed_token)
                    raw.extend(basic_tokens)
                    for t in basic_tokens:
                        retokenized_tokens = self.tokenizer.tokenize(t)
                        parsed_tokens.extend(retokenized_tokens)
                        align_ids.extend([align_id for _ in retokenized_tokens])
                        tags.extend([non_obj_tag for _ in retokenized_tokens])
                        align_id += 1
                else:
                    # fully align to original tokens
                    raw.append(mixed_token)
                    retokenized_tokens = self.tokenizer.tokenize(mixed_token)
                    parsed_tokens.extend(retokenized_tokens)
                    align_ids.extend([align_id for _ in retokenized_tokens])
                    tags.extend([non_obj_tag for _ in retokenized_tokens])
                    align_id += 1
        ids = self.tokenizer.convert_tokens_to_ids(parsed_tokens)
        ids_with_tag = list(zip(ids, tags, align_ids))
        
        return ids_with_tag, raw

    @staticmethod
    def keep_only_relevant_dets(question, answer_choices, rationale_choices):
        dets_to_use = []
        for i, tok in enumerate(question):
            if isinstance(tok, list):
                for j, o in enumerate(tok):
                    if o not in dets_to_use:
                        dets_to_use.append(o)
                    question[i][j] = dets_to_use.index(o)
        if answer_choices is not None:
            for n, answer in enumerate(answer_choices):
                for i, tok in enumerate(answer):
                    if isinstance(tok, list):
                        for j, o in enumerate(tok):
                            if o not in dets_to_use:
                                dets_to_use.append(o)
                            answer_choices[n][i][j] = dets_to_use.index(o)
        if rationale_choices is not None:
            for n, rationale in enumerate(rationale_choices):
                for i, tok in enumerate(rationale):
                    if isinstance(tok, list):
                        for j, o in enumerate(tok):
                            if o not in dets_to_use:
                                dets_to_use.append(o)
                            rationale_choices[n][i][j] = dets_to_use.index(o)

        return dets_to_use, question, answer_choices, rationale_choices

    def __getitem__(self, index):
        # self.person_name_id = 0
        idb = deepcopy(self.database[index])

        metadata = self._load_json(idb['metadata_fn'])
        idb['boxes'] = metadata['boxes']
        idb['segms'] = metadata['segms']
        # idb['width'] = metadata['width']
        # idb['height'] = metadata['height']
        if self.only_use_relevant_dets:
            dets_to_use, idb['question'], idb['answer_choices'], idb['rationale_choices'] = \
                self.keep_only_relevant_dets(idb['question'],
                                             idb['answer_choices'],
                                             idb['rationale_choices'] if not self.task == 'Q2A' else None)
            idb['objects'] = [idb['objects'][i] for i in dets_to_use]
            idb['boxes'] = [idb['boxes'][i] for i in dets_to_use]
            idb['segms'] = [idb['segms'][i] for i in dets_to_use]
        objects_replace_name = []
        for o in idb['objects']:
            if o == 'person':
                objects_replace_name.append(GENDER_NEUTRAL_NAMES[self.person_name_id])
                self.person_name_id = (self.person_name_id + 1) % len(GENDER_NEUTRAL_NAMES)
            else:
                objects_replace_name.append(o)

        non_obj_tag = 0 if self.add_image_as_a_box else -1
        idb['question'] = self.retokenize_and_convert_to_ids_with_tag(idb['question'],
                                                                      objects_replace_name=objects_replace_name,
                                                                      non_obj_tag=non_obj_tag)

        idb['answer_choices'] = [self.retokenize_and_convert_to_ids_with_tag(answer,
                                                                             objects_replace_name=objects_replace_name,
                                                                             non_obj_tag=non_obj_tag)
                                 for answer in idb['answer_choices']]

        idb['rationale_choices'] = [self.retokenize_and_convert_to_ids_with_tag(rationale,
                                                                                objects_replace_name=objects_replace_name,
                                                                                non_obj_tag=non_obj_tag)
                                    for rationale in idb['rationale_choices']] if not self.task == 'Q2A' else None

        # truncate text to seq_len
        if self.task == 'Q2A':
            q = idb['question'][0]
            for a, a_raw in idb['answer_choices']:
                while len(q) + len(a) > self.seq_len:
                    if len(a) > len(q):
                        a.pop()
                    else:
                        q.pop()
        elif self.task == 'QA2R':
            if not self.test_mode:
                q = idb['question'][0]
                a = idb['answer_choices'][idb['answer_label']][0]
                for r, r_raw in idb['rationale_choices']:
                    while len(q) + len(a) + len(r) > self.seq_len:
                        if len(r) > (len(q) + len(a)):
                            r.pop()
                        elif len(q) > 1:
                            q.pop()
                        else:
                            a.pop()
        else:
            raise NotImplemented

        image = self._load_image(idb['img_fn'])
        w0, h0 = image.size
        objects = idb['objects']

        # extract bounding boxes and instance masks in metadata
        boxes = torch.zeros((len(objects), 6))
        masks = torch.zeros((len(objects), *self.mask_size))
        if len(objects) > 0:
            boxes[:, :5] = torch.tensor(idb['boxes'])
            boxes[:, 5] = torch.tensor([self.category_to_idx[obj] for obj in objects])
            for i in range(len(objects)):
                seg_polys = [torch.as_tensor(seg) for seg in idb['segms'][i]]
                masks[i] = generate_instance_mask(seg_polys, idb['boxes'][i], mask_size=self.mask_size,
                                                  dtype=torch.float32, copy=False)
        if self.add_image_as_a_box:
            image_box = torch.as_tensor([[0, 0, w0 - 1, h0 - 1, 1.0, 0]])
            image_mask = torch.ones((1, *self.mask_size))
            boxes = torch.cat((image_box, boxes), dim=0)
            masks = torch.cat((image_mask, masks), dim=0)

        question, question_raw = idb['question']
        question_align_matrix = get_align_matrix([w[2] for w in question])
        answer_choices, answer_choices_raw = zip(*idb['answer_choices'])
        answer_choices = list(answer_choices)
        answer_align_matrix = [get_align_matrix([w[2] for w in a]) for a in answer_choices]
        answer_label = torch.as_tensor(idb['answer_label']) if not self.test_mode else None
        if not self.task == 'Q2A':
            rationale_choices = [r[0] for r in idb['rationale_choices']]
            rationale_align_matrix = [get_align_matrix([w[2] for w in r]) for r in rationale_choices]
            rationale_label = torch.as_tensor(idb['rationale_label']) if not self.test_mode else None

        # transform
        im_info = torch.tensor([w0, h0, 1.0, 1.0, index])
        if self.transform is not None:
            image, boxes, masks, im_info = self.transform(image, boxes, masks, im_info)

        # clamp boxes
        w = im_info[0].item()
        h = im_info[1].item()
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clamp(min=0, max=w - 1)
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clamp(min=0, max=h - 1)

        if self.task == 'Q2AR':
            if not self.test_mode:
                outputs = (image, boxes, masks,
                           question, question_align_matrix,
                           answer_choices, answer_align_matrix, answer_label,
                           rationale_choices, rationale_align_matrix, rationale_label,
                           im_info)
            else:
                outputs = (image, boxes, masks,
                           question, question_align_matrix,
                           answer_choices, answer_align_matrix,
                           rationale_choices, rationale_align_matrix,
                           im_info)
        elif self.task == 'Q2A':
            if not self.test_mode:
                outputs = (image, boxes, masks,
                           question, question_align_matrix,
                           answer_choices, answer_align_matrix, answer_label,
                           im_info)
            else:
                outputs = (image, boxes, masks,
                           question, question_align_matrix,
                           answer_choices, answer_align_matrix,
                           im_info)
        elif self.task == 'QA2R':
            if not self.test_mode:
                outputs = (image, boxes, masks,
                           ([] if self.qa2r_noq else question) + answer_choices[answer_label],
                           answer_align_matrix[answer_label] if self.qa2r_noq else block_digonal_matrix(question_align_matrix, answer_align_matrix[answer_label]),
                           rationale_choices, rationale_align_matrix, rationale_label,
                           im_info)
            else:
                outputs = (image, boxes, masks,
                           [([] if self.qa2r_noq else question) + a for a in answer_choices],
                           [m if self.qa2r_noq else block_digonal_matrix(question_align_matrix, m)
                            for m in answer_align_matrix],
                           rationale_choices, rationale_align_matrix,
                           im_info)

        return outputs

    def __len__(self):
        return len(self.database)

    def _load_image(self, path):
        if '.zip@' in path:
            return self.zipreader.imread(path)
        else:
            return Image.open(path)

    def _load_json(self, path):
        if '.zip@' in path:
            f = self.zipreader.read(path)
            return json.loads(f.decode())
        else:
            with open(path, 'r') as f:
                return json.load(f)

    @property
    def data_names(self):
        if not self.test_mode:
            if self.task == 'Q2A':
                data_names = ['image', 'boxes', 'masks',
                              'question', 'question_align_matrix',
                              'answer_choices', 'answer_align_matrix', 'answer_label',
                              'im_info']
            elif self.task == 'QA2R':
                data_names = ['image', 'boxes', 'masks',
                              'question', 'question_align_matrix',
                              'rationale_choices', 'rationale_align_matrix', 'rationale_label',
                              'im_info']
            else:
                data_names = ['image', 'boxes', 'masks',
                              'question', 'question_align_matrix',
                              'answer_choices', 'answer_align_matrix', 'answer_label',
                              'rationale_choices', 'rationale_align_matrix', 'rationale_label',
                              'im_info']
        else:
            if self.task == 'Q2A':
                data_names = ['image', 'boxes', 'masks',
                              'question', 'question_align_matrix',
                              'answer_choices', 'answer_align_matrix',
                              'im_info']
            elif self.task == 'QA2R':
                data_names = ['image', 'boxes', 'masks',
                              'question', 'question_align_matrix',
                              'rationale_choices', 'rationale_align_matrix',
                              'im_info']
            else:
                data_names = ['image', 'boxes', 'masks',
                              'question', 'question_align_matrix',
                              'answer_choices', 'answer_align_matrix',
                              'rationale_choices', 'rationale_align_matrix',
                              'im_info']

        return data_names
コード例 #9
0
class Distance_Translation_Multi30kDataset(Dataset):
    def __init__(self,
                 ann_file,
                 image_set,
                 root_path,
                 data_path,
                 seq_len=64,
                 with_precomputed_visual_feat=False,
                 mask_raw_pixels=True,
                 with_rel_task=True,
                 with_mlm_task=False,
                 with_mvrc_task=False,
                 transform=None,
                 test_mode=False,
                 zip_mode=False,
                 cache_mode=False,
                 cache_db=False,
                 ignore_db_cache=True,
                 tokenizer=None,
                 pretrained_model_name=None,
                 add_image_as_a_box=False,
                 aspect_grouping=False,
                 languages_used='first',
                 **kwargs):
        """
        Multi30k caption-pair dataset (English and German captions per entry)

        :param ann_file: accepted for interface compatibility; the file actually read is
                         chosen from image_set ('train.json', 'val.json' or 'test.json')
        :param image_set: split name, one of 'train', 'val' or 'test2015'
        :param root_path: root path; '<root_path>/cache' is created and used as tokenizer cache
        :param data_path: folder that contains the annotation files
        :param seq_len: maximum number of token ids returned per sample
        :param with_precomputed_visual_feat: stored on the instance, not used in this class
        :param mask_raw_pixels: stored on the instance and printed at construction
        :param with_rel_task: stored on the instance, not used in this class
        :param with_mlm_task: stored on the instance, not used in this class
        :param with_mvrc_task: stored on the instance, not used in this class
        :param transform: transform (stored on the instance)
        :param test_mode: in test mode every sample is the matching caption pair
        :param zip_mode: stored on the instance; a ZipReader is created regardless
        :param cache_mode: cache whole dataset to RAM first (not supported: must be False)
        :param cache_db: stored on the instance, not used in this class
        :param ignore_db_cache: stored on the instance, not used in this class
        :param tokenizer: default is BertTokenizer loaded from pretrained_model_name
        :param pretrained_model_name: name used to load the default tokenizer,
                                      'bert-base-uncased' if None
        :param add_image_as_a_box: add whole image as a box (stored on the instance)
        :param aspect_grouping: whether to group images via their aspect (not supported: must be False)
        :param languages_used: 'first' uses the English caption, 'second' the German caption,
                               any other value uses both
        :param kwargs: extra keyword arguments, accepted and ignored
        :raises KeyError: if image_set is not one of the known split names
        """
        super(Distance_Translation_Multi30kDataset, self).__init__()

        assert not cache_mode, 'currently not support cache mode!'
        # TODO: need to remove this to allows testing
        # assert not test_mode

        # annotation file name for each supported split
        annot = {
            'train': 'train.json',
            'val': 'val.json',
            'test2015': 'test.json'
        }

        self.seq_len = seq_len
        self.with_rel_task = with_rel_task
        self.with_mlm_task = with_mlm_task
        self.with_mvrc_task = with_mvrc_task
        self.data_path = data_path
        self.root_path = root_path
        self.ann_file = os.path.join(data_path, annot[image_set])
        self.with_precomputed_visual_feat = with_precomputed_visual_feat
        self.mask_raw_pixels = mask_raw_pixels
        self.image_set = image_set
        self.transform = transform
        self.test_mode = test_mode
        self.zip_mode = zip_mode
        self.cache_mode = cache_mode
        self.cache_db = cache_db
        self.ignore_db_cache = ignore_db_cache
        self.aspect_grouping = aspect_grouping
        # FM edit: added option for how many captions
        self.languages_used = languages_used
        self.cache_dir = os.path.join(root_path, 'cache')
        self.add_image_as_a_box = add_image_as_a_box
        # the cache folder must exist before the tokenizer files are stored in it
        if not os.path.exists(self.cache_dir):
            makedirsExist(self.cache_dir)
        self.tokenizer = tokenizer if tokenizer is not None \
            else BertTokenizer.from_pretrained(
            'bert-base-uncased' if pretrained_model_name is None else pretrained_model_name,
            cache_dir=self.cache_dir)

        self.zipreader = ZipReader()

        # FM: Customise for multi30k dataset - only used for inference
        # Read every line up front; the context manager closes the file afterwards
        # instead of leaving the reader's handle open.
        with jsonlines.open(self.ann_file) as reader:
            self.database = list(reader)
        # NOTE: an earlier draft built a cross-coupled test database here (every
        # 'caption_en' paired with every 'caption_de', label 1.0 on matching indices);
        # the plain one-entry-per-line database is used in all modes.

        if self.aspect_grouping:
            assert False, "not support aspect grouping currently!"
            # only reached when assertions are disabled (python -O)
            self.group_ids = self.group_aspect(self.database)

        print('mask_raw_pixels: ', self.mask_raw_pixels)

    @property
    def data_names(self):
        return ['text', 'relationship_label', 'mlm_labels']

    def __getitem__(self, index):
        idb = self.database[index]

        # # indeces for inference
        # caption_en_index = idb['caption_en_index'] if self.test_mode else 0
        # caption_de_index = idb['caption_de_index'] if self.test_mode else 0

        # Task #1: Caption-Image Relationship Prediction
        _p = random.random()
        if not self.test_mode:
            if _p < 0.5:
                relationship_label = 1.0
                caption_en = idb['caption_en']
                caption_de = idb['caption_de']
            else:
                relationship_label = 0.0
                rand_index = random.randrange(0, len(self.database))
                while rand_index == index:
                    rand_index = random.randrange(0, len(self.database))
                # caption_en and image match, german caption is random
                caption_en = idb['caption_en']
                caption_de = self.database[rand_index]['caption_de']
        # for inference
        else:
            relationship_label = 1
            caption_en = idb['caption_en']
            caption_de = idb['caption_de']

        # FM edit: add captions
        caption_tokens_en = self.tokenizer.tokenize(caption_en)
        caption_tokens_de = self.tokenizer.tokenize(caption_de)
        mlm_labels_en = [-1] * len(caption_tokens_en)
        mlm_labels_de = [-1] * len(caption_tokens_de)

        # FM edit: captions of both languages exist in all cases
        if self.languages_used == 'first':
            text_tokens = ['[CLS]'] + caption_tokens_en + ['[SEP]']
            mlm_labels = [-1] + mlm_labels_en + [-1]
        elif self.languages_used == 'second':
            text_tokens = ['[CLS]'] + caption_tokens_de + ['[SEP]']
            mlm_labels = [-1] + mlm_labels_de + [-1]
        else:
            text_tokens = ['[CLS]'] + caption_tokens_en + [
                '[SEP]'
            ] + caption_tokens_de + ['[SEP]']
            mlm_labels = [-1] + mlm_labels_en + [-1] + mlm_labels_de + [-1]

        # convert tokens to ids
        text = self.tokenizer.convert_tokens_to_ids(text_tokens)

        # truncate seq to max len
        if len(text) > self.seq_len:
            text_len_keep = len(text)
            while (text_len_keep) > self.seq_len and (text_len_keep > 0):
                text_len_keep -= 1
            if text_len_keep < 2:
                text_len_keep = 2
            text = text[:(text_len_keep - 1)] + [text[-1]]

        return text, relationship_label, mlm_labels

    @staticmethod
    def b64_decode(string):
        return base64.decodebytes(string.encode())

    @staticmethod
    def group_aspect(database):
        print('grouping aspect...')
        t = time.time()

        # get shape of all images
        widths = torch.as_tensor([idb['width'] for idb in database])
        heights = torch.as_tensor([idb['height'] for idb in database])

        # group
        group_ids = torch.zeros(len(database))
        horz = widths >= heights
        vert = 1 - horz
        group_ids[horz] = 0
        group_ids[vert] = 1

        print('Done (t={:.2f}s)'.format(time.time() - t))

        return group_ids

    def __len__(self):
        return len(self.database)

    def _load_image(self, path):
        if '.zip@' in path:
            return self.zipreader.imread(path).convert('RGB')
        else:
            return Image.open(path).convert('RGB')

    def _load_json(self, path):
        if '.zip@' in path:
            f = self.zipreader.read(path)
            return json.loads(f.decode())
        else:
            with open(path, 'r') as f:
                return json.load(f)
コード例 #10
0
class Foil(Dataset):
    """FOIL captions paired with COCO images and bounding boxes."""

    def __init__(self,
                 root_path,
                 data_path,
                 boxes='gt',
                 proposal_source='official',
                 transform=None,
                 test_mode=False,
                 zip_mode=False,
                 cache_mode=False,
                 cache_db=False,
                 ignore_db_cache=True,
                 tokenizer=None,
                 pretrained_model_name=None,
                 add_image_as_a_box=False,
                 mask_size=(14, 14),
                 aspect_grouping=False,
                 **kwargs):
        """
        Foil Dataset

        :param root_path: root path; the database cache lives in
            ``<root_path>/cache``
        :param data_path: path to dataset (COCO annotations, FOIL files and
            ``foil/vocab.txt`` are resolved relative to it)
        :param boxes: boxes to use: 'gt', 'proposal', 'proposal+gt' or
            'gt+proposal'
        :param proposal_source: origin of proposal boxes, 'official' or 'vg'
        :param transform: optional callable invoked as
            ``transform(image, boxes, None, im_info, flipped)``
        :param test_mode: when True the COCO val2014 annotations and the FOIL
            'test' split are used instead of train2014 / 'train'
        :param zip_mode: reading images and metadata in zip archive
        :param cache_mode: not supported; must be False
        :param cache_db: pickle the freshly built database into the cache
            directory
        :param ignore_db_cache: ignore previous cached database, reload it
            from annotation file
        :param tokenizer: default is ``BertTokenizer.from_pretrained`` of
            ``pretrained_model_name``
        :param pretrained_model_name: tokenizer name, 'bert-base-uncased'
            when None
        :param add_image_as_a_box: add whole image as a box
        :param mask_size: stored on the instance; not otherwise used by this
            class
        :param aspect_grouping: whether to group images via their aspect
        :param kwargs: ``proposal_dets`` (detections JSON path relative to
            ``data_path``) is required when official proposals are
            requested; other keys are ignored
        :raises ValueError: official proposals requested without
            ``proposal_dets``
        """
        super(Foil, self).__init__()

        assert not cache_mode, 'currently not support cache mode!'

        coco_annot_files = {
            "train2014": "annotations/instances_train2014.json",
            "val2014": "annotations/instances_val2014.json",
            "test2015": "annotations/image_info_test2015.json",
        }

        foil_vocab_file = "foil/vocab.txt"

        self.vg_proposal = ("vgbua_res101_precomputed",
                            "trainval2014_resnet101_faster_rcnn_genome")

        self.test_mode = test_mode
        self.data_path = data_path
        self.root_path = root_path
        self.transform = transform

        # Read the vocabulary inside a context manager so the handle is
        # closed deterministically.
        with open(os.path.join(data_path, foil_vocab_file), 'r') as vocab_file:
            self.itos = [line.strip() for line in vocab_file]
        self.stoi = {word: idx for idx, word in enumerate(self.itos)}

        self.image_set = "val2014" if self.test_mode else "train2014"
        coco_annot_file = coco_annot_files[self.image_set]

        self.coco = COCO(
            annotation_file=os.path.join(data_path, coco_annot_file))
        self.foil = FOIL(data_path, 'train' if not test_mode else 'test')
        self.foil_ids = list(self.foil.Foils.keys())
        self.foils = self.foil.loadFoils(foil_ids=self.foil_ids)

        # __getitem__ dispatches on this attribute, so it must be stored.
        self.proposal_source = proposal_source
        self.proposals = {}
        if 'proposal' in boxes and proposal_source == 'official':
            # No detections file is defined for FOIL in this class, so the
            # caller has to name it explicitly.
            proposal_dets = kwargs.get('proposal_dets')
            if proposal_dets is None:
                raise ValueError(
                    "boxes={!r} with proposal_source='official' requires a "
                    "'proposal_dets' keyword argument (path relative to "
                    "data_path)".format(boxes))
            with open(os.path.join(data_path, proposal_dets), 'r') as f:
                proposal_list = json.load(f)
            for proposal in proposal_list:
                self.proposals.setdefault(proposal['image_id'],
                                          []).append(proposal['box'])

        self.boxes = boxes
        self.zip_mode = zip_mode
        self.cache_mode = cache_mode
        self.cache_db = cache_db
        self.ignore_db_cache = ignore_db_cache
        self.aspect_grouping = aspect_grouping
        self.cache_dir = os.path.join(root_path, 'cache')
        self.add_image_as_a_box = add_image_as_a_box
        self.mask_size = mask_size
        if not os.path.exists(self.cache_dir):
            makedirsExist(self.cache_dir)
        self.tokenizer = tokenizer if tokenizer is not None \
            else BertTokenizer.from_pretrained(
            'bert-base-uncased' if pretrained_model_name is None else pretrained_model_name,
            cache_dir=self.cache_dir)

        if zip_mode:
            self.zipreader = ZipReader()

        self.database = self.load_annotations()
        if self.aspect_grouping:
            self.group_ids = self.group_aspect(self.database)

    @property
    def data_names(self):
        """Names of the fields returned by ``__getitem__``, in order."""
        return [
            'image', 'boxes', 'im_info', 'expression', 'label', 'pos',
            'target', 'mask'
        ]

    def _load_gt_boxes(self, img_id):
        """Return the COCO ground-truth boxes of an image as (x1, y1, x2, y2) rows."""
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)
        corners = []
        for ann in anns:
            x_, y_, w_, h_ = ann['bbox']
            corners.append([x_, y_, x_ + w_, y_ + h_])
        if not corners:
            # Keep a 2-D shape so the column indexing done by the caller
            # still works for images without annotations.
            return torch.zeros((0, 4))
        return torch.as_tensor(corners)

    def _load_proposal_boxes(self, img_id):
        """Return the proposal boxes of an image from the configured source."""
        if self.proposal_source == 'official':
            boxes = torch.as_tensor(self.proposals[img_id])
            # stored as (x, y, w, h): turn the size columns into corners
            boxes[:, [2, 3]] += boxes[:, [0, 1]]
            return boxes
        if self.proposal_source == 'vg':
            box_file = os.path.join(
                self.data_path, self.vg_proposal[0],
                '{0}.zip@/{0}'.format(self.vg_proposal[1]))
            boxes_fn = os.path.join(box_file, '{}.json'.format(img_id))
            boxes_data = self._load_json(boxes_fn)
            # np.frombuffer yields a read-only view of the decoded bytes;
            # copy it because the boxes are clamped in place later on.
            decoded = np.frombuffer(self.b64_decode(boxes_data['boxes']),
                                    dtype=np.float32)
            return torch.as_tensor(
                decoded.reshape((boxes_data['num_boxes'], -1)).copy())
        raise NotImplementedError(
            'unsupported proposal_source: {!r}'.format(self.proposal_source))

    def __getitem__(self, index):
        """
        Build one sample.

        :param index: position in ``self.database``
        :return: tuple ``(image, boxes, im_info, exp_ids, label, foil_pos,
            target, mask)`` matching ``data_names``
        :raises NotImplementedError: unknown ``boxes`` or ``proposal_source``
        """
        idb = self.database[index]

        # image related
        img_id = idb['image_id']
        image = self._load_image(idb['image_fn'])
        im_info = torch.as_tensor([idb['width'], idb['height'], 1.0, 1.0])
        flipped = False

        if self.boxes == 'gt':
            boxes = self._load_gt_boxes(img_id)
        elif self.boxes == 'proposal':
            boxes = self._load_proposal_boxes(img_id)
        elif self.boxes == 'proposal+gt' or self.boxes == 'gt+proposal':
            # proposals first, ground truth appended after them
            boxes = torch.cat((self._load_proposal_boxes(img_id),
                               self._load_gt_boxes(img_id)), 0)
        else:
            raise NotImplementedError(
                'unsupported boxes setting: {!r}'.format(self.boxes))

        if self.add_image_as_a_box:
            w0, h0 = im_info[0], im_info[1]
            image_box = torch.as_tensor([[0.0, 0.0, w0 - 1, h0 - 1]])
            boxes = torch.cat((image_box, boxes), dim=0)

        if self.transform is not None:
            image, boxes, _, im_info, flipped = self.transform(
                image, boxes, None, im_info, flipped)

        # clamp boxes
        w = im_info[0].item()
        h = im_info[1].item()
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clamp(min=0, max=w - 1)
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clamp(min=0, max=h - 1)

        # assign label expression with the foil annotation
        label = idb['label']
        foil_pos = idb['pos']

        # expression
        exp = idb['caption_tokens']
        exp_ids = self.tokenizer.convert_tokens_to_ids(exp)

        # NOTE(review): this raises KeyError for a target word missing from
        # the vocabulary, while load_annotations stores index 0 under
        # 'target' for the same case - confirm which behaviour is intended.
        target = self.stoi[idb['target_word']]
        mask = idb['mask']
        # train and test mode return the same fields
        return image, boxes, im_info, exp_ids, label, foil_pos, target, mask

    @staticmethod
    def b64_decode(string):
        """Decode a base64 text string and return the raw bytes."""
        return base64.decodebytes(string.encode())

    @staticmethod
    def _find_subsequence(tokens, pattern):
        """Return the first index where ``pattern`` occurs in ``tokens``, or -1."""
        width = len(pattern)
        # +1 so that a match ending on the very last token is also found
        for start in range(len(tokens) - width + 1):
            if tokens[start:start + width] == pattern:
                return start
        return -1

    def load_annotations(self):
        """
        Build the list of per-caption records, or load it from the cache.

        A pickle named after the image set (plus zip/test-mode suffixes)
        under ``<root_path>/cache`` is reused unless ``ignore_db_cache`` is
        set, and written when ``cache_db`` is set.

        :return: list of dicts with keys 'ann_id', 'foil_id', 'image_id',
            'image_fn', 'width', 'height', 'caption', 'caption_tokens',
            'target_word', 'target', 'foil_word', 'label', 'pos', 'mask'
        """
        database = []
        db_cache_name = 'foil_{}'.format(self.image_set)
        if self.zip_mode:
            db_cache_name = db_cache_name + '_zipmode'
        if self.test_mode:
            db_cache_name = db_cache_name + '_testmode'
        db_cache_root = os.path.join(self.root_path, 'cache')
        db_cache_path = os.path.join(db_cache_root,
                                     '{}.pkl'.format(db_cache_name))

        if os.path.exists(db_cache_path):
            if not self.ignore_db_cache:
                # reading cached database
                print('cached database found in {}.'.format(db_cache_path))
                with open(db_cache_path, 'rb') as f:
                    print('loading cached database from {}...'.format(
                        db_cache_path))
                    tic = time.time()
                    # NOTE: unpickling can execute arbitrary code; only
                    # cache files written by this class should be placed
                    # in the cache directory.
                    database = cPickle.load(f)
                    print('Done (t={:.2f}s)'.format(time.time() - tic))
                    return database
            else:
                print('cached database ignored.')

        # ignore or not find cached database, reload it from annotation file
        print('loading database of split {}...'.format(self.image_set))
        tic = time.time()

        for foil_id, foil in zip(self.foil_ids, self.foils):
            # NOTE(review): the image folder is fixed to 'train2014' even
            # though test mode loads the val2014 COCO annotations - confirm
            # whether self.image_set was intended here.
            iset = 'train2014'
            image_name = 'COCO_{}_{:012d}.jpg'.format(iset, foil['image_id'])
            if self.zip_mode:
                image_fn = os.path.join(
                    self.data_path, iset + '.zip@/' + iset, image_name)
            else:
                image_fn = os.path.join(
                    self.root_path, self.data_path, iset, image_name)

            expression_tokens = self.tokenizer.basic_tokenizer.tokenize(
                foil['caption'])
            expression_wps = []
            for token in expression_tokens:
                expression_wps.extend(
                    self.tokenizer.wordpiece_tokenizer.tokenize(token))

            target_pos = -1
            if foil['foil']:
                # locate the word pieces of the foil word in the caption
                foil_wps = self.tokenizer.wordpiece_tokenizer.tokenize(
                    foil['foil_word'])
                twps_len = len(foil_wps)
                target_pos = self._find_subsequence(expression_wps, foil_wps)
            else:
                twps_len = 1

            idb = {
                'ann_id': foil['id'],
                'foil_id': foil['foil_id'],
                'image_id': foil['image_id'],
                'image_fn': image_fn,
                'width': self.coco.imgs[foil['image_id']]['width'],
                'height': self.coco.imgs[foil['image_id']]['height'],
                'caption': foil['caption'].strip(),
                'caption_tokens': expression_wps,
                'target_word': foil['target_word'],
                'target': self.stoi.get(foil['target_word'], 0),
                'foil_word': foil['foil_word'],
                'label': foil['foil'],
                'pos': target_pos,
                'mask': twps_len
            }
            database.append(idb)

        print('Done (t={:.2f}s)'.format(time.time() - tic))

        # cache database via cPickle
        if self.cache_db:
            print('caching database to {}...'.format(db_cache_path))
            tic = time.time()
            if not os.path.exists(db_cache_root):
                makedirsExist(db_cache_root)
            with open(db_cache_path, 'wb') as f:
                cPickle.dump(database, f)
            print('Done (t={:.2f}s)'.format(time.time() - tic))

        return database

    @staticmethod
    def group_aspect(database):
        """
        Assign every database entry to an aspect-ratio group.

        :param database: sequence of dicts carrying 'width' and 'height'
        :return: float tensor with one entry per item: 0 where
            width >= height, 1 where width < height
        """
        print('grouping aspect...')
        t = time.time()

        # get shape of all images
        widths = torch.as_tensor([idb['width'] for idb in database])
        heights = torch.as_tensor([idb['height'] for idb in database])

        # group
        # The portrait mask is computed with a direct comparison instead of
        # `1 - horz`: comparisons yield bool tensors on current PyTorch, and
        # subtracting from a bool tensor raises a RuntimeError there.
        group_ids = torch.zeros(len(database))
        vert = widths < heights
        group_ids[vert] = 1

        print('Done (t={:.2f}s)'.format(time.time() - t))

        return group_ids

    def __len__(self):
        """Return the number of records in the database."""
        return len(self.database)

    def _load_image(self, path):
        """Load an image as RGB, through the zip reader for '.zip@' paths."""
        if '.zip@' in path:
            return self.zipreader.imread(path).convert('RGB')
        else:
            return Image.open(path).convert('RGB')

    def _load_json(self, path):
        """Parse a JSON file, through the zip reader for '.zip@' paths."""
        if '.zip@' in path:
            f = self.zipreader.read(path)
            return json.loads(f.decode())
        else:
            with open(path, 'r') as f:
                return json.load(f)