Example #1
# Sweep over audio down-sampling factors: for each factor, reload the
# Flickr8K data, rebuild the vocabulary, and persist the run configuration.
torch.manual_seed(123)

batch_size = 32
hidden_size = 1024
dropout = 0.0
feature_fname = 'mfcc_delta_features.pt'

logging.basicConfig(level=logging.INFO)

# Powers of 3 used as down-sampling factors for the training split.
factors = [3, 9, 27, 81, 243]
# Zero-pad width derived from the largest factor — presumably used for run
# naming further down (the rest of this script is truncated here; confirm).
lz = len(str(abs(factors[-1])))
for ds_factor in factors:
    logging.info('Loading data')
    data = dict(train=D.flickr8k_loader(split='train',
                                        batch_size=batch_size,
                                        shuffle=True,
                                        feature_fname=feature_fname,
                                        downsampling_factor=ds_factor),
                val=D.flickr8k_loader(split='val',
                                      batch_size=batch_size,
                                      shuffle=False,
                                      feature_fname=feature_fname))
    fd = D.Flickr8KData
    # Vocabulary must be rebuilt per factor: the label encoder saved below
    # depends on the down-sampled training set.
    fd.init_vocabulary(data['train'].dataset)

    # Saving config — a context manager closes the file handle
    # (the original leaked the handle returned by open()).
    with open('config.pkl', 'wb') as f:
        pickle.dump(dict(feature_fname=feature_fname,
                         label_encoder=fd.get_label_encoder(),
                         language='en'), f)
Example #2
# ASR training setup (English Flickr8K): load data, build the vocabulary,
# and persist the configuration needed to decode later.
import platalea.asr as M
import platalea.dataset as D

torch.manual_seed(123)

batch_size = 8
hidden_size = 1024
dropout = 0.0
feature_fname = 'mfcc_delta_features.pt'

logging.basicConfig(level=logging.INFO)

logging.info('Loading data')
data = dict(train=D.flickr8k_loader(split='train',
                                    batch_size=batch_size,
                                    shuffle=True,
                                    feature_fname=feature_fname),
            val=D.flickr8k_loader(split='val',
                                  batch_size=batch_size,
                                  shuffle=False,
                                  feature_fname=feature_fname))
fd = D.Flickr8KData
# The label encoder saved below is derived from the training vocabulary.
fd.init_vocabulary(data['train'].dataset)

# Saving config — a context manager closes the file handle
# (the original leaked the handle returned by open()).
with open('config.pkl', 'wb') as f:
    pickle.dump(dict(feature_fname=feature_fname,
                     label_encoder=fd.get_label_encoder(),
                     language='en'), f)

config = dict(
Example #3
# Training setup driven by CLI arguments; seeds both torch's and Python's
# RNGs for reproducibility. (This excerpt is truncated below.)
torch.manual_seed(args.seed)
random.seed(args.seed)
logging.basicConfig(level=logging.INFO)

# Logging the arguments
logging.info('Arguments: {}'.format(args))


batch_size = 8
hidden_size = 1024
dropout = 0.0

logging.info('Loading data')
# Down-sampling is applied to the training split only; validation is intact.
data = dict(
    train=D.flickr8k_loader(
        args.flickr8k_root, args.flickr8k_meta, args.flickr8k_language,
        args.audio_features_fn, split='train', batch_size=batch_size,
        shuffle=True, downsampling_factor=args.downsampling_factor),
    val=D.flickr8k_loader(
        args.flickr8k_root, args.flickr8k_meta, args.flickr8k_language,
        args.audio_features_fn, split='val', batch_size=batch_size,
        shuffle=False))

if args.downsampling_factor_text:
    ds_factor_text = args.downsampling_factor_text
    # NOTE(review): step_st is unused in this excerpt — presumably consumed
    # by the truncated code below; confirm before removing.
    step_st = args.downsampling_factor_text
    # The downsampling factor for text is applied on top of the main
    # downsampling factor that is applied to all data
    if args.downsampling_factor:
        ds_factor_text *= args.downsampling_factor
    data_st = dict(
        train=D.flickr8k_loader(
# NOTE(review): this excerpt starts mid-script; `args` is a configuration
# parser object created above the visible region (it exposes
# enable_help()/parse() and attribute access after parsing).
args.enable_help()
args.parse()

# Setting general configuration
torch.manual_seed(args.seed)
random.seed(args.seed)

batch_size = 32
# NOTE(review): hidden_size and dropout are unused in this excerpt —
# presumably consumed by code below the visible region; confirm.
hidden_size = 1024
dropout = 0.0

logging.info('Loading data')
data = dict(train=D.flickr8k_loader(args.flickr8k_root,
                                    args.flickr8k_meta,
                                    args.flickr8k_language,
                                    args.audio_features_fn,
                                    split='train',
                                    batch_size=batch_size,
                                    shuffle=True),
            val=D.flickr8k_loader(args.flickr8k_root,
                                  args.flickr8k_meta,
                                  args.flickr8k_language,
                                  args.audio_features_fn,
                                  split='val',
                                  batch_size=batch_size,
                                  shuffle=False))

logging.info('Building model')
# Text-image matching model built with the library's default configuration.
net = M.TextImage(M.get_default_config())
run_config = dict(max_lr=2 * 1e-4, epochs=args.epochs)
Example #5
# ASR training setup (Japanese Flickr8K variant): load data, build the
# vocabulary, and persist the configuration needed to decode later.
import platalea.asr as M
import platalea.dataset as D

torch.manual_seed(123)


batch_size = 8
hidden_size = 1024
dropout = 0.0
feature_fname = 'mfcc_delta_features.pt'

logging.basicConfig(level=logging.INFO)

logging.info('Loading data')
data = dict(
    train=D.flickr8k_loader(split='train', batch_size=batch_size, shuffle=True,
                            feature_fname=feature_fname, language='jp'),
    val=D.flickr8k_loader(split='val', batch_size=batch_size, shuffle=False,
                          feature_fname=feature_fname, language='jp'))
fd = D.Flickr8KData
# The label encoder saved below is derived from the training vocabulary.
fd.init_vocabulary(data['train'].dataset)

# Saving config — a context manager closes the file handle
# (the original leaked the handle returned by open()).
with open('config.pkl', 'wb') as f:
    pickle.dump(dict(feature_fname=feature_fname,
                     label_encoder=fd.get_label_encoder(),
                     language='jp'), f)

config = dict(
    SpeechEncoder=dict(
        conv=dict(in_channels=39, out_channels=64, kernel_size=6, stride=2,
                  padding=0, bias=False),
Example #6
# CLI setup: register the --epochs option on the shared 'platalea' parser,
# then load the Flickr8K train/val splits and initialise the vocabulary.
parser = configargparse.get_argument_parser('platalea')
parser.add_argument(
    '--epochs', action='store', dest='epochs', type=int, default=32,
    help='number of epochs after which to stop training (default: 32)')

# parse_known_args tolerates options registered elsewhere in the program.
config_args, unknown_args = parser.parse_known_args()

logging.basicConfig(level=logging.INFO)

logging.info('Loading data')
train_loader = D.flickr8k_loader(split='train', batch_size=32, shuffle=True)
val_loader = D.flickr8k_loader(split='val', batch_size=32, shuffle=False)
data = {'train': train_loader, 'val': val_loader}
D.Flickr8KData.init_vocabulary(data['train'].dataset)

config = dict(SpeechEncoder=dict(conv=dict(in_channels=39,
                                           out_channels=64,
                                           kernel_size=6,
                                           stride=2,
                                           padding=0,
                                           bias=False),
                                 rnn=dict(input_size=64,
                                          hidden_size=1024,
                                          num_layers=4,
                                          bidirectional=True,
                                          dropout=0),
                                 att=dict(in_size=2048, hidden_size=128)),
    # Evaluation entry: configure logging, parse the CLI, restore the saved
    # training config and model, then extract transcriptions (loop below).
    logging.basicConfig(level=logging.INFO)

    # Parse command line parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('path', metavar='path', help='Model\'s path')
    parser.add_argument('-b', help='Use beam decoding', dest='use_beam_decoding',
                        action='store_true', default=False)
    args = parser.parse_args()

    # Loading config — a context manager closes the file handle
    # (the original left the file object returned by open() unclosed).
    with open('config.pkl', 'rb') as f:
        conf = pickle.load(f)

    logging.info('Loading data')
    # NOTE(review): batch_size is defined above the visible region — confirm.
    data = dict(
        train=D.flickr8k_loader(split='train', batch_size=batch_size,
                                shuffle=False, feature_fname=conf['feature_fname'],
                                language=conf['language']),
        val=D.flickr8k_loader(split='val', batch_size=batch_size, shuffle=False,
                              feature_fname=conf['feature_fname'],
                              language=conf['language']))
    fd = D.Flickr8KData
    # Restore the label encoder saved at training time so decoding matches.
    fd.le = conf['label_encoder']

    net = torch.load(args.path)

    trn = {}
    logging.info('Extracting transcriptions')
    for set_id, set_name in [('train', 'Training'), ('val', 'Validation')]:
Example #8
# Down-sampling sweep for the basic speech-image model; the per-factor
# training code is truncated below this excerpt.
import torch

import platalea.basic as M
import platalea.dataset as D
from utils.copy_best import copy_best

torch.manual_seed(123)  # fixed seed for reproducibility

logging.basicConfig(level=logging.INFO)

# Powers of 3 used as down-sampling factors for the training split.
factors = [3, 9, 27, 81, 243]
# Zero-pad width derived from the largest factor — presumably used for run
# naming in the truncated loop body; confirm.
lz = len(str(abs(factors[-1])))
for ds_factor in factors:
    logging.info('Loading data')
    # Only the training split is down-sampled; validation stays intact.
    data = dict(train=D.flickr8k_loader(split='train',
                                        batch_size=32,
                                        shuffle=True,
                                        downsampling_factor=ds_factor),
                val=D.flickr8k_loader(split='val',
                                      batch_size=32,
                                      shuffle=False))
    D.Flickr8KData.init_vocabulary(data['train'].dataset)

    config = dict(SpeechEncoder=dict(conv=dict(in_channels=39,
                                               out_channels=64,
                                               kernel_size=6,
                                               stride=2,
                                               padding=0,
                                               bias=False),
                                     rnn=dict(input_size=64,
                                              hidden_size=1024,
                                              num_layers=4,
Example #9
def _load_data(args):
    """Build the train/val loaders for the dataset selected by ``args``.

    Raises:
        ValueError: if ``args.dataset_name`` is not a supported dataset.
    """
    if args.dataset_name == 'flickr8k':
        return dict(
            train=D.flickr8k_loader(
                args.flickr8k_root, args.flickr8k_meta, args.flickr8k_language,
                args.audio_features_fn, split='train', batch_size=32,
                shuffle=True, downsampling_factor=args.downsampling_factor),
            val=D.flickr8k_loader(
                args.flickr8k_root, args.flickr8k_meta, args.flickr8k_language,
                args.audio_features_fn, split='val', batch_size=32,
                shuffle=False))
    elif args.dataset_name == "spokencoco":
        return dict(
            train=D.spokencoco_loader(
                args.spokencoco_root, args.spokencoco_meta,
                args.audio_features_fn, split='train', batch_size=32,
                shuffle=True, downsampling_factor=args.downsampling_factor,
                debug=args.debug),
            val=D.spokencoco_loader(
                args.spokencoco_root, args.spokencoco_meta,
                args.audio_features_fn, split='val', batch_size=32,
                shuffle=False, debug=args.debug))
    else:
        raise ValueError(
            "dataset_name should be in ['flickr8k', 'spokencoco']")


def _build_config(args):
    """Return the SpeechImage model configuration derived from ``args``."""
    return dict(
        # 39 input channels match the MFCC+delta feature dimensionality
        # used elsewhere in this project — confirm against the feature file.
        SpeechEncoder=dict(conv=dict(in_channels=39,
                                     out_channels=64,
                                     kernel_size=6,
                                     stride=2,
                                     padding=0,
                                     bias=False),
                           rnn=dict(input_size=64,
                                    hidden_size=args.hidden_size_factor,
                                    num_layers=4,
                                    bidirectional=True,
                                    dropout=0),
                           # Bidirectional RNN doubles the attention input.
                           att=dict(in_size=2 * args.hidden_size_factor,
                                    hidden_size=128)),
        ImageEncoder=dict(linear=dict(in_size=2048,
                                      out_size=2 * args.hidden_size_factor),
                          norm=True),
        margin_size=0.2)


def train(args):
    """Train a SpeechImage model on the dataset selected by ``args``.

    Seeds the RNGs, loads the data, builds the model and runs the
    experiment, logging wall-clock duration.

    Raises:
        ValueError: if ``args.dataset_name`` is not a supported dataset.
    """
    # Setting general configuration: seed both torch and Python's RNGs.
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # Logging the arguments
    logging.info('Arguments: {}'.format(args))

    logging.info('Loading data')
    data = _load_data(args)

    config = _build_config(args)

    logging.info('Building model')
    net = M.SpeechImage(config)
    run_config = dict(
        max_lr=args.cyclic_lr_max,
        min_lr=args.cyclic_lr_min,
        epochs=args.epochs,
        l2_regularization=args.l2_regularization,
    )

    logging.info('Training')
    old_time = datetime.datetime.now()
    logging.info(f'Start of training - {old_time}')
    M.experiment(net, data, run_config, wandb_mode='disabled')
    new_time = datetime.datetime.now()
    logging.info(f'End of training - {new_time}')
    diff_time = new_time - old_time
    logging.info(f'Total duration: {diff_time}')
Example #10
# Evaluate a saved ASR model on the Flickr8K validation split,
# optionally with beam decoding.
# Parse command line parameters
parser = argparse.ArgumentParser()
parser.add_argument('path', metavar='path', help='Model\'s path')
parser.add_argument('-b',
                    help='Use beam decoding',
                    dest='use_beam_decoding',
                    action='store_true',
                    default=False)
args = parser.parse_args()

# Loading config — a context manager closes the file handle
# (the original leaked the handle returned by open()).
with open('config.pkl', 'rb') as f:
    conf = pickle.load(f)

logging.info('Loading data')
# NOTE(review): batch_size is defined above the visible region — confirm.
data = dict(val=D.flickr8k_loader(split='val',
                                  batch_size=batch_size,
                                  shuffle=False,
                                  feature_fname=conf['feature_fname']))
fd = D.Flickr8KData
# Restore the label encoder saved at training time so decoding matches.
fd.le = conf['label_encoder']

logging.info('Loading model')
net = torch.load(args.path)
logging.info('Evaluating')
# Inference only: disable autograd and switch the model to eval mode.
with torch.no_grad():
    net.eval()
    if args.use_beam_decoding:
        result = platalea.score.score_asr(net,
                                          data['val'].dataset,
                                          beam_size=10)
    else:
        result = platalea.score.score_asr(net, data['val'].dataset)
Example #11
                  dest='use_test_set',
                  action='store_true',
                  default=False)
# Load either the test or the validation split (per --use_test_set) and the
# saved model for evaluation.
args.enable_help()
args.parse()

batch_size = 16

logging.basicConfig(level=logging.INFO)

logging.info('Loading data')
# The original if/else duplicated the loader call, differing only in the
# split name — select the split once and load with a single call.
split = 'test' if args.use_test_set else 'val'
data = D.flickr8k_loader(args.flickr8k_root,
                         args.flickr8k_meta,
                         args.flickr8k_language,
                         args.audio_features_fn,
                         split=split,
                         batch_size=batch_size,
                         shuffle=False)

logging.info('Loading model')
net = torch.load(args.path)
logging.info('Evaluating')