Example #1
def train():
    # Create training and validation datasets
    train_set = create_dataset(FLAGS.train_files.split(','),
                               batch_size=FLAGS.train_batch_size,
                               cache_path=FLAGS.feature_cache,
                               train_phase=True)

    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
                                                 tfv1.data.get_output_shapes(train_set),
                                                 output_classes=tfv1.data.get_output_classes(train_set))

    # Make initialization ops for switching between the two sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_csvs = FLAGS.dev_files.split(',')
        dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size, train_phase=False) for csv in dev_csvs]
        dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

    # Dropout
    dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {
        rate: 0. for rate in dropout_rates
    }

    # Building the graph
    optimizer = create_optimizer()
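
The pattern above of one dropout-rate placeholder per layer, with separate feed dictionaries for training and evaluation, can be exercised in isolation. A minimal sketch, assuming TensorFlow 1.x graph mode via tensorflow.compat.v1; the stand-in rates replace the FLAGS values:

import tensorflow.compat.v1 as tfv1
tfv1.disable_eager_execution()

# One dropout-rate placeholder per layer, as in the example above.
dropout_rates = [tfv1.placeholder(tfv1.float32, name='dropout_{}'.format(i))
                 for i in range(6)]

train_feed = {rate: 0.05 for rate in dropout_rates}       # hypothetical training rates
no_dropout_feed = {rate: 0. for rate in dropout_rates}    # evaluation: dropout disabled

# A toy op that depends on one of the rates.
keep_prob = 1.0 - dropout_rates[0]

with tfv1.Session() as session:
    print(session.run(keep_prob, feed_dict=train_feed))       # ~0.95
    print(session.run(keep_prob, feed_dict=no_dropout_feed))  # 1.0
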
Example #2
def evaluate(test_csvs, create_model, try_loading):
    scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                    FLAGS.lm_binary_path, FLAGS.lm_trie_path,
                    Config.alphabet)

    # Note: the test_csvs argument is overridden here by the --test_files flag.
    test_csvs = FLAGS.test_files.split(',')
    test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size) for csv in test_csvs]
    iterator = tf.data.Iterator.from_structure(test_sets[0].output_types,
                                               test_sets[0].output_shapes,
                                               output_classes=test_sets[0].output_classes)
    test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]

    (batch_x, batch_x_len), batch_y = iterator.get_next()

    # One rate per layer
    no_dropout = [None] * 6
    logits, _ = create_model(batch_x=batch_x,
                             seq_length=batch_x_len,
                             dropout=no_dropout)

    # Transpose to batch major and apply softmax for decoder
    transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))

    loss = tf.nn.ctc_loss(labels=batch_y,
                          inputs=logits,
                          sequence_length=batch_x_len)

    tf.train.get_or_create_global_step()

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    # Create a saver using variables from the above newly created graph
    saver = tf.train.Saver()

    with tf.Session(config=Config.session_config) as session:
        # Restore variables from training checkpoint
        loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
        if not loaded:
            loaded = try_loading(session, saver, 'checkpoint', 'most recent')
        if not loaded:
            log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
            exit(1)

        def run_test(init_op, dataset):
            logitses = []
            losses = []
            seq_lengths = []
            ground_truths = []

            bar = create_progressbar(prefix='Computing acoustic model predictions | ',
                                     widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
            log_progress('Computing acoustic model predictions...')

            step_count = 0

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # First pass, compute losses and transposed logits for decoding
            while True:
                try:
                    logits, loss_, lengths, transcripts = session.run([transposed, loss, batch_x_len, batch_y])
                except tf.errors.OutOfRangeError:
                    break

                step_count += 1
                bar.update(step_count)

                logitses.append(logits)
                losses.extend(loss_)
                seq_lengths.append(lengths)
                ground_truths.extend(sparse_tensor_value_to_texts(transcripts, Config.alphabet))

            bar.finish()

            predictions = []

            bar = create_progressbar(max_value=step_count,
                                     prefix='Decoding predictions | ').start()
            log_progress('Decoding predictions...')

            # Second pass, decode logits and compute WER and edit distance metrics
            for logits, seq_length in bar(zip(logitses, seq_lengths)):
                decoded = ctc_beam_search_decoder_batch(logits, seq_length, Config.alphabet, FLAGS.beam_width,
                                                        num_processes=num_processes, scorer=scorer)
                predictions.extend(d[0][1] for d in decoded)

            distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]

            wer, cer, samples = calculate_report(ground_truths, predictions, distances, losses)
            mean_loss = np.mean(losses)

            # Take only the first report_count items
            report_samples = itertools.islice(samples, FLAGS.report_count)

            print('Test on %s - WER: %f, CER: %f, loss: %f' %
                  (dataset, wer, cer, mean_loss))
            print('-' * 80)
            for sample in report_samples:
                # Note: sample.distance (a raw edit distance) is printed under the CER label here.
                print('WER: %f, CER: %f, loss: %f' %
                      (sample.wer, sample.distance, sample.loss))
                print(' - src: "%s"' % sample.src)
                print(' - res: "%s"' % sample.res)
                print('-' * 80)

            return samples

        samples = []
        for csv, init_op in zip(test_csvs, test_init_ops):
            print('Testing model on {}'.format(csv))
            samples.extend(run_test(init_op, dataset=csv))
        return samples
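
All of these examples drive several datasets through a single reinitializable iterator: the iterator is built from the datasets' common structure, and each make_initializer() op re-points it at a different dataset. A minimal sketch of the pattern, assuming TensorFlow 1.x graph mode:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

ds_a = tf.data.Dataset.from_tensor_slices([1, 2, 3])
ds_b = tf.data.Dataset.from_tensor_slices([10, 20])

# Build the iterator from structure only, then create one init op per dataset.
iterator = tf.data.Iterator.from_structure(tf.data.get_output_types(ds_a),
                                           tf.data.get_output_shapes(ds_a))
next_element = iterator.get_next()
init_a = iterator.make_initializer(ds_a)
init_b = iterator.make_initializer(ds_b)

with tf.Session() as session:
    for init_op in (init_a, init_b):
        session.run(init_op)  # switch the iterator to the next dataset
        while True:
            try:
                print(session.run(next_element))
            except tf.errors.OutOfRangeError:
                break
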
Example #3
def evaluate(test_csvs, create_model, try_loading):
    if FLAGS.lm_binary_path:
        scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.lm_binary_path,
                        FLAGS.lm_trie_path, Config.alphabet)
    else:
        scorer = None

    # Note: the test_csvs argument is overridden here by the --test_files flag.
    test_csvs = FLAGS.test_files.split(',')
    test_sets = [
        create_dataset([csv],
                       batch_size=FLAGS.test_batch_size,
                       train_phase=False) for csv in test_csvs
    ]
    iterator = tfv1.data.Iterator.from_structure(
        tfv1.data.get_output_types(test_sets[0]),
        tfv1.data.get_output_shapes(test_sets[0]),
        output_classes=tfv1.data.get_output_classes(test_sets[0]))
    test_init_ops = [
        iterator.make_initializer(test_set) for test_set in test_sets
    ]

    batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()

    # One rate per layer
    no_dropout = [None] * 6
    logits, _ = create_model(batch_x=batch_x,
                             batch_size=FLAGS.test_batch_size,
                             seq_length=batch_x_len,
                             dropout=no_dropout)

    # Transpose to batch major and apply softmax for decoder
    transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))

    loss = tfv1.nn.ctc_loss(labels=batch_y,
                            inputs=logits,
                            sequence_length=batch_x_len)

    tfv1.train.get_or_create_global_step()

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    # Create a saver using variables from the above newly created graph
    saver = tfv1.train.Saver()

    with tfv1.Session(config=Config.session_config) as session:
        # Restore variables from training checkpoint
        loaded = try_loading(session, saver, 'best_dev_checkpoint',
                             'best validation')
        if not loaded:
            loaded = try_loading(session, saver, 'checkpoint', 'most recent')
        if not loaded:
            log_error(
                'Checkpoint directory ({}) does not contain a valid checkpoint state.'
                .format(FLAGS.checkpoint_dir))
            exit(1)

        def run_test(init_op, dataset):
            wav_filenames = []
            losses = []
            predictions = []
            ground_truths = []

            bar = create_progressbar(prefix='Test epoch | ',
                                     widgets=[
                                         'Steps: ',
                                         progressbar.Counter(), ' | ',
                                         progressbar.Timer()
                                     ]).start()
            log_progress('Test epoch...')

            step_count = 0

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # First pass, compute losses and transposed logits for decoding
            while True:
                try:
                    batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
                        session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
                except tf.errors.OutOfRangeError:
                    break

                decoded = ctc_beam_search_decoder_batch(
                    batch_logits,
                    batch_lengths,
                    Config.alphabet,
                    FLAGS.beam_width,
                    num_processes=num_processes,
                    scorer=scorer,
                    cutoff_prob=FLAGS.cutoff_prob,
                    cutoff_top_n=FLAGS.cutoff_top_n)
                predictions.extend(d[0][1] for d in decoded)
                ground_truths.extend(
                    sparse_tensor_value_to_texts(batch_transcripts,
                                                 Config.alphabet))
                wav_filenames.extend(
                    wav_filename.decode('UTF-8')
                    for wav_filename in batch_wav_filenames)
                losses.extend(batch_loss)

                step_count += 1
                bar.update(step_count)

            bar.finish()

            wer, cer, samples = calculate_report(wav_filenames, ground_truths,
                                                 predictions, losses)
            mean_loss = np.mean(losses)

            # Take only the first report_count items
            report_samples = itertools.islice(samples, FLAGS.report_count)

            print('Test on %s - WER: %f, CER: %f, loss: %f' %
                  (dataset, wer, cer, mean_loss))
            print('-' * 80)
            for sample in report_samples:
                print('WER: %f, CER: %f, loss: %f' %
                      (sample.wer, sample.cer, sample.loss))
                print(' - wav: file://%s' % sample.wav_filename)
                print(' - src: "%s"' % sample.src)
                print(' - res: "%s"' % sample.res)
                print('-' * 80)

            return samples

        samples = []
        for csv, init_op in zip(test_csvs, test_init_ops):
            print('Testing model on {}'.format(csv))
            samples.extend(run_test(init_op, dataset=csv))
        return samples
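
calculate_report() above turns edit distances between decoded predictions and ground truths into the WER/CER figures that get printed. A minimal sketch of the underlying metric using a plain Wagner-Fischer edit distance; the real code ships its own levenshtein() and report helpers:

def levenshtein(a, b):
    """Edit distance between two sequences (lists of words, or strings)."""
    if len(a) < len(b):
        a, b = b, a
    previous = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        current = [i]
        for j, cb in enumerate(b, 1):
            current.append(min(previous[j] + 1,                    # deletion
                               current[j - 1] + 1,                 # insertion
                               previous[j - 1] + (ca != cb)))      # substitution
        previous = current
    return previous[-1]

truth, hypo = 'the cat sat', 'the cat sag'
wer = levenshtein(truth.split(), hypo.split()) / len(truth.split())  # word level
cer = levenshtein(truth, hypo) / len(truth)                          # character level
print(wer, cer)  # 0.333..., 0.0909...
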
Example #4
def train():
    do_cache_dataset = True

    # pylint: disable=too-many-boolean-expressions
    if (FLAGS.data_aug_features_multiplicative > 0 or
            FLAGS.data_aug_features_additive > 0 or
            FLAGS.augmentation_spec_dropout_keeprate < 1 or
            FLAGS.augmentation_freq_and_time_masking or
            FLAGS.augmentation_pitch_and_tempo_scaling or
            FLAGS.augmentation_speed_up_std > 0 or
            FLAGS.augmentation_sparse_warp):
        do_cache_dataset = False

    # Create training and validation datasets
    train_set = create_dataset(FLAGS.train_files.split(','),
                               batch_size=FLAGS.train_batch_size,
                               enable_cache=FLAGS.feature_cache and do_cache_dataset,
                               cache_path=FLAGS.feature_cache,
                               train_phase=True)

    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
                                                 tfv1.data.get_output_shapes(train_set),
                                                 output_classes=tfv1.data.get_output_classes(train_set))

    # Make initialization ops for switching between the two sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_csvs = FLAGS.dev_files.split(',')
        dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size, train_phase=False) for csv in dev_csvs]
        dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

    # The transfer learning approach here requires us to supply the layers
    # which we want to exclude from the source model.
    # Say we want to exclude all layers except for the first one; we can use this:
    #
    #    drop_source_layers=['2', '3', 'lstm', '5', '6']
    #
    # If we want to use all layers from the source model except the last one, we use this:
    #
    #    drop_source_layers=['6']
    #

    if FLAGS.load == "transfer":
        drop_source_layers = ['2', '3', 'lstm', '5', '6'][-int(FLAGS.drop_source_layers):]
    else:
        drop_source_layers=None
    
    # Dropout
    dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {
        rate: 0. for rate in dropout_rates
    }

    # Building the graph
    optimizer = create_optimizer()

    # Enable mixed precision training
    if FLAGS.automatic_mixed_precision:
        log_info('Enabling automatic mixed precision training.')
        optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)

    gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates, drop_source_layers)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)

    # global_step is automagically incremented by the optimizer
    global_step = tfv1.train.get_or_create_global_step()
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

    # Summaries
    step_summaries_op = tfv1.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')

    best_dev_saver = tfv1.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')

    # Save flags next to checkpoints
    os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)

    flags_file = os.path.join(FLAGS.checkpoint_dir, 'flags.txt')
    with open(flags_file, 'w') as fout:
        fout.write(FLAGS.flags_into_string())

    initializer = tfv1.global_variables_initializer()

    with tfv1.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        # Loading or initializing
        loaded = False

        # Initialize training from a CuDNN RNN checkpoint
        if FLAGS.cudnn_checkpoint:
            if FLAGS.use_cudnn_rnn:
                log_error('Trying to use --cudnn_checkpoint but --use_cudnn_rnn '
                          'was specified. The --cudnn_checkpoint flag is only '
                          'needed when converting a CuDNN RNN checkpoint to '
                          'a CPU-capable graph. If your system is capable of '
                          'using CuDNN RNN, you can just specify the CuDNN RNN '
                          'checkpoint normally with --checkpoint_dir.')
                sys.exit(1)

            log_info('Converting CuDNN RNN checkpoint from {}'.format(FLAGS.cudnn_checkpoint))
            ckpt = tfv1.train.load_checkpoint(FLAGS.cudnn_checkpoint)
            missing_variables = []

            # Load compatible variables from checkpoint
            for v in tfv1.global_variables():
                try:
                    v.load(ckpt.get_tensor(v.op.name), session=session)
                except tf.errors.NotFoundError:
                    missing_variables.append(v)

            # Check that the only missing variables are the Adam moment tensors
            if any('Adam' not in v.op.name for v in missing_variables):
                log_error('Tried to load a CuDNN RNN checkpoint but there were '
                          'more missing variables than just the Adam moment '
                          'tensors.')
                sys.exit(1)

            # Initialize Adam moment tensors from scratch to allow use of CuDNN
            # RNN checkpoints.
            log_info('Initializing missing Adam moment tensors.')
            init_op = tfv1.variables_initializer(missing_variables)
            session.run(init_op)
            loaded = True

        if not loaded and FLAGS.load in ['auto', 'last']:
            tfv1.get_default_graph().finalize()
            loaded = try_loading(session, checkpoint_saver, 'checkpoint', 'most recent')
        if not loaded and FLAGS.load in ['auto', 'best']:
            tfv1.get_default_graph().finalize()
            loaded = try_loading(session, best_dev_saver, 'best_dev_checkpoint', 'best validation')
        if not loaded:
            if FLAGS.load == "transfer":
                if FLAGS.source_model_checkpoint_dir:
                    print('Initializing model from', FLAGS.source_model_checkpoint_dir)
                    ckpt = tfv1.train.load_checkpoint(FLAGS.source_model_checkpoint_dir)
                    variables = list(ckpt.get_variable_to_shape_map().keys())
                    print('variable', variables)
                    print('global', tf.global_variables())

                    # Load desired source variables
                    missing_variables2 = []
                    for v in tf.global_variables():
                        if not any(layer in v.op.name for layer in drop_source_layers):
                            print('Loading', v.op.name)
                            try:
                                v.load(ckpt.get_tensor(v.op.name), session=session)
                                print('OK')
                            except tf.errors.NotFoundError:
                                missing_variables2.append(v)
                                print('KO')
                            except ValueError:
                                print('KO for ValueError')
                    print('missing_variables =', missing_variables2)

                    # Initialize all variables needed for DeepSpeech, but not loaded from ckpt
                    init_op = tfv1.variables_initializer(
                        [v for v in tf.global_variables()
                         if any(layer in v.op.name
                                for layer in drop_source_layers)
                         ] + missing_variables2)
                    tfv1.get_default_graph().finalize()
                    session.run(init_op)

            elif FLAGS.load in ['auto', 'init']:
                log_info('Initializing variables...')
                tfv1.get_default_graph().finalize()
                session.run(initializer)
            else:
                log_error('Unable to load %s model from specified checkpoint dir'
                        ' - consider using load option "auto" or "init".' % FLAGS.load)
                sys.exit(1)


        def run_set(set_name, epoch, init_op, dataset=None):
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict

            total_loss = 0.0
            step_count = 0

            step_summary_writer = step_summary_writers.get(set_name)
            checkpoint_time = time.time()

            # Setup progress bar
            class LossWidget(progressbar.widgets.FormatLabel):
                def __init__(self):
                    progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')

                def __call__(self, progress, data, **kwargs):
                    data['mean_loss'] = total_loss / step_count if step_count else 0.0
                    return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

            prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
            widgets = [' | ', progressbar.widgets.Timer(),
                       ' | Steps: ', progressbar.widgets.Counter(),
                       ' | ', LossWidget()]
            suffix = ' | Dataset: {}'.format(dataset) if dataset else None
            pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop
            while True:
                try:
                    _, current_step, batch_loss, problem_files, step_summary = \
                        session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
                                    feed_dict=feed_dict)
                except tf.errors.InvalidArgumentError as err:
                    if FLAGS.augmentation_sparse_warp:
                        log_info("Ignoring sparse warp error: {}".format(err))
                        continue
                    else:
                        raise
                except tf.errors.OutOfRangeError:
                    break

                if problem_files.size > 0:
                    problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
                    log_error('The following files caused an infinite (or NaN) '
                              'loss: {}'.format(','.join(problem_files)))

                total_loss += batch_loss
                step_count += 1

                pbar.update(step_count)

                step_summary_writer.add_summary(step_summary, current_step)

                if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                    checkpoint_time = time.time()

            pbar.finish()
            mean_loss = total_loss / step_count if step_count > 0 else 0.0
            return mean_loss, step_count

        log_info('STARTING Optimization')
        train_start_time = datetime.utcnow()
        best_dev_loss = float('inf')
        dev_losses = []
        try:
            for epoch in range(FLAGS.epochs):
                # Training
                log_progress('Training epoch %d...' % epoch)
                train_loss, _ = run_set('train', epoch, train_init_op)
                log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
                checkpoint_saver.save(session, checkpoint_path, global_step=global_step)

                if FLAGS.dev_files:
                    # Validation
                    dev_loss = 0.0
                    total_steps = 0
                    for csv, init_op in zip(dev_csvs, dev_init_ops):
                        log_progress('Validating epoch %d on %s...' % (epoch, csv))
                        set_loss, steps = run_set('dev', epoch, init_op, dataset=csv)
                        dev_loss += set_loss * steps
                        total_steps += steps
                        log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss))
                    dev_loss = dev_loss / total_steps

                    dev_losses.append(dev_loss)

                    if dev_loss < best_dev_loss:
                        best_dev_loss = dev_loss
                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
                        log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))

                    # Early stopping
                    if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
                        mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
                        std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
                        dev_losses = dev_losses[-FLAGS.es_steps:]
                        log_debug('Checking for early stopping (last %d steps) validation loss: '
                                  '%f, with standard deviation: %f and mean: %f' %
                                  (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
                        if dev_losses[-1] > np.max(dev_losses[:-1]) or \
                           (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
                            log_info('Early stop triggered as (for last %d steps) validation loss:'
                                     ' %f with standard deviation: %f and mean: %f' %
                                     (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
                            break
        except KeyboardInterrupt:
            pass
        log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
    log_debug('Session closed.')
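
The early-stopping rule in this example fires when the newest validation loss exceeds every other loss in the recent window, or when it sits close to the window's mean while the deviation is small (a plateau). A minimal sketch with hypothetical thresholds standing in for FLAGS.es_steps, FLAGS.es_mean_th, and FLAGS.es_std_th:

import numpy as np

def should_stop(dev_losses, es_steps=4, es_mean_th=0.5, es_std_th=0.5):
    if len(dev_losses) < es_steps:
        return False
    window = dev_losses[-es_steps:]
    mean_loss = np.mean(window[:-1])
    std_loss = np.std(window[:-1])
    # Stop on a regression past the recent max, or on a flat plateau.
    return (window[-1] > np.max(window[:-1]) or
            (abs(window[-1] - mean_loss) < es_mean_th and std_loss < es_std_th))

print(should_stop([5.0, 4.0, 3.5, 6.0]))    # True: loss rose above the recent max
print(should_stop([4.0, 4.01, 3.99, 4.0]))  # True: plateau (near mean, low deviation)
print(should_stop([10.0, 5.0, 4.0, 3.0]))   # False: still improving
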
Example #5
def train():
    # Create training and validation datasets
    train_set = create_dataset(FLAGS.train_files.split(','),
                               batch_size=FLAGS.train_batch_size,
                               cache_path=FLAGS.feature_cache)

    iterator = tfv1.data.Iterator.from_structure(
        tfv1.data.get_output_types(train_set),
        tfv1.data.get_output_shapes(train_set),
        output_classes=tfv1.data.get_output_classes(train_set))

    # Make initialization ops for switching between the two sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_csvs = FLAGS.dev_files.split(',')
        dev_sets = [
            create_dataset([csv], batch_size=FLAGS.dev_batch_size)
            for csv in dev_csvs
        ]
        dev_init_ops = [
            iterator.make_initializer(dev_set) for dev_set in dev_sets
        ]

    # Dropout
    dropout_rates = [
        tfv1.placeholder(tf.float32, name='dropout_{}'.format(i))
        for i in range(6)
    ]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {rate: 0. for rate in dropout_rates}

    # Building the graph
    optimizer = create_optimizer()
    gradients, loss = get_tower_results(iterator, optimizer, dropout_rates)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)

    # global_step is automagically incremented by the optimizer
    global_step = tfv1.train.get_or_create_global_step()
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients,
                                                  global_step=global_step)

    # Summaries
    step_summaries_op = tfv1.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train':
        tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'),
                                max_queue=120),
        'dev':
        tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'),
                                max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
    checkpoint_filename = 'checkpoint'

    best_dev_saver = tfv1.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
    best_dev_filename = 'best_dev_checkpoint'

    initializer = tfv1.global_variables_initializer()

    with tfv1.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        # Loading or initializing
        loaded = False

        # Initialize training from a CuDNN RNN checkpoint
        if FLAGS.cudnn_checkpoint:
            if FLAGS.use_cudnn_rnn:
                log_error(
                    'Trying to use --cudnn_checkpoint but --use_cudnn_rnn '
                    'was specified. The --cudnn_checkpoint flag is only '
                    'needed when converting a CuDNN RNN checkpoint to '
                    'a CPU-capable graph. If your system is capable of '
                    'using CuDNN RNN, you can just specify the CuDNN RNN '
                    'checkpoint normally with --checkpoint_dir.')
                exit(1)

            log_info('Converting CuDNN RNN checkpoint from {}'.format(
                FLAGS.cudnn_checkpoint))
            ckpt = tfv1.train.load_checkpoint(FLAGS.cudnn_checkpoint)
            missing_variables = []

            # Load compatible variables from checkpoint
            for v in tfv1.global_variables():
                try:
                    v.load(ckpt.get_tensor(v.op.name), session=session)
                except tf.errors.NotFoundError:
                    missing_variables.append(v)

            # Check that the only missing variables are the Adam moment tensors
            if any('Adam' not in v.op.name for v in missing_variables):
                log_error(
                    'Tried to load a CuDNN RNN checkpoint but there were '
                    'more missing variables than just the Adam moment '
                    'tensors.')
                exit(1)

            # Initialize Adam moment tensors from scratch to allow use of CuDNN
            # RNN checkpoints.
            log_info('Initializing missing Adam moment tensors.')
            init_op = tfv1.variables_initializer(missing_variables)
            session.run(init_op)
            loaded = True

        tfv1.get_default_graph().finalize()

        if not loaded and FLAGS.load in ['auto', 'last']:
            loaded = try_loading(session, checkpoint_saver,
                                 checkpoint_filename, 'most recent')
        if not loaded and FLAGS.load in ['auto', 'best']:
            loaded = try_loading(session, best_dev_saver, best_dev_filename,
                                 'best validation')
        if not loaded:
            if FLAGS.load in ['auto', 'init']:
                log_info('Initializing variables...')
                session.run(initializer)
            else:
                log_error(
                    'Unable to load %s model from specified checkpoint dir'
                    ' - consider using load option "auto" or "init".' %
                    FLAGS.load)
                sys.exit(1)

        def run_set(set_name, epoch, init_op, dataset=None):
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict

            total_loss = 0.0
            step_count = 0

            step_summary_writer = step_summary_writers.get(set_name)
            checkpoint_time = time.time()

            # Setup progress bar
            class LossWidget(progressbar.widgets.FormatLabel):
                def __init__(self):
                    progressbar.widgets.FormatLabel.__init__(
                        self, format='Loss: %(mean_loss)f')

                def __call__(self, progress, data, **kwargs):
                    data[
                        'mean_loss'] = total_loss / step_count if step_count else 0.0
                    return progressbar.widgets.FormatLabel.__call__(
                        self, progress, data, **kwargs)

            prefix = 'Epoch {} | {:>10}'.format(
                epoch, 'Training' if is_train else 'Validation')
            widgets = [
                ' | ',
                progressbar.widgets.Timer(), ' | Steps: ',
                progressbar.widgets.Counter(), ' | ',
                LossWidget()
            ]
            suffix = ' | Dataset: {}'.format(dataset) if dataset else None
            pbar = create_progressbar(prefix=prefix,
                                      widgets=widgets,
                                      suffix=suffix).start()

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop
            while True:
                try:
                    _, current_step, batch_loss, step_summary = \
                        session.run([train_op, global_step, loss, step_summaries_op],
                                    feed_dict=feed_dict)
                except tf.errors.OutOfRangeError:
                    break

                total_loss += batch_loss
                step_count += 1

                pbar.update(step_count)

                step_summary_writer.add_summary(step_summary, current_step)

                if is_train and FLAGS.checkpoint_secs > 0 and time.time(
                ) - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session,
                                          checkpoint_path,
                                          global_step=current_step)
                    checkpoint_time = time.time()

            pbar.finish()
            mean_loss = total_loss / step_count if step_count > 0 else 0.0
            return mean_loss, step_count

        log_info('STARTING Optimization')
        train_start_time = datetime.utcnow()
        best_dev_loss = float('inf')
        dev_losses = []
        try:
            for epoch in range(FLAGS.epochs):
                # Training
                log_progress('Training epoch %d...' % epoch)
                train_loss, _ = run_set('train', epoch, train_init_op)
                log_progress('Finished training epoch %d - loss: %f' %
                             (epoch, train_loss))
                checkpoint_saver.save(session,
                                      checkpoint_path,
                                      global_step=global_step)

                if FLAGS.dev_files:
                    # Validation
                    dev_loss = 0.0
                    total_steps = 0
                    for csv, init_op in zip(dev_csvs, dev_init_ops):
                        log_progress('Validating epoch %d on %s...' %
                                     (epoch, csv))
                        set_loss, steps = run_set('dev',
                                                  epoch,
                                                  init_op,
                                                  dataset=csv)
                        dev_loss += set_loss * steps
                        total_steps += steps
                        log_progress(
                            'Finished validating epoch %d on %s - loss: %f' %
                            (epoch, csv, set_loss))
                    dev_loss = dev_loss / total_steps

                    dev_losses.append(dev_loss)

                    if dev_loss < best_dev_loss:
                        best_dev_loss = dev_loss
                        save_path = best_dev_saver.save(
                            session,
                            best_dev_path,
                            global_step=global_step,
                            latest_filename=best_dev_filename)
                        log_info(
                            "Saved new best validating model with loss %f to: %s"
                            % (best_dev_loss, save_path))

                    # Early stopping
                    if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
                        mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
                        std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
                        dev_losses = dev_losses[-FLAGS.es_steps:]
                        log_debug(
                            'Checking for early stopping (last %d steps) validation loss: '
                            '%f, with standard deviation: %f and mean: %f' %
                            (FLAGS.es_steps, dev_losses[-1], std_loss,
                             mean_loss))
                        if dev_losses[-1] > np.max(dev_losses[:-1]) or \
                           (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
                            log_info(
                                'Early stop triggered as (for last %d steps) validation loss:'
                                ' %f with standard deviation: %f and mean: %f'
                                % (FLAGS.es_steps, dev_losses[-1], std_loss,
                                   mean_loss))
                            break
        except KeyboardInterrupt:
            pass
        log_info('FINISHED optimization in {}'.format(datetime.utcnow() -
                                                      train_start_time))
    log_debug('Session closed.')
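
The checkpoint logic above follows a fixed fallback order: with --load auto it tries the most recent checkpoint, then the best-validation one, then fresh initialization. A minimal sketch of that control flow, with hypothetical callables standing in for try_loading() and session.run(initializer):

def load_or_init(load_flag, load_last, load_best, init):
    loaded = False
    if load_flag in ('auto', 'last'):
        loaded = load_last()
    if not loaded and load_flag in ('auto', 'best'):
        loaded = load_best()
    if not loaded:
        if load_flag in ('auto', 'init'):
            init()
        else:
            raise SystemExit('Unable to load %s model from checkpoint dir' % load_flag)

# 'auto': fall through from most-recent to best-validation to fresh init.
load_or_init('auto',
             load_last=lambda: False,   # pretend no recent checkpoint exists
             load_best=lambda: True,    # pretend the best_dev checkpoint restores
             init=lambda: print('initializing variables'))
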
Example #6
def evaluate_with_pruning(test_csvs,
                          prune_percentage,
                          random,
                          scores_file,
                          result_file,
                          verbose=True,
                          skip_lstm=False):
    '''Code originally comes from the DeepSpeech repository (./DeepSpeech/evaluate.py).
    The code is adapted for evaluation on pruned versions of the DeepSpeech model.
    '''
    tfv1.reset_default_graph()
    if FLAGS.lm_binary_path:
        scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.lm_binary_path,
                        FLAGS.lm_trie_path, Config.alphabet)
    else:
        scorer = None

    test_csvs = test_csvs.split(',')
    test_sets = [
        create_dataset([csv],
                       batch_size=FLAGS.test_batch_size,
                       train_phase=False) for csv in test_csvs
    ]
    iterator = tfv1.data.Iterator.from_structure(
        tfv1.data.get_output_types(test_sets[0]),
        tfv1.data.get_output_shapes(test_sets[0]),
        output_classes=tfv1.data.get_output_classes(test_sets[0]))
    test_init_ops = [
        iterator.make_initializer(test_set) for test_set in test_sets
    ]

    batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()

    # One rate per layer
    no_dropout = [None] * 6
    logits, _ = create_model(batch_x=batch_x,
                             batch_size=FLAGS.test_batch_size,
                             seq_length=batch_x_len,
                             dropout=no_dropout)

    # Transpose to batch major and apply softmax for decoder
    transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))

    loss = tfv1.nn.ctc_loss(labels=batch_y,
                            inputs=logits,
                            sequence_length=batch_x_len)

    tfv1.train.get_or_create_global_step()

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    # Create a saver using variables from the above newly created graph
    saver = tfv1.train.Saver()

    with tfv1.Session(config=Config.session_config) as session:

        # Restore variables from training checkpoint
        loaded = False
        if not loaded and FLAGS.load in ['auto', 'last']:
            loaded = try_loading(session,
                                 saver,
                                 'checkpoint',
                                 'most recent',
                                 load_step=False)
        if not loaded and FLAGS.load in ['auto', 'best']:
            loaded = try_loading(session,
                                 saver,
                                 'best_dev_checkpoint',
                                 'best validation',
                                 load_step=False)
        if not loaded:
            print('Could not load checkpoint from {}'.format(
                FLAGS.checkpoint_dir))
            sys.exit(1)

        ###### PRUNING PART ######

        if not prune_percentage:
            if verbose: print('No pruning done.')
        else:
            if verbose: print('-' * 80)
            if verbose: print('pruning with {}%...'.format(prune_percentage))
            scores_per_layer = np.load(scores_file)
            layer_masks = prune_matrices(scores_per_layer,
                                         prune_percentage=prune_percentage,
                                         random=random,
                                         verbose=verbose,
                                         skip_lstm=skip_lstm)

            n_layers_to_prune = len(layer_masks)
            i = 0
            for index, v in enumerate(tf.trainable_variables()):
                lstm_layer_name = 'cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/kernel:0'
                if 'weights' not in v.name and v.name != lstm_layer_name:
                    continue
                if i >= n_layers_to_prune:
                    break  # all prunable layers have been masked
                # make mask into the shape of the weights
                if v.name == lstm_layer_name:
                    if skip_lstm: continue
                    # Shape of LSTM weights: [(2*neurons), (4*neurons)]
                    cell_template = np.ones((2, 4))
                    mask = np.repeat(layer_masks[i], v.shape[0] // 2, axis=0)
                    mask = mask.reshape(
                        [layer_masks[i].shape[0], v.shape[0] // 2])
                    mask = np.swapaxes(mask, 0, 1)
                    mask = np.kron(mask, cell_template)
                else:
                    mask = np.repeat(layer_masks[i], v.shape[0], axis=0)
                    mask = mask.reshape([layer_masks[i].shape[0], v.shape[0]])
                    mask = np.swapaxes(mask, 0, 1)

                # apply mask to weights
                session.run(v.assign(tf.multiply(v, mask)))
                i += 1

        ###### END PRUNING PART ######

        def run_test(init_op, dataset):
            wav_filenames = []
            losses = []
            predictions = []
            ground_truths = []

            bar = create_progressbar(prefix='Test epoch | ',
                                     widgets=[
                                         'Steps: ',
                                         progressbar.Counter(), ' | ',
                                         progressbar.Timer()
                                     ]).start()
            log_progress('Test epoch...')

            step_count = 0

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # First pass, compute losses and transposed logits for decoding
            while True:
                try:
                    batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
                        session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
                except tf.errors.OutOfRangeError:
                    break

                decoded = ctc_beam_search_decoder_batch(
                    batch_logits,
                    batch_lengths,
                    Config.alphabet,
                    FLAGS.beam_width,
                    num_processes=num_processes,
                    scorer=scorer,
                    cutoff_prob=FLAGS.cutoff_prob,
                    cutoff_top_n=FLAGS.cutoff_top_n)
                predictions.extend(d[0][1] for d in decoded)
                ground_truths.extend(
                    sparse_tensor_value_to_texts(batch_transcripts,
                                                 Config.alphabet))
                wav_filenames.extend(
                    wav_filename.decode('UTF-8')
                    for wav_filename in batch_wav_filenames)
                losses.extend(batch_loss)

                step_count += 1
                bar.update(step_count)

            bar.finish()

            wer, cer, samples = calculate_report(wav_filenames, ground_truths,
                                                 predictions, losses)
            mean_loss = np.mean(losses)

            # Take only the first report_count items
            report_samples = itertools.islice(samples, FLAGS.report_count)

            if verbose:
                print('Test on %s - WER: %f, CER: %f, loss: %f' %
                      (dataset, wer, cer, mean_loss))
            if verbose: print('-' * 80)

            if result_file:
                pruning_type = 'score-based' if not random else 'random'
                result_string = '''Results for evaluating model with pruning percentage of {}% and {} pruning:
                Test on {} - WER: {}, CER: {}, loss: {}
                '''.format(prune_percentage * 100, pruning_type, dataset, wer,
                           cer, mean_loss)
                write_to_file(result_file, result_string, 'a+')

            return wer, cer, mean_loss

        results = []
        for csv, init_op in zip(test_csvs, test_init_ops):
            if verbose: print('Testing model on {}'.format(csv))
            results.extend(run_test(init_op, dataset=csv))
        return results
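
The mask construction above expands a 1-D keep/prune mask over output units to the full weight-matrix shape; for the LSTM kernel it additionally tiles each unit's entry across its (input halves x gates) block with np.kron. A minimal NumPy sketch with hypothetical sizes:

import numpy as np

unit_mask = np.array([1, 0, 1])  # keep units 0 and 2, prune unit 1
in_dim = 4                       # hypothetical input dimension

# Dense layer: tile the unit mask across the input dimension.
dense_mask = np.repeat(unit_mask, in_dim, axis=0)
dense_mask = dense_mask.reshape([unit_mask.shape[0], in_dim])
dense_mask = np.swapaxes(dense_mask, 0, 1)
print(dense_mask.shape)  # (4, 3): one column of 0s/1s per unit

# LSTM kernel: expand each unit's entry over a 2x4 block (input halves x gates).
cell_template = np.ones((2, 4))
lstm_mask = np.kron(dense_mask, cell_template)
print(lstm_mask.shape)   # (8, 12)
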
Example #7
def train():
    # Create training and validation datasets
    train_set = create_dataset(FLAGS.train_files.split(','),
                               batch_size=FLAGS.train_batch_size,
                               cache_path=FLAGS.train_cached_features_path)

    iterator = tf.data.Iterator.from_structure(train_set.output_types,
                                               train_set.output_shapes,
                                               output_classes=train_set.output_classes)

    # Make initialization ops for switching between the two sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_set = create_dataset(FLAGS.dev_files.split(','),
                                 batch_size=FLAGS.dev_batch_size,
                                 cache_path=FLAGS.dev_cached_features_path)
        dev_init_op = iterator.make_initializer(dev_set)

    # Dropout
    dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {
        rate: 0. for rate in dropout_rates
    }

    # Building the graph
    optimizer = create_optimizer()
    gradients, loss = get_tower_results(iterator, optimizer, dropout_rates)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)

    # global_step is automagically incremented by the optimizer
    global_step = tf.train.get_or_create_global_step()
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

    # Summaries
    step_summaries_op = tf.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
    checkpoint_filename = 'checkpoint'

    best_dev_saver = tf.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
    best_dev_filename = 'best_dev_checkpoint'

    initializer = tf.global_variables_initializer()

    with tf.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        tf.get_default_graph().finalize()

        # Loading or initializing
        loaded = False
        if FLAGS.load in ['auto', 'last']:
            loaded = try_loading(session, checkpoint_saver, checkpoint_filename, 'most recent')
        if not loaded and FLAGS.load in ['auto', 'best']:
            loaded = try_loading(session, best_dev_saver, best_dev_filename, 'best validation')
        if not loaded:
            if FLAGS.load in ['auto', 'init']:
                log_info('Initializing variables...')
                session.run(initializer)
            else:
                log_error('Unable to load %s model from specified checkpoint dir'
                          ' - consider using load option "auto" or "init".' % FLAGS.load)
                sys.exit(1)

        def run_set(set_name, init_op):
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict

            total_loss = 0.0
            step_count = 0

            step_summary_writer = step_summary_writers.get(set_name)
            checkpoint_time = time.time()

            class LossWidget(progressbar.widgets.FormatLabel):
                def __init__(self):
                    progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')

                def __call__(self, progress, data, **kwargs):
                    data['mean_loss'] = total_loss / step_count if step_count else 0.0
                    return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

            if FLAGS.show_progressbar:
                pbar = progressbar.ProgressBar(widgets=['Epoch {}'.format(epoch),
                                                        ' | ', progressbar.widgets.Timer(),
                                                        ' | Steps: ', progressbar.widgets.Counter(),
                                                        ' | ', LossWidget()])
                pbar.start()

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop
            while True:
                try:
                    _, current_step, batch_loss, step_summary = \
                        session.run([train_op, global_step, loss, step_summaries_op],
                                    feed_dict=feed_dict)
                except tf.errors.OutOfRangeError:
                    break

                total_loss += batch_loss
                step_count += 1

                if FLAGS.show_progressbar:
                    pbar.update(step_count)

                step_summary_writer.add_summary(step_summary, current_step)

                if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                    checkpoint_time = time.time()

            if FLAGS.show_progressbar:
                pbar.finish()

            return total_loss / step_count if step_count > 0 else 0.0

        log_info('STARTING Optimization')
        best_dev_loss = float('inf')
        dev_losses = []
        try:
            for epoch in range(FLAGS.epochs):
                # Training
                if not FLAGS.show_progressbar:
                    log_info('Training epoch %d...' % epoch)
                train_loss = run_set('train', train_init_op)
                if not FLAGS.show_progressbar:
                    log_info('Finished training epoch %d - loss: %f' % (epoch, train_loss))
                checkpoint_saver.save(session, checkpoint_path, global_step=global_step)

                if FLAGS.dev_files:
                    # Validation
                    if not FLAGS.show_progressbar:
                        log_info('Validating epoch %d...' % epoch)
                    dev_loss = run_set('dev', dev_init_op)
                    if not FLAGS.show_progressbar:
                        log_info('Finished validating epoch %d - loss: %f' % (epoch, dev_loss))
                    dev_losses.append(dev_loss)

                    if dev_loss < best_dev_loss:
                        best_dev_loss = dev_loss
                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename=best_dev_filename)
                        log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))

                    # Early stopping
                    if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
                        mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
                        std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
                        dev_losses = dev_losses[-FLAGS.es_steps:]
                        log_debug('Checking for early stopping (last %d steps) validation loss: '
                                  '%f, with standard deviation: %f and mean: %f' %
                                  (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
                        if dev_losses[-1] > np.max(dev_losses[:-1]) or \
                           (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
                            log_info('Early stop triggered as (for last %d steps) validation loss:'
                                     ' %f with standard deviation: %f and mean: %f' %
                                     (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
                            break
        except KeyboardInterrupt:
            pass
    log_debug('Session closed.')
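
The LossWidget used in these training loops is a small progressbar2 idiom: subclass FormatLabel and inject a computed value into the format data on each render. A minimal standalone sketch, assuming the progressbar2 package as the examples do:

import progressbar

total_loss, step_count = 12.5, 5  # hypothetical running totals

class LossWidget(progressbar.widgets.FormatLabel):
    def __init__(self):
        progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')

    def __call__(self, progress, data, **kwargs):
        # Inject the running mean loss into the format data before rendering.
        data['mean_loss'] = total_loss / step_count if step_count else 0.0
        return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

bar = progressbar.ProgressBar(widgets=[progressbar.widgets.Timer(), ' | ', LossWidget()],
                              max_value=10)
bar.start()
for i in range(10):
    bar.update(i + 1)
bar.finish()
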
Example #8
def train():
    do_cache_dataset = True

    # pylint: disable=too-many-boolean-expressions
    if (FLAGS.data_aug_features_multiplicative > 0 or
            FLAGS.data_aug_features_additive > 0 or
            FLAGS.augmentation_spec_dropout_keeprate < 1 or
            FLAGS.augmentation_freq_and_time_masking or
            FLAGS.augmentation_pitch_and_tempo_scaling or
            FLAGS.augmentation_speed_up_std > 0 or
            FLAGS.augmentation_sparse_warp):
        do_cache_dataset = False

    exception_box = ExceptionBox()

    # Create training and validation datasets
    train_set = create_dataset(FLAGS.train_files.split(','),
                               batch_size=FLAGS.train_batch_size,
                               enable_cache=FLAGS.feature_cache and do_cache_dataset,
                               cache_path=FLAGS.feature_cache,
                               train_phase=True,
                               exception_box=exception_box,
                               process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
                               buffering=FLAGS.read_buffer)

    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
                                                 tfv1.data.get_output_shapes(train_set),
                                                 output_classes=tfv1.data.get_output_classes(train_set))

    # Make initialization ops for switching between the two sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_sources = FLAGS.dev_files.split(',')
        dev_sets = [create_dataset([source],
                                   batch_size=FLAGS.dev_batch_size,
                                   train_phase=False,
                                   exception_box=exception_box,
                                   process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
                                   buffering=FLAGS.read_buffer) for source in dev_sources]
        dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

    # Dropout
    dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
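    # Training feeds the real rates below; validation feeds zeros
    # (no_dropout_feed_dict), making the forward pass deterministic.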
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {
        rate: 0. for rate in dropout_rates
    }

    # Building the graph
    learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
    reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
    optimizer = create_optimizer(learning_rate_var)

    # Enable mixed precision training
    if FLAGS.automatic_mixed_precision:
        log_info('Enabling automatic mixed precision training.')
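        # The graph rewrite casts eligible ops to float16 and adds automatic
        # loss scaling around the optimizer.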
        optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)

    gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)

    # global_step is automagically incremented by the optimizer
    global_step = tfv1.train.get_or_create_global_step()
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

    # Summaries
    step_summaries_op = tfv1.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')

    best_dev_saver = tfv1.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')

    # Save flags next to checkpoints
    os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
    flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
    with open(flags_file, 'w') as fout:
        fout.write(FLAGS.flags_into_string())

    with tfv1.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        # Prevent further graph changes
        tfv1.get_default_graph().finalize()

        # Load checkpoint or initialize variables
        if FLAGS.load == 'auto':
            method_order = ['best', 'last', 'init']
        else:
            method_order = [FLAGS.load]
        load_or_init_graph(session, method_order)

        def run_set(set_name, epoch, init_op, dataset=None):
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict

            total_loss = 0.0
            step_count = 0

            step_summary_writer = step_summary_writers.get(set_name)
            checkpoint_time = time.time()

            # Setup progress bar
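            # LossWidget closes over total_loss/step_count, so every redraw
            # shows the running mean loss for this epoch.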
            class LossWidget(progressbar.widgets.FormatLabel):
                def __init__(self):
                    progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')

                def __call__(self, progress, data, **kwargs):
                    data['mean_loss'] = total_loss / step_count if step_count else 0.0
                    return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

            prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
            widgets = [' | ', progressbar.widgets.Timer(),
                       ' | Steps: ', progressbar.widgets.Counter(),
                       ' | ', LossWidget()]
            suffix = ' | Dataset: {}'.format(dataset) if dataset else None
            pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop
            while True:
                try:
                    _, current_step, batch_loss, problem_files, step_summary = \
                        session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
                                    feed_dict=feed_dict)
                    exception_box.raise_if_set()
                except tf.errors.InvalidArgumentError as err:
                    if FLAGS.augmentation_sparse_warp:
                        log_info("Ignoring sparse warp error: {}".format(err))
                        continue
                    else:
                        raise
                except tf.errors.OutOfRangeError:
                    exception_box.raise_if_set()
                    break

                if problem_files.size > 0:
                    problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
                    log_error('The following files caused an infinite (or NaN) '
                              'loss: {}'.format(','.join(problem_files)))

                total_loss += batch_loss
                step_count += 1

                pbar.update(step_count)

                step_summary_writer.add_summary(step_summary, current_step)

                if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                    checkpoint_time = time.time()

            pbar.finish()
            mean_loss = total_loss / step_count if step_count > 0 else 0.0
            return mean_loss, step_count

        log_info('STARTING Optimization')
        train_start_time = datetime.utcnow()
        best_dev_loss = float('inf')
        dev_losses = []
        epochs_without_improvement = 0
        try:
            for epoch in range(FLAGS.epochs):
                # Training
                log_progress('Training epoch %d...' % epoch)
                train_loss, _ = run_set('train', epoch, train_init_op)
                log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
                checkpoint_saver.save(session, checkpoint_path, global_step=global_step)

                if FLAGS.dev_files:
                    # Validation
                    dev_loss = 0.0
                    total_steps = 0
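                    # Weight each dev set's mean loss by its step count so the
                    # overall dev_loss is a per-step average across all sets.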
                    for source, init_op in zip(dev_sources, dev_init_ops):
                        log_progress('Validating epoch %d on %s...' % (epoch, source))
                        set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
                        dev_loss += set_loss * steps
                        total_steps += steps
                        log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))

                    dev_loss = dev_loss / total_steps
                    dev_losses.append(dev_loss)

                    # Count epochs without improvement; the counter drives both
                    # early stopping and learning-rate reduction on a plateau.
                    # To count as an improvement, the loss must drop by more
                    # than FLAGS.es_min_delta.
                    if dev_loss > best_dev_loss - FLAGS.es_min_delta:
                        epochs_without_improvement += 1
                    else:
                        epochs_without_improvement = 0

                    # Save new best model
                    if dev_loss < best_dev_loss:
                        best_dev_loss = dev_loss
                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
                        log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))

                    # Early stopping
                    if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
                        log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
                            epochs_without_improvement))
                        break

                    # Reduce learning rate on plateau
                    if (FLAGS.reduce_lr_on_plateau and
                            epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
                        # If the learning rate was just reduced and there is
                        # still no improvement, wait another FLAGS.plateau_epochs
                        # before reducing it again.
                        session.run(reduce_learning_rate_op)
                        current_learning_rate = learning_rate_var.eval()
                        log_info('Encountered a plateau, reducing learning rate to {}'.format(
                            current_learning_rate))

        except KeyboardInterrupt:
            pass
        log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
    log_debug('Session closed.')
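Exemple #8 replaces the window-statistics rule with a simple counter: an epoch "improves" only if the loss drops by more than FLAGS.es_min_delta, early stop triggers after FLAGS.es_epochs stagnant epochs, and the learning rate is multiplied by FLAGS.plateau_reduction once per FLAGS.plateau_epochs of stagnation. A closed-form sketch of the resulting schedule (the function name and numeric values are illustrative):

def plateau_learning_rate(initial_lr, epochs_without_improvement,
                          plateau_epochs=10, plateau_reduction=0.1):
    # One reduction per completed plateau window, matching the modulo
    # check in the training loop above.
    reductions = epochs_without_improvement // plateau_epochs
    return initial_lr * (plateau_reduction ** reductions)

# Hypothetical run: 25 stagnant epochs with a 10-epoch window -> 2 reductions.
print(plateau_learning_rate(0.001, 25))  # ~1e-05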
Exemple #9
0
def evaluate(test_csvs, create_model, try_loading):
    scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.lm_binary_path,
                    FLAGS.lm_trie_path, Config.alphabet)

    test_set = create_dataset(test_csvs,
                              batch_size=FLAGS.test_batch_size,
                              cache_path=FLAGS.test_cached_features_path)
    it = test_set.make_one_shot_iterator()

    (batch_x, batch_x_len), batch_y = it.get_next()

    # One rate per layer
    no_dropout = [None] * 6
    logits, _ = create_model(batch_x=batch_x,
                             seq_length=batch_x_len,
                             dropout=no_dropout)

    # Transpose to batch major and apply softmax for decoder
    transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))

    loss = tf.nn.ctc_loss(labels=batch_y,
                          inputs=logits,
                          sequence_length=batch_x_len)

    global_step = tf.train.get_or_create_global_step()

    with tf.Session(config=Config.session_config) as session:
        # Create a saver using variables from the above newly created graph
        saver = tf.train.Saver()

        # Restore variables from training checkpoint
        loaded = try_loading(session, saver, 'best_dev_checkpoint',
                             'best validation')
        if not loaded:
            loaded = try_loading(session, saver, 'checkpoint', 'most recent')
        if not loaded:
            log_error(
                'Checkpoint directory ({}) does not contain a valid checkpoint state.'
                .format(FLAGS.checkpoint_dir))
            exit(1)

        logitses = []
        losses = []
        seq_lengths = []
        ground_truths = []

        print('Computing acoustic model predictions...')
        bar = progressbar.ProgressBar(widgets=[
            'Steps: ',
            progressbar.Counter(), ' | ',
            progressbar.Timer()
        ])

        step_count = 0

        # First pass, compute losses and transposed logits for decoding
        while True:
            try:
                logits, loss_, lengths, transcripts = session.run(
                    [transposed, loss, batch_x_len, batch_y])
            except tf.errors.OutOfRangeError:
                break

            step_count += 1
            bar.update(step_count)

            logitses.append(logits)
            losses.extend(loss_)
            seq_lengths.append(lengths)
            ground_truths.extend(
                sparse_tensor_value_to_texts(transcripts, Config.alphabet))

    bar.finish()
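    # The second pass below runs outside the TF session: beam-search decoding
    # is CPU-bound and only needs the numpy arrays collected above.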

    predictions = []

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    print('Decoding predictions...')
    bar = progressbar.ProgressBar(max_value=step_count,
                                  widgets=[progressbar.AdaptiveETA()])

    # Second pass, decode logits and compute WER and edit distance metrics
    for logits, seq_length in bar(zip(logitses, seq_lengths)):
        decoded = ctc_beam_search_decoder_batch(logits,
                                                seq_length,
                                                Config.alphabet,
                                                FLAGS.beam_width,
                                                num_processes=num_processes,
                                                scorer=scorer)
        predictions.extend(d[0][1] for d in decoded)

    distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]

    wer, cer, samples = calculate_report(ground_truths, predictions, distances,
                                         losses)
    mean_loss = np.mean(losses)

    # Take only the first report_count items
    report_samples = itertools.islice(samples, FLAGS.report_count)

    print('Test - WER: %f, CER: %f, loss: %f' % (wer, cer, mean_loss))
    print('-' * 80)
    for sample in report_samples:
        print('WER: %f, CER: %f, loss: %f' %
              (sample.wer, sample.distance, sample.loss))
        print(' - src: "%s"' % sample.src)
        print(' - res: "%s"' % sample.res)
        print('-' * 80)

    return samples
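For reference, the WER reported by calculate_report is built from exactly these (ground truth, prediction) pairs: conventionally, word error rate is the word-level Levenshtein distance divided by the number of reference words. A self-contained sketch (helper names are illustrative; the project's own levenshtein/calculate_report may differ in detail):

def edit_distance(a, b):
    # Classic dynamic-programming Levenshtein distance over two sequences.
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, 1):
        curr = [i]
        for j, y in enumerate(b, 1):
            curr.append(min(prev[j] + 1,              # deletion
                            curr[j - 1] + 1,          # insertion
                            prev[j - 1] + (x != y)))  # substitution
        prev = curr
    return prev[-1]

def word_error_rate(truth, hypothesis):
    truth_words = truth.split()
    return edit_distance(truth_words, hypothesis.split()) / len(truth_words)

print(word_error_rate('the cat sat', 'the bat sat'))  # 0.333...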
Exemple #10
0
def evaluate(test_csvs, create_model):
    if FLAGS.scorer_path:
        scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path,
                        Config.alphabet)
    else:
        scorer = None

    test_csvs = FLAGS.test_files.split(',')
    test_sets = [
        create_dataset([csv],
                       batch_size=FLAGS.test_batch_size,
                       train_phase=False) for csv in test_csvs
    ]
    iterator = tfv1.data.Iterator.from_structure(
        tfv1.data.get_output_types(test_sets[0]),
        tfv1.data.get_output_shapes(test_sets[0]),
        output_classes=tfv1.data.get_output_classes(test_sets[0]))
    test_init_ops = [
        iterator.make_initializer(test_set) for test_set in test_sets
    ]

    batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()

    # One rate per layer
    no_dropout = [None] * 6
    logits, _ = create_model(batch_x=batch_x,
                             batch_size=FLAGS.test_batch_size,
                             seq_length=batch_x_len,
                             dropout=no_dropout)

    # Transpose to batch major and apply softmax for decoder
    transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))

    loss = tfv1.nn.ctc_loss(labels=batch_y,
                            inputs=logits,
                            sequence_length=batch_x_len)

    tfv1.train.get_or_create_global_step()

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    with tfv1.Session(config=Config.session_config) as session:
        if FLAGS.load == 'auto':
            method_order = ['best', 'last']
        else:
            method_order = [FLAGS.load]
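        # With --load auto, the best-validation checkpoint is tried before the
        # most recent one; 'init' is absent here, since evaluation needs
        # trained weights.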
        load_or_init_graph(session, method_order)

        def run_test(init_op, dataset):
            wav_filenames = []
            losses = []
            predictions = []
            ground_truths = []

            bar = create_progressbar(prefix='Test epoch | ',
                                     widgets=[
                                         'Steps: ',
                                         progressbar.Counter(), ' | ',
                                         progressbar.Timer()
                                     ]).start()
            log_progress('Test epoch...')

            step_count = 0

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # First pass, compute losses and transposed logits for decoding
            while True:
                try:
                    batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
                        session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
                except tf.errors.OutOfRangeError:
                    break

                decoded = ctc_beam_search_decoder_batch(
                    batch_logits,
                    batch_lengths,
                    Config.alphabet,
                    FLAGS.beam_width,
                    num_processes=num_processes,
                    scorer=scorer,
                    cutoff_prob=FLAGS.cutoff_prob,
                    cutoff_top_n=FLAGS.cutoff_top_n)
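                # Each entry of `decoded` lists one sample's beams as
                # (score, transcript) pairs, best first; d[0][1] takes the
                # top transcript.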
                predictions.extend(d[0][1] for d in decoded)
                ground_truths.extend(
                    sparse_tensor_value_to_texts(batch_transcripts,
                                                 Config.alphabet))
                wav_filenames.extend(
                    wav_filename.decode('UTF-8')
                    for wav_filename in batch_wav_filenames)
                losses.extend(batch_loss)

                step_count += 1
                bar.update(step_count)

            bar.finish()

            # Print test summary
            test_samples = calculate_and_print_report(wav_filenames,
                                                      ground_truths,
                                                      predictions, losses,
                                                      dataset)
            return test_samples

        samples = []
        for csv, init_op in zip(test_csvs, test_init_ops):
            print('Testing model on {}'.format(csv))
            samples.extend(run_test(init_op, dataset=csv))
        return samples
Exemple #11
0
def train():
    # Create training and validation datasets
    train_set, train_batches = create_dataset(
        FLAGS.train_files.split(','),
        batch_size=FLAGS.train_batch_size,
        cache_path=FLAGS.train_cached_features_path)

    iterator = tf.data.Iterator.from_structure(
        train_set.output_types,
        train_set.output_shapes,
        output_classes=train_set.output_classes)

    # Make initialization ops for switching between the two sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_set, dev_batches = create_dataset(
            FLAGS.dev_files.split(','),
            batch_size=FLAGS.dev_batch_size,
            cache_path=FLAGS.dev_cached_features_path)
        dev_init_op = iterator.make_initializer(dev_set)

    # Dropout
    dropout_rates = [
        tf.placeholder(tf.float32, name='dropout_{}'.format(i))
        for i in range(6)
    ]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {rate: 0. for rate in dropout_rates}

    # Building the graph
    optimizer = create_optimizer()
    gradients, loss = get_tower_results(iterator, optimizer, dropout_rates)
    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)
    # global_step is automagically incremented by the optimizer
    global_step = tf.Variable(0, trainable=False, name='global_step')
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients,
                                                  global_step=global_step)

    # Summaries
    step_summaries_op = tf.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
    checkpoint_filename = 'checkpoint'

    best_dev_saver = tf.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
    best_dev_filename = 'best_dev_checkpoint'

    initializer = tf.global_variables_initializer()

    with tf.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        tf.get_default_graph().finalize()

        # Loading or initializing
        loaded = False
        if FLAGS.load in ['auto', 'last']:
            loaded = try_loading(session, checkpoint_saver,
                                 checkpoint_filename, 'most recent epoch')
        if not loaded and FLAGS.load in ['auto', 'best']:
            loaded = try_loading(session, best_dev_saver, best_dev_filename,
                                 'best validation')
        if not loaded:
            if FLAGS.load in ['auto', 'init']:
                log_info('Initializing...')
                session.run(initializer)
            else:
                log_error(
                    'Unable to load %s model from specified checkpoint dir'
                    ' - consider using load option "auto" or "init".' %
                    FLAGS.load)
                sys.exit(1)

        # Retrieving global_step from restored model and setting training parameters accordingly
        step = session.run(global_step)
        num_gpus = len(Config.available_devices)
        steps_per_epoch = max(1, train_batches // num_gpus)
        current_epoch = step // steps_per_epoch
        target_epoch = (current_epoch + abs(FLAGS.epoch)
                        if FLAGS.epoch < 0 else FLAGS.epoch)

        log_debug('step: %d' % step)
        log_debug('epoch: %d' % current_epoch)
        log_debug('target epoch: %d' % target_epoch)
        log_debug('steps per epoch: %d' % steps_per_epoch)
        log_debug('batches per step (GPUs): %d' % num_gpus)
        log_debug('number of batches in train set: %d' % train_batches)
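        # e.g. (hypothetical numbers) 1000 train batches on 2 GPUs give
        # steps_per_epoch = 500, so a restored global step of 1250 resumes at
        # epoch 2.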

        def run_set(set_name, init_op, num_batches):
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
            total_loss = 0.0
            step_summary_writer = step_summary_writers.get(set_name)
            num_steps = max(1, num_batches // num_gpus)
            checkpoint_time = time.time()

            if FLAGS.show_progressbar:
                pbar = progressbar.ProgressBar(max_value=num_steps,
                                               redirect_stdout=True).start()
            else:
                pbar = lambda i: i

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop
            for step_index in pbar(range(num_steps)):
                if coord.should_stop():
                    break

                _, current_step, batch_loss, step_summary = \
                    session.run([train_op, global_step, loss, step_summaries_op],
                                feed_dict=feed_dict)
                total_loss += batch_loss
                step_summary_writer.add_summary(step_summary, current_step)

                if is_train and FLAGS.checkpoint_secs > 0 and \
                        time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session,
                                          checkpoint_path,
                                          global_step=current_step)
                    checkpoint_time = time.time()

            return total_loss / num_steps

        if target_epoch > current_epoch:
            log_info('STARTING Optimization')
            best_dev_loss = float('inf')
            dev_losses = []
            coord = tf.train.Coordinator()
            with coord.stop_on_exception():
                for current_epoch in range(current_epoch, target_epoch):
                    if coord.should_stop():
                        break

                    # Training
                    log_info('Training epoch %d ...' % current_epoch)
                    train_loss = run_set('train', train_init_op, train_batches)
                    log_info('Finished training epoch %d - loss: %f' %
                             (current_epoch, train_loss))
                    checkpoint_saver.save(session,
                                          checkpoint_path,
                                          global_step=global_step)

                    if FLAGS.dev_files:
                        # Validation
                        log_info('Validating epoch %d ...' % current_epoch)
                        dev_loss = run_set('dev', dev_init_op, dev_batches)
                        dev_losses.append(dev_loss)
                        log_info('Finished validating epoch %d - loss: %f' %
                                 (current_epoch, dev_loss))

                        if dev_loss < best_dev_loss:
                            best_dev_loss = dev_loss
                            save_path = best_dev_saver.save(
                                session,
                                best_dev_path,
                                latest_filename=best_dev_filename)
                            log_info(
                                "Saved new best validating model with loss %f to: %s"
                                % (best_dev_loss, save_path))

                        # Early stopping
                        if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
                            mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
                            std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
                            dev_losses = dev_losses[-FLAGS.es_steps:]
                            log_debug(
                                'Checking for early stopping (last %d steps) validation loss: '
                                '%f, with standard deviation: %f and mean: %f'
                                % (FLAGS.es_steps, dev_losses[-1], std_loss,
                                   mean_loss))
                            if dev_losses[-1] > np.max(dev_losses[:-1]) or \
                               (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
                                log_info(
                                    'Early stop triggered as (for last %d steps) validation loss:'
                                    ' %f with standard deviation: %f and mean: %f'
                                    % (FLAGS.es_steps, dev_losses[-1],
                                       std_loss, mean_loss))
                                break
                coord.request_stop()
        else:
            log_info('Target epoch already reached - skipped training.')
    log_debug('Session closed.')