def main():
    args = parser.parse_args()

    # Phases to be processed.
    phases = [phase.strip() for phase in args.phases.split(',')]

    # Annotation files to be processed. The test split has no public captions,
    # so its annotation file is left empty.
    if sorted(phases) == sorted(['train', 'val', 'test']) and args.ann_files == '':
        tmplt = 'data/annotations/captions_%s2017.json'
        ann_files = [tmplt % 'train', tmplt % 'val', '']
    else:
        ann_files = [ann_file.strip() for ann_file in args.ann_files.split(',')]

    # Batch size for extracting feature vectors.
    batch_size = args.batch_size
    # Maximum caption length in words; captions longer than max_length are discarded.
    max_length = args.max_length
    # Words occurring fewer than word_count_threshold times in the training set
    # are mapped to the special unknown token.
    word_count_threshold = args.word_count_threshold
    vocab_size = args.vocab_size

    for phase, ann_file in zip(phases, ann_files):
        _process_caption_data(phase, ann_file=ann_file, max_length=max_length)

        if phase == 'train':
            # Build the vocabulary from the training captions only.
            captions_data = load_json('./data/train/captions_train2017.json')
            word_to_idx = _build_vocab(captions_data,
                                       threshold=word_count_threshold,
                                       vocab_size=vocab_size)
            save_json(word_to_idx, './data/word_to_idx.json')

            new_captions_data = _build_caption_vector(captions_data,
                                                      word_to_idx=word_to_idx,
                                                      max_length=max_length)
            save_json(new_captions_data, ann_file)

    print('Finished processing caption data')

    # Extract spatial feature maps with a ResNet-101 backbone.
    feature_extractor = FeatureExtractor(model_name='resnet101', layer=3)

    for phase in phases:
        if not os.path.isdir('./data/%s/feats/' % phase):
            os.makedirs('./data/%s/feats/' % phase)

        image_paths = os.listdir('./image/%s/' % phase)
        dataset = CocoImageDataset(root='./image/%s/' % phase, image_paths=image_paths)
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  batch_size=batch_size,
                                                  num_workers=8)

        for batch_paths, batch_images in tqdm(data_loader):
            feats = feature_extractor(batch_images).data.cpu().numpy()
            # Flatten the two spatial dimensions so each image's features
            # form a 2-D array.
            feats = feats.reshape(-1, feats.shape[1] * feats.shape[2], feats.shape[-1])
            # Save one .npy file per image, named after its image path.
            for j in range(len(feats)):
                np.save('./data/%s/feats/%s.npy' % (phase, batch_paths[j]), feats[j])
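

# A minimal sketch of the command-line interface main() assumes. Only the
# argument names are taken from the attribute accesses above (args.phases,
# args.ann_files, args.batch_size, args.max_length, args.word_count_threshold,
# args.vocab_size); the defaults and help strings below are illustrative
# assumptions, and the actual parser and entry-point guard may be defined
# elsewhere in this script.
import argparse

parser = argparse.ArgumentParser(
    description='Preprocess COCO captions and extract image features.')
parser.add_argument('--phases', type=str, default='train,val,test',
                    help='comma-separated splits to process')
parser.add_argument('--ann_files', type=str, default='',
                    help='comma-separated caption annotation files, one per phase')
parser.add_argument('--batch_size', type=int, default=64,
                    help='batch size for feature extraction')
parser.add_argument('--max_length', type=int, default=15,
                    help='captions longer than this are discarded')
parser.add_argument('--word_count_threshold', type=int, default=1,
                    help='words rarer than this map to the unknown token')
parser.add_argument('--vocab_size', type=int, default=10000,
                    help='maximum vocabulary size')

if __name__ == '__main__':
    main()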