def prepare_eval_data(config):
    """Prepare the data for evaluating the model.

    Builds the list of evaluation image ids/files (optionally capped at
    ``config.max_eval_ann_num`` annotations), loads or builds the
    vocabulary, downloads the images, and wraps everything in a DataSet.

    Args:
        config: configuration object providing ``eval_caption_file``,
            ``max_eval_ann_num``, ``eval_image_dir``, ``vocabulary_file``,
            ``vocabulary_size`` and ``batch_size``.

    Returns:
        Tuple ``(coco, dataset, vocabulary)``.
    """
    coco = COCO(config.eval_caption_file, config.max_eval_ann_num)

    if not config.max_eval_ann_num:
        print('No config.max_eval_ann_num')
        image_ids = list(coco.imgs.keys())
    else:
        print('config.max_eval_ann_num=', config.max_eval_ann_num)
        # NOTE(review): annotations are many-to-one with images, so this
        # list may contain duplicate image ids — confirm downstream code
        # (DataSet, coco.download) tolerates duplicates.
        image_ids = [
            coco.anns[ann_id]['image_id']
            for ann_id in islice(coco.anns, config.max_eval_ann_num)
        ]

    # image_ids is already capped above, so the original's second
    # islice over it was redundant; one shared file-list build suffices.
    image_files = [
        os.path.join(config.eval_image_dir, coco.imgs[image_id]['file_name'])
        for image_id in image_ids
    ]

    print("Building the vocabulary...")
    if os.path.exists(config.vocabulary_file):
        vocabulary = Vocabulary(config.vocabulary_size, config.vocabulary_file)
    else:
        vocabulary = build_vocabulary(config)
    print("Vocabulary built.")
    print("Number of words = %d" % (vocabulary.size))

    print('Download Images')
    coco.download(config.eval_image_dir, image_ids)
    print('Finished download images')

    print("Building the dataset...")
    dataset = DataSet(image_ids, image_files, config.batch_size)
    print("Dataset built.")
    return coco, dataset, vocabulary
def prepare_train_data(config):
    """Prepare the data for training the model.

    Loads COCO annotations, builds/loads the vocabulary, materializes
    (caption, image_id, image_file) triples — cached as a CSV at
    ``config.temp_annotation_file`` — converts captions to padded word-index
    arrays with masks — cached at ``config.temp_data_file`` — downloads the
    images, and returns the assembled DataSet.

    Args:
        config: configuration object providing the caption/vocabulary/image
            paths, ``max_train_ann_num``, ``max_caption_length``,
            ``vocabulary_size`` and ``batch_size``.

    Returns:
        A shuffling ``DataSet`` over the training captions.
    """
    coco = COCO(config.train_caption_file, config.max_train_ann_num)
    coco.filter_by_cap_len(config.max_caption_length)

    print("Building the vocabulary...")
    vocabulary = Vocabulary(config.vocabulary_size)
    if not os.path.exists(config.vocabulary_file):
        if not config.max_train_ann_num:
            vocabulary.build(coco.all_captions())
        else:
            vocabulary.build((coco.all_captions())[:config.max_train_ann_num])
        vocabulary.save(config.vocabulary_file)
    else:
        vocabulary.load(config.vocabulary_file)
    print("Vocabulary built.")
    print("Number of words = %d" % (vocabulary.size))

    # Drop annotations containing out-of-vocabulary words before pairing
    # captions with images.
    coco.filter_by_words(set(vocabulary.words))

    print("Processing the captions...")
    if not os.path.exists(config.temp_annotation_file):
        if not config.max_train_ann_num:
            print('No config.max_train_ann_num')
            ann_ids = list(coco.anns)
        else:
            print('config.max_train_ann_num=', config.max_train_ann_num)
            ann_ids = list(islice(coco.anns, config.max_train_ann_num))
        captions = [coco.anns[ann_id]['caption'] for ann_id in ann_ids]
        # NOTE(review): one image can carry several annotations, so
        # image_ids may contain duplicates — confirm DataSet expects that.
        image_ids = [coco.anns[ann_id]['image_id'] for ann_id in ann_ids]
        # image_ids is already capped above; the original re-sliced it
        # redundantly when building image_files.
        image_files = [
            os.path.join(config.train_image_dir,
                         coco.imgs[image_id]['file_name'])
            for image_id in image_ids
        ]
        annotations = pd.DataFrame({'image_id': image_ids,
                                    'image_file': image_files,
                                    'caption': captions})
        annotations.to_csv(config.temp_annotation_file)
    else:
        annotations = pd.read_csv(config.temp_annotation_file)
        if not config.max_train_ann_num:
            print('No config.max_train_ann_num')
            captions = annotations['caption'].values
            image_ids = annotations['image_id'].values
            image_files = annotations['image_file'].values
        else:
            print('config.max_train_ann_num=', config.max_train_ann_num)
            captions = annotations['caption'].values[:config.max_train_ann_num]
            image_ids = annotations['image_id'].values[:config.max_train_ann_num]
            image_files = annotations['image_file'].values[:config.max_train_ann_num]

    if not os.path.exists(config.temp_data_file):
        word_idxs = []
        masks = []
        for caption in tqdm(captions):
            current_word_idxs_ = vocabulary.process_sentence(caption)
            current_num_words = len(current_word_idxs_)
            # Pad each caption to max_caption_length; mask marks real words.
            current_word_idxs = np.zeros(config.max_caption_length,
                                         dtype=np.int32)
            current_masks = np.zeros(config.max_caption_length)
            current_word_idxs[:current_num_words] = np.array(
                current_word_idxs_)
            current_masks[:current_num_words] = 1.0
            word_idxs.append(current_word_idxs)
            masks.append(current_masks)
        word_idxs = np.array(word_idxs)
        masks = np.array(masks)
        data = {'word_idxs': word_idxs, 'masks': masks}
        # NOTE(review): np.save appends '.npy' when the path lacks it, while
        # np.load does not — verify temp_data_file already ends in '.npy'.
        np.save(config.temp_data_file, data)
    else:
        # The cached dict is stored as a 0-d object array, which NumPy
        # >= 1.16.3 refuses to unpickle unless allow_pickle=True is passed
        # explicitly (the original call raises ValueError there).
        data = np.load(config.temp_data_file,
                       allow_pickle=True,
                       encoding='latin1').item()
        word_idxs = data['word_idxs']
        masks = data['masks']
    print("Captions processed.")
    print("Number of captions = %d" % (len(captions)))

    print('Download Images')
    coco.download(config.train_image_dir, image_ids)
    print('Finished download images')

    print("Building the dataset...")
    dataset = DataSet(image_ids, image_files, config.batch_size, word_idxs,
                      masks, True, True)
    print("Dataset built.")
    return dataset