# Example #1
# 0
    def __init__(self,
                 imdb_file,
                 image_feat_directories,
                 verbose=False,
                 **data_params):
        """Build a VQA dataset from an imdb file plus image-feature dirs.

        Args:
            imdb_file: path to a ``.npy`` imdb file; element 0 is a header
                dict (``has_answer``, ``has_gt_layout``, ``version``),
                elements 1.. are the question entries.
            image_feat_directories: directories holding precomputed image
                feature ``.npy`` files.
            verbose: stored on the instance; unused in this constructor.
            **data_params: configuration; must contain ``cp``,
                ``image_depth_first``, ``vocab_question_file``,
                ``T_encoder``, ``vocab_answer_file``, ``test_mode`` and
                ``fastRead``; may contain ``image_max_loc``,
                ``load_gt_layout``, ``T_decoder``, ``assembler`` and
                ``prune_filter_module``.

        Raises:
            TypeError: if ``imdb_file`` is not ``.npy`` or the imdb
                version does not match the expected ``imdb_version``.
        """
        super(vqa_dataset, self).__init__()
        if imdb_file.endswith('.npy'):
            # NOTE(review): with NumPy >= 1.16.3, loading an object-array
            # imdb requires allow_pickle=True -- confirm against the numpy
            # version this project pins.
            imdb = np.load(imdb_file)
        else:
            raise TypeError('unknown imdb format.')
        self.verbose = verbose
        self.imdb = imdb
        self.image_feat_directories = image_feat_directories
        self.data_params = data_params
        self.cp = data_params['cp']
        self.image_depth_first = data_params['image_depth_first']
        # Optional cap on the number of feature locations per image.
        self.image_max_loc = data_params.get('image_max_loc')
        self.vocab_dict = text_processing.VocabDict(
            data_params['vocab_question_file'])
        self.T_encoder = data_params['T_encoder']

        # Read the imdb header (element 0); real entries start at index 1.
        header_idx = 0
        self.first_element_idx = 1
        header = self.imdb[header_idx]
        self.load_answer = header['has_answer']
        # BUG FIX (dead store): the original read header['has_gt_layout']
        # and immediately overwrote it with False, so the header flag never
        # had any effect.  Behavior is preserved: ground-truth layout is
        # only loaded when explicitly requested via data_params.
        self.load_gt_layout = data_params.get('load_gt_layout', False)
        data_version = header['version']
        if data_version != imdb_version:
            print("observed imdb_version is", data_version,
                  "expected imdb version is", imdb_version)
            raise TypeError('imdb version do not match.')

        # the answer dict is always loaded, regardless of self.load_answer
        self.answer_dict = text_processing.VocabDict(
            data_params['vocab_answer_file'])

        if self.load_gt_layout:
            self.T_decoder = data_params['T_decoder']
            self.assembler = data_params['assembler']
            self.prune_filter_module = data_params.get(
                'prune_filter_module', False)
        else:
            print('imdb does not contain ground-truth layout')
            print('Loading model and config ...')

        # Load one feature map per directory to peek its rank/layout and
        # pick a matching reader for each directory.
        self.image_feat_readers = []
        for image_dir in self.image_feat_directories:
            image_file_name = os.path.basename(
                self.imdb[self.first_element_idx]['feature_path'])
            if self.cp:
                # NOTE(review): the cp branch ignores image_dir and resolves
                # from the full directory list -- presumably intended;
                # verify against get_cp_feat_path.
                image_feat_path = get_cp_feat_path(self.image_feat_directories,
                                                   image_file_name)
            else:
                image_feat_path = os.path.join(image_dir, image_file_name)
            feats = np.load(image_feat_path)
            self.image_feat_readers.append(
                get_image_feat_reader(feats.ndim, self.image_depth_first,
                                      feats, self.image_max_loc))

        self.testMode = bool(data_params['test_mode'])
        self.fastRead = bool(data_params['fastRead'])
        if self.fastRead:
            # Pre-load every feature file from the first directory into an
            # in-memory cache keyed by file name.
            self.featDict = {}
            image_count = 0
            image_dir0 = self.image_feat_directories[0]
            for feat_file in tqdm(os.listdir(image_dir0)):
                if feat_file.endswith("npy"):
                    image_feats = read_in_image_feats(
                        self.image_feat_directories, self.image_feat_readers,
                        feat_file, self.cp)
                    self.featDict[feat_file] = image_feats
                    image_count += 1
            print("load %d images" % image_count)
    # --- CLI: locations of the raw VQA data and of the imdb output ---
    parser.add_argument("--data_dir",
                        type=str,
                        required=True,
                        help="data directory")
    parser.add_argument("--out_dir",
                        type=str,
                        required=True,
                        help="imdb output directory")
    args = parser.parse_args()

    data_dir = args.data_dir
    out_dir = args.out_dir

    # NOTE(review): the answer vocabulary is hard-coded to an absolute,
    # user-specific path; the commented line shows the portable default.
    # TODO: turn this into a CLI flag instead of a hard-coded path.
    #vocab_answer_file = os.path.join(out_dir, 'answers_vqa.txt')
    vocab_answer_file = '/home1/BTP/pg_aa_1/btp/data/answers_vqa_larger.txt'
    answer_dict = text_processing.VocabDict(vocab_answer_file)
    # Only answers present in the vocabulary are treated as valid when
    # building the imdb.
    valid_answer_set = set(answer_dict.word_list)

    # Only the val2014 split is built here; train/test are disabled.
    # imdb_train2014 = build_imdb('train2014', valid_answer_set)
    imdb_val2014 = build_imdb('val2014', valid_answer_set)
    # imdb_test2015 = build_imdb('test2015', valid_answer_set)

    # Persist the built imdb as <out_dir>/imdb/imdb_val2014.npy.
    imdb_dir = os.path.join(out_dir, 'imdb')
    os.makedirs(imdb_dir, exist_ok=True)
    #np.save(os.path.join(imdb_dir, 'imdb_train2014.npy'),
    #        np.array(imdb_train2014))
    np.save(os.path.join(imdb_dir, 'imdb_val2014.npy'), np.array(imdb_val2014))
    # np.save(os.path.join(imdb_dir, 'imdb_test2015.npy'),
    #        np.array(imdb_test2015))
    '''imdb_minival2014 = build_imdb('minival2014',
                                  valid_answer_set,