def score_speech_text(net, dataset):
    """Score speech-text retrieval on the evaluation set of `dataset`."""
    data = dataset.evaluation()
    audio_e = net.embed_audio(data['audio'])
    text_e = net.embed_text(data['text'])
    # Each audio item matches only the text at the same index.
    correct = torch.eye(len(data['audio'])).type(torch.bool)
    result = E.ranking(audio_e, text_e, correct)
    return dict(medr=np.median(result['ranks']),
                recall={1: np.mean(result['recall'][1]),
                        5: np.mean(result['recall'][5]),
                        10: np.mean(result['recall'][10])})
def score(net, dataset):
    """Score image-speech retrieval on the evaluation set of `dataset`."""
    data = dataset.evaluation()
    correct = data['correct'].cpu().numpy()
    image_e = net.embed_image(data['image'])
    audio_e = net.embed_audio(data['audio'])
    result = E.ranking(image_e, audio_e, correct)
    return dict(medr=np.median(result['ranks']),
                recall={1: np.mean(result['recall'][1]),
                        5: np.mean(result['recall'][5]),
                        10: np.mean(result['recall'][10])})
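# The helper below is a minimal, self-contained sketch of what a cross-modal
# ranking evaluation such as E.ranking is assumed to compute: cosine
# similarities between the two embedding sets, the 1-based rank of each
# item's best correct match, and per-item recall@k indicators (so that
# np.mean over result['recall'][k], as used above, yields recall@k).
# It is an illustrative assumption, not the project's actual implementation;
# `ranking_sketch` is a hypothetical name, and it relies on `torch` being
# imported at module level like the rest of this file.
def ranking_sketch(queries, candidates, correct, ks=(1, 5, 10)):
    # Assumes 2D CPU torch tensors for the embeddings and a boolean torch
    # matrix `correct` where correct[i, j] is True iff candidate j matches
    # query i (an identity matrix in score_speech_text above).
    q = torch.nn.functional.normalize(queries, dim=1)
    c = torch.nn.functional.normalize(candidates, dim=1)
    sim = q @ c.t()  # cosine similarity, shape (queries, candidates)
    order = sim.argsort(dim=1, descending=True)
    # positions[i, j]: 0-based rank of candidate j in query i's ordering.
    positions = order.argsort(dim=1)
    # Best (lowest) rank among the correct candidates for each query, 1-based.
    ranks = positions.masked_fill(~correct, sim.shape[1]).min(dim=1).values + 1
    # Per-query recall@k indicators; averaging them gives recall@k.
    recall = {k: (ranks <= k).float().numpy() for k in ks}
    return dict(ranks=ranks.numpy(), recall=recall)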
    label_encoder=fd.get_label_encoder(), language='en'),
    open('config.pkl', 'wb'))

if args.text_image_model_dir:
    # Reuse a previously trained text-image model.
    net_fname = 'net_{}.best.pt'.format(ds_factor)
    net = torch.load(os.path.join(args.text_image_model_dir, net_fname))
else:
    # Train a text-image model from scratch and keep the best checkpoint.
    logging.info('Building model text-image')
    net = M2.TextImage(M2.get_default_config())
    run_config = dict(max_lr=2 * 1e-4, epochs=32)
    logging.info('Training text-image')
    M2.experiment(net, data, run_config)
    suffix = str(ds_factor).zfill(lz)
    res_fname = 'result_text_image_{}.json'.format(suffix)
    copyfile('result.json', res_fname)
    net_fname = 'ti_{}.best.pt'.format(ds_factor)
    copy_best(res_fname, net_fname)
    net = torch.load(net_fname)

# Score image retrieval from the ASR hypotheses of the spoken captions.
logging.info('Evaluating text-image with ASR\'s output')
data = data['val'].dataset.evaluation()
correct = data['correct'].cpu().numpy()
image_e = net.embed_image(data['image'])
text_e = net.embed_text(hyp_asr)
result = E.ranking(image_e, text_e, correct)
print(dict(medr=np.median(result['ranks']),
           recall={1: np.mean(result['recall'][1]),
                   5: np.mean(result['recall'][5]),
                   10: np.mean(result['recall'][10])}))