Beispiel #1
0
def score_speech_text(net, dataset):
    """Score speech/text retrieval on ``dataset.evaluation()``.

    Audio and text items are embedded with ``net`` and passed to
    ``E.ranking`` together with an identity correctness matrix, i.e. the
    i-th audio item is only considered correct for the i-th text item.

    Returns a dict with ``'medr'`` (median of the returned ranks) and
    ``'recall'`` (mean of ``result['recall'][k]`` for k in 1, 5 and 10).
    """
    evaluation = dataset.evaluation()
    speech_emb = net.embed_audio(evaluation['audio'])
    text_emb = net.embed_text(evaluation['text'])
    # Square diagonal mask sized by the number of audio items: only
    # position (i, i) is marked correct.
    n_items = len(evaluation['audio'])
    ground_truth = torch.eye(n_items).type(torch.bool)
    outcome = E.ranking(speech_emb, text_emb, ground_truth)
    recall_at = outcome['recall']
    return dict(medr=np.median(outcome['ranks']),
                recall={k: np.mean(recall_at[k]) for k in (1, 5, 10)})
Beispiel #2
0
def score(net, dataset):
    """Score image/audio retrieval on ``dataset.evaluation()``.

    The correctness matrix is taken from the evaluation data's ``'correct'``
    entry (moved to CPU and converted to a NumPy array). Images and audio
    are embedded with ``net`` and handed to ``E.ranking``.

    Returns a dict with ``'medr'`` (median of the returned ranks) and
    ``'recall'`` (mean of ``result['recall'][k]`` for k in 1, 5 and 10).
    """
    evaluation = dataset.evaluation()
    match = evaluation['correct'].cpu().numpy()
    image_emb = net.embed_image(evaluation['image'])
    audio_emb = net.embed_audio(evaluation['audio'])
    outcome = E.ranking(image_emb, audio_emb, match)
    recall_at = outcome['recall']
    summary = {k: np.mean(recall_at[k]) for k in (1, 5, 10)}
    return dict(medr=np.median(outcome['ranks']), recall=summary)
Beispiel #3
0
                         label_encoder=fd.get_label_encoder(),
                         language='en'),
                    open('config.pkl', 'wb'))

    if args.text_image_model_dir:
        net_fname = 'net_{}.best.pt'.format(ds_factor)
        net = torch.load(os.path.join(args.text_image_model_dir, net_fname))
    else:
        logging.info('Building model text-image')
        net = M2.TextImage(M2.get_default_config())
        run_config = dict(max_lr=2 * 1e-4, epochs=32)
        logging.info('Training text-image')
        M2.experiment(net, data, run_config)
        suffix = str(ds_factor).zfill(lz)
        res_fname = 'result_text_image_{}.json'.format(suffix)
        copyfile('result.json', res_fname)
        net_fname = 'ti_{}.best.pt'.format(ds_factor)
        copy_best(res_fname, net_fname)
        net = torch.load(net_fname)

    logging.info('Evaluating text-image with ASR\'s output')
    data = data['val'].dataset.evaluation()
    correct = data['correct'].cpu().numpy()
    image_e = net.embed_image(data['image'])
    text_e = net.embed_text(hyp_asr)
    result = E.ranking(image_e, text_e, correct)
    print(dict(medr=np.median(result['ranks']),
               recall={1: np.mean(result['recall'][1]),
                       5: np.mean(result['recall'][5]),
                       10: np.mean(result['recall'][10])}))