# Example 1
def main():
    """Preprocess COCO caption data and extract image feature vectors.

    Pipeline:
      1. Clean/filter captions per phase and build a vocabulary from the
         training split (words below ``word_count_threshold`` map to <UNK>).
      2. Encode training captions as index vectors and save them.
      3. Run a ResNet-101 feature extractor over every image of every phase
         and save one ``.npy`` feature file per image.
    """
    args = parser.parse_args()
    # phases to be processed.
    phases = [phase.strip() for phase in args.phases.split(',')]

    # annotation files to be processed
    if sorted(phases) == sorted(['train', 'val', 'test'
                                 ]) and args.ann_files == '':
        # Default COCO-2017 annotation paths. Map each phase to its file so
        # the pairing below stays correct regardless of the order in which
        # the user listed the phases ('test' has no public annotations).
        tmplt = 'data/annotations/captions_%s2017.json'
        phase_to_ann = {
            'train': tmplt % 'train',
            'val': tmplt % 'val',
            'test': '',
        }
        ann_files = [phase_to_ann[phase] for phase in phases]
    else:
        ann_files = [
            ann_file.strip() for ann_file in args.ann_files.split(',')
        ]

    # batch size for extracting feature vectors.
    batch_size = args.batch_size

    # maximum length of caption (number of words). Captions longer than
    # max_length are dropped.
    max_length = args.max_length

    # if a word occurs fewer than word_count_threshold times in the training
    # dataset, it is mapped to the special unknown token.
    word_count_threshold = args.word_count_threshold
    vocab_size = args.vocab_size

    for phase, ann_file in zip(phases, ann_files):
        _process_caption_data(phase, ann_file=ann_file, max_length=max_length)

        if phase == 'train':
            # Vocabulary and caption vectors are built from the training
            # split only; val/test reuse the saved word_to_idx mapping.
            captions_data = load_json('./data/train/captions_train2017.json')

            word_to_idx = _build_vocab(captions_data,
                                       threshold=word_count_threshold,
                                       vocab_size=vocab_size)
            save_json(word_to_idx, './data/word_to_idx.json')

            new_captions_data = _build_caption_vector(captions_data,
                                                      word_to_idx=word_to_idx,
                                                      max_length=max_length)
            # NOTE(review): this writes the processed captions back to the
            # *raw* annotation path (ann_file), overwriting the original
            # COCO annotation file — confirm this is intended rather than
            # './data/train/captions_train2017.json'.
            save_json(new_captions_data, ann_file)

    print('Finished processing caption data')

    feature_extractor = FeatureExtractor(model_name='resnet101', layer=3)
    for phase in phases:
        if not os.path.isdir('./data/%s/feats/' % phase):
            os.makedirs('./data/%s/feats/' % phase)

        image_paths = os.listdir('./image/%s/' % phase)
        dataset = CocoImageDataset(root='./image/%s/' % phase,
                                   image_paths=image_paths)
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  batch_size=batch_size,
                                                  num_workers=8)

        for batch_paths, batch_images in tqdm(data_loader):
            feats = feature_extractor(batch_images).data.cpu().numpy()
            # Flatten the spatial grid: (N, H, W, C) -> (N, H*W, C),
            # one feature vector per spatial location.
            feats = feats.reshape(-1, feats.shape[1] * feats.shape[2],
                                  feats.shape[-1])
            for j in range(len(feats)):
                np.save('./data/%s/feats/%s.npy' % (phase, batch_paths[j]),
                        feats[j])