hidden_size = 1024
dropout = 0.0
feature_fname = 'mfcc_delta_features.pt'

logging.basicConfig(level=logging.INFO)

logging.info('Loading data')
data = dict(
    train=D.flickr8k_loader(split='train', batch_size=batch_size,
                            shuffle=True, feature_fname=feature_fname),
    val=D.flickr8k_loader(split='val', batch_size=batch_size,
                          shuffle=False, feature_fname=feature_fname))
fd = D.Flickr8KData
fd.init_vocabulary(data['train'].dataset)

# Saving config
pickle.dump(
    dict(feature_fname=feature_fname,
         label_encoder=fd.get_label_encoder(),
         language='en'),
    open('config.pkl', 'wb'))

logging.info('Building model')
net = M.TextImage(M.get_default_config())
run_config = dict(max_lr=2 * 1e-4, epochs=32)

logging.info('Training')
M.experiment(net, data, run_config)
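# Minimal sketch (assumed usage, not part of the original script): a later
# evaluation step could reload the configuration saved above to recover the
# feature file name and label encoder. The keys mirror the dict written by
# the pickle.dump call in the script.
import pickle

with open('config.pkl', 'rb') as f:
    config = pickle.load(f)
feature_fname = config['feature_fname']
label_encoder = config['label_encoder']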
# Replacing original transcriptions with ASR's output
for i in range(len(hyp_asr)):
    item = ds.split_data[i]
    if item[2] == ref_asr[i]:
        ds.split_data[i] = (item[0], item[1], hyp_asr[i])
    else:
        msg = 'Extracted reference #{} ({}) doesn\'t match ' \
              'dataset\'s one ({}) for {} set.'
        msg = msg.format(i, ref_asr[i], item[2], set_name)
        logging.warning(msg)

if args.asr_model_dir:
    # Saving config for text-image model
    pickle.dump(
        dict(feature_fname=feature_fname,
             label_encoder=fd.get_label_encoder(),
             language='en'),
        open('config.pkl', 'wb'))

logging.info('Building model text-image')
net = M2.TextImage(M2.get_default_config())
run_config = dict(max_lr=2 * 1e-4, epochs=32)

logging.info('Training text-image')
M2.experiment(net, data, run_config)

suffix = str(ds_factor).zfill(lz)
res_fname = 'result_text_image_{}.json'.format(suffix)
copyfile('result.json', res_fname)
net_fname = 'ti_{}.best.pt'.format(ds_factor)
copy_best(res_fname, net_fname)
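# Hedged sketch of what the copy_best helper called above might look like; it
# is not the repository's implementation. It assumes result.json holds one
# JSON record per epoch with an 'epoch' field and a recall@10 score under
# r['recall']['10'], and that checkpoints are saved as net.<epoch>.pt; those
# names are assumptions, so adapt them to the actual result format.
import json
from shutil import copyfile

def copy_best(result_fname, output_fname,
              score=lambda r: r['recall']['10']):
    # Pick the epoch with the highest score and copy its checkpoint.
    with open(result_fname) as f:
        records = [json.loads(line) for line in f]
    best = max(records, key=score)
    copyfile('net.{}.pt'.format(best['epoch']), output_fname)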