Ejemplo n.º 1
0
# Script-level settings. `hidden_size` and `dropout` are not referenced again
# in the visible part of this script; `feature_fname` is passed to both data
# loaders and stored in the saved config.
hidden_size = 1024
dropout = 0.0
feature_fname = 'mfcc_delta_features.pt'

logging.basicConfig(level=logging.INFO)

logging.info('Loading data')
# One loader per split, both reading the same feature file; shuffle=True is
# passed for the training split only.
data = dict(train=D.flickr8k_loader(split='train',
                                    batch_size=batch_size,
                                    shuffle=True,
                                    feature_fname=feature_fname),
            val=D.flickr8k_loader(split='val',
                                  batch_size=batch_size,
                                  shuffle=False,
                                  feature_fname=feature_fname))
fd = D.Flickr8KData
# The vocabulary is initialised from the dataset behind the training loader.
fd.init_vocabulary(data['train'].dataset)

# Saving config
# The file is opened in a `with` block so that config.pkl is flushed and
# closed as soon as the dump finishes, instead of relying on the garbage
# collector to close a handle nobody holds a reference to.
with open('config.pkl', 'wb') as config_file:
    pickle.dump(
        dict(feature_fname=feature_fname,
             label_encoder=fd.get_label_encoder(),
             language='en'), config_file)

logging.info('Building model')
net = M.TextImage(M.get_default_config())
# Run settings handed to M.experiment together with the model and loaders.
run_config = dict(max_lr=2 * 1e-4, epochs=32)

logging.info('Training')
M.experiment(net, data, run_config)
Ejemplo n.º 2
0
        # Replacing original transcriptions with ASR's output.
        # An entry is overwritten only when its stored transcription
        # (item[2]) equals the reference extracted alongside the hypothesis;
        # otherwise the entry is left untouched and a warning is logged.
        for i, hyp in enumerate(hyp_asr):
            item = ds.split_data[i]
            if item[2] == ref_asr[i]:
                ds.split_data[i] = (item[0], item[1], hyp)
            else:
                # Adjacent literals keep the message free of the source
                # indentation that a backslash continuation inside the
                # string would embed in the logged text.
                msg = ('Extracted reference #{} ({}) doesn\'t match '
                       'dataset\'s one ({}) for {} set.')

                # Report the field that was actually compared (index 2);
                # index 3 is out of range for the 3-tuples written above.
                msg = msg.format(i, ref_asr[i], item[2], set_name)
                logging.warning(msg)

    if args.asr_model_dir:
        # Saving config for text-image model
        # Opened in a `with` block so config.pkl is flushed and closed
        # before training starts, rather than whenever the otherwise
        # unreferenced file object happens to be collected.
        with open('config.pkl', 'wb') as config_out:
            pickle.dump(
                dict(feature_fname=feature_fname,
                     label_encoder=fd.get_label_encoder(),
                     language='en'), config_out)

    logging.info('Building model text-image')
    net = M2.TextImage(M2.get_default_config())
    # Run settings handed to M2.experiment together with the model and data.
    run_config = dict(max_lr=2 * 1e-4, epochs=32)

    logging.info('Training text-image')
    M2.experiment(net, data, run_config)
    # Keep a per-factor copy of result.json; the factor is zero-padded to
    # `lz` characters in the copy's file name.
    suffix = str(ds_factor).zfill(lz)
    res_fname = 'result_text_image_{}.json'.format(suffix)
    copyfile('result.json', res_fname)
    # NOTE(review): the checkpoint name uses the raw `ds_factor`, not the
    # zero-padded `suffix` used for the result file - confirm this is
    # intentional.
    net_fname = 'ti_{}.best.pt'.format(ds_factor)
    # copy_best is defined outside this view; presumably it uses the results
    # file to locate the best checkpoint and stores it as `net_fname` -
    # TODO confirm.
    copy_best(res_fname, net_fname)