Ejemplo n.º 1
0
Archivo: test.py Proyecto: uniq10/nabu
def main(_):
    '''decodes the test set with a combined asr and language model and prints
    the score

    The database, features, asr and (by default) decoder configs are read from
    FLAGS.asr_expdir, the lm config from FLAGS.lm_expdir and the asr-lm
    combination config from config/asr_lm.cfg. The graph variables that are
    present in the lm checkpoint are restored first, after which the ones
    present in the asr checkpoint are restored.
    '''

    decoder_cfg_file = None

    #read the database config file
    parsed_database_cfg = configparser.ConfigParser()
    parsed_database_cfg.read(os.path.join(FLAGS.asr_expdir, 'database.cfg'))
    database_cfg = dict(parsed_database_cfg.items('database'))

    #read the features config file
    parsed_feat_cfg = configparser.ConfigParser()
    parsed_feat_cfg.read(
        os.path.join(FLAGS.asr_expdir, 'model', 'features.cfg'))
    feat_cfg = dict(parsed_feat_cfg.items('features'))

    #read the asr config file
    parsed_asr_cfg = configparser.ConfigParser()
    parsed_asr_cfg.read(os.path.join(FLAGS.asr_expdir, 'model', 'asr.cfg'))
    asr_cfg = dict(parsed_asr_cfg.items('asr'))

    #read the lm config file
    parsed_lm_cfg = configparser.ConfigParser()
    parsed_lm_cfg.read(os.path.join(FLAGS.lm_expdir, 'model', 'lm.cfg'))
    lm_cfg = dict(parsed_lm_cfg.items('lm'))

    #read the asr-lm config file (path is relative to the working directory)
    parsed_asr_lm_cfg = configparser.ConfigParser()
    parsed_asr_lm_cfg.read('config/asr_lm.cfg')
    asr_lm_cfg = dict(parsed_asr_lm_cfg.items('asr-lm'))

    #read the decoder config file, defaulting to the one of the asr model
    if decoder_cfg_file is None:
        decoder_cfg_file = os.path.join(FLAGS.asr_expdir, 'model',
                                        'decoder.cfg')
    parsed_decoder_cfg = configparser.ConfigParser()
    parsed_decoder_cfg.read(decoder_cfg_file)
    decoder_cfg = dict(parsed_decoder_cfg.items('decoder'))

    #create a feature reader
    featdir = os.path.join(database_cfg['test_dir'], feat_cfg['name'])

    with open(os.path.join(featdir, 'maxlength'), 'r') as fid:
        max_length = int(fid.read())

    reader = feature_reader.FeatureReader(
        scpfile=os.path.join(featdir, 'feats.scp'),
        cmvnfile=os.path.join(featdir, 'cmvn.scp'),
        utt2spkfile=os.path.join(featdir, 'utt2spk'),
        max_length=max_length)

    #read the feature dimension (taken from the training data)
    with open(
        os.path.join(database_cfg['train_dir'], feat_cfg['name'],
                     'dim'),
        'r') as fid:

        input_dim = int(fid.read())

    #create the coder
    with open(os.path.join(database_cfg['train_dir'], 'alphabet')) as fid:
        alphabet = fid.read().split(' ')
    coder = target_coder.TargetCoder(alphabet)


    #create the classifier
    classifier = asr_lm_classifier.AsrLmClassifier(
        conf=asr_lm_cfg,
        asr_conf=asr_cfg,
        lm_conf=lm_cfg,
        output_dim=coder.num_labels)

    #create a decoder
    graph = tf.Graph()
    with graph.as_default():
        decoder = decoder_factory.factory(
            conf=decoder_cfg,
            classifier=classifier,
            input_dim=input_dim,
            max_input_length=reader.max_length,
            coder=coder,
            expdir=FLAGS.asr_expdir)

        def checkpoint_saver(expdir):
            '''create a saver for the graph variables that also appear in the
            network checkpoint of expdir'''

            #names of the variables stored in the checkpoint; a set gives
            #O(1) membership tests and, unlike zip(*...)[0], also works when
            #zip returns an iterator (python 3)
            varnames = set(
                entry[0] for entry in checkpoint_utils.list_variables(
                    os.path.join(expdir, 'model', 'network.ckpt')))
            variables = [v for v in tf.all_variables()
                         if v.name.split(':')[0] in varnames]
            return tf.train.Saver(variables)

        #create the lm and asr savers
        lm_saver = checkpoint_saver(FLAGS.lm_expdir)
        asr_saver = checkpoint_saver(FLAGS.asr_expdir)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True #pylint: disable=E1101
    config.allow_soft_placement = True

    with tf.Session(graph=graph, config=config) as sess:
        #load the lm model
        lm_saver.restore(
            sess, os.path.join(FLAGS.lm_expdir, 'model', 'network.ckpt'))

        #load the asr model; variables present in both checkpoints end up
        #with the asr values because this restore runs last
        asr_saver.restore(
            sess, os.path.join(FLAGS.asr_expdir, 'model', 'network.ckpt'))

        #decode with the neural net
        decoded = decoder.decode(reader, sess)

    #the path to the text file
    textfile = database_cfg['testtext']

    #read all the reference transcriptions
    with open(textfile) as fid:
        lines = fid.readlines()

    #NOTE(review): the references are kept as raw text here, while other test
    #scripts pass coder.encode(...) targets to decoder.score — confirm which
    #form this decoder expects
    references = dict()
    for line in lines:
        splitline = line.strip().split(' ')
        references[splitline[0]] = ' '.join(splitline[1:])

    #compute the character error rate
    score = decoder.score(decoded, references)

    print('score: %f' % score)
Ejemplo n.º 2
0
def main(_):
    '''decodes the test set with the attention visualizer decoder

    All model, feature and database configuration is taken from FLAGS.expdir.
    After decoding, the score against the encoded reference targets is
    printed and, for every utterance, its hypotheses are written to the file
    FLAGS.expdir/decoded/<utterance id>, one line per hypothesis holding the
    hypothesis score, a tab and the hypothesis text.
    '''

    expdir = FLAGS.expdir
    modeldir = os.path.join(expdir, 'model')

    def load_section(cfg_path, section):
        '''parse the config file at cfg_path and return one section as a dict'''
        cfg_parser = configparser.ConfigParser()
        cfg_parser.read(cfg_path)
        return dict(cfg_parser.items(section))

    # this script always uses the attention visualizer decoder config; the
    # decoder config of the model directory is only a fallback
    decoder_cfg_file = 'config/decoder/attention_visualizer.cfg'

    database_cfg = load_section(os.path.join(expdir, 'database.cfg'),
                                'database')
    feat_cfg = load_section(os.path.join(modeldir, 'features.cfg'), 'features')
    nnet_cfg = load_section(os.path.join(modeldir, 'asr.cfg'), 'asr')
    if decoder_cfg_file is None:
        decoder_cfg_file = os.path.join(modeldir, 'decoder.cfg')
    decoder_cfg = load_section(decoder_cfg_file, 'decoder')

    # reader over the test features, padded to the stored maximum length
    test_featdir = os.path.join(database_cfg['test_dir'], feat_cfg['name'])
    with open(os.path.join(test_featdir, 'maxlength'), 'r') as fid:
        test_max_length = int(fid.read())
    reader = feature_reader.FeatureReader(
        scpfile=os.path.join(test_featdir, 'feats.scp'),
        cmvnfile=os.path.join(test_featdir, 'cmvn.scp'),
        utt2spkfile=os.path.join(test_featdir, 'utt2spk'),
        max_length=test_max_length)

    # the input dimension and the alphabet come from the training data
    train_dir = database_cfg['train_dir']
    with open(os.path.join(train_dir, feat_cfg['name'], 'dim'), 'r') as fid:
        input_dim = int(fid.read())
    with open(os.path.join(train_dir, 'alphabet')) as fid:
        coder = target_coder.TargetCoder(fid.read().split(' '))

    classifier = asr_factory.factory(conf=nnet_cfg,
                                     output_dim=coder.num_labels)

    # build the decoding graph and a saver for all trainable variables
    graph = tf.Graph()
    with graph.as_default():
        decoder = decoder_factory.factory(
            conf=decoder_cfg,
            classifier=classifier,
            input_dim=input_dim,
            max_input_length=reader.max_length,
            coder=coder,
            expdir=expdir)
        saver = tf.train.Saver(tf.trainable_variables())

    session_cfg = tf.ConfigProto()
    session_cfg.gpu_options.allow_growth = True  #pylint: disable=E1101
    session_cfg.allow_soft_placement = True

    # restore the trained network and decode the whole test set
    with tf.Session(graph=graph, config=session_cfg) as sess:
        saver.restore(sess, os.path.join(modeldir, 'network.ckpt'))
        decoded = decoder.decode(reader, sess)

    # each target line is "<utterance id> <transcription>"; the transcription
    # is encoded with the coder before scoring
    with open(os.path.join(database_cfg['test_dir'], 'targets')) as fid:
        target_lines = fid.readlines()
    references = {
        utt_id: coder.encode(transcription)
        for utt_id, _, transcription in (
            target_line.strip().partition(' ') for target_line in target_lines)
    }

    score = decoder.score(decoded, references)
    print('score: %f' % score)

    # dump the hypotheses of every utterance to its own file
    decodedir = os.path.join(expdir, 'decoded')
    if not os.path.isdir(decodedir):
        os.makedirs(decodedir)
    for utt_id in decoded:
        with open(os.path.join(decodedir, utt_id), 'w') as out:
            for beam in decoded[utt_id]:
                out.write('%f\t%s\n' % (beam[0], beam[1]))
Ejemplo n.º 3
0
def train_asr(clusterfile, job_name, task_index, ssh_command, expdir):
    ''' does everything for asr training

    Args:
        clusterfile: the file where all the machines in the cluster are
            specified if None, local training will be done
        job_name: one of ps or worker in the case of distributed training
        task_index: the task index in this job
        ssh_command: the command to use for ssh, if 'None' no tunnel will be
            created
        expdir: the experiments directory
    '''

    #read the database config file
    parsed_database_cfg = configparser.ConfigParser()
    parsed_database_cfg.read(os.path.join(expdir, 'database.cfg'))
    database_cfg = dict(parsed_database_cfg.items('database'))

    #read the features config file
    parsed_feat_cfg = configparser.ConfigParser()
    parsed_feat_cfg.read(os.path.join(expdir, 'model', 'features.cfg'))
    feat_cfg = dict(parsed_feat_cfg.items('features'))

    #read the asr config file
    parsed_nnet_cfg = configparser.ConfigParser()
    parsed_nnet_cfg.read(os.path.join(expdir, 'model', 'asr.cfg'))
    nnet_cfg = dict(parsed_nnet_cfg.items('asr'))

    #read the trainer config file
    parsed_trainer_cfg = configparser.ConfigParser()
    parsed_trainer_cfg.read(os.path.join(expdir, 'trainer.cfg'))
    trainer_cfg = dict(parsed_trainer_cfg.items('trainer'))

    #read the decoder config file
    parsed_decoder_cfg = configparser.ConfigParser()
    parsed_decoder_cfg.read(os.path.join(expdir, 'model', 'decoder.cfg'))
    decoder_cfg = dict(parsed_decoder_cfg.items('decoder'))

    #create the cluster and server
    server = create_server.create_server(clusterfile=clusterfile,
                                         job_name=job_name,
                                         task_index=task_index,
                                         expdir=expdir,
                                         ssh_command=ssh_command)

    #the ps should just wait
    if job_name == 'ps':
        server.join()

    featdir = os.path.join(database_cfg['train_dir'], feat_cfg['name'])

    #create the coder
    with open(os.path.join(database_cfg['train_dir'], 'alphabet')) as fid:
        alphabet = fid.read().split(' ')
    coder = target_coder.TargetCoder(alphabet)

    #create a feature reader for the training data
    with open(os.path.join(featdir, 'maxlength'), 'r') as fid:
        max_length = int(fid.read())

    featreader = feature_reader.FeatureReader(
        scpfile=os.path.join(featdir, 'feats_shuffled.scp'),
        cmvnfile=os.path.join(featdir, 'cmvn.scp'),
        utt2spkfile=os.path.join(featdir, 'utt2spk'),
        max_length=max_length)

    #read the feature dimension
    with open(os.path.join(featdir, 'dim'), 'r') as fid:
        input_dim = int(fid.read())

    #the path to the text file
    textfile = os.path.join(database_cfg['train_dir'], 'targets')

    #create a batch dispenser for the training data
    dispenser = batchdispenser.AsrBatchDispenser(
        feature_reader=featreader,
        target_coder=coder,
        size=int(trainer_cfg['batch_size']),
        target_path=textfile)

    #create a reader for the validation data. A separate dev set is used when
    #the database config has a dev_dir entry (the key read below; the old
    #check on 'dev_data' did not match it). Otherwise valid_utt utterances
    #are split off the training data, or no validation is done.
    if 'dev_dir' in database_cfg:
        val_featdir = os.path.join(database_cfg['dev_dir'], feat_cfg['name'])

        with open(os.path.join(val_featdir, 'maxlength'), 'r') as fid:
            val_max_length = int(fid.read())

        val_reader = feature_reader.FeatureReader(
            scpfile=os.path.join(val_featdir, 'feats.scp'),
            cmvnfile=os.path.join(val_featdir, 'cmvn.scp'),
            utt2spkfile=os.path.join(val_featdir, 'utt2spk'),
            max_length=val_max_length)

        val_textfile = os.path.join(database_cfg['dev_dir'], 'targets')

        #read the validation targets: "<utt id> <transcription>" per line
        with open(val_textfile) as fid:
            lines = fid.readlines()

        val_targets = dict()
        for line in lines:
            splitline = line.strip().split(' ')
            val_targets[splitline[0]] = ' '.join(splitline[1:])

    else:
        if int(trainer_cfg['valid_utt']) > 0:
            val_dispenser = dispenser.split(int(trainer_cfg['valid_utt']))
            val_reader = val_dispenser.feature_reader
            val_targets = val_dispenser.target_dict
        else:
            val_reader = None
            val_targets = None

    #encode the validation targets
    if val_targets is not None:
        for utt in val_targets:
            val_targets[utt] = dispenser.target_coder.encode(val_targets[utt])

    #create the classifier
    classifier = asr_factory.factory(conf=nnet_cfg,
                                     output_dim=coder.num_labels)

    #without validation data there is no val_reader to take the maximum
    #input length from, so fall back to the one of the training data
    if val_reader is not None:
        max_input_length = val_reader.max_length
    else:
        max_input_length = max_length

    #create the callable for the decoder
    decoder = partial(decoder_factory.factory,
                      conf=decoder_cfg,
                      classifier=classifier,
                      input_dim=input_dim,
                      max_input_length=max_input_length,
                      coder=coder,
                      expdir=expdir)

    #create the trainer
    tr = trainer_factory.factory(conf=trainer_cfg,
                                 decoder=decoder,
                                 classifier=classifier,
                                 input_dim=input_dim,
                                 dispenser=dispenser,
                                 val_reader=val_reader,
                                 val_targets=val_targets,
                                 expdir=expdir,
                                 server=server,
                                 task_index=task_index)

    print('starting training')

    #train the classifier
    tr.train()
Ejemplo n.º 4
0
def _read_padded_utterances(reader, max_length):
    '''read the utterances of reader into one zero-padded array

    The first utterance is read unconditionally; after that utterances are
    read until reader.get_utt() reports that it has looped. Every utterance is
    zero-padded along its first (time) axis to max_length.

    Args:
        reader: a feature reader whose get_utt() returns a
            (name, utterance, looped) triple
        max_length: the length every utterance is padded to

    Returns:
        a pair (data, lengths): data has shape [num_read, max_length, dim]
        and lengths is an int32 array with the unpadded length of every
        utterance that was read
    '''

    padded = []
    lengths = []

    def add(utt):
        '''add a batch axis to utt, record its length and pad it'''
        utt = utt.reshape(1, -1, utt.shape[1])
        lengths.append(utt.shape[1])
        padded.append(np.concatenate([
            utt,
            np.zeros([utt.shape[0], max_length - utt.shape[1], utt.shape[2]])
        ], 1))

    #NOTE(review): the looped flag of the first read is not checked, so at
    #least two reads are done; with a single-utterance reader this presumably
    #reads that utterance twice — confirm the get_utt semantics
    _, utt, _ = reader.get_utt()
    add(utt)
    looped = False
    while not looped:
        _, utt, looped = reader.get_utt()
        add(utt)

    #a single concatenation instead of growing the array per utterance
    return np.concatenate(padded, 0), np.array(lengths, dtype=np.int32)


def main(_):
    '''does everything for testing of pure reconstruction with the simple
    loss function

    All test utterances are read into memory and zero-padded, the
    reconstruction loss is evaluated batch by batch and the average loss over
    all test utterances is printed.
    '''

    #read the database config file
    parsed_database_cfg = configparser.ConfigParser()
    parsed_database_cfg.read(os.path.join(FLAGS.expdir, 'database.cfg'))
    database_cfg = dict(parsed_database_cfg.items('database'))

    #read the features config file
    parsed_feat_cfg = configparser.ConfigParser()
    parsed_feat_cfg.read(os.path.join(FLAGS.expdir, 'model', 'features.cfg'))
    feat_cfg = dict(parsed_feat_cfg.items('features'))

    #read the asr config file
    parsed_nnet_cfg = configparser.ConfigParser()
    parsed_nnet_cfg.read(os.path.join(FLAGS.expdir, 'model', 'asr.cfg'))
    nnet_cfg = dict(parsed_nnet_cfg.items('asr'))

    # read the trainer config file
    parsed_trainer_cfg = configparser.ConfigParser()
    parsed_trainer_cfg.read(os.path.join(FLAGS.expdir, 'trainer.cfg'))
    trainer_cfg = dict(parsed_trainer_cfg.items('trainer'))

    # check on what features the reconstruction is made: audio samples or
    # (any other value) the input features themselves
    if 'reconstruction_features' in trainer_cfg:
        if trainer_cfg['reconstruction_features'] == 'audio_samples':
            audio_used = True
        else:
            audio_used = False
    else:
        raise Exception(
            'no reconstruction features specified, something wrong')

    #read the quantization config file if necessary
    if audio_used:
        parsed_quant_cfg = configparser.ConfigParser()
        parsed_quant_cfg.read(
            os.path.join(FLAGS.expdir, 'model', 'quantization.cfg'))
        quant_cfg = dict(parsed_quant_cfg.items('features'))

    #create a feature reader
    featdir = os.path.join(database_cfg['test_dir'], feat_cfg['name'])

    with open(os.path.join(featdir, 'maxlength'), 'r') as fid:
        max_length_feat = int(fid.read())

    feat_reader = feature_reader.FeatureReader(
        scpfile=os.path.join(featdir, 'feats.scp'),
        cmvnfile=os.path.join(featdir, 'cmvn.scp'),
        utt2spkfile=os.path.join(featdir, 'utt2spk'),
        max_length=max_length_feat)

    #create an audio sample reader if necessary
    if audio_used:
        audiodir = os.path.join(database_cfg['test_dir'], quant_cfg['name'])

        with open(os.path.join(audiodir, 'maxlength'), 'r') as fid:
            max_length_audio = int(fid.read())

        audio_reader = feature_reader.FeatureReader(
            scpfile=os.path.join(audiodir, 'feats.scp'),
            cmvnfile=None,
            utt2spkfile=None,
            max_length=max_length_audio)

    #check number of test examples
    number_examples = feat_reader.num_utt

    # set a batch_size to determine how many test examples are
    # processed in each steps
    # this doesn't really matter, only for memory issues
    # take the same one as used in training
    batch_size = int(trainer_cfg['batch_size'])

    #create a ndarray of all of the features
    features, features_lengths = _read_padded_utterances(
        feat_reader, max_length_feat)

    #create a ndarray of all of the targets
    if audio_used:
        audio, audio_lengths = _read_padded_utterances(
            audio_reader, max_length_audio)
        max_audio_length = audio.shape[1]

    else:
        #dummy audio: it is still fed to the graph, but the reconstruction
        #targets and their lengths are taken from the features below
        audio = np.zeros([number_examples, 1, 1])
        max_audio_length = 1
        audio_lengths = np.ones([number_examples])

    # store dimensions
    max_feature_length = features.shape[1]
    feature_dim = features.shape[2]

    #create a graph
    graph = tf.Graph()

    with graph.as_default():
        #create the classifier
        if audio_used:
            outputdim = int(quant_cfg['quant_levels'])
        else:
            outputdim = feature_dim
        classifier = asr_factory.factory(conf=nnet_cfg,
                                         output_dim=(1, outputdim))

        # create placeholders for reconstruction and features
        features_ph = tf.placeholder(
            tf.float32,
            shape=[batch_size, max_feature_length, feature_dim],
            name='features')

        audio_ph = tf.placeholder(tf.int32,
                                  shape=[batch_size, max_audio_length, 1],
                                  name='audio')

        audio_lengths_ph = tf.placeholder(tf.int32,
                                          shape=[batch_size],
                                          name='audio_lenght')

        feature_lengths_ph = tf.placeholder(tf.int32,
                                            shape=[batch_size],
                                            name='feat_lenght')

        # decide what to give as targets; when the features themselves are
        # reconstructed, their real lengths are the target lengths (the audio
        # lengths are dummy ones in that case)
        if audio_used:
            rec_ph = audio_ph
            rec_l_ph = audio_lengths_ph
        else:
            rec_ph = features_ph
            rec_l_ph = feature_lengths_ph

        #create the logits for reconstructed audio samples
        logits, logits_lengths = classifier(
            inputs=features_ph,
            input_seq_length=feature_lengths_ph,
            targets=(None, rec_ph),
            target_seq_length=(None, rec_l_ph),
            is_training=False)

        #compute the loss score
        score = compute_loss((None, rec_ph), logits, logits_lengths,
                             (None, rec_l_ph), audio_used)

        saver = tf.train.Saver(tf.trainable_variables())

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  #pylint: disable=E1101
    config.allow_soft_placement = True

    with tf.Session(graph=graph, config=config) as sess:

        #create a saver and load the model
        saver.restore(sess, os.path.join(FLAGS.expdir, 'model',
                                         'network.ckpt'))

        all_processed = False
        number_batch = 0
        total_elements = 0
        avrg_loss = 0.0

        #float division so the last, partially filled batch is counted too
        #(integer division would floor before the ceil)
        total_steps = int(np.ceil(number_examples / float(batch_size)))

        # process the loss on the test set batch by batch
        while not all_processed:
            # put a part of the features and audio samples in a batch
            start = number_batch * batch_size
            end = (number_batch + 1) * batch_size
            if end >= number_examples:
                end = number_examples
                all_processed = True
            part_features = features[start:end, :, :]
            part_features_lengths = features_lengths[start:end]
            part_audio = audio[start:end, :, :]
            part_audio_lengths = audio_lengths[start:end]

            # pad with zeros if the last batch isn't completely filled, since
            # the placeholders have a fixed batch size
            if all_processed:
                elements_last_batch = end - start
                to_add = batch_size - elements_last_batch
                part_features = np.concatenate([
                    part_features,
                    np.zeros([to_add, max_feature_length, feature_dim])
                ], 0)
                part_features_lengths = np.concatenate([
                    part_features_lengths,
                    np.zeros([to_add], dtype=np.int32)
                ], 0)
                part_audio = np.concatenate([
                    part_audio,
                    np.zeros([to_add, max_audio_length, 1], dtype=np.int32)
                ], 0)
                part_audio_lengths = np.concatenate(
                    [part_audio_lengths,
                     np.zeros([to_add], dtype=np.int32)], 0)

            # number of real (non-padding) elements in the current batch
            numel = end - start

            # compute loss on this batch
            loss = sess.run(score,
                            feed_dict={
                                features_ph: part_features,
                                audio_ph: part_audio,
                                feature_lengths_ph: part_features_lengths,
                                audio_lengths_ph: part_audio_lengths
                            })

            # update the running average, weighting the batch loss by the
            # number of real elements in the batch
            avrg_loss = ((total_elements * avrg_loss + numel * loss) /
                         (numel + total_elements))
            total_elements += numel

            number_batch = number_batch + 1

            # print some info about how we're proceeding
            print('Computing loss on test set: step %d of %d'
                  % (number_batch, total_steps))

        #test for correctness
        if not total_elements == number_examples:
            raise Exception(
                'something went wrong in loop where test loss is calculated')

    # print eventual result
    print('========================================')
    print('The loss on the test set: %f' % avrg_loss)
    print('========================================')
Ejemplo n.º 5
0
def train_asr(clusterfile,
              job_name,
              task_index,
              ssh_command,
              expdir):

    ''' does everything for asr training
    Args:
        clusterfile: the file where all the machines in the cluster are
            specified if None, local training will be done
        job_name: one of ps or worker in the case of distributed training
        task_index: the task index in this job
        ssh_command: the command to use for ssh, if 'None' no tunnel will be
            created
        expdir: the experiments directory
    '''

    #read the database config file
    parsed_database_cfg = configparser.ConfigParser()
    parsed_database_cfg.read(os.path.join(expdir, 'database.cfg'))
    database_cfg = dict(parsed_database_cfg.items('database'))

    #read the features config file
    parsed_feat_cfg = configparser.ConfigParser()
    parsed_feat_cfg.read(os.path.join(expdir, 'model', 'features.cfg'))
    feat_cfg = dict(parsed_feat_cfg.items('features'))

    #read the asr config file
    parsed_nnet_cfg = configparser.ConfigParser()
    parsed_nnet_cfg.read(os.path.join(expdir, 'model', 'asr.cfg'))
    nnet_cfg = dict(parsed_nnet_cfg.items('asr'))

    #read the trainer config file
    parsed_trainer_cfg = configparser.ConfigParser()
    parsed_trainer_cfg.read(os.path.join(expdir, 'trainer.cfg'))
    trainer_cfg = dict(parsed_trainer_cfg.items('trainer'))

    #read the decoder config file
    parsed_decoder_cfg = configparser.ConfigParser()
    parsed_decoder_cfg.read(os.path.join(expdir, 'model', 'decoder.cfg'))
    decoder_cfg = dict(parsed_decoder_cfg.items('decoder'))

    #make distinction between three implemented different kind of training forms
    if database_cfg['train_mode'] == 'supervised':
        nonsupervised = False
    elif database_cfg['train_mode'] == 'nonsupervised' or\
            database_cfg['train_mode'] == 'semisupervised':
        nonsupervised = True
    else:
        raise Exception('Wrong kind of training mode')

    #when (partly) nonsupervised, what features are used for the reconstruction
    #currently two possible options implemented
    if nonsupervised:
        if trainer_cfg['reconstruction_features'] == 'audio_samples':
            audio_used = True
        elif trainer_cfg['reconstruction_features'] == 'input_features':
            audio_used = False
        else:
            raise Exception(
                'Unknown specification for the reconstruction features')

    #read the quant config file if nonsupervised training and samples used
    if nonsupervised:
        if audio_used:
            parsed_quant_cfg = configparser.ConfigParser()
            parsed_quant_cfg.read(os.path.join(expdir,
                                               'model', 'quantization.cfg'))
            quant_cfg = dict(parsed_quant_cfg.items('features'))

    #based on the other settings, compute and overwrite samples_per_hlfeature
    #and unpredictable_samples in the classifier config dictionary
    if nonsupervised:
        if audio_used:
            rate_after_quant = int(quant_cfg['quant_rate'])
            win_lenght = float(feat_cfg['winlen'])
            win_shift = float(feat_cfg['winstep'])
            samples_one_window = int(win_lenght*rate_after_quant)
            samples_one_shift = int(win_shift*rate_after_quant)
            #### THIS IS ONLY RELEVANT WHEN USING A LISTENER WITH PYRAM STRUCT
            # and this line should be adapted otherwise
            time_compression = 2**int(nnet_cfg['listener_numlayers'])
            #store values in config dictionary
            nnet_cfg['samples_per_hlfeature'] = samples_one_shift\
                                        *time_compression
            nnet_cfg['unpredictable_samples'] = (samples_one_window+\
                                    (time_compression-1)\
                        *samples_one_shift)-nnet_cfg['samples_per_hlfeature']


    #create the cluster and server
    server = create_server.create_server(
        clusterfile=clusterfile,
        job_name=job_name,
        task_index=task_index,
        expdir=expdir,
        ssh_command=ssh_command)

    #the ps should just wait
    if job_name == 'ps':
        server.join()

    # path to where the training samples are stored
    featdir = os.path.join(database_cfg['train_dir'], feat_cfg['name'])

    #create the coder
    with open(os.path.join(database_cfg['train_dir'], 'alphabet')) as fid:
        alphabet = fid.read().split(' ')
    coder = target_coder.TargetCoder(alphabet)

    #create a feature reader for the training data
    with open(featdir + '/maxlength', 'r') as fid:
        max_length = int(fid.read())

    featreader = feature_reader.FeatureReader(
        scpfile=featdir + '/feats_shuffled.scp',
        cmvnfile=featdir + '/cmvn.scp',
        utt2spkfile=featdir + '/utt2spk',
        max_length=max_length)

    #read the feature dimension
    with open(featdir + '/dim', 'r') as fid:
        input_dim = int(fid.read())

    #the path to the text file
    textfile = os.path.join(database_cfg['train_dir'], 'targets')


    # If nonsupervised and audio used, we also need to read samples
    # these can be done with a second feature reader
    if nonsupervised:
        if audio_used:
            featdir2 = os.path.join(database_cfg['train_dir'],
                                    quant_cfg['name'])

            with open(featdir2 + '/maxlength', 'r') as fid:
                max_length_audio = int(fid.read())

            audioreader = feature_reader.FeatureReader(
                scpfile=featdir2 + '/feats_shuffled.scp',
                cmvnfile=None,
                utt2spkfile=None,
                max_length=max_length_audio)

    ## create a batch dispenser, depending on which situation we're in
    if not nonsupervised:
    # in the normal supervised training mode, regular dispenser is needed
        if 'las_ignoring_mode' in trainer_cfg:
            if trainer_cfg['las_ignoring_mode'] == 'True':
            # if we ignore unlabeled examples
                dispenser = batchdispenser.AsrTextBatchDispenser(
                    feature_reader=featreader,
                    target_coder=coder,
                    size=int(trainer_cfg['batch_size']),
                    target_path=textfile)
            elif trainer_cfg['las_ignoring_mode'] == 'False':
            # if we choose to process the unlabeled examples
                if 'fixed_ratio' in trainer_cfg:
                    if trainer_cfg['fixed_ratio'] == 'True':
                    # if we choose to process with batches with fixed
                    # labeled/unlabeled ratio
                        dispenser = \
                            batchdispenser.AsrTextBatchDispenserAltFixRatio(
                                feature_reader=featreader,
                                target_coder=coder,
                                size=int(trainer_cfg['batch_size']),
                                target_path=textfile,
                                percentage_unlabeled=1-float(
                                    database_cfg['part_labeled']))
                    elif trainer_cfg['fixed_ratio'] == 'False':
                        # if the fixed ratio is not used
                        dispenser = batchdispenser.AsrTextBatchDispenserAlt(
                            feature_reader=featreader,
                            target_coder=coder,
                            size=int(trainer_cfg['batch_size']),
                            target_path=textfile)
                    else:
                        raise Exception('wrong information in fixed_ratio var')
                else:
                # if fixed ratio is not specified, we choose to do without it
                    dispenser = batchdispenser.AsrTextBatchDispenserAlt(
                        feature_reader=featreader,
                        target_coder=coder,
                        size=int(trainer_cfg['batch_size']),
                        target_path=textfile)
            else:
                raise Exception('wrong information in LAS_ignoring_mode var')
        else:
        # if no specification is made about the ignoring, ignore the unlabeled
            dispenser = batchdispenser.AsrTextBatchDispenser(
                feature_reader=featreader,
                target_coder=coder,
                size=int(trainer_cfg['batch_size']),
                target_path=textfile)
    else:
        # when doing (partly) nonsupervised extra reconstruction features needed
        if audio_used:
        # when the audio is the reconstruction feature
            if 'fixed_ratio' in trainer_cfg:
                if trainer_cfg['fixed_ratio'] == 'True':
                    # if specified to work with fixed lab/unlab ratio batches
                    dispenser = \
                        batchdispenser.AsrTextAndAudioBatchDispenserFixRatio(
                            feature_reader=featreader,
                            audio_reader=audioreader,
                            target_coder=coder,
                            size=int(trainer_cfg['batch_size']),
                            target_path=textfile,
                            percentage_unlabeled=1-float(
                                database_cfg['part_labeled']))
                elif trainer_cfg['fixed_ratio'] == 'False':
                # if specified to not use the fixed ratio
                    dispenser = batchdispenser.AsrTextAndAudioBatchDispenser(
                        feature_reader=featreader,
                        audio_reader=audioreader,
                        target_coder=coder,
                        size=int(trainer_cfg['batch_size']),
                        target_path=textfile)
                else:
                    raise Exception('wrong information in fixed_ratio var')
            else:
            # without specification, suppose no fixed ratio batches
                dispenser = batchdispenser.AsrTextAndAudioBatchDispenser(
                    feature_reader=featreader,
                    audio_reader=audioreader,
                    target_coder=coder,
                    size=int(trainer_cfg['batch_size']),
                    target_path=textfile)
        else:
        # if no audio is used, the input features are used
            if 'fixed_ratio' in trainer_cfg:
                if trainer_cfg['fixed_ratio'] == 'True':
                # if specified to work with fixed labeled/unlabled ratio batches
                    dispenser = \
                        batchdispenser.AsrTextAndFeatBatchDispenserFixRatio(
                            feature_reader=featreader,
                            target_coder=coder,
                            size=int(trainer_cfg['batch_size']),
                            target_path=textfile,
                            percentage_unlabeled=1-float(
                                database_cfg['part_labeled']))
                elif trainer_cfg['fixed_ratio'] == 'False':
                # if specified to not use the fixed ratio
                    dispenser = batchdispenser.AsrTextAndFeatBatchDispenser(
                        feature_reader=featreader,
                        target_coder=coder,
                        size=int(trainer_cfg['batch_size']),
                        target_path=textfile)
                else:
                    raise Exception('wrong information in fixed_ratio var')
            else:
            # without specification, suppose no fixed ratio batches
                dispenser = batchdispenser.AsrTextAndFeatBatchDispenser(
                    feature_reader=featreader,
                    target_coder=coder,
                    size=int(trainer_cfg['batch_size']),
                    target_path=textfile)


    # read validation data. If there are text targets, they are only important
    # for the validation data. If only nonsupervised, we must validate on the
    # reconstructed features
    if 'dev_data' in database_cfg:
        # create a reader for the validation inputs
        featdir = database_cfg['dev_dir'] + '/' +  feat_cfg['name']

        with open(featdir + '/maxlength', 'r') as fid:
            max_length = int(fid.read())

        val_reader = feature_reader.FeatureReader(
            scpfile=featdir + '/feats.scp',
            cmvnfile=featdir + '/cmvn.scp',
            utt2spkfile=featdir + '/utt2spk',
            max_length=max_length)

        textfile = os.path.join(database_cfg['dev_dir'], 'targets')

        #read the validation text targets
        with open(textfile) as fid:
            lines = fid.readlines()

            val_text_targets = dict()
            for line in lines:
                splitline = line.strip().split(' ')
                val_text_targets[splitline[0]] = ' '.join(splitline[1:])

        if nonsupervised:
        #also store the reconstruction targets
            val_rec_targets = dict()
            if audio_used:
                audiodir = database_cfg['dev_dir'] + '/' +  quant_cfg['name']
                with open(audiodir + '/maxlength', 'r') as fid:
                    max_length_audio = int(fid.read())
                val_audio_reader = feature_reader.FeatureReader(
                    scpfile=audiodir + '/feats.scp',
                    cmvnfile=None,
                    utt2spkfile=audiodir + '/utt2spk',
                    max_length=max_length_audio)
                for _ in range(val_audio_reader.num_utt):
                    utt_id, audio, _ = val_audio_reader.get_utt()
                    val_rec_targets[utt_id] = audio
            else: #input features are used
                for _ in range(val_reader.num_utt):
                    utt_id, feat, _ = val_reader.get_utt()
                    val_rec_targets[utt_id] = feat
        else:
            with open(textfile) as fid:
                lines = fid.readlines()

                val_rec_targets = dict()
                for line in lines:
                    splitline = line.strip().split(' ')
                    val_rec_targets[splitline[0]] = None

        val_targets = dict()
        for utt_id in val_text_targets:
            val_targets[utt_id] = (val_text_targets[utt_id],
                                   val_rec_targets[utt_id])

    else:
        if int(trainer_cfg['valid_utt']) > 0:
            val_dispenser = dispenser.split(int(trainer_cfg['valid_utt']))
            val_reader = val_dispenser.feature_reader
            val_targets = val_dispenser.target_dict
        else:
            val_reader = None
            val_targets = None

    #encode the validation targets
    if val_targets is not None:
        for utt in val_targets:
            val_targets[utt] = (dispenser.target_coder.encode(
                val_targets[utt][0]), val_targets[utt][1])


    #create the classifier
    if nonsupervised:
        if audio_used:
            output_dim_second_el = int(quant_cfg['quant_levels'])
        else: # input features used
            output_dim_second_el = input_dim
    else: # only supervised training
        output_dim_second_el = None

    classifier = asr_factory.factory(
        conf=nnet_cfg,
        output_dim=(coder.num_labels, output_dim_second_el))

    #create the callable for the decoder
    decoder = partial(
        decoder_factory.factory,
        conf=decoder_cfg,
        classifier=classifier,
        input_dim=input_dim,
        max_input_length=val_reader.max_length,
        coder=coder,
        expdir=expdir)

    #create the trainer
    if nonsupervised:
        if audio_used:
            reconstruction_dim = 1
        else:
            reconstruction_dim = input_dim
    else:
        reconstruction_dim = 1

    tr = trainer_factory.factory(
        conf=trainer_cfg,
        decoder=decoder,
        classifier=classifier,
        input_dim=input_dim,
        reconstruction_dim=reconstruction_dim,
        dispenser=dispenser,
        val_reader=val_reader,
        val_targets=val_targets,
        expdir=expdir,
        server=server,
        task_index=task_index)

    print 'starting training'

    #train the classifier
    tr.train()
Ejemplo n.º 6
0
def _read_config_section(path, section):
    '''read one section of an ini-style config file into a plain dict

    Args:
        path: the path to the config file
        section: the name of the section that should be read

    Returns:
        a dict mapping the option names of the section to their string values
    '''

    parsed_cfg = configparser.ConfigParser()
    parsed_cfg.read(path)
    return dict(parsed_cfg.items(section))


def main(_):
    '''does everything for testing

    Restores the trained model stored in FLAGS.expdir, decodes the test set
    described in the database config, prints the score of the decoded
    hypotheses against the reference transcriptions and writes the
    hypotheses (one file per utterance) to FLAGS.expdir/decoded.

    Raises:
        Exception: if the database config specifies a purely nonsupervised or
            an unknown training mode
    '''

    #set this to a config file path to override the model's decoder config
    decoder_cfg_file = None

    #read the database config file
    database_cfg = _read_config_section(
        os.path.join(FLAGS.expdir, 'database.cfg'), 'database')

    # check the training mode
    if database_cfg['train_mode'] == 'supervised':
        nonsupervised = False
    elif database_cfg['train_mode'] == 'semisupervised':
        nonsupervised = True
    elif database_cfg['train_mode'] == 'nonsupervised':
        # implicit string concatenation: a backslash continuation inside the
        # literal would embed the source indentation into the message
        raise Exception('Purely nonsupervised models should be tested with '
                        'the test_reconstruction file.')
    else:
        raise Exception('Wrong kind of training mode')

    #read the features config file
    feat_cfg = _read_config_section(
        os.path.join(FLAGS.expdir, 'model', 'features.cfg'), 'features')

    #read the asr config file
    nnet_cfg = _read_config_section(
        os.path.join(FLAGS.expdir, 'model', 'asr.cfg'), 'asr')

    # read the trainer config file
    trainer_cfg = _read_config_section(
        os.path.join(FLAGS.expdir, 'trainer.cfg'), 'trainer')

    #read the decoder config file
    if decoder_cfg_file is None:
        decoder_cfg_file = os.path.join(FLAGS.expdir, 'model', 'decoder.cfg')
    decoder_cfg = _read_config_section(decoder_cfg_file, 'decoder')

    # if (partly) nonsupervised, check what kind of reconstruction features
    # were used: either the audio samples or the input features. The
    # short-circuit keeps the trainer option from being read when supervised.
    audio_used = (nonsupervised and
                  trainer_cfg['reconstruction_features'] == 'audio_samples')

    if audio_used:
        #read the quantization config file
        quant_cfg = _read_config_section(
            os.path.join(FLAGS.expdir, 'model', 'quantization.cfg'),
            'features')

    #create a feature reader
    featdir = os.path.join(database_cfg['test_dir'], feat_cfg['name'])

    with open(os.path.join(featdir, 'maxlength'), 'r') as fid:
        max_length = int(fid.read())

    reader = feature_reader.FeatureReader(
        scpfile=os.path.join(featdir, 'feats.scp'),
        cmvnfile=os.path.join(featdir, 'cmvn.scp'),
        utt2spkfile=os.path.join(featdir, 'utt2spk'),
        max_length=max_length)

    #read the feature dimension
    dimfile = os.path.join(database_cfg['train_dir'], feat_cfg['name'], 'dim')
    with open(dimfile, 'r') as fid:
        input_dim = int(fid.read())

    #create the coder
    with open(os.path.join(database_cfg['train_dir'], 'alphabet')) as fid:
        alphabet = fid.read().split(' ')
    coder = target_coder.TargetCoder(alphabet)

    #determine the second output dimension of the classifier
    if not nonsupervised:
        outputdim2 = 1
    elif audio_used:
        outputdim2 = int(quant_cfg['quant_levels'])
    else: #the input features are used as reconstruction targets
        outputdim2 = input_dim

    #create the classifier
    classifier = asr_factory.factory(
        conf=nnet_cfg,
        output_dim=(coder.num_labels, outputdim2))

    #create a decoder in its own graph
    graph = tf.Graph()
    with graph.as_default():
        decoder = decoder_factory.factory(
            conf=decoder_cfg,
            classifier=classifier,
            input_dim=input_dim,
            max_input_length=reader.max_length,
            coder=coder,
            expdir=FLAGS.expdir)

        # only the trainable variables are restored from the checkpoint
        saver = tf.train.Saver(tf.trainable_variables())

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True #pylint: disable=E1101
    config.allow_soft_placement = True

    with tf.Session(graph=graph, config=config) as sess:
        #load the model
        saver.restore(sess, os.path.join(FLAGS.expdir, 'model', 'network.ckpt'))

        #decode with the neural net
        decoded = decoder.decode(reader, sess)

    #read all the reference transcriptions: "<utt_id> <transcription...>"
    references = dict()
    with open(database_cfg['testtext']) as fid:
        for line in fid:
            line = line.strip()
            if not line:
                # skip blank lines so no bogus '' utterance is scored
                continue
            splitline = line.split(' ')
            references[splitline[0]] = coder.encode(' '.join(splitline[1:]))

    #compute the character error rate
    score = decoder.score(decoded, references)

    print('score: %f' % score)

    #write the resulting beams to disk, one file per utterance
    decodedir = os.path.join(FLAGS.expdir, 'decoded')
    if not os.path.isdir(decodedir):
        os.makedirs(decodedir)
    for utt in decoded:
        with open(os.path.join(decodedir, utt), 'w') as fid:
            for hypothesis in decoded[utt]:
                fid.write('%f\t%s\n' % (hypothesis[0], hypothesis[1]))