Beispiel #1
0
    def load_ann_fields(self, ann_reader):
        """Collect the annotation fields exposed by ``ann_reader``.

        The result always starts with a copy of ``ann_reader.fields``. For
        the ``'test'`` split nothing else is added. For every other split two
        ``ProxyField`` objects are appended: ``'a_label_scores'`` (built on
        ``ann_reader['a_scores']``) and ``'a_label_counts'`` (built on
        ``ann_reader['a_counts']``).

        Args:
            ann_reader: object providing a ``fields`` iterable and item
                access by field name.

        Returns:
            list: the collected fields.
        """
        collected = list(ann_reader.fields)
        if self.split == 'test':
            return collected

        def to_vocab_vector(pairs):
            # Spread (answer token, value) pairs over a float vector with one
            # slot per answer-vocabulary entry; untouched slots stay at zero.
            vector = torch.zeros(len(self.answer_vocab)).float()
            for a_token, value in pairs:
                vector[self.answer_vocab[a_token]] = value
            return vector

        for field_name, source_name in (('a_label_scores', 'a_scores'),
                                        ('a_label_counts', 'a_counts')):
            collected.append(
                ProxyField(field_name,
                           ann_reader[source_name],
                           depend_fn=to_vocab_vector))

        return collected
Beispiel #2
0
    def build_fields(self):
        """Assemble the dataset fields for the current split.

        The annotation fields ``q_ids``, ``q_tokens``, ``img_ids``,
        ``q_labels``, ``q_lens`` and ``a_labels`` are always included. When
        ``self.req_field_names`` is ``None`` or contains ``'img_obj_feats'``,
        three more fields are appended: ``img_obj_feats`` and
        ``img_obj_boxes`` (each a ``SwitchField`` over the per-file H5
        readers) and the ``img_obj_nums`` annotation field.

        Returns:
            list: the assembled field objects.
        """
        ann_reader = DictReader(
            self.load_combined_anns(self.data_dir, self.split))
        img_id_field = ann_reader['img_ids']
        fields = [
            ann_reader[name] for name in ('q_ids', 'q_tokens', 'img_ids',
                                          'q_labels', 'q_lens', 'a_labels')
        ]

        if self.req_field_names is None or 'img_obj_feats' in self.req_field_names:
            info_path = self.data_dir.joinpath(
                'objects/gqa_objects_info.json')
            # Parse the metadata inside a ``with`` block so the file handle
            # is closed right away instead of being left to the garbage
            # collector.
            with info_path.open() as info_file:
                obj_infos = json.load(info_file)

            h5_files = self.data_dir.joinpath('objects').glob(
                'gqa_objects_*.h5')
            # One reader per feature file, keyed by the file name so that
            # ``switch_fn`` below can pick the right one.
            h5_readers = {
                h5_file.name: H5Reader(h5_file)
                for h5_file in h5_files
            }

            def idx_map_fn(dset_id):
                # Translate a dataset index into the ``idx`` recorded for the
                # sample's image in the objects-info metadata.
                img_id = img_id_field[dset_id]
                return obj_infos[img_id]['idx']

            for reader in h5_readers.values():
                for field in reader.fields:
                    field.idx_map_fn = idx_map_fn

            def switch_fn(dset_id, fields):
                # Select the per-file field named by the metadata's ``file``
                # entry for this sample's image.
                img_id = img_id_field[dset_id]
                obj_info = obj_infos[img_id]
                return fields[f'gqa_objects_{obj_info["file"]}.h5']

            def feats_to_tensor(x):
                return torch.from_numpy(x)

            feat_fields = {
                key: ProxyField('features', reader['features'],
                                feats_to_tensor)
                for key, reader in h5_readers.items()
            }

            def scale_boxes(box, img_shape):
                # Divide the box values by the image shape repeated twice.
                # NOTE(review): assumes img_shape is a 1-D tensor whose
                # doubled length matches the last box dimension - confirm.
                box = torch.from_numpy(box)
                return box / torch.cat((img_shape, img_shape))

            box_fields = {
                key: ProxyField('boxes',
                                (reader['bboxes'], ann_reader['img_shapes']),
                                scale_boxes)
                for key, reader in h5_readers.items()
            }

            fields.append(SwitchField('img_obj_feats', feat_fields, switch_fn))
            fields.append(SwitchField('img_obj_boxes', box_fields, switch_fn))
            fields.append(ann_reader['img_obj_nums'])

        return fields
Beispiel #3
0
    def load_obj_fields(self, img_ids_field, img_shapes_field):
        """Build the image object-feature and box fields from zarr stores.

        Args:
            img_ids_field: field indexed by dataset index; each entry must
                support ``.item()`` (used to derive the lookup key).
            img_shapes_field: field passed alongside the boxes so they can be
                divided by the image shape.

        Returns:
            ``None`` when ``self.req_field_names`` is set and
            ``utils.not_in(('img_obj_feats', 'img_box_feats'), ...)`` is
            truthy (presumably: neither field was requested - confirm against
            ``utils``). Otherwise ``[obj_field, box_field]``, where
            ``obj_field`` concatenates object features with the scaled boxes
            along the last dimension.
        """
        feat_root = self.data_dir.joinpath(self.name)
        requested = self.req_field_names
        if requested is not None and utils.not_in(
                ('img_obj_feats', 'img_box_feats'), requested):
            return None

        # The 'test' split has its own pair of stores; every other split
        # reads from the 'trainval' pair.
        if self.split == 'test':
            feat_store_name, box_store_name = 'test.zarr', 'test_boxes.zarr'
        else:
            feat_store_name, box_store_name = ('trainval.zarr',
                                               'trainval_boxes.zarr')
        feat_store = zarr.open(
            feat_root.joinpath(feat_store_name).as_posix(), mode='r')
        box_store = zarr.open(
            feat_root.joinpath(box_store_name).as_posix(), mode='r')

        def store_key(dset_id):
            # Key handed to PseudoField: the sample's image id as a string.
            return str(img_ids_field[dset_id].item())

        def to_tensor(x):
            return torch.from_numpy(np.asarray(x))

        def scale_boxes(box, img_shape):
            # Divide the boxes by the image shape repeated twice along the
            # last dimension.
            box_tensor = to_tensor(box)
            row = img_shape.unsqueeze(0)
            divisor = torch.cat((row, row), dim=-1).float()
            return box_tensor / divisor

        def join_feats(obj, box):
            return torch.cat((obj, box), dim=-1)

        raw_obj_field = ProxyField(
            'img_obj_feats',
            PseudoField('img_obj_feats', feat_store, store_key),
            to_tensor)
        box_field = ProxyField(
            'img_box_feats',
            (PseudoField('img_box_feats', box_store, store_key),
             img_shapes_field),
            scale_boxes)
        obj_field = ProxyField('img_obj_feats', (raw_obj_field, box_field),
                               join_feats)
        return [obj_field, box_field]
Beispiel #4
0
 def build_fields(self):
     """Create the fields for the image files of this dataset.

     An ``'img_files'`` field holding the result of
     ``self._get_image_files()`` is always produced. ``'img_ids'`` (each file
     mapped through ``self.file2idx_fn``) and ``'images'`` (a ``ProxyField``
     over the file field using ``self._load_image``) are added when
     ``self.req_field_names`` is ``None`` or lists them.

     Returns:
         list: the created fields.

     Raises:
         AssertionError: if a requested name is not in
             ``self.valid_field_names``.
     """
     requested = self.req_field_names
     if requested is not None:
         assert all(name in self.valid_field_names for name in requested)

     def is_wanted(field_name):
         # No explicit request means every optional field is wanted.
         return requested is None or field_name in requested

     img_files = self._get_image_files()
     file_field = Field('img_files', datas=img_files)
     built = [file_field]
     if is_wanted('img_ids'):
         img_idxs = list(map(self.file2idx_fn, img_files))
         built.append(Field('img_ids', datas=img_idxs))
     if is_wanted('images'):
         built.append(ProxyField('images', file_field, self._load_image))
     return built
Beispiel #5
0
    def build_fields(self):
        """Assemble the dataset fields for the current split.

        The annotation fields ``q_ids``, ``q_tokens``, ``img_ids``,
        ``img_shapes`` and ``a_counts`` are always included. Optional groups
        are added when ``self.req_field_names`` is ``None`` or names them:

        * ``img_obj_feats``: ``img_boxes`` and ``img_obj_feats`` switch
          fields over one H5 reader per part of ``self.split`` (split on
          ``'_'``).
        * ``q_labels``: ``q_labels`` cut to ``self.seq_len`` and ``q_lens``
          capped at ``self.seq_len``.
        * ``a_label_scores`` / ``a_label_counts``: vocabulary-sized vectors
          built from ``a_scores`` / ``a_counts``; never added for the
          ``'test'`` split.
        * ``q_tokens`` / ``img_ids``: the corresponding annotation fields.

        Returns:
            list: the assembled field objects.
        """
        data_dir = self.data_dir.joinpath(self.name)
        ann_reader = DictReader(
            self.load_combined_anns(self.data_dir, self.split))
        img_id_field = ann_reader['img_ids']
        fields = [
            ann_reader[name] for name in ('q_ids', 'q_tokens', 'img_ids',
                                          'img_shapes', 'a_counts')
        ]
        split_fields = ann_reader['splits']

        def is_wanted(field_name):
            # No explicit request means every optional field is wanted.
            return (self.req_field_names is None
                    or field_name in self.req_field_names)

        if is_wanted('img_obj_feats'):
            # One reader per component of the (possibly combined) split name.
            img_readers = {
                split: H5Reader(data_dir.joinpath(f'{split}_obj_feat_36.hdf5'))
                for split in self.split.split('_')
            }
            # Per split: image id -> position of that id in the reader's
            # 'img_ids' sequence.
            img_id_to_idx = {
                split:
                {img_id: idx
                 for idx, img_id in enumerate(reader['img_ids'])}
                for split, reader in img_readers.items()
            }

            # The helpers below do not depend on the loop variables, so they
            # are defined once instead of once per reader.
            def id_map_fn(dset_id):
                # Resolve a dataset index to the position of its image inside
                # the reader belonging to that sample's split.
                img_id = img_id_field[dset_id]
                return int(
                    img_id_to_idx[split_fields[dset_id]][int(img_id)])

            def scale_boxes(bb_locs, img_shape):
                # Divide even columns by img_shape[0] and odd columns by
                # img_shape[1].
                # NOTE(review): assumes bb_locs is 2-D with interleaved
                # coordinates matching the img_shape order - confirm.
                bb_locs = torch.from_numpy(bb_locs)
                new_bb_locs = torch.empty_like(bb_locs)
                new_bb_locs[:, 0::2] = bb_locs[:, 0::2] / img_shape[0]
                new_bb_locs[:, 1::2] = bb_locs[:, 1::2] / img_shape[1]
                return new_bb_locs

            def join_feats(bb_feat, bb_locs):
                # Append the scaled boxes to the object features.
                bb_feat = torch.from_numpy(bb_feat)
                return torch.cat((bb_feat, bb_locs), dim=-1)

            img_box_fields = {}
            bb_feature_fields = {}
            for split, img_reader in img_readers.items():
                for field in img_reader.fields:
                    field.idx_map_fn = id_map_fn

                img_box_field = ProxyField(
                    'img_boxes',
                    (img_reader['img_boxes'], ann_reader['img_shapes']),
                    scale_boxes)
                img_box_fields[split] = img_box_field

                bb_feature_fields[split] = ProxyField(
                    'img_obj_feats',
                    (img_reader['img_obj_feats'], img_box_field), join_feats)

            def switch_fn(idx, fields):
                # Pick the per-split field matching the sample's split.
                return fields[split_fields[idx]]

            fields.append(SwitchField('img_boxes', img_box_fields, switch_fn))
            fields.append(
                SwitchField('img_obj_feats', bb_feature_fields, switch_fn))

        if is_wanted('q_labels'):
            q_label_field = ProxyField('q_labels',
                                       ann_reader['q_labels'],
                                       depend_fn=lambda x: x[:self.seq_len])
            q_len_field = ProxyField(
                'q_lens',
                ann_reader['q_lens'],
                depend_fn=lambda x: min(x, torch.tensor(self.seq_len)))
            fields.extend([q_label_field, q_len_field])

        def to_vocab_vector(pairs):
            # Spread (answer token, value) pairs over a float vector with one
            # slot per answer-vocabulary entry.
            # NOTE(review): untouched slots are initialised to
            # ``answer_vocab.padding_idx`` rather than 0 - confirm intended.
            vector = torch.empty(len(self.answer_vocab)).fill_(
                self.answer_vocab.padding_idx).float()
            for a_token, value in pairs:
                vector[self.answer_vocab[a_token]] = value
            return vector

        # Answer targets are skipped for the 'test' split. The check is
        # parenthesised explicitly: ``a or b and c`` parses as
        # ``a or (b and c)``, which previously let the fields through for
        # 'test' whenever ``req_field_names`` was None.
        has_answers = self.split != 'test'

        if is_wanted('a_label_scores') and has_answers:
            fields.append(
                ProxyField('a_label_scores',
                           ann_reader['a_scores'],
                           depend_fn=to_vocab_vector))

        if is_wanted('a_label_counts') and has_answers:
            fields.append(
                ProxyField('a_label_counts',
                           ann_reader['a_counts'],
                           depend_fn=to_vocab_vector))

        # NOTE(review): 'q_tokens' and 'img_ids' are already part of the
        # unconditional fields above, so these appends add them a second
        # time. Kept as-is because downstream handling of duplicate fields
        # is not visible here - confirm whether they can be dropped.
        if is_wanted('q_tokens'):
            fields.append(ann_reader['q_tokens'])
        if is_wanted('img_ids'):
            fields.append(ann_reader['img_ids'])
        return fields