def main():
    opt = parse_args()
    print(json.dumps(vars(opt), indent=2))

    rootpath = opt.rootpath
    trainCollection = opt.trainCollection
    valCollection = opt.valCollection
    testCollection = opt.testCollection

    if opt.loss_fun == "mrl" and opt.measure == "cosine":
        assert opt.text_norm is True
        assert opt.visual_norm is True

    # checkpoint path
    model_info = '%s_concate_%s_dp_%.1f_measure_%s' % (
        opt.model, opt.concate, opt.dropout, opt.measure)

    # text-side multi-level encoding info
    text_encode_info = 'vocab_%s_word_dim_%s_text_rnn_size_%s_text_norm_%s' % \
        (opt.vocab, opt.word_dim, opt.text_rnn_size, opt.text_norm)
    text_encode_info += "_kernel_sizes_%s_num_%s" % (opt.text_kernel_sizes,
                                                     opt.text_kernel_num)

    # video-side multi-level encoding info
    visual_encode_info = 'visual_feature_%s_visual_rnn_size_%d_visual_norm_%s' % \
        (opt.visual_feature, opt.visual_rnn_size, opt.visual_norm)
    visual_encode_info += "_kernel_sizes_%s_num_%s" % (opt.visual_kernel_sizes,
                                                       opt.visual_kernel_num)

    # common space learning info
    mapping_info = "mapping_text_%s_img_%s" % (opt.text_mapping_layers,
                                               opt.visual_mapping_layers)
    loss_info = 'loss_func_%s_margin_%s_direction_%s_max_violation_%s_cost_style_%s' % \
        (opt.loss_fun, opt.margin, opt.direction, opt.max_violation, opt.cost_style)
    optimizer_info = 'optimizer_%s_lr_%s_decay_%.2f_grad_clip_%.1f_val_metric_%s' % \
        (opt.optimizer, opt.learning_rate, opt.lr_decay_rate, opt.grad_clip, opt.val_metric)

    opt.logger_name = os.path.join(rootpath, trainCollection, opt.cv_name,
                                   valCollection, model_info, text_encode_info,
                                   visual_encode_info, mapping_info, loss_info,
                                   optimizer_info, opt.postfix)
    print(opt.logger_name)

    if checkToSkip(os.path.join(opt.logger_name, 'model_best.pth.tar'), opt.overwrite):
        sys.exit(0)
    if checkToSkip(os.path.join(opt.logger_name, 'val_metric.txt'), opt.overwrite):
        sys.exit(0)
    makedirsforfile(os.path.join(opt.logger_name, 'val_metric.txt'))

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    tb_logger.configure(opt.logger_name, flush_secs=5)

    # convert '-'-separated strings to lists of ints; list() is required so that
    # len() and item assignment work under Python 3
    opt.text_kernel_sizes = list(map(int, opt.text_kernel_sizes.split('-')))
    opt.visual_kernel_sizes = list(map(int, opt.visual_kernel_sizes.split('-')))

    # collections: train, val
    collections = {'train': trainCollection, 'val': valCollection}
    cap_file = {
        'train': '%s.caption.txt' % trainCollection,
        'val': '%s.caption.txt' % valCollection
    }

    # caption files
    caption_files = {
        x: os.path.join(rootpath, collections[x], 'TextData', cap_file[x])
        for x in collections
    }

    # load visual features
    visual_feat_path = {
        x: os.path.join(rootpath, collections[x], 'FeatureData', opt.visual_feature)
        for x in collections
    }
    visual_feats = {x: BigFile(visual_feat_path[x]) for x in visual_feat_path}
    opt.visual_feat_dim = visual_feats['train'].ndims

    # set bow vocabulary and encoding
    bow_vocab_file = os.path.join(rootpath, opt.trainCollection, 'TextData',
                                  'vocabulary', 'bow', opt.vocab + '.pkl')
    bow_vocab = pickle.load(open(bow_vocab_file, 'rb'))
    bow2vec = get_text_encoder('bow')(bow_vocab)
    opt.bow_vocab_size = len(bow_vocab)

    # set rnn vocabulary
    rnn_vocab_file = os.path.join(rootpath, opt.trainCollection, 'TextData',
                                  'vocabulary', 'rnn', opt.vocab + '.pkl')
    rnn_vocab = pickle.load(open(rnn_vocab_file, 'rb'))
    opt.vocab_size = len(rnn_vocab)

    # initialize word embedding
    opt.we_parameter = None
    if opt.word_dim == 500:
        w2v_data_path = os.path.join(rootpath, "word2vec", 'flickr', 'vec500flickr30m')
        opt.we_parameter = get_we_parameter(rnn_vocab, w2v_data_path)

    # mapping layer structure
    opt.text_mapping_layers = list(map(int, opt.text_mapping_layers.split('-')))
    opt.visual_mapping_layers = list(map(int, opt.visual_mapping_layers.split('-')))
    if opt.concate == 'full':
        opt.text_mapping_layers[0] = opt.bow_vocab_size + opt.text_rnn_size * 2 + \
            opt.text_kernel_num * len(opt.text_kernel_sizes)
        opt.visual_mapping_layers[0] = opt.visual_feat_dim + opt.visual_rnn_size * 2 + \
            opt.visual_kernel_num * len(opt.visual_kernel_sizes)
    elif opt.concate == 'reduced':
        opt.text_mapping_layers[0] = opt.text_rnn_size * 2 + \
            opt.text_kernel_num * len(opt.text_kernel_sizes)
        opt.visual_mapping_layers[0] = opt.visual_rnn_size * 2 + \
            opt.visual_kernel_num * len(opt.visual_kernel_sizes)
    else:
        raise NotImplementedError('Concatenation mode %s not implemented' % opt.concate)

    # set data loaders
    video2frames = {
        x: read_dict(os.path.join(rootpath, collections[x], 'FeatureData',
                                  opt.visual_feature, 'video2frames.txt'))
        for x in collections
    }
    if testCollection.startswith('msvd'):
        data_loaders = data.get_train_data_loaders(
            caption_files, visual_feats, rnn_vocab, bow2vec, opt.batch_size,
            opt.workers, opt.n_caption, video2frames=video2frames,
            padding_size=opt.batch_padding)
        val_video_ids_list = data.read_video_ids(caption_files['val'])
        val_vid_data_loader = data.get_vis_data_loader(
            visual_feats['val'], opt.batch_size, opt.workers,
            video2frames['val'], video_ids=val_video_ids_list)
        val_text_data_loader = data.get_txt_data_loader(
            caption_files['val'], rnn_vocab, bow2vec, opt.batch_size, opt.workers)
    else:
        data_loaders = data.get_data_loaders(
            caption_files, visual_feats, rnn_vocab, bow2vec, opt.batch_size,
            opt.workers, opt.n_caption, video2frames=video2frames)
    print("=======================Data Loaded=================================")

    # construct the model
    model = get_model(opt.model)(opt)
    opt.we_parameter = None

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
            if testCollection.startswith('msvd'):
                validate_split(opt, val_vid_data_loader, val_text_data_loader,
                               model, measure=opt.measure)
            else:
                validate(opt, data_loaders['val'], model, measure=opt.measure)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # train the model
    best_rsum = 0
    no_impr_counter = 0
    lr_counter = 0
    best_epoch = None
    fout_val_metric_hist = open(
        os.path.join(opt.logger_name, 'val_metric_hist.txt'), 'w')
    loss_value = []
    pos_value = []
    neg_value = []
    for epoch in range(opt.num_epochs):
        print('Epoch[{0} / {1}] LR: {2}'.format(
            epoch, opt.num_epochs, get_learning_rate(model.optimizer)[0]))
        print('-' * 10)

        # train for one epoch
        loss_t, pos_t, neg_t = train(opt, data_loaders['train'], model, epoch)
        loss_value.append(loss_t)
        pos_value.append(pos_t)
        neg_value.append(neg_t)

        # evaluate on validation set
        if testCollection.startswith('msvd'):
            rsum = validate_split(opt, val_vid_data_loader, val_text_data_loader,
                                  model, measure=opt.measure)
        else:
            rsum = validate(opt, data_loaders['val'], model, measure=opt.measure)

        # remember best R@sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        print(' * Current perf: {}'.format(rsum))
        print(' * Best perf: {}'.format(best_rsum))
        print('')
        fout_val_metric_hist.write('epoch_%d: %f\n' % (epoch, rsum))
        fout_val_metric_hist.flush()

        if is_best:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model': model.state_dict(),
                    'best_rsum': best_rsum,
                    'opt': opt,
                    'Eiters': model.Eiters,
                },
                is_best,
                filename='checkpoint_epoch_%s.pth.tar' % epoch,
                prefix=opt.logger_name + '/',
                best_epoch=best_epoch)
            best_epoch = epoch

        lr_counter += 1
        decay_learning_rate(opt, model.optimizer, opt.lr_decay_rate)
        if not is_best:
            # early stop if the validation performance does not improve
            # in more than ten consecutive epochs
            no_impr_counter += 1
            if no_impr_counter > 10:
                print('Early stopping happened.\n')
                break
            # when the validation performance decreases after an epoch,
            # divide the learning rate by 2 and continue training,
            # but use each learning rate for at least 3 epochs
            if lr_counter > 2:
                decay_learning_rate(opt, model.optimizer, 0.5)
                lr_counter = 0
        else:
            no_impr_counter = 0

    # plot training curves; make sure the output directory exists
    os.makedirs('./plots', exist_ok=True)

    # loss_value
    loss_value = np.array(loss_value)
    plt.title("Loss v. Time")
    plt.xlabel("Epoch")
    plt.xticks(np.arange(len(loss_value)))
    plt.ylabel("Loss Value")
    plt.plot(np.arange(len(loss_value)), loss_value)
    plt.savefig("./plots/training_loss_{}.png".format(opt.logtimestamp))
    plt.close()

    # pos_value
    pos_value = np.array(pos_value)
    plt.title("Pos Score v. Time")
    plt.xlabel("Epoch")
    plt.xticks(np.arange(len(pos_value)))
    plt.ylabel("Pos Value")
    plt.plot(np.arange(len(pos_value)), pos_value)
    plt.savefig("./plots/training_pos_{}.png".format(opt.logtimestamp))
    plt.close()

    # neg_value
    neg_value = np.array(neg_value)
    plt.title("Neg Score v. Time")
    plt.xlabel("Epoch")
    plt.xticks(np.arange(len(neg_value)))
    plt.ylabel("Neg Value")
    plt.plot(np.arange(len(neg_value)), neg_value)
    plt.savefig("./plots/training_neg_{}.png".format(opt.logtimestamp))
    plt.close()

    fout_val_metric_hist.close()

    print('best performance on validation: {}\n'.format(best_rsum))
    with open(os.path.join(opt.logger_name, 'val_metric.txt'), 'w') as fout:
        fout.write('best performance on validation: ' + str(best_rsum))

    # generate evaluation shell script
    if testCollection == 'iacc.3':
        template = ''.join(open('util/TEMPLATE_do_predict.sh').readlines())
        script_str = template.replace('@@@query_sets@@@',
                                      'tv16.avs.txt,tv17.avs.txt,tv18.avs.txt')
    else:
        template = ''.join(open('util/TEMPLATE_do_test.sh').readlines())
        script_str = template.replace('@@@n_caption@@@', str(opt.n_caption))
    script_str = script_str.replace('@@@rootpath@@@', rootpath)
    script_str = script_str.replace('@@@testCollection@@@', testCollection)
    script_str = script_str.replace('@@@logger_name@@@', opt.logger_name)
    script_str = script_str.replace('@@@overwrite@@@', str(opt.overwrite))

    # write the script that performs evaluation on the test set
    runfile = 'do_test_%s_%s.sh' % (opt.model, testCollection)
    open(runfile, 'w').write(script_str + '\n')
    os.system('chmod +x %s' % runfile)
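# The training loop above relies on two small optimizer helpers,
# get_learning_rate() and decay_learning_rate(), whose definitions are not part
# of this excerpt. The sketch below only illustrates how such helpers could be
# written for a standard torch.optim optimizer; it is not the repository's
# authoritative implementation (the "_sketch" suffix avoids shadowing the real
# functions, and `opt` is accepted purely for signature compatibility).
def get_learning_rate_sketch(optimizer):
    # return the current learning rate of every parameter group
    return [group['lr'] for group in optimizer.param_groups]


def decay_learning_rate_sketch(opt, optimizer, decay):
    # multiply the learning rate of every parameter group by `decay`
    for group in optimizer.param_groups:
        group['lr'] *= decay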
def main():
    opt = parse_args()
    print(json.dumps(vars(opt), indent=2))

    rootpath = opt.rootpath
    testCollection = opt.testCollection
    n_caption = opt.n_caption
    resume = os.path.join(opt.logger_name, opt.checkpoint_name)

    if not os.path.exists(resume):
        logging.info(resume + ' does not exist.')
        sys.exit(0)

    checkpoint = torch.load(resume)
    start_epoch = checkpoint['epoch']
    best_rsum = checkpoint['best_rsum']
    print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
        resume, start_epoch, best_rsum))
    options = checkpoint['opt']
    if not hasattr(options, 'concate'):
        setattr(options, "concate", "full")

    trainCollection = options.trainCollection
    output_dir = resume.replace(trainCollection, testCollection)
    output_dir = output_dir.replace('/%s/' % options.cv_name,
                                    '/results/%s/' % trainCollection)
    result_pred_sents = os.path.join(output_dir, 'id.sent.score.txt')
    pred_error_matrix_file = os.path.join(output_dir, 'pred_errors_matrix.pth.tar')
    if checkToSkip(pred_error_matrix_file, opt.overwrite):
        sys.exit(0)
    makedirsforfile(pred_error_matrix_file)

    # data loader prepare
    caption_files = {
        'test': os.path.join(rootpath, testCollection, 'TextData',
                             '%s.caption.txt' % testCollection)
    }
    img_feat_path = os.path.join(rootpath, testCollection, 'FeatureData',
                                 options.visual_feature)
    visual_feats = {'test': BigFile(img_feat_path)}
    assert options.visual_feat_dim == visual_feats['test'].ndims
    video2frames = {
        'test': read_dict(os.path.join(rootpath, testCollection, 'FeatureData',
                                       options.visual_feature, 'video2frames.txt'))
    }

    # set bow vocabulary and encoding
    bow_vocab_file = os.path.join(rootpath, options.trainCollection, 'TextData',
                                  'vocabulary', 'bow', options.vocab + '.pkl')
    bow_vocab = pickle.load(open(bow_vocab_file, 'rb'))
    bow2vec = get_text_encoder('bow')(bow_vocab)
    options.bow_vocab_size = len(bow_vocab)

    # set rnn vocabulary
    rnn_vocab_file = os.path.join(rootpath, options.trainCollection, 'TextData',
                                  'vocabulary', 'rnn', options.vocab + '.pkl')
    rnn_vocab = pickle.load(open(rnn_vocab_file, 'rb'))
    options.vocab_size = len(rnn_vocab)

    # construct the model
    model = get_model(options.model)(options)
    model.load_state_dict(checkpoint['model'])
    model.Eiters = checkpoint['Eiters']
    model.val_start()

    if testCollection.startswith('msvd'):  # or testCollection.startswith('msrvtt'):
        # set data loaders
        video_ids_list = data.read_video_ids(caption_files['test'])
        vid_data_loader = data.get_vis_data_loader(
            visual_feats['test'], opt.batch_size, opt.workers,
            video2frames['test'], video_ids=video_ids_list)
        text_data_loader = data.get_txt_data_loader(
            caption_files['test'], rnn_vocab, bow2vec, opt.batch_size, opt.workers)
        # mapping
        video_embs, video_ids = evaluation.encode_text_or_vid(
            model.embed_vis, vid_data_loader)
        cap_embs, caption_ids = evaluation.encode_text_or_vid(
            model.embed_txt, text_data_loader)
    else:
        # set data loader
        data_loader = data.get_test_data_loaders(
            caption_files, visual_feats, rnn_vocab, bow2vec, opt.batch_size,
            opt.workers, opt.n_caption, video2frames=video2frames)
        # mapping
        video_embs, cap_embs, video_ids, caption_ids = evaluation.encode_data(
            model, data_loader['test'], opt.log_step, logging.info)
        # remove duplicate videos (each video appears once per caption)
        idx = list(range(0, video_embs.shape[0], n_caption))
        video_embs = video_embs[idx, :]
        video_ids = video_ids[::opt.n_caption]

    c2i_all_errors = evaluation.cal_error(video_embs, cap_embs, options.measure)
    torch.save(
        {
            'errors': c2i_all_errors,
            'videos': video_ids,
            'captions': caption_ids
        }, pred_error_matrix_file)
    print("write into: %s" % pred_error_matrix_file)

    if testCollection.startswith('msvd'):  # or testCollection.startswith('msrvtt'):
        # caption retrieval (video-to-text)
        (r1, r5, r10, medr, meanr, i2t_map_score) = evaluation.i2t_varied(
            c2i_all_errors, caption_ids, video_ids)
        # video retrieval (text-to-video)
        (r1i, r5i, r10i, medri, meanri, t2i_map_score) = evaluation.t2i_varied(
            c2i_all_errors, caption_ids, video_ids)
    else:
        # video retrieval (text-to-video)
        (r1i, r5i, r10i, medri, meanri) = evaluation.t2i(c2i_all_errors, n_caption=n_caption)
        t2i_map_score = evaluation.t2i_map(c2i_all_errors, n_caption=n_caption)
        # caption retrieval (video-to-text)
        (r1, r5, r10, medr, meanr) = evaluation.i2t(c2i_all_errors, n_caption=n_caption)
        i2t_map_score = evaluation.i2t_map(c2i_all_errors, n_caption=n_caption)

    print(" * Text to Video:")
    print(" * r_1_5_10, medr, meanr: {}".format([
        round(r1i, 1), round(r5i, 1), round(r10i, 1),
        round(medri, 1), round(meanri, 1)
    ]))
    print(" * recall sum: {}".format(round(r1i + r5i + r10i, 1)))
    print(" * mAP: {}".format(round(t2i_map_score, 3)))
    print(" * " + '-' * 10)

    # caption retrieval
    print(" * Video to text:")
    print(" * r_1_5_10, medr, meanr: {}".format([
        round(r1, 1), round(r5, 1), round(r10, 1),
        round(medr, 1), round(meanr, 1)
    ]))
    print(" * recall sum: {}".format(round(r1 + r5 + r10, 1)))
    print(" * mAP: {}".format(round(i2t_map_score, 3)))
    print(" * " + '-' * 10)
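# For reference, a minimal sketch of how text-to-video recall@K and median rank
# can be derived from an error matrix shaped (n_videos * n_caption, n_videos),
# assuming caption i belongs to video i // n_caption. This only illustrates the
# idea behind evaluation.t2i; the repository's evaluation module remains the
# authoritative implementation, and the "_sketch" name is hypothetical.
import numpy as np

def t2i_recall_sketch(c2i_errors, n_caption=20):
    ranks = np.zeros(c2i_errors.shape[0])
    for i in range(c2i_errors.shape[0]):
        order = np.argsort(c2i_errors[i])    # lower error = better match
        gt_video = i // n_caption            # assumed ground-truth video index
        ranks[i] = np.where(order == gt_video)[0][0]
    r1 = 100.0 * np.mean(ranks < 1)
    r5 = 100.0 * np.mean(ranks < 5)
    r10 = 100.0 * np.mean(ranks < 10)
    medr = np.median(ranks) + 1
    meanr = ranks.mean() + 1
    return r1, r5, r10, medr, meanr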
def main():
    opt = parse_args()
    logging.info(json.dumps(vars(opt), indent=2))

    rootpath = opt.rootpath
    testCollection = opt.testCollection
    # collection structure comes from the command line, as in the other entry points
    collectionStrt = opt.collectionStrt
    assert collectionStrt == "multiple"

    resume = os.path.join(opt.logger_name, opt.checkpoint_name)
    if not os.path.exists(resume):
        logging.info(resume + ' does not exist.')
        sys.exit(0)

    checkpoint = torch.load(resume)
    start_epoch = checkpoint['epoch']
    best_rsum = checkpoint['best_rsum']
    logging.info("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
        resume, start_epoch, best_rsum))
    options = checkpoint['opt']

    trainCollection = options.trainCollection
    valCollection = options.valCollection

    visual_feat_file = BigFile(
        os.path.join(rootpath, testCollection, 'FeatureData', options.visual_feature))
    assert options.visual_feat_dim == visual_feat_file.ndims
    video2frame = read_dict(
        os.path.join(rootpath, testCollection, 'FeatureData',
                     options.visual_feature, 'video2frames.txt'))
    vid_data_loader = data.get_vis_data_loader(visual_feat_file, opt.batch_size,
                                               opt.workers, video2frame)
    vis_embs = None

    # set bow vocabulary and encoding
    bow_vocab_file = os.path.join(rootpath, options.trainCollection, 'TextData',
                                  'vocabulary', 'bow', options.vocab + '.pkl')
    bow_vocab = pickle.load(open(bow_vocab_file, 'rb'))
    bow2vec = get_text_encoder('bow')(bow_vocab)
    options.bow_vocab_size = len(bow_vocab)

    # set rnn vocabulary
    rnn_vocab_file = os.path.join(rootpath, options.trainCollection, 'TextData',
                                  'vocabulary', 'rnn', options.vocab + '.pkl')
    rnn_vocab = pickle.load(open(rnn_vocab_file, 'rb'))
    options.vocab_size = len(rnn_vocab)

    model = get_model(options.model)(options)
    model.load_state_dict(checkpoint['model'])
    model.val_start()

    output_dir = resume.replace(trainCollection, testCollection)

    for query_set in opt.query_sets.strip().split(','):
        output_dir_tmp = output_dir.replace(
            valCollection, '%s/%s/%s' % (query_set, trainCollection, valCollection))
        output_dir_tmp = output_dir_tmp.replace('/%s/' % options.cv_name, '/results/')
        pred_result_file = os.path.join(output_dir_tmp, 'id.sent.score.txt')
        logging.info(pred_result_file)
        if checkToSkip(pred_result_file, opt.overwrite):
            sys.exit(0)
        makedirsforfile(pred_result_file)

        # query data loader
        query_file = os.path.join(rootpath, testCollection, 'TextData', query_set)
        query_loader = data.get_txt_data_loader(query_file, rnn_vocab, bow2vec,
                                                opt.batch_size, opt.workers)

        # encode videos (only once, then reuse across query sets)
        if vis_embs is None:
            start = time.time()
            if options.space == 'hybrid':
                video_embs, video_tag_probs, video_ids = \
                    evaluation.encode_text_or_vid_tag_hist_prob(model.embed_vis, vid_data_loader)
            else:
                video_embs, video_ids = evaluation.encode_text_or_vid(
                    model.embed_vis, vid_data_loader)
            vis_embs = video_embs  # cache so videos are not re-encoded per query set
            logging.info("encode video time: %.3f s" % (time.time() - start))

        # encode text
        start = time.time()
        if options.space == 'hybrid':
            query_embs, query_tag_probs, query_ids = \
                evaluation.encode_text_or_vid_tag_hist_prob(model.embed_txt, query_loader)
        else:
            query_embs, query_ids = evaluation.encode_text_or_vid(
                model.embed_txt, query_loader)
        logging.info("encode text time: %.3f s" % (time.time() - start))

        if options.space == 'hybrid':
            t2v_matrix_1 = evaluation.cal_simi(query_embs, video_embs)
            # eval_avs(t2v_matrix_1, query_ids, video_ids, pred_result_file,
            #          rootpath, testCollection, query_set)

            t2v_matrix_2 = evaluation.cal_simi(query_tag_probs, video_tag_probs)
            # pred_result_file = os.path.join(output_dir_tmp, 'id.sent.score_2.txt')
            # eval_avs(t2v_matrix_2, query_ids, video_ids, pred_result_file,
            #          rootpath, testCollection, query_set)

            t2v_matrix_1 = norm_score(t2v_matrix_1)
            t2v_matrix_2 = norm_score(t2v_matrix_2)
            for w in [0.8]:
                print("\n")
                t2v_matrix = w * t2v_matrix_1 + (1 - w) * t2v_matrix_2
                pred_result_file = os.path.join(output_dir_tmp,
                                                'id.sent.score_%.1f.txt' % w)
                eval_avs(t2v_matrix, query_ids, video_ids, pred_result_file,
                         rootpath, testCollection, query_set)
        else:
            t2v_matrix_1 = evaluation.cal_simi(query_embs, video_embs)
            eval_avs(t2v_matrix_1, query_ids, video_ids, pred_result_file,
                     rootpath, testCollection, query_set)
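# norm_score() is defined elsewhere in the repository. As a rough illustration
# of why it is applied before fusing the latent-space and concept-space
# similarity matrices, the sketch below min-max normalizes a matrix so that the
# two sets of scores are on a comparable scale before the weighted sum. This is
# an assumption about its behavior, not the repository's actual implementation.
import numpy as np

def norm_score_sketch(score_matrix):
    # scale all scores into [0, 1] so matrices from different spaces can be
    # combined with a simple weighted sum
    lo, hi = score_matrix.min(), score_matrix.max()
    return (score_matrix - lo) / (hi - lo + 1e-12)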
def main():
    opt = parse_args()
    print(json.dumps(vars(opt), indent=2))

    rootpath = opt.rootpath
    collectionStrt = opt.collectionStrt
    resume = os.path.join(opt.logger_name, opt.checkpoint_name)
    if not os.path.exists(resume):
        logging.info(resume + ' does not exist.')
        sys.exit(0)

    checkpoint = torch.load(resume)
    start_epoch = checkpoint['epoch']
    best_rsum = checkpoint['best_rsum']
    print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
        resume, start_epoch, best_rsum))
    options = checkpoint['opt']

    # collection setting
    testCollection = opt.testCollection
    collections_pathname = options.collections_pathname
    collections_pathname['test'] = testCollection
    trainCollection = options.trainCollection

    output_dir = resume.replace(trainCollection, testCollection)
    if 'checkpoints' in output_dir:
        output_dir = output_dir.replace('/checkpoints/', '/results/')
    else:
        output_dir = output_dir.replace(
            '/%s/' % options.cv_name,
            '/results/%s/%s/' % (options.cv_name, trainCollection))
    result_pred_sents = os.path.join(output_dir, 'id.sent.score.txt')
    pred_error_matrix_file = os.path.join(output_dir, 'pred_errors_matrix.pth.tar')
    if checkToSkip(pred_error_matrix_file, opt.overwrite):
        sys.exit(0)
    makedirsforfile(pred_error_matrix_file)

    log_config(output_dir)
    logging.info(json.dumps(vars(opt), indent=2))

    # data loader prepare
    if collectionStrt == 'single':
        test_cap = os.path.join(rootpath, collections_pathname['test'], 'TextData',
                                '%s%s.caption.txt' % (testCollection, opt.split))
    elif collectionStrt == 'multiple':
        test_cap = os.path.join(rootpath, collections_pathname['test'], 'TextData',
                                '%s.caption.txt' % testCollection)
    else:
        raise NotImplementedError('collection structure %s not implemented' % collectionStrt)

    caption_files = {'test': test_cap}
    img_feat_path = os.path.join(rootpath, collections_pathname['test'],
                                 'FeatureData', options.visual_feature)
    visual_feats = {'test': BigFile(img_feat_path)}
    assert options.visual_feat_dim == visual_feats['test'].ndims
    video2frames = {
        'test': read_dict(os.path.join(rootpath, collections_pathname['test'],
                                       'FeatureData', options.visual_feature,
                                       'video2frames.txt'))
    }

    # set bow vocabulary and encoding
    bow_vocab_file = os.path.join(rootpath, collections_pathname['train'], 'TextData',
                                  'vocabulary', 'bow', options.vocab + '.pkl')
    bow_vocab = pickle.load(open(bow_vocab_file, 'rb'))
    bow2vec = get_text_encoder('bow')(bow_vocab)
    options.bow_vocab_size = len(bow_vocab)

    # set rnn vocabulary
    rnn_vocab_file = os.path.join(rootpath, collections_pathname['train'], 'TextData',
                                  'vocabulary', 'rnn', options.vocab + '.pkl')
    rnn_vocab = pickle.load(open(rnn_vocab_file, 'rb'))
    options.vocab_size = len(rnn_vocab)

    # construct the model
    model = get_model(options.model)(options)
    model.load_state_dict(checkpoint['model'])
    model.Eiters = checkpoint['Eiters']
    model.val_start()

    # set data loaders
    video_ids_list = data.read_video_ids(caption_files['test'])
    vid_data_loader = data.get_vis_data_loader(
        visual_feats['test'], opt.batch_size, opt.workers,
        video2frames['test'], video_ids=video_ids_list)
    text_data_loader = data.get_txt_data_loader(
        caption_files['test'], rnn_vocab, bow2vec, opt.batch_size, opt.workers)

    # mapping
    if options.space == 'hybrid':
        video_embs, video_tag_probs, video_ids = \
            evaluation.encode_text_or_vid_tag_hist_prob(model.embed_vis, vid_data_loader)
        cap_embs, cap_tag_probs, caption_ids = \
            evaluation.encode_text_or_vid_tag_hist_prob(model.embed_txt, text_data_loader)
    else:
        video_embs, video_ids = evaluation.encode_text_or_vid(
            model.embed_vis, vid_data_loader)
        cap_embs, caption_ids = evaluation.encode_text_or_vid(
            model.embed_txt, text_data_loader)

    v2t_gt, t2v_gt = metrics.get_gt(video_ids, caption_ids)

    logging.info("write into: %s" % output_dir)
    if options.space != 'latent':
        tag_vocab_path = os.path.join(
            rootpath, collections_pathname['train'], 'TextData', 'tags',
            'video_label_th_1', 'tag_vocab_%d.json' % options.tag_vocab_size)
        evaluation.pred_tag(video_tag_probs, video_ids, tag_vocab_path,
                            os.path.join(output_dir, 'video'))
        evaluation.pred_tag(cap_tag_probs, caption_ids, tag_vocab_path,
                            os.path.join(output_dir, 'text'))

    if options.space in ['latent', 'hybrid']:
        # logging.info("=======Latent Space=======")
        t2v_all_errors_1 = evaluation.cal_error(video_embs, cap_embs, options.measure)

    if options.space in ['concept', 'hybrid']:
        # logging.info("=======Concept Space=======")
        t2v_all_errors_2 = evaluation.cal_error_batch(video_tag_probs, cap_tag_probs,
                                                      options.measure_2)

    if options.space in ['hybrid']:
        w = 0.6
        t2v_all_errors_1 = norm_score(t2v_all_errors_1)
        t2v_all_errors_2 = norm_score(t2v_all_errors_2)
        t2v_tag_all_errors = w * t2v_all_errors_1 + (1 - w) * t2v_all_errors_2
        cal_perf(t2v_tag_all_errors, v2t_gt, t2v_gt)
        torch.save(
            {
                'errors': t2v_tag_all_errors,
                'videos': video_ids,
                'captions': caption_ids
            }, pred_error_matrix_file)
        logging.info("write into: %s" % pred_error_matrix_file)
    elif options.space in ['latent']:
        cal_perf(t2v_all_errors_1, v2t_gt, t2v_gt)
        torch.save(
            {
                'errors': t2v_all_errors_1,
                'videos': video_ids,
                'captions': caption_ids
            }, pred_error_matrix_file)
        logging.info("write into: %s" % pred_error_matrix_file)
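# For orientation, a minimal sketch of a cosine-based error matrix between
# caption and video embeddings (lower = better match), in the spirit of
# evaluation.cal_error. The actual function in the repository may differ in
# argument order, normalization, and the set of supported measures; this is an
# assumption-labelled illustration only, hence the "_sketch" name.
import numpy as np

def cal_error_sketch(video_embs, cap_embs):
    videos = video_embs / (np.linalg.norm(video_embs, axis=1, keepdims=True) + 1e-12)
    captions = cap_embs / (np.linalg.norm(cap_embs, axis=1, keepdims=True) + 1e-12)
    # negative cosine similarity, shaped (n_captions, n_videos)
    return -np.dot(captions, videos.T)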
def main():
    opt = parse_args()
    print(json.dumps(vars(opt), indent=2))

    rootpath = opt.rootpath
    testCollection = opt.testCollection
    resume = os.path.join(opt.logger_name, opt.checkpoint_name)
    if not os.path.exists(resume):
        logging.info(resume + ' does not exist.')
        sys.exit(0)

    checkpoint = torch.load(resume)
    start_epoch = checkpoint['epoch']
    best_rsum = checkpoint['best_rsum']
    print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
        resume, start_epoch, best_rsum))
    options = checkpoint['opt']
    if not hasattr(options, 'concate'):
        setattr(options, "concate", "full")

    model = get_model(options.model)(options)
    model.load_state_dict(checkpoint['model'])
    model.val_start()

    trainCollection = options.trainCollection
    valCollection = options.valCollection

    visual_feat_file = BigFile(
        os.path.join(rootpath, testCollection, 'FeatureData', options.visual_feature))
    assert options.visual_feat_dim == visual_feat_file.ndims
    video2frame = read_dict(
        os.path.join(rootpath, testCollection, 'FeatureData',
                     options.visual_feature, 'video2frames.txt'))
    visual_loader = data.get_vis_data_loader(visual_feat_file, opt.batch_size,
                                             opt.workers, video2frame)
    vis_embs = None

    # set bow vocabulary and encoding
    bow_vocab_file = os.path.join(rootpath, options.trainCollection, 'TextData',
                                  'vocabulary', 'bow', options.vocab + '.pkl')
    bow_vocab = pickle.load(open(bow_vocab_file, 'rb'))
    bow2vec = get_text_encoder('bow')(bow_vocab)
    options.bow_vocab_size = len(bow_vocab)

    # set rnn vocabulary
    rnn_vocab_file = os.path.join(rootpath, options.trainCollection, 'TextData',
                                  'vocabulary', 'rnn', options.vocab + '.pkl')
    rnn_vocab = pickle.load(open(rnn_vocab_file, 'rb'))
    options.vocab_size = len(rnn_vocab)

    output_dir = resume.replace(trainCollection, testCollection)

    for query_set in opt.query_sets.strip().split(','):
        output_dir_tmp = output_dir.replace(
            valCollection, '%s/%s/%s' % (query_set, trainCollection, valCollection))
        output_dir_tmp = output_dir_tmp.replace('/%s/' % options.cv_name, '/results/')
        pred_result_file = os.path.join(output_dir_tmp, 'id.sent.score.txt')
        print(pred_result_file)
        if checkToSkip(pred_result_file, opt.overwrite):
            continue
        try:
            makedirsforfile(pred_result_file)
        except Exception as e:
            print(e)

        # data loader prepare
        query_file = os.path.join(rootpath, testCollection, 'TextData', query_set)

        # set data loader
        query_loader = data.get_txt_data_loader(query_file, rnn_vocab, bow2vec,
                                                opt.batch_size, opt.workers)

        if vis_embs is None:
            start = time.time()
            vis_embs, vis_ids = encode_data(model.embed_vis, visual_loader)
            print("encode image time: %.3f s" % (time.time() - start))

        start = time.time()
        query_embs, query_ids = encode_data(model.embed_txt, query_loader)
        print("encode text time: %.3f s" % (time.time() - start))

        start = time.time()
        t2i_matrix = query_embs.dot(vis_embs.T)
        inds = np.argsort(t2i_matrix, axis=1)
        print("compute similarity time: %.3f s" % (time.time() - start))

        with open(pred_result_file, 'w') as fout:
            for index in range(inds.shape[0]):
                ind = inds[index][::-1]
                fout.write(query_ids[index] + ' ' +
                           ' '.join([vis_ids[i] + ' %s' % t2i_matrix[index][i]
                                     for i in ind]) + '\n')

        if testCollection == 'iacc.3':
            template = ''.join(open('tv-avs-eval/TEMPLATE_do_eval.sh').readlines())
            script_str = template.replace('@@@rootpath@@@', rootpath)
            script_str = script_str.replace('@@@testCollection@@@', testCollection)
            script_str = script_str.replace('@@@topic_set@@@', query_set.split('.')[0])
            script_str = script_str.replace('@@@overwrite@@@', str(opt.overwrite))
            script_str = script_str.replace('@@@score_file@@@', pred_result_file)

            runfile = 'do_eval_%s.sh' % testCollection
            open(os.path.join('tv-avs-eval', runfile), 'w').write(script_str + '\n')
            os.system('cd tv-avs-eval; chmod +x %s; bash %s; cd -' % (runfile, runfile))