def next_job(self, job):
        '''Sends a finished job back to the coordinator and retrieves in exchange the next one.

        Kwargs:
            job (WorkerJob): job that was finished by a worker and whose results are to be
                digested by the coordinator

        Returns:
            WorkerJob. Next job from one of the running epochs; it will be
                associated with the worker of the finished job and put into state 'running'
        '''
        if self.is_chief:
            # Try to find the epoch the job belongs to
            epoch = next((epoch for epoch in self._epochs_running if epoch.id == job.epoch_id), None)
            if epoch:
                # We are going to manipulate things - let's avoid undefined state
                with self._lock:
                    # Let the epoch finish the job
                    epoch.finish_job(job)
                    # Check if the epoch is done now
                    if epoch.done():
                        # If it declares itself done, move it from 'running' to 'done' collection
                        self._epochs_running.remove(epoch)
                        self._epochs_done.append(epoch)
                        log_info('%s' % epoch)
            else:
                # There was no running epoch found for this job - this should never happen.
                log_error('There is no running epoch of ID %d for job with ID %d. This should never happen.' % (job.epoch_id, job.id))
            return self.get_job(job.worker)

        # We are a remote worker and have to hand over to the chief worker by HTTP
        result = self._talk_to_chief('', data=pickle.dumps(job))
        if result:
            result = pickle.loads(result)
        return result
Example #2
def main(_):
    initialize_globals()

    if not FLAGS.test_files:
        log_error('You need to specify what files to use for evaluation via '
                  'the --test_files flag.')
        exit(1)

    from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import
    samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)

    if FLAGS.test_output_file:
        # Save decoded tuples as JSON, converting NumPy floats to Python floats
        with open(FLAGS.test_output_file, 'w') as fout:
            json.dump(samples, fout, default=float)
def do_single_file_inference(input_file_path):
    with tf.Session(config=Config.session_config) as session:
        inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)

        # Create a saver using variables from the above newly created graph
        mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
        saver = tf.train.Saver(mapping)

        # Restore variables from training checkpoint
        # TODO: This restores the most recent checkpoint, but if we use validation to counteract
        #       over-fitting, we may want to restore an earlier checkpoint.
        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if not checkpoint:
            log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
            exit(1)

        checkpoint_path = checkpoint.model_checkpoint_path
        saver.restore(session, checkpoint_path)
        session.run(outputs['initialize_state'])

        features = audiofile_to_input_vector(input_file_path, Config.n_input, Config.n_context)
        num_strides = len(features) - (Config.n_context * 2)

        # Create a view into the array with overlapping strides of size
        # numcontext (past) + 1 (present) + numcontext (future)
        window_size = 2*Config.n_context+1
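        # e.g. with n_context = 9 (see initialize_globals), each window spans 19 frames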
        features = np.lib.stride_tricks.as_strided(
            features,
            (num_strides, window_size, Config.n_input),
            (features.strides[0], features.strides[0], features.strides[1]),
            writeable=False)

        logits = session.run(outputs['outputs'], feed_dict = {
            inputs['input']: [features],
            inputs['input_lengths']: [num_strides],
        })

        logits = np.squeeze(logits)

        scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                        FLAGS.lm_binary_path, FLAGS.lm_trie_path,
                        Config.alphabet)
        decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width, scorer=scorer)
        # Print highest probability result
        print(decoded[0][1])
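
# The as_strided call in do_single_file_inference() above builds overlapping context
# windows over the feature matrix without copying any data. A minimal standalone
# sketch of the same windowing trick (array sizes here are illustrative only):
import numpy as np

example_feats = np.arange(30, dtype=np.float32).reshape(10, 3)  # 10 frames x 3 features
example_context = 2                                             # frames of past/future context
example_window = 2 * example_context + 1                        # past + present + future
example_strides = len(example_feats) - 2 * example_context

# Each window shares memory with example_feats; nothing is copied.
example_windows = np.lib.stride_tricks.as_strided(
    example_feats,
    (example_strides, example_window, example_feats.shape[1]),
    (example_feats.strides[0], example_feats.strides[0], example_feats.strides[1]),
    writeable=False)
assert example_windows.shape == (6, 5, 3)  # window i, row j is frame i + j of example_feats
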
    def get_job(self, worker=0):
        '''Retrieves the first job for a worker.

        Kwargs:
            worker (int): index of the worker to get the first job for

        Returns:
            WorkerJob. a job of one of the running epochs that will get
                associated with the given worker and put into state 'running'
        '''
        # Let's ensure that this does not interfere with other workers/requests
        with self._lock:
            if self.is_chief:
                # First try to get a next job
                job = self._get_job(worker)

                if job is None:
                    # If there was no next job, we give it a second chance by triggering the epoch state machine
                    if self._next_epoch():
                        # Epoch state machine got a new epoch
                        # Second try to get a next job
                        job = self._get_job(worker)
                        if job is None:
                            # Although the epoch state machine provided a new epoch, that epoch had no new job for us
                            log_error('Unexpected case - no job for worker %d.' % (worker))
                        return job

                    # Epoch state machine has no new epoch
                    # This happens at the end of the whole training - nothing to worry about
                    log_traffic('No jobs left for worker %d.' % (worker))
                    self._log_all_jobs()
                    return None

                # We got a new job from one of the currently running epochs
                log_traffic('Got new %s' % job)
                return job

            # We are a remote worker and have to hand over to the chief worker by HTTP
            result = self._talk_to_chief(PREFIX_GET_JOB + str(FLAGS.task_index))
            if result:
                result = pickle.loads(result)
            return result
Example #5
def export():
    r'''
    Restores the trained variables into a simpler graph that will be exported for serving.
    '''
    log_info('Exporting the model...')
    with tf.device('/cpu:0'):
        from tensorflow.python.framework.ops import Tensor, Operation

        tf.reset_default_graph()
        session = tf.Session(config=Config.session_config)

        inputs, outputs, _ = create_inference_graph(
            batch_size=FLAGS.export_batch_size,
            n_steps=FLAGS.n_steps,
            tflite=FLAGS.export_tflite)
        input_names = ",".join(tensor.op.name for tensor in inputs.values())
        output_names_tensors = [
            tensor.op.name for tensor in outputs.values()
            if isinstance(tensor, Tensor)
        ]
        output_names_ops = [
            tensor.name for tensor in outputs.values()
            if isinstance(tensor, Operation)
        ]
        output_names = ",".join(output_names_tensors + output_names_ops)
        input_shapes = ":".join(",".join(map(str, tensor.shape))
                                for tensor in inputs.values())

        if not FLAGS.export_tflite:
            mapping = {
                v.op.name: v
                for v in tf.global_variables()
                if not v.op.name.startswith('previous_state_')
            }
        else:
            # Create a saver using variables from the above newly created graph
            def fixup(name):
                if name.startswith('rnn/lstm_cell/'):
                    return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
                return name

            mapping = {fixup(v.op.name): v for v in tf.global_variables()}

        saver = tf.train.Saver(mapping)

        # Restore variables from training checkpoint
        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        checkpoint_path = checkpoint.model_checkpoint_path

        output_filename = 'output_graph.pb'
        if FLAGS.remove_export:
            if os.path.isdir(FLAGS.export_dir):
                log_info('Removing old export')
                shutil.rmtree(FLAGS.export_dir)
        try:
            output_graph_path = os.path.join(FLAGS.export_dir, output_filename)

            if not os.path.isdir(FLAGS.export_dir):
                os.makedirs(FLAGS.export_dir)

            def do_graph_freeze(output_file=None,
                                output_node_names=None,
                                variables_blacklist=None):
                return freeze_graph.freeze_graph_with_def_protos(
                    input_graph_def=session.graph_def,
                    input_saver_def=saver.as_saver_def(),
                    input_checkpoint=checkpoint_path,
                    output_node_names=output_node_names,
                    restore_op_name=None,
                    filename_tensor_name=None,
                    output_graph=output_file,
                    clear_devices=False,
                    variable_names_blacklist=variables_blacklist,
                    initializer_nodes='')

            if not FLAGS.export_tflite:
                do_graph_freeze(
                    output_file=output_graph_path,
                    output_node_names=output_names,
                    variables_blacklist='previous_state_c,previous_state_h')
            else:
                frozen_graph = do_graph_freeze(output_node_names=output_names,
                                               variables_blacklist='')
                output_tflite_path = os.path.join(
                    FLAGS.export_dir,
                    output_filename.replace('.pb', '.tflite'))

                converter = tf.lite.TFLiteConverter(
                    frozen_graph,
                    input_tensors=inputs.values(),
                    output_tensors=outputs.values())
                converter.post_training_quantize = True
                tflite_model = converter.convert()

                with open(output_tflite_path, 'wb') as fout:
                    fout.write(tflite_model)

                log_info('Exported model for TF Lite engine as {}'.format(
                    os.path.basename(output_tflite_path)))

            log_info('Models exported at %s' % (FLAGS.export_dir))
        except RuntimeError as e:
            log_error(str(e))
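
# A minimal sketch of loading the frozen output_graph.pb written by export() above,
# using TF1 APIs; the path and the tensor name below are illustrative assumptions,
# not the project's actual node names:
def load_frozen_graph_example(pb_path='/tmp/export/output_graph.pb'):
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        # Import under an empty name scope so tensors keep their exported names
        tf.import_graph_def(graph_def, name='')
    # Tensors can then be fetched by name, e.g. graph.get_tensor_by_name('input_node:0')
    return graph
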
Example #6
def export():
    r'''
    Restores the trained variables into a simpler graph that will be exported for serving.
    '''
    log_info('Exporting the model...')
    with tf.device('/cpu:0'):
        from tensorflow.python.framework.ops import Tensor, Operation

        tf.reset_default_graph()
        session = tf.Session(config=Config.session_config)

        inputs, outputs, _ = create_inference_graph(batch_size=1,
                                                    n_steps=FLAGS.n_steps)
        input_names = ",".join(tensor.op.name for tensor in inputs.values())
        output_names_tensors = [
            tensor.op.name for tensor in outputs.values()
            if isinstance(tensor, Tensor)
        ]
        output_names_ops = [
            tensor.name for tensor in outputs.values()
            if isinstance(tensor, Operation)
        ]
        output_names = ",".join(output_names_tensors + output_names_ops)
        input_shapes = ":".join(",".join(map(str, tensor.shape))
                                for tensor in inputs.values())

        mapping = {
            v.op.name: v
            for v in tf.global_variables()
            if not v.op.name.startswith('previous_state_')
        }

        saver = tf.train.Saver(mapping)

        # Restore variables from training checkpoint
        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        checkpoint_path = checkpoint.model_checkpoint_path

        output_filename = 'output_graph.pb'
        try:
            output_graph_path = os.path.join(FLAGS.export_dir, output_filename)

            if not os.path.isdir(FLAGS.export_dir):
                os.makedirs(FLAGS.export_dir)

            def do_graph_freeze(output_file=None,
                                output_node_names=None,
                                variables_blacklist=None):
                freeze_graph.freeze_graph_with_def_protos(
                    input_graph_def=session.graph_def,
                    input_saver_def=saver.as_saver_def(),
                    input_checkpoint=checkpoint_path,
                    output_node_names=output_node_names,
                    restore_op_name=None,
                    filename_tensor_name=None,
                    output_graph=output_file,
                    clear_devices=False,
                    variable_names_blacklist=variables_blacklist,
                    initializer_nodes='')

            do_graph_freeze(
                output_file=output_graph_path,
                output_node_names=output_names,
                variables_blacklist='previous_state_c,previous_state_h')
            log_info('Models exported at %s' % (FLAGS.export_dir))
        except RuntimeError as e:
            log_error(str(e))
Example #7
def initialize_globals():
    c = AttrDict()

    # ps and worker hosts required for p2p cluster setup
    FLAGS.ps_hosts = list(filter(len, FLAGS.ps_hosts.split(',')))
    FLAGS.worker_hosts = list(filter(len, FLAGS.worker_hosts.split(',')))

    # Create a cluster from the parameter server and worker hosts.
    c.cluster = tf.train.ClusterSpec({
        'ps': FLAGS.ps_hosts,
        'worker': FLAGS.worker_hosts
    })

    # The absolute number of computing nodes - regardless of cluster or single mode
    num_workers = max(1, len(FLAGS.worker_hosts))

    # If replica numbers are negative, we multiply their absolute values with the number of workers
    if FLAGS.replicas < 0:
        FLAGS.replicas = num_workers * -FLAGS.replicas
    if FLAGS.replicas_to_agg < 0:
        FLAGS.replicas_to_agg = num_workers * -FLAGS.replicas_to_agg
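    # e.g. with 4 worker hosts, --replicas -2 results in FLAGS.replicas == 8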

    # The device path base for this node
    c.worker_device = '/job:%s/task:%d' % (FLAGS.job_name, FLAGS.task_index)

    # This node's CPU device
    c.cpu_device = c.worker_device + '/cpu:0'

    # This node's available GPU devices
    c.available_devices = [
        c.worker_device + gpu for gpu in get_available_gpus()
    ]

    # If there is no GPU available, we fall back to CPU based operation
    if 0 == len(c.available_devices):
        c.available_devices = [c.cpu_device]

    # Set default dropout rates
    if FLAGS.dropout_rate2 < 0:
        FLAGS.dropout_rate2 = FLAGS.dropout_rate
    if FLAGS.dropout_rate3 < 0:
        FLAGS.dropout_rate3 = FLAGS.dropout_rate
    if FLAGS.dropout_rate6 < 0:
        FLAGS.dropout_rate6 = FLAGS.dropout_rate

    # Set default checkpoint dir
    if len(FLAGS.checkpoint_dir) == 0:
        FLAGS.checkpoint_dir = xdg.save_data_path(
            os.path.join('deepspeech', 'checkpoints'))

    if FLAGS.benchmark_steps > 0:
        FLAGS.checkpoint_dir = None

    # Set default summary dir
    if len(FLAGS.summary_dir) == 0:
        FLAGS.summary_dir = xdg.save_data_path(
            os.path.join('deepspeech', 'summaries'))

    # Standard session configuration that'll be used for all new sessions.
    c.session_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=FLAGS.log_placement,
        inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
        intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads)

    c.alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path))

    # Geometric Constants
    # ===================

    # For an explanation of the meaning of the geometric constants, please refer to
    # doc/Geometry.md

    # Number of MFCC features
    c.n_input = 26  # TODO: Determine this programmatically from the sample rate

    # The number of frames in the context
    c.n_context = 9  # TODO: Determine the optimal value using a validation data set

    # Number of units in hidden layers
    c.n_hidden = FLAGS.n_hidden

    c.n_hidden_1 = c.n_hidden

    c.n_hidden_2 = c.n_hidden

    c.n_hidden_5 = c.n_hidden

    # LSTM cell state dimension
    c.n_cell_dim = c.n_hidden

    # The number of units in the third layer, which feeds in to the LSTM
    c.n_hidden_3 = c.n_cell_dim

    # Units in the sixth layer = number of characters in the target language plus one
    c.n_hidden_6 = c.alphabet.size() + 1  # +1 for CTC blank label
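    # e.g. a 28-character English alphabet (a-z, space, apostrophe) yields n_hidden_6 == 29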

    # Queues that are used to gracefully stop parameter servers.
    # Each queue stands for one ps. A finishing worker sends a token to each queue before joining/quitting.
    # Each ps will dequeue as many tokens as there are workers before joining/quitting.
    # This ensures parameter servers won't quit if they are still required by at least one worker and
    # also won't wait forever (like with a standard `server.join()`).
    done_queues = []
    for i, ps in enumerate(FLAGS.ps_hosts):
        # Queues are hosted by their respective owners
        with tf.device('/job:ps/task:%d' % i):
            done_queues.append(
                tf.FIFOQueue(1, tf.int32, shared_name=('queue%i' % i)))

    # Placeholder to pass in the worker's index as token
    c.token_placeholder = tf.placeholder(tf.int32)

    # Enqueue operations for each parameter server
    c.done_enqueues = [
        queue.enqueue(c.token_placeholder) for queue in done_queues
    ]

    # Dequeue operations for each parameter server
    c.done_dequeues = [queue.dequeue() for queue in done_queues]

    if len(FLAGS.one_shot_infer) > 0:
        FLAGS.train = False
        FLAGS.test = False
        FLAGS.export_dir = ''
        if not os.path.exists(FLAGS.one_shot_infer):
            log_error(
                'Path specified in --one_shot_infer is not a valid file.')
            exit(1)

    # Determine if we are the chief worker
    c.is_chief = len(FLAGS.worker_hosts) == 0 or \
                 (FLAGS.task_index == 0 and FLAGS.job_name == 'worker')

    ConfigSingleton._config = c
Example #8
def initialize_globals():
    c = AttrDict()

    # Set default dropout rates
    if FLAGS.dropout_rate2 < 0:
        FLAGS.dropout_rate2 = FLAGS.dropout_rate
    if FLAGS.dropout_rate3 < 0:
        FLAGS.dropout_rate3 = FLAGS.dropout_rate
    if FLAGS.dropout_rate6 < 0:
        FLAGS.dropout_rate6 = FLAGS.dropout_rate

    # Set default checkpoint dir
    if not FLAGS.checkpoint_dir:
        FLAGS.checkpoint_dir = xdg.save_data_path(
            os.path.join('deepspeech', 'checkpoints'))

    if FLAGS.load not in ['last', 'best', 'init', 'auto', 'transfer']:
        FLAGS.load = 'auto'

    # Set default summary dir
    if not FLAGS.summary_dir:
        FLAGS.summary_dir = xdg.save_data_path(
            os.path.join('deepspeech', 'summaries'))

    # Standard session configuration that'll be used for all new sessions.
    c.session_config = tfv1.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=FLAGS.log_placement,
        inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
        intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads,
        gpu_options=tfv1.GPUOptions(allow_growth=FLAGS.use_allow_growth))

    # CPU device
    c.cpu_device = '/cpu:0'

    # Available GPU devices
    c.available_devices = get_available_gpus(c.session_config)

    # If there is no GPU available, we fall back to CPU based operation
    if not c.available_devices:
        c.available_devices = [c.cpu_device]

    if FLAGS.utf8:
        c.alphabet = UTF8Alphabet()
    else:
        c.alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path))

    # Geometric Constants
    # ===================

    # For an explanation of the meaning of the geometric constants, please refer to
    # doc/Geometry.md

    # Number of MFCC features
    c.n_input = 26  # TODO: Determine this programmatically from the sample rate

    # The number of frames in the context
    c.n_context = 9  # TODO: Determine the optimal value using a validation data set

    # Number of units in hidden layers
    c.n_hidden = FLAGS.n_hidden

    c.n_hidden_1 = c.n_hidden

    c.n_hidden_2 = c.n_hidden

    c.n_hidden_5 = c.n_hidden

    # LSTM cell state dimension
    c.n_cell_dim = c.n_hidden

    # The number of units in the third layer, which feeds in to the LSTM
    c.n_hidden_3 = c.n_cell_dim

    # Units in the sixth layer = number of characters in the target language plus one
    c.n_hidden_6 = c.alphabet.size() + 1  # +1 for CTC blank label

    # Size of audio window in samples
    if (FLAGS.feature_win_len * FLAGS.audio_sample_rate) % 1000 != 0:
        log_error(
            '--feature_win_len value ({}) in milliseconds ({}) multiplied '
            'by --audio_sample_rate value ({}) must be an integer value. Adjust '
            'your --feature_win_len value or resample your audio accordingly.'
            ''.format(FLAGS.feature_win_len, FLAGS.feature_win_len / 1000,
                      FLAGS.audio_sample_rate))
        sys.exit(1)

    c.audio_window_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_len /
                                                        1000)
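    # e.g. a 32 ms window at 16000 Hz gives 16000 * 0.032 = 512 samples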

    # Stride for feature computations in samples
    if (FLAGS.feature_win_step * FLAGS.audio_sample_rate) % 1000 != 0:
        log_error(
            '--feature_win_step value ({}) in milliseconds ({}) multiplied '
            'by --audio_sample_rate value ({}) must be an integer value. Adjust '
            'your --feature_win_step value or resample your audio accordingly.'
            ''.format(FLAGS.feature_win_step, FLAGS.feature_win_step / 1000,
                      FLAGS.audio_sample_rate))
        sys.exit(1)

    c.audio_step_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_step /
                                                      1000)

    if FLAGS.one_shot_infer:
        if not os.path.exists(FLAGS.one_shot_infer):
            log_error(
                'Path specified in --one_shot_infer is not a valid file.')
            sys.exit(1)

    ConfigSingleton._config = c  # pylint: disable=protected-access
Example #9
def train():
    # Create training and validation datasets
    train_set, train_batches = create_dataset(
        FLAGS.train_files.split(','),
        batch_size=FLAGS.train_batch_size,
        cache_path=FLAGS.train_cached_features_path)

    iterator = tf.data.Iterator.from_structure(
        train_set.output_types,
        train_set.output_shapes,
        output_classes=train_set.output_classes)

    # Make initialization ops for switching between the two sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_set, dev_batches = create_dataset(
            FLAGS.dev_files.split(','),
            batch_size=FLAGS.dev_batch_size,
            cache_path=FLAGS.dev_cached_features_path)
        dev_init_op = iterator.make_initializer(dev_set)

    # Dropout
    dropout_rates = [
        tf.placeholder(tf.float32, name='dropout_{}'.format(i))
        for i in range(6)
    ]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {rate: 0. for rate in dropout_rates}

    # Building the graph
    optimizer = create_optimizer()
    gradients, loss = get_tower_results(iterator, optimizer, dropout_rates)
    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)
    # global_step is automagically incremented by the optimizer
    global_step = tf.Variable(0, trainable=False, name='global_step')
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients,
                                                  global_step=global_step)

    # Summaries
    step_summaries_op = tf.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train':
        tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'),
                              max_queue=120),
        'dev':
        tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'),
                              max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
    checkpoint_filename = 'checkpoint'

    best_dev_saver = tf.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
    best_dev_filename = 'best_dev_checkpoint'

    initializer = tf.global_variables_initializer()

    with tf.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        tf.get_default_graph().finalize()

        # Loading or initializing
        loaded = False
        if FLAGS.load in ['auto', 'last']:
            loaded = try_loading(session, checkpoint_saver,
                                 checkpoint_filename, 'most recent epoch')
        if not loaded and FLAGS.load in ['auto', 'best']:
            loaded = try_loading(session, best_dev_saver, best_dev_filename,
                                 'best validation')
        if not loaded:
            if FLAGS.load in ['auto', 'init']:
                log_info('Initializing...')
                session.run(initializer)
            else:
                log_error(
                    'Unable to load %s model from specified checkpoint dir'
                    ' - consider using load option "auto" or "init".' %
                    FLAGS.load)
                sys.exit(1)

        # Retrieving global_step from restored model and setting training parameters accordingly
        step = session.run(global_step)
        num_gpus = len(Config.available_devices)
        steps_per_epoch = max(1, train_batches // num_gpus)
        current_epoch = step // steps_per_epoch
        target_epoch = current_epoch + abs(FLAGS.epoch) if FLAGS.epoch < 0 else FLAGS.epoch
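        # e.g. resuming at epoch 10 with --epoch -5 yields target_epoch 15, i.e. five more epochs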

        log_debug('step: %d' % step)
        log_debug('epoch: %d' % current_epoch)
        log_debug('target epoch: %d' % target_epoch)
        log_debug('steps per epoch: %d' % steps_per_epoch)
        log_debug('batches per step (GPUs): %d' % num_gpus)
        log_debug('number of batches in train set: %d' % train_batches)

        def run_set(set_name, init_op, num_batches):
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
            total_loss = 0.0
            step_summary_writer = step_summary_writers.get(set_name)
            num_steps = max(1, num_batches // num_gpus)
            checkpoint_time = time.time()

            if FLAGS.show_progressbar:
                pbar = progressbar.ProgressBar(max_value=num_steps,
                                               redirect_stdout=True).start()
            else:
                pbar = lambda i: i

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop
            for step_index in pbar(range(num_steps)):
                if coord.should_stop():
                    break

                _, current_step, batch_loss, step_summary = \
                    session.run([train_op, global_step, loss, step_summaries_op],
                                feed_dict=feed_dict)
                total_loss += batch_loss
                step_summary_writer.add_summary(step_summary, current_step)

                if is_train and FLAGS.checkpoint_secs > 0 and \
                        time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session,
                                          checkpoint_path,
                                          global_step=current_step)
                    checkpoint_time = time.time()

            return total_loss / num_steps

        if target_epoch > current_epoch:
            log_info('STARTING Optimization')
            best_dev_loss = float('inf')
            dev_losses = []
            coord = tf.train.Coordinator()
            with coord.stop_on_exception():
                for current_epoch in range(current_epoch, target_epoch):
                    if coord.should_stop():
                        break

                    # Training
                    log_info('Training epoch %d ...' % current_epoch)
                    train_loss = run_set('train', train_init_op, train_batches)
                    log_info('Finished training epoch %d - loss: %f' %
                             (current_epoch, train_loss))
                    checkpoint_saver.save(session,
                                          checkpoint_path,
                                          global_step=global_step)

                    if FLAGS.dev_files:
                        # Validation
                        log_info('Validating epoch %d ...' % current_epoch)
                        dev_loss = run_set('dev', dev_init_op, dev_batches)
                        dev_losses.append(dev_loss)
                        log_info('Finished validating epoch %d - loss: %f' %
                                 (current_epoch, dev_loss))

                        if dev_loss < best_dev_loss:
                            best_dev_loss = dev_loss
                            save_path = best_dev_saver.save(
                                session,
                                best_dev_path,
                                latest_filename=best_dev_filename)
                            log_info(
                                "Saved new best validating model with loss %f to: %s"
                                % (best_dev_loss, save_path))

                        # Early stopping
                        if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
                            mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
                            std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
                            dev_losses = dev_losses[-FLAGS.es_steps:]
                            log_debug(
                                'Checking for early stopping (last %d steps) validation loss: '
                                '%f, with standard deviation: %f and mean: %f'
                                % (FLAGS.es_steps, dev_losses[-1], std_loss,
                                   mean_loss))
                            if dev_losses[-1] > np.max(dev_losses[:-1]) or \
                               (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
                                log_info(
                                    'Early stop triggered as (for last %d steps) validation loss:'
                                    ' %f with standard deviation: %f and mean: %f'
                                    % (FLAGS.es_steps, dev_losses[-1],
                                       std_loss, mean_loss))
                                break
                coord.request_stop()
        else:
            log_info('Target epoch already reached - skipped training.')
    log_debug('Session closed.')
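    # Numeric illustration of the early-stop test in the epoch loop above (made-up
    # values): with es_steps = 4 and dev_losses = [10.0, 10.2, 10.1, 10.3], mean and
    # std are taken over [10.0, 10.2, 10.1]; training stops because the latest loss
    # 10.3 exceeds max(dev_losses[:-1]) = 10.2.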
Example #10
        def run_set(set_name, epoch, init_op, dataset=None):
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict

            total_loss = 0.0
            step_count = 0

            step_summary_writer = step_summary_writers.get(set_name)
            checkpoint_time = time.time()

            # Setup progress bar
            class LossWidget(progressbar.widgets.FormatLabel):
                def __init__(self):
                    progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')

                def __call__(self, progress, data, **kwargs):
                    data['mean_loss'] = total_loss / step_count if step_count else 0.0
                    return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

            prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
            widgets = [' | ', progressbar.widgets.Timer(),
                       ' | Steps: ', progressbar.widgets.Counter(),
                       ' | ', LossWidget()]
            suffix = ' | Dataset: {}'.format(dataset) if dataset else None
            pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop
            while True:
                try:
                    _, current_step, batch_loss, problem_files, step_summary = \
                        session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
                                    feed_dict=feed_dict)
                except tf.errors.InvalidArgumentError as err:
                    if FLAGS.augmentation_sparse_warp:
                        log_info("Ignoring sparse warp error: {}".format(err))
                        continue
                    else:
                        raise
                except tf.errors.OutOfRangeError:
                    break

                if problem_files.size > 0:
                    problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
                    log_error('The following files caused an infinite (or NaN) '
                              'loss: {}'.format(','.join(problem_files)))

                total_loss += batch_loss
                step_count += 1

                pbar.update(step_count)

                step_summary_writer.add_summary(step_summary, current_step)

                if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                    checkpoint_time = time.time()

            pbar.finish()
            mean_loss = total_loss / step_count if step_count > 0 else 0.0
            return mean_loss, step_count
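        # Note: callers weight each set's mean loss by its returned step count when
        # averaging over several dev CSVs (see the validation loop in the later train()).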
Example #11
def evaluate(test_csvs, create_model, try_loading):
    if FLAGS.lm_binary_path:
        scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.lm_binary_path,
                        FLAGS.lm_trie_path, Config.alphabet)
    else:
        scorer = None

    test_csvs = FLAGS.test_files.split(',')
    test_sets = [
        create_dataset([csv],
                       batch_size=FLAGS.test_batch_size,
                       train_phase=False) for csv in test_csvs
    ]
    iterator = tfv1.data.Iterator.from_structure(
        tfv1.data.get_output_types(test_sets[0]),
        tfv1.data.get_output_shapes(test_sets[0]),
        output_classes=tfv1.data.get_output_classes(test_sets[0]))
    test_init_ops = [
        iterator.make_initializer(test_set) for test_set in test_sets
    ]

    batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()

    # One rate per layer
    no_dropout = [None] * 6
    logits, _ = create_model(batch_x=batch_x,
                             batch_size=FLAGS.test_batch_size,
                             seq_length=batch_x_len,
                             dropout=no_dropout)

    # Transpose to batch major and apply softmax for decoder
    transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))

    loss = tfv1.nn.ctc_loss(labels=batch_y,
                            inputs=logits,
                            sequence_length=batch_x_len)

    tfv1.train.get_or_create_global_step()

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    # Create a saver using variables from the above newly created graph
    saver = tfv1.train.Saver()

    with tfv1.Session(config=Config.session_config) as session:
        # Restore variables from training checkpoint
        loaded = try_loading(session, saver, 'best_dev_checkpoint',
                             'best validation')
        if not loaded:
            loaded = try_loading(session, saver, 'checkpoint', 'most recent')
        if not loaded:
            log_error(
                'Checkpoint directory ({}) does not contain a valid checkpoint state.'
                .format(FLAGS.checkpoint_dir))
            exit(1)

        def run_test(init_op, dataset):
            wav_filenames = []
            losses = []
            predictions = []
            ground_truths = []

            bar = create_progressbar(prefix='Test epoch | ',
                                     widgets=[
                                         'Steps: ',
                                         progressbar.Counter(), ' | ',
                                         progressbar.Timer()
                                     ]).start()
            log_progress('Test epoch...')

            step_count = 0

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # First pass, compute losses and transposed logits for decoding
            while True:
                try:
                    batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
                        session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
                except tf.errors.OutOfRangeError:
                    break

                decoded = ctc_beam_search_decoder_batch(
                    batch_logits,
                    batch_lengths,
                    Config.alphabet,
                    FLAGS.beam_width,
                    num_processes=num_processes,
                    scorer=scorer,
                    cutoff_prob=FLAGS.cutoff_prob,
                    cutoff_top_n=FLAGS.cutoff_top_n)
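                # Each decoded entry is a list of beam candidates ordered best first;
                # d[0][1] is the top transcript text and d[0][0] its score.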
                predictions.extend(d[0][1] for d in decoded)
                ground_truths.extend(
                    sparse_tensor_value_to_texts(batch_transcripts,
                                                 Config.alphabet))
                wav_filenames.extend(
                    wav_filename.decode('UTF-8')
                    for wav_filename in batch_wav_filenames)
                losses.extend(batch_loss)

                step_count += 1
                bar.update(step_count)

            bar.finish()

            wer, cer, samples = calculate_report(wav_filenames, ground_truths,
                                                 predictions, losses)
            mean_loss = np.mean(losses)

            # Take only the first report_count items
            report_samples = itertools.islice(samples, FLAGS.report_count)

            print('Test on %s - WER: %f, CER: %f, loss: %f' %
                  (dataset, wer, cer, mean_loss))
            print('-' * 80)
            for sample in report_samples:
                print('WER: %f, CER: %f, loss: %f' %
                      (sample.wer, sample.cer, sample.loss))
                print(' - wav: file://%s' % sample.wav_filename)
                print(' - src: "%s"' % sample.src)
                print(' - res: "%s"' % sample.res)
                print('-' * 80)

            return samples

        samples = []
        for csv, init_op in zip(test_csvs, test_init_ops):
            print('Testing model on {}'.format(csv))
            samples.extend(run_test(init_op, dataset=csv))
        return samples
Example #12
def export():
    r'''
    Restores the trained variables into a simpler graph that will be exported for serving.
    '''
    log_info('Exporting the model...')
    with tf.device('/cpu:0'):
        from tensorflow.python.framework.ops import Tensor, Operation

        tf.reset_default_graph()
        session = tf.Session(config=Config.session_config)

        inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
        input_names = ",".join(tensor.op.name for tensor in inputs.values())
        output_names_tensors = [ tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor) ]
        output_names_ops = [ tensor.name for tensor in outputs.values() if isinstance(tensor, Operation) ]
        output_names = ",".join(output_names_tensors + output_names_ops)
        input_shapes = ":".join(",".join(map(str, tensor.shape)) for tensor in inputs.values())

        if not FLAGS.export_tflite:
            mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
        else:
            # Create a saver using variables from the above newly created graph
            def fixup(name):
                if name.startswith('rnn/lstm_cell/'):
                    return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
                return name

            mapping = {fixup(v.op.name): v for v in tf.global_variables()}

        saver = tf.train.Saver(mapping)

        # Restore variables from training checkpoint
        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        checkpoint_path = checkpoint.model_checkpoint_path

        output_filename = 'output_graph.pb'
        if FLAGS.remove_export:
            if os.path.isdir(FLAGS.export_dir):
                log_info('Removing old export')
                shutil.rmtree(FLAGS.export_dir)
        try:
            output_graph_path = os.path.join(FLAGS.export_dir, output_filename)

            if not os.path.isdir(FLAGS.export_dir):
                os.makedirs(FLAGS.export_dir)

            def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
                return freeze_graph.freeze_graph_with_def_protos(
                    input_graph_def=session.graph_def,
                    input_saver_def=saver.as_saver_def(),
                    input_checkpoint=checkpoint_path,
                    output_node_names=output_node_names,
                    restore_op_name=None,
                    filename_tensor_name=None,
                    output_graph=output_file,
                    clear_devices=False,
                    variable_names_blacklist=variables_blacklist,
                    initializer_nodes='')

            if not FLAGS.export_tflite:
                do_graph_freeze(output_file=output_graph_path, output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
            else:
                frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='')
                output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))

                converter = lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
                converter.post_training_quantize = True
                tflite_model = converter.convert()

                with open(output_tflite_path, 'wb') as fout:
                    fout.write(tflite_model)

                log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))

            log_info('Models exported at %s' % (FLAGS.export_dir))
        except RuntimeError as e:
            log_error(str(e))
Example #13
def train():
    # Create training and validation datasets
    train_set = create_dataset(FLAGS.train_files.split(','),
                               batch_size=FLAGS.train_batch_size,
                               cache_path=FLAGS.feature_cache)

    iterator = tfv1.data.Iterator.from_structure(
        tfv1.data.get_output_types(train_set),
        tfv1.data.get_output_shapes(train_set),
        output_classes=tfv1.data.get_output_classes(train_set))

    # Make initialization ops for switching between the two sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_csvs = FLAGS.dev_files.split(',')
        dev_sets = [
            create_dataset([csv], batch_size=FLAGS.dev_batch_size)
            for csv in dev_csvs
        ]
        dev_init_ops = [
            iterator.make_initializer(dev_set) for dev_set in dev_sets
        ]

    # Dropout
    dropout_rates = [
        tfv1.placeholder(tf.float32, name='dropout_{}'.format(i))
        for i in range(6)
    ]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {rate: 0. for rate in dropout_rates}

    # Building the graph
    optimizer = create_optimizer()
    gradients, loss = get_tower_results(iterator, optimizer, dropout_rates)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)

    # global_step is automagically incremented by the optimizer
    global_step = tfv1.train.get_or_create_global_step()
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients,
                                                  global_step=global_step)

    # Summaries
    step_summaries_op = tfv1.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train':
        tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'),
                                max_queue=120),
        'dev':
        tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'),
                                max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
    checkpoint_filename = 'checkpoint'

    best_dev_saver = tfv1.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
    best_dev_filename = 'best_dev_checkpoint'

    initializer = tfv1.global_variables_initializer()

    with tfv1.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        # Loading or initializing
        loaded = False

        # Initialize training from a CuDNN RNN checkpoint
        if FLAGS.cudnn_checkpoint:
            if FLAGS.use_cudnn_rnn:
                log_error(
                    'Trying to use --cudnn_checkpoint but --use_cudnn_rnn '
                    'was specified. The --cudnn_checkpoint flag is only '
                    'needed when converting a CuDNN RNN checkpoint to '
                    'a CPU-capable graph. If your system is capable of '
                    'using CuDNN RNN, you can just specify the CuDNN RNN '
                    'checkpoint normally with --checkpoint_dir.')
                exit(1)

            log_info('Converting CuDNN RNN checkpoint from {}'.format(
                FLAGS.cudnn_checkpoint))
            ckpt = tfv1.train.load_checkpoint(FLAGS.cudnn_checkpoint)
            missing_variables = []

            # Load compatible variables from checkpoint
            for v in tfv1.global_variables():
                try:
                    v.load(ckpt.get_tensor(v.op.name), session=session)
                except tf.errors.NotFoundError:
                    missing_variables.append(v)

            # Check that the only missing variables are the Adam moment tensors
            if any('Adam' not in v.op.name for v in missing_variables):
                log_error(
                    'Tried to load a CuDNN RNN checkpoint but there were '
                    'more missing variables than just the Adam moment '
                    'tensors.')
                exit(1)

            # Initialize Adam moment tensors from scratch to allow use of CuDNN
            # RNN checkpoints.
            log_info('Initializing missing Adam moment tensors.')
            init_op = tfv1.variables_initializer(missing_variables)
            session.run(init_op)
            loaded = True

        tfv1.get_default_graph().finalize()

        if not loaded and FLAGS.load in ['auto', 'last']:
            loaded = try_loading(session, checkpoint_saver,
                                 checkpoint_filename, 'most recent')
        if not loaded and FLAGS.load in ['auto', 'best']:
            loaded = try_loading(session, best_dev_saver, best_dev_filename,
                                 'best validation')
        if not loaded:
            if FLAGS.load in ['auto', 'init']:
                log_info('Initializing variables...')
                session.run(initializer)
            else:
                log_error(
                    'Unable to load %s model from specified checkpoint dir'
                    ' - consider using load option "auto" or "init".' %
                    FLAGS.load)
                sys.exit(1)

        def run_set(set_name, epoch, init_op, dataset=None):
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict

            total_loss = 0.0
            step_count = 0

            step_summary_writer = step_summary_writers.get(set_name)
            checkpoint_time = time.time()

            # Setup progress bar
            class LossWidget(progressbar.widgets.FormatLabel):
                def __init__(self):
                    progressbar.widgets.FormatLabel.__init__(
                        self, format='Loss: %(mean_loss)f')

                def __call__(self, progress, data, **kwargs):
                    data['mean_loss'] = total_loss / step_count if step_count else 0.0
                    return progressbar.widgets.FormatLabel.__call__(
                        self, progress, data, **kwargs)

            prefix = 'Epoch {} | {:>10}'.format(
                epoch, 'Training' if is_train else 'Validation')
            widgets = [
                ' | ',
                progressbar.widgets.Timer(), ' | Steps: ',
                progressbar.widgets.Counter(), ' | ',
                LossWidget()
            ]
            suffix = ' | Dataset: {}'.format(dataset) if dataset else None
            pbar = create_progressbar(prefix=prefix,
                                      widgets=widgets,
                                      suffix=suffix).start()

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop
            while True:
                try:
                    _, current_step, batch_loss, step_summary = \
                        session.run([train_op, global_step, loss, step_summaries_op],
                                    feed_dict=feed_dict)
                except tf.errors.OutOfRangeError:
                    break

                total_loss += batch_loss
                step_count += 1

                pbar.update(step_count)

                step_summary_writer.add_summary(step_summary, current_step)

                if is_train and FLAGS.checkpoint_secs > 0 and \
                        time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session,
                                          checkpoint_path,
                                          global_step=current_step)
                    checkpoint_time = time.time()

            pbar.finish()
            mean_loss = total_loss / step_count if step_count > 0 else 0.0
            return mean_loss, step_count

        log_info('STARTING Optimization')
        train_start_time = datetime.utcnow()
        best_dev_loss = float('inf')
        dev_losses = []
        try:
            for epoch in range(FLAGS.epochs):
                # Training
                log_progress('Training epoch %d...' % epoch)
                train_loss, _ = run_set('train', epoch, train_init_op)
                log_progress('Finished training epoch %d - loss: %f' %
                             (epoch, train_loss))
                checkpoint_saver.save(session,
                                      checkpoint_path,
                                      global_step=global_step)

                if FLAGS.dev_files:
                    # Validation
                    dev_loss = 0.0
                    total_steps = 0
                    for csv, init_op in zip(dev_csvs, dev_init_ops):
                        log_progress('Validating epoch %d on %s...' %
                                     (epoch, csv))
                        set_loss, steps = run_set('dev',
                                                  epoch,
                                                  init_op,
                                                  dataset=csv)
                        dev_loss += set_loss * steps
                        total_steps += steps
                        log_progress(
                            'Finished validating epoch %d on %s - loss: %f' %
                            (epoch, csv, set_loss))
                    dev_loss = dev_loss / total_steps

                    dev_losses.append(dev_loss)

                    if dev_loss < best_dev_loss:
                        best_dev_loss = dev_loss
                        save_path = best_dev_saver.save(
                            session,
                            best_dev_path,
                            global_step=global_step,
                            latest_filename=best_dev_filename)
                        log_info(
                            "Saved new best validating model with loss %f to: %s"
                            % (best_dev_loss, save_path))

                    # Early stopping
                    if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
                        mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
                        std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
                        dev_losses = dev_losses[-FLAGS.es_steps:]
                        log_debug(
                            'Checking for early stopping (last %d steps) validation loss: '
                            '%f, with standard deviation: %f and mean: %f' %
                            (FLAGS.es_steps, dev_losses[-1], std_loss,
                             mean_loss))
                        if dev_losses[-1] > np.max(dev_losses[:-1]) or \
                           (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
                            log_info(
                                'Early stop triggered as (for last %d steps) validation loss:'
                                ' %f with standard deviation: %f and mean: %f'
                                % (FLAGS.es_steps, dev_losses[-1], std_loss,
                                   mean_loss))
                            break
        except KeyboardInterrupt:
            pass
        log_info('FINISHED optimization in {}'.format(datetime.utcnow() -
                                                      train_start_time))
    log_debug('Session closed.')
Example #14
def train():
    r'''
    Trains the network on a given server of a cluster.
    If no server is provided, it performs single-process training.
    '''

    # Reading training set
    train_index = SampleIndex()

    train_data = preprocess(FLAGS.train_files.split(','),
                            FLAGS.train_batch_size,
                            Config.n_input,
                            Config.n_context,
                            Config.alphabet,
                            hdf5_cache_path=FLAGS.train_cached_features_path)

    train_set = DataSet(train_data,
                        FLAGS.train_batch_size,
                        limit=FLAGS.limit_train,
                        next_index=train_index.inc)

    # Reading validation set
    dev_index = SampleIndex()

    dev_data = preprocess(FLAGS.dev_files.split(','),
                          FLAGS.dev_batch_size,
                          Config.n_input,
                          Config.n_context,
                          Config.alphabet,
                          hdf5_cache_path=FLAGS.dev_cached_features_path)

    dev_set = DataSet(dev_data,
                      FLAGS.dev_batch_size,
                      limit=FLAGS.limit_dev,
                      next_index=dev_index.inc)

    # Combining all sets to a multi set model feeder
    model_feeder = ModelFeeder(train_set,
                               dev_set,
                               Config.n_input,
                               Config.n_context,
                               Config.alphabet,
                               tower_feeder_count=len(
                                   Config.available_devices))

    # Dropout
    dropout_rates = [
        tf.placeholder(tf.float32, name='dropout_{}'.format(i))
        for i in range(6)
    ]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {
        dropout_rates[0]: 0.,
        dropout_rates[1]: 0.,
        dropout_rates[2]: 0.,
        dropout_rates[3]: 0.,
        dropout_rates[4]: 0.,
        dropout_rates[5]: 0.,
    }

    # Building the graph
    optimizer = create_optimizer()
    gradients, loss = get_tower_results(model_feeder, optimizer, dropout_rates)
    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)
    # global_step is automagically incremented by the optimizer
    global_step = tf.Variable(0, trainable=False, name='global_step')
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients,
                                                  global_step=global_step)

    # Summaries
    step_summaries_op = tf.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train':
        tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'),
                              max_queue=120),
        'dev':
        tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'),
                              max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
    checkpoint_filename = 'checkpoint'

    best_dev_saver = tf.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
    best_dev_filename = 'best_dev_checkpoint'

    initializer = tf.global_variables_initializer()

    with tf.Session(config=Config.session_config) as session:
        log_debug('Session opened.')
        tf.get_default_graph().finalize()

        # Loading or initializing
        loaded = False
        if FLAGS.load in ['auto', 'last']:
            loaded = try_loading(session, checkpoint_saver,
                                 checkpoint_filename, 'most recent epoch')
        if not loaded and FLAGS.load in ['auto', 'best']:
            loaded = try_loading(session, best_dev_saver, best_dev_filename,
                                 'best validation')
        if not loaded:
            if FLAGS.load in ['auto', 'init']:
                log_info('Initializing...')
                session.run(initializer)
            else:
                log_error(
                    'Unable to load %s model from specified checkpoint dir'
                    ' - consider using load option "auto" or "init".' %
                    FLAGS.load)
                sys.exit(1)

        # Retrieving global_step from restored model and setting training parameters accordingly
        model_feeder.set_data_set(no_dropout_feed_dict, train_set)
        step = session.run(global_step, feed_dict=no_dropout_feed_dict)
        num_gpus = len(Config.available_devices)
        steps_per_epoch = max(1, train_set.total_batches // num_gpus)
        steps_trained = step % steps_per_epoch
        current_epoch = step // steps_per_epoch
        target_epoch = current_epoch + abs(
            FLAGS.epoch) if FLAGS.epoch < 0 else FLAGS.epoch
        train_index.index = steps_trained * num_gpus
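        # For example: with 2 GPUs, 1000 batches in the train set and a restored
        # global step of 1600, steps_per_epoch is 500, so we are 100 steps into
        # epoch 3 and train_index is advanced by 200 batches (100 steps x 2 GPUs).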

        log_debug('step: %d' % step)
        log_debug('epoch: %d' % current_epoch)
        log_debug('target epoch: %d' % target_epoch)
        log_debug('steps per epoch: %d' % steps_per_epoch)
        log_debug('batches per step (GPUs): %d' % num_gpus)
        log_debug('number of batches in train set: %d' %
                  train_set.total_batches)
        log_debug('number of batches already trained in epoch: %d' %
                  train_index.index)

        def run_set(set_name):
            data_set = getattr(model_feeder, set_name)
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
            model_feeder.set_data_set(feed_dict, data_set)
            total_loss = 0.0
            step_summary_writer = step_summary_writers.get(set_name)
            num_steps = max(1, data_set.total_batches // num_gpus)
            checkpoint_time = time.time()
            if FLAGS.show_progressbar:
                pbar = progressbar.ProgressBar(max_value=num_steps,
                                               redirect_stdout=True).start()
            # Batch loop
            for step_index in range(steps_trained, num_steps):
                if coord.should_stop():
                    break
                _, current_step, batch_loss, step_summary = \
                    session.run([train_op, global_step, loss, step_summaries_op],
                                feed_dict=feed_dict)
                total_loss += batch_loss
                step_summary_writer.add_summary(step_summary, current_step)
                if FLAGS.show_progressbar:
                    pbar.update(step_index + 1, force=True)
                if is_train and FLAGS.checkpoint_secs > 0 and \
                   time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session,
                                          checkpoint_path,
                                          global_step=current_step)
                    checkpoint_time = time.time()
            if FLAGS.show_progressbar:
                pbar.finish()
            return total_loss / num_steps

        if target_epoch > current_epoch:
            log_info('STARTING Optimization')
            best_dev_loss = float('inf')
            dev_losses = []
            coord = tf.train.Coordinator()
            with coord.stop_on_exception():
                log_debug('Starting queue runners...')
                model_feeder.start_queue_threads(session, coord=coord)
                log_debug('Queue runners started.')
                # Epoch loop
                for current_epoch in range(current_epoch, target_epoch):
                    # Training
                    if coord.should_stop():
                        break
                    log_info('Training epoch %d ...' % current_epoch)
                    train_loss = run_set('train')
                    log_info('Finished training epoch %d - loss: %f' %
                             (current_epoch, train_loss))
                    checkpoint_saver.save(session,
                                          checkpoint_path,
                                          global_step=global_step)
                    steps_trained = 0
                    # Validation
                    log_info('Validating epoch %d ...' % current_epoch)
                    dev_loss = run_set('dev')
                    dev_losses.append(dev_loss)
                    log_info('Finished validating epoch %d - loss: %f' %
                             (current_epoch, dev_loss))
                    if dev_loss < best_dev_loss:
                        best_dev_loss = dev_loss
                        save_path = best_dev_saver.save(
                            session,
                            best_dev_path,
                            latest_filename=best_dev_filename)
                        log_info(
                            "Saved new best validating model with loss %f to: %s"
                            % (best_dev_loss, save_path))
                    # Early stopping
                    if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
                        mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
                        std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
                        dev_losses = dev_losses[-FLAGS.es_steps:]
                        log_debug(
                            'Checking for early stopping (last %d steps) validation loss: '
                            '%f, with standard deviation: %f and mean: %f' %
                            (FLAGS.es_steps, dev_losses[-1], std_loss,
                             mean_loss))
                        if dev_losses[-1] > np.max(dev_losses[:-1]) or \
                           (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
                            log_info(
                                'Early stop triggered as (for last %d steps) validation loss:'
                                ' %f with standard deviation: %f and mean: %f'
                                % (FLAGS.es_steps, dev_losses[-1], std_loss,
                                   mean_loss))
                            break
                log_debug('Closing queues...')
                coord.request_stop()
                model_feeder.close_queues(session)
                log_debug('Queues closed.')
        else:
            log_info('Target epoch already reached - skipped training.')
    log_debug('Session closed.')
Example #15
def initialize_globals():
    c = AttrDict()

    # ps and worker hosts required for p2p cluster setup
    FLAGS.ps_hosts = list(filter(len, FLAGS.ps_hosts.split(',')))
    FLAGS.worker_hosts = list(filter(len, FLAGS.worker_hosts.split(',')))

    # Create a cluster from the parameter server and worker hosts.
    c.cluster = tf.train.ClusterSpec({'ps': FLAGS.ps_hosts, 'worker': FLAGS.worker_hosts})

    # The absolute number of computing nodes - regardless of cluster or single mode
    num_workers = max(1, len(FLAGS.worker_hosts))

    # If replica numbers are negative, we multiply their absolute values with the number of workers
    if FLAGS.replicas < 0:
        FLAGS.replicas = num_workers * -FLAGS.replicas
    if FLAGS.replicas_to_agg < 0:
        FLAGS.replicas_to_agg = num_workers * -FLAGS.replicas_to_agg
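    # For example, with 4 worker hosts, --replicas -2 resolves to 8 replicas.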

    # The device path base for this node
    c.worker_device = '/job:%s/task:%d' % (FLAGS.job_name, FLAGS.task_index)

    # This node's CPU device
    c.cpu_device = c.worker_device + '/cpu:0'

    # This node's available GPU devices
    c.available_devices = [c.worker_device + gpu for gpu in get_available_gpus()]

    # If there is no GPU available, we fall back to CPU based operation
    if 0 == len(c.available_devices):
        c.available_devices = [c.cpu_device]

    # Set default dropout rates
    if FLAGS.dropout_rate2 < 0:
        FLAGS.dropout_rate2 = FLAGS.dropout_rate
    if FLAGS.dropout_rate3 < 0:
        FLAGS.dropout_rate3 = FLAGS.dropout_rate
    if FLAGS.dropout_rate6 < 0:
        FLAGS.dropout_rate6 = FLAGS.dropout_rate

    # Set default checkpoint dir
    if len(FLAGS.checkpoint_dir) == 0:
        FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('deepspeech','checkpoints'))

    # Set default summary dir
    if len(FLAGS.summary_dir) == 0:
        FLAGS.summary_dir = xdg.save_data_path(os.path.join('deepspeech','summaries'))

    # Standard session configuration that'll be used for all new sessions.
    c.session_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement,
                                      inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
                                      intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads)

    c.alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path))

    # Geometric Constants
    # ===================

    # For an explanation of the meaning of the geometric constants, please refer to
    # doc/Geometry.md

    # Number of MFCC features
    c.n_input = 26 # TODO: Determine this programmatically from the sample rate

    # The number of frames in the context
    c.n_context = 9 # TODO: Determine the optimal value using a validation data set

    # Number of units in hidden layers
    c.n_hidden = FLAGS.n_hidden

    c.n_hidden_1 = c.n_hidden

    c.n_hidden_2 = c.n_hidden

    c.n_hidden_5 = c.n_hidden

    # LSTM cell state dimension
    c.n_cell_dim = c.n_hidden

    # The number of units in the third layer, which feeds in to the LSTM
    c.n_hidden_3 = c.n_cell_dim

    # Units in the sixth layer = number of characters in the target language plus one
    c.n_hidden_6 = c.alphabet.size() + 1 # +1 for CTC blank label

    # Queues that are used to gracefully stop parameter servers.
    # Each queue stands for one ps. A finishing worker sends a token to each queue before joining/quitting.
    # Each ps will dequeue as many tokens as there are workers before joining/quitting.
    # This ensures parameter servers won't quit while they are still required by at least one
    # worker, and also won't wait forever (as they would with a standard `server.join()`).
    done_queues = []
    for i, ps in enumerate(FLAGS.ps_hosts):
        # Queues are hosted by their respective owners
        with tf.device('/job:ps/task:%d' % i):
            done_queues.append(tf.FIFOQueue(1, tf.int32, shared_name=('queue%i' % i)))

    # Placeholder to pass in the worker's index as token
    c.token_placeholder = tf.placeholder(tf.int32)

    # Enqueue operations for each parameter server
    c.done_enqueues = [queue.enqueue(c.token_placeholder) for queue in done_queues]

    # Dequeue operations for each parameter server
    c.done_dequeues = [queue.dequeue() for queue in done_queues]
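    # Token flow: a finishing worker enqueues one token into every ps queue,
    # and each ps dequeues one token per worker before quitting, so a ps only
    # exits once all workers have reported that they are done.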

    if len(FLAGS.one_shot_infer) > 0:
        FLAGS.train = False
        FLAGS.test = False
        FLAGS.export_dir = ''
        if not os.path.exists(FLAGS.one_shot_infer):
            log_error('Path specified in --one_shot_infer is not a valid file.')
            exit(1)

    # Determine if we are the chief worker
    c.is_chief = len(FLAGS.worker_hosts) == 0 or (FLAGS.task_index == 0 and FLAGS.job_name == 'worker')

    ConfigSingleton._config = c
def train(server=None):
    r'''
    Trains the network on a given server of a cluster.
    If no server provided, it performs single process training.
    '''

    # The transfer learning approach here requires us to supply the layers
    # we want to exclude from the source model.
    # Say we want to exclude all layers except for the first one; we can use this:
    #
    #    drop_source_layers=['2', '3', 'lstm', '5', '6']
    #
    # If we want to use all layers from the source model except the last one, we use this:
    #
    #    drop_source_layers=['6']
    #

    drop_source_layers = ['2', '3', 'lstm', '5',
                          '6'][-int(FLAGS.drop_source_layers):]
    # Note: a value of 0 would slice as [-0:] and select the whole list,
    # dropping every layer instead of none, so guard against that case.
    if int(FLAGS.drop_source_layers) == 0:
        drop_source_layers = []
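    # For illustration: with --drop_source_layers 2 the slice keeps ['5', '6'],
    # i.e. layers 5 and 6 are re-initialized while the remaining layers are
    # restored from the source model.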

    # Initializing and starting the training coordinator
    coord = TrainingCoordinator(Config.is_chief)
    coord.start()

    # Create a variable to hold the global_step.
    # It will automagically get incremented by the optimizer.
    global_step = tf.Variable(0, trainable=False, name='global_step')

    dropout_rates = [
        tf.placeholder(tf.float32, name='dropout_{}'.format(i))
        for i in range(6)
    ]

    # Reading training set
    train_data = preprocess(FLAGS.train_files.split(','),
                            FLAGS.train_batch_size,
                            Config.n_input,
                            Config.n_context,
                            Config.alphabet,
                            hdf5_cache_path=FLAGS.train_cached_features_path)

    train_set = DataSet(train_data,
                        FLAGS.train_batch_size,
                        limit=FLAGS.limit_train,
                        next_index=lambda i: coord.get_next_index('train'))

    # Reading validation set
    dev_data = preprocess(FLAGS.dev_files.split(','),
                          FLAGS.dev_batch_size,
                          Config.n_input,
                          Config.n_context,
                          Config.alphabet,
                          hdf5_cache_path=FLAGS.dev_cached_features_path)

    dev_set = DataSet(dev_data,
                      FLAGS.dev_batch_size,
                      limit=FLAGS.limit_dev,
                      next_index=lambda i: coord.get_next_index('dev'))

    # Combining all sets to a multi set model feeder
    model_feeder = ModelFeeder(train_set,
                               dev_set,
                               Config.n_input,
                               Config.n_context,
                               Config.alphabet,
                               tower_feeder_count=len(
                                   Config.available_devices))

    # Create the optimizer
    optimizer = create_optimizer()

    # Synchronous distributed training is facilitated by a special proxy-optimizer
    if server is not None:
        optimizer = tf.train.SyncReplicasOptimizer(
            optimizer,
            replicas_to_aggregate=FLAGS.replicas_to_agg,
            total_num_replicas=FLAGS.replicas)

    # Get the data_set specific graph end-points
    gradients, loss = get_tower_results(model_feeder, optimizer, dropout_rates,
                                        drop_source_layers)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)

    # Add summaries of all variables and gradients to log
    log_grads_and_vars(avg_tower_gradients)

    # Op to merge all summaries for the summary hook
    merge_all_summaries_op = tf.summary.merge_all()

    # These are saved on every step
    step_summaries_op = tf.summary.merge_all('step_summaries')

    step_summary_writers = {
        'train':
        tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'),
                              max_queue=120),
        'dev':
        tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'),
                              max_queue=120)
    }

    # Apply gradients to modify the model
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients,
                                                  global_step=global_step)

    if FLAGS.early_stop and FLAGS.validation_step <= 0:
        log_warn(
            'Parameter --validation_step needs to be >0 for early stopping to work'
        )

    class CoordHook(tf.train.SessionRunHook):
        r'''
        Embedded coordination hook-class that will use variables of the
        surrounding Python context.
        '''
        def after_create_session(self, session, coord):
            log_debug('Starting queue runners...')
            model_feeder.start_queue_threads(session, coord)
            log_debug('Queue runners started.')

        def end(self, session):
            # Closing the data_set queues
            log_debug('Closing queues...')
            model_feeder.close_queues(session)
            log_debug('Queues closed.')

            # Telling the ps that we are done
            send_token_to_ps(session)

    # Collecting the hooks
    hooks = [CoordHook()]

    # Hook to handle initialization and queues for sync replicas.
    if server is not None:
        hooks.append(optimizer.make_session_run_hook(Config.is_chief))

    # Hook to save TensorBoard summaries
    if FLAGS.summary_secs > 0:
        hooks.append(
            tf.train.SummarySaverHook(save_secs=FLAGS.summary_secs,
                                      output_dir=FLAGS.summary_dir,
                                      summary_op=merge_all_summaries_op))

    # Hook with the number of checkpoint files to keep in checkpoint_dir
    if FLAGS.train and FLAGS.max_to_keep > 0:
        saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
        hooks.append(
            tf.train.CheckpointSaverHook(checkpoint_dir=FLAGS.checkpoint_dir,
                                         save_secs=FLAGS.checkpoint_secs,
                                         saver=saver))

    no_dropout_feed_dict = {
        dropout_rates[0]: 0.,
        dropout_rates[1]: 0.,
        dropout_rates[2]: 0.,
        dropout_rates[3]: 0.,
        dropout_rates[4]: 0.,
        dropout_rates[5]: 0.,
    }

    # Progress Bar
    def update_progressbar(set_name):
        if not hasattr(update_progressbar, 'current_set_name'):
            update_progressbar.current_set_name = None

        if (update_progressbar.current_set_name != set_name
                or update_progressbar.current_job_index
                == update_progressbar.total_jobs):

            # finish prev pbar if it exists
            if hasattr(update_progressbar, 'pbar') and update_progressbar.pbar:
                update_progressbar.pbar.finish()

            update_progressbar.total_jobs = None
            update_progressbar.current_job_index = 0

            current_epoch = coord._epoch - 1
            suffix = "graph_noisySVA_CV_2layers_"
            checkpoint_stash = "/docker_files/ckpt_stash/"
            checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            checkpoint_path = checkpoint.model_checkpoint_path
            ckpt_dest_name = suffix + str(current_epoch - 118) + "_eph"
            str_to_replace = "s/" + checkpoint_path.split(
                '/')[-1] + "/" + ckpt_dest_name + "/"
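            # str_to_replace is an s/old/new/ substitution handed to the
            # `rename` calls below; this assumes the Perl-style rename utility
            # that accepts such expressions, and rewrites the copied checkpoint
            # file names to the per-epoch destination name.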

            subprocess.Popen(
                ["cp", checkpoint_path + ".meta", checkpoint_stash])
            #pdb.set_trace()
            subprocess.Popen([
                "rename", str_to_replace,
                checkpoint_stash + checkpoint_path.split('/')[-1] + ".meta"
            ])

            subprocess.Popen([
                "cp", checkpoint_path + ".data-00000-of-00001",
                checkpoint_stash
            ])
            subprocess.Popen([
                "rename", str_to_replace, checkpoint_stash +
                checkpoint_path.split('/')[-1] + ".data-00000-of-00001"
            ])

            subprocess.Popen(
                ["cp", checkpoint_path + ".index", checkpoint_stash])
            subprocess.Popen([
                "rename", str_to_replace,
                checkpoint_stash + checkpoint_path.split('/')[-1] + ".index"
            ])

            if set_name == "train":
                log_info('Training epoch %i...' % current_epoch)
                update_progressbar.total_jobs = coord._num_jobs_train
            else:
                log_info('Validating epoch %i...' % current_epoch)
                update_progressbar.total_jobs = coord._num_jobs_dev

            # recreate pbar
            update_progressbar.pbar = progressbar.ProgressBar(
                max_value=update_progressbar.total_jobs,
                redirect_stdout=True).start()

            update_progressbar.current_set_name = set_name

        if update_progressbar.pbar:
            update_progressbar.pbar.update(
                update_progressbar.current_job_index + 1, force=True)

        update_progressbar.current_job_index += 1

    # Initialize update_progressbar()'s child fields to safe values
    update_progressbar.pbar = None

    ### TRANSFER LEARNING ###
    def init_fn(scaffold, session):
        if FLAGS.source_model_checkpoint_dir:
            drop_source_layers.append('layer_6')
            print('Initializing from', FLAGS.source_model_checkpoint_dir)
            ckpt = tf.train.load_checkpoint(FLAGS.source_model_checkpoint_dir)
            variables = list(ckpt.get_variable_to_shape_map().keys())
            # Open the node-name log once; re-opening it in 'w' mode for every
            # variable would truncate the file and keep only the last entry.
            with open("/data/german_DS/deepspeech-german/nodes.txt",
                      "w") as nodetxtfile:
                for v in tf.global_variables():
                    if not any(layer in v.op.name for layer in drop_source_layers):
                        print('Loading', v.op.name)
                        nodetxtfile.write(v.op.name + '\n')
                        v.load(ckpt.get_tensor(v.op.name), session=session)

    scaffold = tf.train.Scaffold(
        init_op=tf.variables_initializer([
            v for v in tf.global_variables()
            if any(layer in v.op.name for layer in drop_source_layers)
        ]),
        init_fn=init_fn)
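    # The scaffold splits initialization: variables belonging to the dropped
    # layers are freshly initialized by init_op, while init_fn above loads the
    # remaining variables from the source checkpoint.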
    ### TRANSFER LEARNING ###

    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    try:
        with tf.train.MonitoredTrainingSession(
                master='' if server is None else server.target,
                is_chief=Config.is_chief,
                hooks=hooks,
                scaffold=scaffold,  # transfer-learning
                checkpoint_dir=FLAGS.checkpoint_dir,
                save_checkpoint_secs=None,  # already taken care of by a hook
                log_step_count_steps=0,  # disable logging of steps/s to avoid TF warning in validation sets
                config=Config.session_config) as session:
            #tf.get_default_graph().finalize()
            #do_export = False
            try:
                if Config.is_chief:
                    # Retrieving global_step from the (potentially restored) model
                    model_feeder.set_data_set(no_dropout_feed_dict,
                                              model_feeder.train)
                    step = session.run(global_step,
                                       feed_dict=no_dropout_feed_dict)
                    coord.start_coordination(model_feeder, step)
                    #if do_export:
                    #export(session)
                    #print("########INDISE EXPORT###########")
                    #do_export = True

                # Get the first job
                job = coord.get_job()

                while job and not session.should_stop():
                    log_debug('Computing %s...' % job)

                    is_train = job.set_name == 'train'

                    # The feed_dict (mainly for switching between queues)
                    if is_train:
                        feed_dict = {
                            dropout_rates[0]: FLAGS.dropout_rate,
                            dropout_rates[1]: FLAGS.dropout_rate2,
                            dropout_rates[2]: FLAGS.dropout_rate3,
                            dropout_rates[3]: FLAGS.dropout_rate4,
                            dropout_rates[4]: FLAGS.dropout_rate5,
                            dropout_rates[5]: FLAGS.dropout_rate6,
                        }
                    else:
                        feed_dict = no_dropout_feed_dict

                    # Sets the current data_set for the respective placeholder in feed_dict
                    model_feeder.set_data_set(
                        feed_dict, getattr(model_feeder, job.set_name))

                    # Initialize loss aggregator
                    total_loss = 0.0

                    # Setting the training operation in case of training requested
                    train_op = apply_gradient_op if is_train else []

                    # So far the only extra parameter is the feed_dict
                    extra_params = {'feed_dict': feed_dict}

                    step_summary_writer = step_summary_writers.get(
                        job.set_name)

                    # Loop over the batches
                    for job_step in range(job.steps):
                        if session.should_stop():
                            break

                        log_debug('Starting batch...')
                        # Compute the batch
                        _, current_step, batch_loss, step_summary = session.run(
                            [train_op, global_step, loss, step_summaries_op],
                            **extra_params)

                        # Log step summaries
                        step_summary_writer.add_summary(
                            step_summary, current_step)

                        # Log the finished batch step (useful when debugging race conditions / distributed TF)
                        log_debug('Finished batch step %d.' % current_step)

                        # Add batch to loss
                        total_loss += batch_loss

                    # Gathering job results
                    job.loss = total_loss / job.steps

                    # Display progressbar
                    if FLAGS.show_progressbar:
                        update_progressbar(job.set_name)

                    # Send the current job to coordinator and receive the next one
                    log_debug('Sending %s...' % job)
                    job = coord.next_job(job)

                if update_progressbar.pbar:
                    update_progressbar.pbar.finish()

                #export()
                #mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
                #saver = tf.train.Saver(mapping)
                #def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
                #    freeze_graph.freeze_graph_with_def_protos(
                #        input_graph_def=session.graph_def,
                #        input_saver_def=saver.as_saver_def(),
                #        input_checkpoint=checkpoint_path,
                #        output_node_names=output_node_names,
                #        restore_op_name=None,
                #        filename_tensor_name=None,
                #        output_graph=output_file,
                #        clear_devices=False,
                #        variable_names_blacklist=variables_blacklist,
                #        initializer_nodes='')
                #output_graph_path = "output_graph.pb"
                #do_graph_freeze(output_file=output_graph_path, output_node_names='logits,initialize_state', variables_blacklist='previous_state_c,previous_state_h')

            except Exception as e:
                log_error(str(e))
                traceback.print_exc()
                # Calling all hook's end() methods to end blocking calls
                for hook in hooks:
                    hook.end(session)
                # Only chief has a SyncReplicasOptimizer queue runner that needs to be stopped for unblocking process exit.
                # A rather graceful way to do this is by stopping the ps.
                # Only one party can send it w/o failing.
                if Config.is_chief:
                    send_token_to_ps(session, kill=True)
                sys.exit(1)

        log_debug('Session closed.')

    except tf.errors.InvalidArgumentError as e:
        log_error(str(e))
        log_error(
            'The checkpoint in {0} does not match the shapes of the model.'
            ' Did you change alphabet.txt or the --n_hidden parameter'
            ' between train runs using the same checkpoint dir? Try moving'
            ' or removing the contents of {0}.'.format(FLAGS.checkpoint_dir))
        sys.exit(1)

    # Stopping the coordinator
    coord.stop()
Example #17
def initialize_globals():
    c = AttrDict()

    # CPU device
    c.cpu_device = '/cpu:0'

    # Available GPU devices
    c.available_devices = get_available_gpus()

    # If there is no GPU available, we fall back to CPU based operation
    if 0 == len(c.available_devices):
        c.available_devices = [c.cpu_device]

    # Set default dropout rates
    if FLAGS.dropout_rate2 < 0:
        FLAGS.dropout_rate2 = FLAGS.dropout_rate
    if FLAGS.dropout_rate3 < 0:
        FLAGS.dropout_rate3 = FLAGS.dropout_rate
    if FLAGS.dropout_rate6 < 0:
        FLAGS.dropout_rate6 = FLAGS.dropout_rate

    # Set default checkpoint dir
    if len(FLAGS.checkpoint_dir) == 0:
        FLAGS.checkpoint_dir = xdg.save_data_path(
            os.path.join('deepspeech', 'checkpoints'))

    if FLAGS.load not in ['last', 'best', 'init', 'auto']:
        FLAGS.load = 'auto'

    # Set default summary dir
    if len(FLAGS.summary_dir) == 0:
        FLAGS.summary_dir = xdg.save_data_path(
            os.path.join('deepspeech', 'summaries'))

    # Standard session configuration that'll be used for all new sessions.
    c.session_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=FLAGS.log_placement,
        inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
        intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads)

    c.alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path))

    # Geometric Constants
    # ===================

    # For an explanation of the meaning of the geometric constants, please refer to
    # doc/Geometry.md

    # Number of MFCC features
    c.n_input = 26  # TODO: Determine this programmatically from the sample rate

    # The number of frames in the context
    c.n_context = 9  # TODO: Determine the optimal value using a validation data set

    # Number of units in hidden layers
    c.n_hidden = FLAGS.n_hidden

    c.n_hidden_1 = c.n_hidden

    c.n_hidden_2 = c.n_hidden

    c.n_hidden_5 = c.n_hidden

    # LSTM cell state dimension
    c.n_cell_dim = c.n_hidden

    # The number of units in the third layer, which feeds in to the LSTM
    c.n_hidden_3 = c.n_cell_dim

    # Units in the sixth layer = number of characters in the target language plus one
    c.n_hidden_6 = c.alphabet.size() + 1  # +1 for CTC blank label

    if len(FLAGS.one_shot_infer) > 0:
        FLAGS.train = False
        FLAGS.test = False
        FLAGS.export_dir = ''
        if not os.path.exists(FLAGS.one_shot_infer):
            log_error(
                'Path specified in --one_shot_infer is not a valid file.')
            exit(1)

    ConfigSingleton._config = c
Example #18
def evaluate(test_data, inference_graph):
    scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.lm_binary_path,
                    FLAGS.lm_trie_path, Config.alphabet)

    def create_windows(features):
        num_strides = len(features) - (Config.n_context * 2)

        # Create a view into the array with overlapping strides of size
        # numcontext (past) + 1 (present) + numcontext (future)
        window_size = 2 * Config.n_context + 1
        features = np.lib.stride_tricks.as_strided(
            features, (num_strides, window_size, Config.n_input),
            (features.strides[0], features.strides[0], features.strides[1]),
            writeable=False)
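        # For illustration: with Config.n_context = 9 and Config.n_input = 26,
        # each of the num_strides windows has shape (19, 26): 9 past frames,
        # the current frame and 9 future frames, without copying the array.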

        return features

    # Create overlapping windows over the features
    test_data['features'] = test_data['features'].apply(create_windows)

    with tf.Session(config=Config.session_config) as session:
        inputs, outputs, layers = inference_graph
        layer_4 = layers['rnn_output']
        layer_5 = layers['layer_5']
        layer_6 = layers['layer_6']
        # Transpose to batch major for decoder
        transposed = tf.transpose(outputs['outputs'], [1, 0, 2])

        labels_ph = tf.placeholder(tf.int32, [FLAGS.test_batch_size, None],
                                   name="labels")
        label_lengths_ph = tf.placeholder(tf.int32, [FLAGS.test_batch_size],
                                          name="label_lengths")

        # We add 1 to all elements of the transcript to avoid any zero values
        # since we use that as an end-of-sequence token for converting the batch
        # into a SparseTensor. So here we convert the placeholder back into a
        # SparseTensor and subtract ones to get the real labels.
        sparse_labels = tf.contrib.layers.dense_to_sparse(labels_ph)
        neg_ones = tf.SparseTensor(sparse_labels.indices,
                                   -1 * tf.ones_like(sparse_labels.values),
                                   sparse_labels.dense_shape)
        sparse_labels = tf.sparse_add(sparse_labels, neg_ones)
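        # For illustration: a transcript [5, 12, 0] arrives here as [6, 13, 1],
        # so a real label of 0 is never confused with the 0 used as the
        # end-of-sequence marker; subtracting one from the sparse values
        # recovers the original [5, 12, 0].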

        loss = tf.nn.ctc_loss(labels=sparse_labels,
                              inputs=layers['raw_logits'],
                              sequence_length=inputs['input_lengths'])

        # Create a saver using variables from the above newly created graph
        mapping = {
            v.op.name: v
            for v in tf.global_variables()
            if not v.op.name.startswith('previous_state_')
        }
        saver = tf.train.Saver(mapping)

        # Restore variables from training checkpoint
        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if not checkpoint:
            log_error(
                'Checkpoint directory ({}) does not contain a valid checkpoint state.'
                .format(FLAGS.checkpoint_dir))
            exit(1)

        checkpoint_path = checkpoint.model_checkpoint_path
        saver.restore(session, checkpoint_path)

        logitses = []
        losses = []
        ## To Print the embeddings
        layer_4s = []
        layer_5s = []
        layer_6s = []

        print('Computing acoustic model predictions...')
        batch_count = len(test_data) // FLAGS.test_batch_size
        print('Batch Count: ', batch_count)
        bar = progressbar.ProgressBar(max_value=batch_count,
                                      widget=progressbar.AdaptiveETA)

        # First pass, compute losses and transposed logits for decoding
        for batch in bar(split_data(test_data, FLAGS.test_batch_size)):
            session.run(outputs['initialize_state'])
            # TODO: Remove this restriction to generalize to larger batch sizes.
            assert FLAGS.test_batch_size == 1, 'Embedding Extraction will only work for Batch Size = 1 for now!'

            features = pad_to_dense(batch['features'].values)
            features_len = batch['features_len'].values
            labels = pad_to_dense(batch['transcript'].values + 1)
            label_lengths = batch['transcript_len'].values

            logits, loss_, lay4, lay5, lay6 = session.run(
                [transposed, loss, layer_4, layer_5, layer_6],
                feed_dict={
                    inputs['input']: features,
                    inputs['input_lengths']: features_len,
                    labels_ph: labels,
                    label_lengths_ph: label_lengths
                })

            logitses.append(logits)
            losses.extend(loss_)
            layer_4s.append(lay4)
            layer_5s.append(lay5)
            layer_6s.append(lay6)
            print('Saving to Files: ')
            #lay4.tofile('embeddings/lay4.txt')
            #lay5.tofile('embeddings/lay5.txt')
            #lay6.tofile('embeddings/lay6.txt')
            #            np.save('embeddings/lay41.npy', lay4)
            filename = batch.fname.iloc[0]
            save_np_array(lay4, Config.LAYER4 + filename + '.npy')
            save_np_array(lay5, Config.LAYER5 + filename + '.npy')
            save_np_array(lay6, Config.LAYER6 + filename + '.npy')
            #            print('\nLayer 4 Shape: ', load_np_array('embeddings/lay41.npy').shape)
            #            print('\nLayer 4 Shape: ', np.load('embeddings/lay41.npy').shape)
            print('Layer 5 Shape: ', lay5.shape)
            print('Layer 6 Shape: ', lay6.shape)
    print('LAYER4: ', Config.LAYER4)
    ground_truths = []
    predictions = []
    fnames = []

    print('Decoding predictions...')
    bar = progressbar.ProgressBar(max_value=batch_count,
                                  widget=progressbar.AdaptiveETA)

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except:
        num_processes = 1

    # Second pass, decode logits and compute WER and edit distance metrics
    for logits, batch in bar(
            zip(logitses, split_data(test_data, FLAGS.test_batch_size))):
        seq_lengths = batch['features_len'].values.astype(np.int32)
        decoded = ctc_beam_search_decoder_batch(logits,
                                                seq_lengths,
                                                Config.alphabet,
                                                FLAGS.beam_width,
                                                num_processes=num_processes,
                                                scorer=scorer)
        #print('Batch\n', batch)
        ground_truths.extend(
            Config.alphabet.decode(l) for l in batch['transcript'])
        fnames.extend([l for l in batch['fname']])
        #fnames.append(batch['fname'])
        #print(fnames)
        predictions.extend(d[0][1] for d in decoded)

    distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]

    wer, cer, samples = calculate_report(ground_truths, predictions, distances,
                                         losses, fnames)
    print('Sample Lengths: ', len(samples))
    mean_loss = np.mean(losses)

    # Take only the first report_count items
    report_samples = itertools.islice(samples, FLAGS.report_count)
    print(report_samples)
    print('Test - WER: %f, CER: %f, loss: %f' % (wer, cer, mean_loss))
    print('-' * 80)
    count = 0
    for sample in report_samples:
        count += 1
        with open(Config.TEXT + sample.fname + '.txt', 'w') as f:
            f.write(sample.res)
        print("File Name: ", sample.fname)
        print('WER: %f, CER: %f, loss: %f' %
              (sample.wer, sample.distance, sample.loss))
        print(' - src: "%s"' % sample.src)
        print(' - res: "%s"' % sample.res)
        print('-' * 80)
    print('Total Count: ', count)
    return samples
Example #19
def initialize_globals():
    c = AttrDict()

    # CPU device
    c.cpu_device = '/cpu:0'

    # Available GPU devices
    c.available_devices = get_available_gpus()

    # If there is no GPU available, we fall back to CPU based operation
    if not c.available_devices:
        c.available_devices = [c.cpu_device]

    # Set default dropout rates
    if FLAGS.dropout_rate2 < 0:
        FLAGS.dropout_rate2 = FLAGS.dropout_rate
    if FLAGS.dropout_rate3 < 0:
        FLAGS.dropout_rate3 = FLAGS.dropout_rate
    if FLAGS.dropout_rate6 < 0:
        FLAGS.dropout_rate6 = FLAGS.dropout_rate

    # Set default checkpoint dir
    if not FLAGS.checkpoint_dir:
        FLAGS.checkpoint_dir = xdg.save_data_path(
            os.path.join('deepspeech', 'checkpoints'))

    if FLAGS.load not in ['last', 'best', 'init', 'auto']:
        FLAGS.load = 'auto'

    # Set default summary dir
    if not FLAGS.summary_dir:
        FLAGS.summary_dir = xdg.save_data_path(
            os.path.join('deepspeech', 'summaries'))

    c.alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path))

    # Geometric Constants
    # ===================

    # For an explanation of the meaning of the geometric constants, please refer to
    # doc/Geometry.md

    # Number of MFCC features
    c.n_input = 26  # TODO: Determine this programmatically from the sample rate

    # The number of frames in the context
    c.n_context = 9  # TODO: Determine the optimal value using a validation data set

    # Number of units in hidden layers
    c.n_hidden = FLAGS.n_hidden

    c.n_hidden_1 = c.n_hidden

    c.n_hidden_2 = c.n_hidden

    c.n_hidden_5 = c.n_hidden

    # LSTM cell state dimension
    c.n_cell_dim = c.n_hidden

    # The number of units in the third layer, which feeds in to the LSTM
    c.n_hidden_3 = c.n_cell_dim

    # Units in the sixth layer = number of characters in the target language plus one
    c.n_hidden_6 = c.alphabet.size() + 1  # +1 for CTC blank label

    # Size of audio window in samples
    c.audio_window_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_len /
                                                        1000)

    # Stride for feature computations in samples
    c.audio_step_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_step /
                                                      1000)
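    # For example, a 16 kHz sample rate with a 32 ms window and a 20 ms step
    # yields 512 and 320 samples respectively.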

    if FLAGS.one_shot_infer:
        if not os.path.exists(FLAGS.one_shot_infer):
            log_error(
                'Path specified in --one_shot_infer is not a valid file.')
            exit(1)

    ConfigSingleton._config = c  # pylint: disable=protected-access
def do_single_file_inference(checkpoint_dir, input_file_path, layer_wanted,
                             softmax_wanted, save_filename, save_folder,
                             stride_size_s, win_size_s, fea_format,
                             csv_format):
    with tf.Session(config=Config.session_config) as session:
        inputs, outputs, _ = create_inference_graph(
            batch_size=1,
            n_steps=-1,
            layer_wanted=layer_wanted,
            softmax_applied=softmax_wanted)

        # Create a saver using variables from the above newly created graph
        mapping = {
            v.op.name: v
            for v in tf.global_variables()
            if not v.op.name.startswith('previous_state_')
        }
        saver = tf.train.Saver(mapping)

        # Restore variables from training checkpoint
        checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
        if not checkpoint:
            log_error(
                'Checkpoint directory ({}) does not contain a valid checkpoint state.'
                .format(checkpoint_dir))
            exit(1)
        checkpoint_path = checkpoint.model_checkpoint_path
        saver.restore(session, checkpoint_path)

        session.run(outputs['initialize_state'])

        # transformation of the audio file
        features = audiofile_to_input_vector(input_file_path, Config.n_input,
                                             Config.n_context)
        #print(features.shape)
        num_strides = len(features) - (Config.n_context * 2)

        # Create a view into the array with overlapping strides of size
        # numcontext (past) + 1 (present) + numcontext (future)
        window_size = 2 * Config.n_context + 1
        features = np.lib.stride_tricks.as_strided(
            features, (num_strides, window_size, Config.n_input),
            (features.strides[0], features.strides[0], features.strides[1]),
            writeable=False)
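        # This yields num_strides overlapping windows of shape
        # (2 * Config.n_context + 1, Config.n_input) without copying the
        # underlying feature array.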

        # This is not the logits but the output of the requested layer
        logits = session.run(outputs['outputs'],
                             feed_dict={
                                 inputs['input']: [features],
                                 inputs['input_lengths']: [num_strides],
                             })

        logits = np.squeeze(logits)
        if fea_format:
            write_fea_file(logits,
                           save_folder,
                           save_filename,
                           stride_size_s=stride_size_s,
                           win_len_s=win_size_s)
        if csv_format:
            np.savetxt(save_folder + '/' + save_filename + '.csv',
                       logits,
                       delimiter=',')
Example #21
def train(server=None):
    r'''
    Trains the network on a given server of a cluster.
    If no server provided, it performs single process training.
    '''

    # Initializing and starting the training coordinator
    coord = TrainingCoordinator(Config.is_chief)
    coord.start()

    # Create a variable to hold the global_step.
    # It will automagically get incremented by the optimizer.
    global_step = tf.Variable(0, trainable=False, name='global_step')

    dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]

    # Reading training set
    train_data = preprocess(FLAGS.train_files.split(','),
                            FLAGS.train_batch_size,
                            Config.n_input,
                            Config.n_context,
                            Config.alphabet,
                            hdf5_cache_path=FLAGS.train_cached_features_path)

    train_set = DataSet(train_data,
                        FLAGS.train_batch_size,
                        limit=FLAGS.limit_train,
                        next_index=lambda i: coord.get_next_index('train'))

    # Reading validation set
    dev_data = preprocess(FLAGS.dev_files.split(','),
                          FLAGS.dev_batch_size,
                          Config.n_input,
                          Config.n_context,
                          Config.alphabet,
                          hdf5_cache_path=FLAGS.dev_cached_features_path)

    dev_set = DataSet(dev_data,
                      FLAGS.dev_batch_size,
                      limit=FLAGS.limit_dev,
                      next_index=lambda i: coord.get_next_index('dev'))

    # Combining all sets to a multi set model feeder
    model_feeder = ModelFeeder(train_set,
                               dev_set,
                               Config.n_input,
                               Config.n_context,
                               Config.alphabet,
                               tower_feeder_count=len(Config.available_devices))

    # Create the optimizer
    optimizer = create_optimizer()

    # Synchronous distributed training is facilitated by a special proxy-optimizer
    if server is not None:
        optimizer = tf.train.SyncReplicasOptimizer(optimizer,
                                                   replicas_to_aggregate=FLAGS.replicas_to_agg,
                                                   total_num_replicas=FLAGS.replicas)

    # Get the data_set specific graph end-points
    gradients, loss = get_tower_results(model_feeder, optimizer, dropout_rates)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)

    # Add summaries of all variables and gradients to log
    log_grads_and_vars(avg_tower_gradients)

    # Op to merge all summaries for the summary hook
    merge_all_summaries_op = tf.summary.merge_all()

    # These are saved on every step
    step_summaries_op = tf.summary.merge_all('step_summaries')

    step_summary_writers = {
        'train': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
    }

    # Apply gradients to modify the model
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)


    if FLAGS.early_stop and FLAGS.validation_step <= 0:
        log_warn('Parameter --validation_step needs to be >0 for early stopping to work')

    class CoordHook(tf.train.SessionRunHook):
        r'''
        Embedded coordination hook-class that will use variables of the
        surrounding Python context.
        '''
        def after_create_session(self, session, coord):
            log_debug('Starting queue runners...')
            model_feeder.start_queue_threads(session, coord)
            log_debug('Queue runners started.')

        def end(self, session):
            # Closing the data_set queues
            log_debug('Closing queues...')
            model_feeder.close_queues(session)
            log_debug('Queues closed.')

            # Telling the ps that we are done
            send_token_to_ps(session)

    # Collecting the hooks
    hooks = [CoordHook()]

    # Hook to handle initialization and queues for sync replicas.
    if server is not None:
        hooks.append(optimizer.make_session_run_hook(Config.is_chief))

    # Hook to save TensorBoard summaries
    if FLAGS.summary_secs > 0:
        hooks.append(tf.train.SummarySaverHook(save_secs=FLAGS.summary_secs, output_dir=FLAGS.summary_dir, summary_op=merge_all_summaries_op))

    # Hook with the number of checkpoint files to keep in checkpoint_dir
    if FLAGS.train and FLAGS.max_to_keep > 0:
        saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
        hooks.append(tf.train.CheckpointSaverHook(checkpoint_dir=FLAGS.checkpoint_dir, save_secs=FLAGS.checkpoint_secs, saver=saver))

    no_dropout_feed_dict = {
        dropout_rates[0]: 0.,
        dropout_rates[1]: 0.,
        dropout_rates[2]: 0.,
        dropout_rates[3]: 0.,
        dropout_rates[4]: 0.,
        dropout_rates[5]: 0.,
    }

    # Progress Bar
    def update_progressbar(set_name):
        if not hasattr(update_progressbar, 'current_set_name'):
            update_progressbar.current_set_name = None

        if (update_progressbar.current_set_name != set_name or
            update_progressbar.current_job_index == update_progressbar.total_jobs):

            # finish prev pbar if it exists
            if hasattr(update_progressbar, 'pbar') and update_progressbar.pbar:
                update_progressbar.pbar.finish()

            update_progressbar.total_jobs = None
            update_progressbar.current_job_index = 0

            current_epoch = coord._epoch-1

            if set_name == "train":
                log_info('Training epoch %i...' % current_epoch)
                update_progressbar.total_jobs = coord._num_jobs_train
            else:
                log_info('Validating epoch %i...' % current_epoch)
                update_progressbar.total_jobs = coord._num_jobs_dev

            # recreate pbar
            update_progressbar.pbar = progressbar.ProgressBar(max_value=update_progressbar.total_jobs,
                                                              redirect_stdout=True).start()

            update_progressbar.current_set_name = set_name

        if update_progressbar.pbar:
            update_progressbar.pbar.update(update_progressbar.current_job_index+1, force=True)

        update_progressbar.current_job_index += 1

    # Initialize update_progressbar()'s child fields to safe values
    update_progressbar.pbar = None

    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    try:
        with tf.train.MonitoredTrainingSession(master='' if server is None else server.target,
                                               is_chief=Config.is_chief,
                                               hooks=hooks,
                                               checkpoint_dir=FLAGS.checkpoint_dir,
                                               save_checkpoint_secs=None, # already taken care of by a hook
                                               log_step_count_steps=0, # disable logging of steps/s to avoid TF warning in validation sets
                                               config=Config.session_config) as session:
            tf.get_default_graph().finalize()

            try:
                if Config.is_chief:
                    # Retrieving global_step from the (potentially restored) model
                    model_feeder.set_data_set(no_dropout_feed_dict, model_feeder.train)
                    step = session.run(global_step, feed_dict=no_dropout_feed_dict)
                    coord.start_coordination(model_feeder, step)

                # Get the first job
                job = coord.get_job()

                while job and not session.should_stop():
                    log_debug('Computing %s...' % job)

                    is_train = job.set_name == 'train'

                    # The feed_dict (mainly for switching between queues)
                    if is_train:
                        feed_dict = {
                            dropout_rates[0]: FLAGS.dropout_rate,
                            dropout_rates[1]: FLAGS.dropout_rate2,
                            dropout_rates[2]: FLAGS.dropout_rate3,
                            dropout_rates[3]: FLAGS.dropout_rate4,
                            dropout_rates[4]: FLAGS.dropout_rate5,
                            dropout_rates[5]: FLAGS.dropout_rate6,
                        }
                    else:
                        feed_dict = no_dropout_feed_dict

                    # Sets the current data_set for the respective placeholder in feed_dict
                    model_feeder.set_data_set(feed_dict, getattr(model_feeder, job.set_name))

                    # Initialize loss aggregator
                    total_loss = 0.0

                    # Setting the training operation in case of training requested
                    train_op = apply_gradient_op if is_train else []

                    # So far the only extra parameter is the feed_dict
                    extra_params = { 'feed_dict': feed_dict }

                    step_summary_writer = step_summary_writers.get(job.set_name)

                    # Loop over the batches
                    for job_step in range(job.steps):
                        if session.should_stop():
                            break

                        log_debug('Starting batch...')
                        # Compute the batch
                        _, current_step, batch_loss, step_summary = session.run([train_op, global_step, loss, step_summaries_op], **extra_params)

                        # Log step summaries
                        step_summary_writer.add_summary(step_summary, current_step)

                        # Log the finished batch step (useful for debugging race conditions / distributed TF)
                        log_debug('Finished batch step %d.' % current_step)

                        # Add batch to loss
                        total_loss += batch_loss

                    # Gathering job results
                    job.loss = total_loss / job.steps

                    # Display progressbar
                    if FLAGS.show_progressbar:
                        update_progressbar(job.set_name)

                    # Send the current job to coordinator and receive the next one
                    log_debug('Sending %s...' % job)
                    job = coord.next_job(job)

                if update_progressbar.pbar:
                    update_progressbar.pbar.finish()

            except Exception as e:
                log_error(str(e))
                traceback.print_exc()
                # Call all hooks' end() methods to end blocking calls
                for hook in hooks:
                    hook.end(session)
                # Only the chief has a SyncReplicasOptimizer queue runner that needs to be stopped to unblock process exit.
                # A rather graceful way to do this is by stopping the ps.
                # Only one party can send the token without failing.
                if Config.is_chief:
                    send_token_to_ps(session, kill=True)
                sys.exit(1)

        log_debug('Session closed.')

    except tf.errors.InvalidArgumentError as e:
        log_error(str(e))
        log_error('The checkpoint in {0} does not match the shapes of the model.'
                  ' Did you change alphabet.txt or the --n_hidden parameter'
                  ' between train runs using the same checkpoint dir? Try moving'
                  ' or removing the contents of {0}.'.format(FLAGS.checkpoint_dir))
        sys.exit(1)

    # Stopping the coordinator
    coord.stop()
Example #22
def evaluate(test_csvs, create_model, try_loading):
    scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.lm_binary_path,
                    FLAGS.lm_trie_path, Config.alphabet)

    test_set = create_dataset(test_csvs,
                              batch_size=FLAGS.test_batch_size,
                              cache_path=FLAGS.test_cached_features_path)
    it = test_set.make_one_shot_iterator()

    (batch_x, batch_x_len), batch_y = it.get_next()

    # One rate per layer
    no_dropout = [None] * 6
    logits, _ = create_model(batch_x=batch_x,
                             seq_length=batch_x_len,
                             dropout=no_dropout)

    # Transpose to batch major and apply softmax for decoder
    transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))

    loss = tf.nn.ctc_loss(labels=batch_y,
                          inputs=logits,
                          sequence_length=batch_x_len)

    global_step = tf.train.get_or_create_global_step()

    with tf.Session(config=Config.session_config) as session:
        # Create a saver using variables from the above newly created graph
        saver = tf.train.Saver()

        # Restore variables from training checkpoint
        loaded = try_loading(session, saver, 'best_dev_checkpoint',
                             'best validation')
        if not loaded:
            loaded = try_loading(session, saver, 'checkpoint', 'most recent')
        if not loaded:
            log_error(
                'Checkpoint directory ({}) does not contain a valid checkpoint state.'
                .format(FLAGS.checkpoint_dir))
            exit(1)

        logitses = []
        losses = []
        seq_lengths = []
        ground_truths = []

        print('Computing acoustic model predictions...')
        bar = progressbar.ProgressBar(widgets=[
            'Steps: ',
            progressbar.Counter(), ' | ',
            progressbar.Timer()
        ])

        step_count = 0

        # First pass, compute losses and transposed logits for decoding
        while True:
            try:
                logits, loss_, lengths, transcripts = session.run(
                    [transposed, loss, batch_x_len, batch_y])
            except tf.errors.OutOfRangeError:
                break

            step_count += 1
            bar.update(step_count)

            logitses.append(logits)
            losses.extend(loss_)
            seq_lengths.append(lengths)
            ground_truths.extend(
                sparse_tensor_value_to_texts(transcripts, Config.alphabet))

    bar.finish()

    predictions = []

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except:
        num_processes = 1

    print('Decoding predictions...')
    bar = progressbar.ProgressBar(max_value=step_count,
                                  widgets=[progressbar.AdaptiveETA()])

    # Second pass, decode logits and compute WER and edit distance metrics
    for logits, seq_length in bar(zip(logitses, seq_lengths)):
        decoded = ctc_beam_search_decoder_batch(logits,
                                                seq_length,
                                                Config.alphabet,
                                                FLAGS.beam_width,
                                                num_processes=num_processes,
                                                scorer=scorer)
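        # Each batch entry decodes to a list of (score, transcript) candidates sorted
        # best-first, so d[0][1] below takes the top transcript for that sample.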
        predictions.extend(d[0][1] for d in decoded)

    distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]

    wer, cer, samples = calculate_report(ground_truths, predictions, distances,
                                         losses)
    mean_loss = np.mean(losses)

    # Take only the first report_count items
    report_samples = itertools.islice(samples, FLAGS.report_count)

    print('Test - WER: %f, CER: %f, loss: %f' % (wer, cer, mean_loss))
    print('-' * 80)
    for sample in report_samples:
        print('WER: %f, CER: %f, loss: %f' %
              (sample.wer, sample.distance, sample.loss))
        print(' - src: "%s"' % sample.src)
        print(' - res: "%s"' % sample.res)
        print('-' * 80)

    return samples
Example #23
def evaluate(test_data, inference_graph, alphabet):
    scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                    FLAGS.lm_binary_path, FLAGS.lm_trie_path,
                    Config.alphabet)


    def create_windows(features):
        num_strides = len(features) - (Config.n_context * 2)

        # Create a view into the array with overlapping strides of size
        # numcontext (past) + 1 (present) + numcontext (future)
        window_size = 2*Config.n_context+1
        features = np.lib.stride_tricks.as_strided(
            features,
            (num_strides, window_size, Config.n_input),
            (features.strides[0], features.strides[0], features.strides[1]),
            writeable=False)

        return features

    # Create overlapping windows over the features
    test_data['features'] = test_data['features'].apply(create_windows)
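    # A quick shape check (a sketch, assuming the defaults n_input=26 and n_context=9):
    # a (1000, 26) feature matrix becomes a read-only view of shape
    # (1000 - 2*9, 2*9 + 1, 26) == (982, 19, 26) without copying any data.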

    with tf.Session(config=Config.session_config) as session:
        inputs, outputs, layers = inference_graph

        # Transpose to batch major for decoder
        transposed = tf.transpose(outputs['outputs'], [1, 0, 2])

        labels_ph = tf.placeholder(tf.int32, [FLAGS.test_batch_size, None], name="labels")
        label_lengths_ph = tf.placeholder(tf.int32, [FLAGS.test_batch_size], name="label_lengths")

        sparse_labels = tf.cast(ctc_label_dense_to_sparse(labels_ph, label_lengths_ph, FLAGS.test_batch_size), tf.int32)
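        # tf.nn.ctc_loss expects its labels as an int32 SparseTensor, hence the
        # dense-to-sparse conversion and cast above.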
        loss = tf.nn.ctc_loss(labels=sparse_labels,
                              inputs=layers['raw_logits'],
                              sequence_length=inputs['input_lengths'])

        # Create a saver using variables from the above newly created graph
        mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
        saver = tf.train.Saver(mapping)

        # Restore variables from training checkpoint
        if FLAGS.checkpoint_dir is not None:
            checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            if not checkpoint:
                log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
                exit(1)

            checkpoint_path = checkpoint.model_checkpoint_path
            saver.restore(session, checkpoint_path)

        logitses = []
        losses = []

        print('Computing acoustic model predictions...')
        batch_count = len(test_data) // FLAGS.test_batch_size
        bar = progressbar.ProgressBar(max_value=batch_count,
                                      widgets=[progressbar.AdaptiveETA()])

        # First pass, compute losses and transposed logits for decoding
        for batch in bar(split_data(test_data, FLAGS.test_batch_size)):
            session.run(outputs['initialize_state'])

            features = pad_to_dense(batch['features'].values)
            features_len = batch['features_len'].values
            labels = pad_to_dense(batch['transcript'].values)
            label_lengths = batch['transcript_len'].values

            logits, loss_ = session.run([transposed, loss], feed_dict={
                inputs['input']: features,
                inputs['input_lengths']: features_len,
                labels_ph: labels,
                label_lengths_ph: label_lengths
            })

            logitses.append(logits)
            losses.extend(loss_)

    ground_truths = []
    predictions = []

    print('Decoding predictions...')
    bar = progressbar.ProgressBar(max_value=batch_count,
                                  widgets=[progressbar.AdaptiveETA()])

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except:
        num_processes = 1

    # Second pass, decode logits and compute WER and edit distance metrics
    for logits, batch in bar(zip(logitses, split_data(test_data, FLAGS.test_batch_size))):
        seq_lengths = batch['features_len'].values.astype(np.int32)
        decoded = ctc_beam_search_decoder_batch(logits, seq_lengths, alphabet, FLAGS.beam_width,
                                                num_processes=num_processes, scorer=scorer)

        ground_truths.extend(alphabet.decode(l) for l in batch['transcript'])
        predictions.extend(d[0][1] for d in decoded)

    distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]

    wer, samples = calculate_report(ground_truths, predictions, distances, losses)
    mean_edit_distance = np.mean(distances)
    mean_loss = np.mean(losses)

    # Take only the first report_count items
    report_samples = itertools.islice(samples, FLAGS.report_count)

    print('Test - WER: %f, CER: %f, loss: %f' %
          (wer, mean_edit_distance, mean_loss))
    print('-' * 80)
    for sample in report_samples:
        print('WER: %f, CER: %f, loss: %f' %
              (sample.wer, sample.distance, sample.loss))
        print(' - src: "%s"' % sample.src)
        print(' - res: "%s"' % sample.res)
        print('-' * 80)

    return samples
Example #24
def fail(message, code=1):
    log_error(message)
    sys.exit(code)
Example #25
def train():
    do_cache_dataset = True

    # pylint: disable=too-many-boolean-expressions
    if (FLAGS.data_aug_features_multiplicative > 0 or
            FLAGS.data_aug_features_additive > 0 or
            FLAGS.augmentation_spec_dropout_keeprate < 1 or
            FLAGS.augmentation_freq_and_time_masking or
            FLAGS.augmentation_pitch_and_tempo_scaling or
            FLAGS.augmentation_speed_up_std > 0 or
            FLAGS.augmentation_sparse_warp):
        do_cache_dataset = False

    # Create training and validation datasets
    train_set = create_dataset(FLAGS.train_files.split(','),
                               batch_size=FLAGS.train_batch_size,
                               enable_cache=FLAGS.feature_cache and do_cache_dataset,
                               cache_path=FLAGS.feature_cache,
                               train_phase=True)

    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
                                                 tfv1.data.get_output_shapes(train_set),
                                                 output_classes=tfv1.data.get_output_classes(train_set))

    # Make initialization ops for switching between the two sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_csvs = FLAGS.dev_files.split(',')
        dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size, train_phase=False) for csv in dev_csvs]
        dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

    # The transfer learning approach here needs us to supply the layers which we
    # want to exclude from the source model.
    # Say we want to exclude all layers except for the first one, we can use this:
    #
    #    drop_source_layers=['2', '3', 'lstm', '5', '6']
    #
    # If we want to use all layers from the source model except the last one, we use this:
    #
    #    drop_source_layers=['6']
    #
    # (See the worked example after the assignment below.)

    if FLAGS.load == "transfer":
        drop_source_layers = ['2', '3', 'lstm', '5', '6'][-int(FLAGS.drop_source_layers):]
    else:
        drop_source_layers = None
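    # A quick sketch of the slicing above (assuming --drop_source_layers is the number of
    # layers to drop, counted from the output end of ['2', '3', 'lstm', '5', '6']):
    #
    #    FLAGS.drop_source_layers = 1  ->  drop_source_layers == ['6']
    #    FLAGS.drop_source_layers = 3  ->  drop_source_layers == ['lstm', '5', '6']
    #    FLAGS.drop_source_layers = 5  ->  drop_source_layers == ['2', '3', 'lstm', '5', '6']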
    
    # Dropout
    dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {
        rate: 0. for rate in dropout_rates
    }

    # Building the graph
    optimizer = create_optimizer()

    # Enable mixed precision training
    if FLAGS.automatic_mixed_precision:
        log_info('Enabling automatic mixed precision training.')
        optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)

    gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates, drop_source_layers)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)

    # global_step is automagically incremented by the optimizer
    global_step = tfv1.train.get_or_create_global_step()
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

    # Summaries
    step_summaries_op = tfv1.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')

    best_dev_saver = tfv1.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')

    # Save flags next to checkpoints
    os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)

    flags_file = os.path.join(FLAGS.checkpoint_dir, 'flags.txt')
    with open(flags_file, 'w') as fout:
        fout.write(FLAGS.flags_into_string())

    initializer = tfv1.global_variables_initializer()

    with tfv1.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        # Loading or initializing
        loaded = False

        # Initialize training from a CuDNN RNN checkpoint
        if FLAGS.cudnn_checkpoint:
            if FLAGS.use_cudnn_rnn:
                log_error('Trying to use --cudnn_checkpoint but --use_cudnn_rnn '
                          'was specified. The --cudnn_checkpoint flag is only '
                          'needed when converting a CuDNN RNN checkpoint to '
                          'a CPU-capable graph. If your system is capable of '
                          'using CuDNN RNN, you can just specify the CuDNN RNN '
                          'checkpoint normally with --checkpoint_dir.')
                sys.exit(1)

            log_info('Converting CuDNN RNN checkpoint from {}'.format(FLAGS.cudnn_checkpoint))
            ckpt = tfv1.train.load_checkpoint(FLAGS.cudnn_checkpoint)
            missing_variables = []

            # Load compatible variables from checkpoint
            for v in tfv1.global_variables():
                try:
                    v.load(ckpt.get_tensor(v.op.name), session=session)
                except tf.errors.NotFoundError:
                    missing_variables.append(v)

            # Check that the only missing variables are the Adam moment tensors
            if any('Adam' not in v.op.name for v in missing_variables):
                log_error('Tried to load a CuDNN RNN checkpoint but there were '
                          'more missing variables than just the Adam moment '
                          'tensors.')
                sys.exit(1)

            # Initialize Adam moment tensors from scratch to allow use of CuDNN
            # RNN checkpoints.
            log_info('Initializing missing Adam moment tensors.')
            init_op = tfv1.variables_initializer(missing_variables)
            session.run(init_op)
            loaded = True

        if not loaded and FLAGS.load in ['auto', 'last']:
            tfv1.get_default_graph().finalize()
            loaded = try_loading(session, checkpoint_saver, 'checkpoint', 'most recent')
        if not loaded and FLAGS.load in ['auto', 'best']:
            tfv1.get_default_graph().finalize()
            loaded = try_loading(session, best_dev_saver, 'best_dev_checkpoint', 'best validation')
        if not loaded:
            if FLAGS.load == "transfer":
                if FLAGS.source_model_checkpoint_dir:
                    print('Initializing model from', FLAGS.source_model_checkpoint_dir)
                    ckpt = tfv1.train.load_checkpoint(FLAGS.source_model_checkpoint_dir)
                    variables = list(ckpt.get_variable_to_shape_map().keys())
                    print('variable', variables)
                    print('global', tf.global_variables())				
                    # Load desired source variables
                    missing_variables2 = []				
                    for v in tf.global_variables():
                        if not any(layer in v.op.name for layer in drop_source_layers):
                            print('Loading', v.op.name)
                            try:						
                                v.load(ckpt.get_tensor(v.op.name), session=session)
                                print('OK')
                            except tf.errors.NotFoundError:
                                missing_variables2.append(v)
                                print('KO')
                            except ValueError:
                                #missing_variables2.append(v)
                                print('KO for valueError')						
                    print('missing_variables =', missing_variables2)					
                    # Initialize all variables needed for DS, but not loaded from ckpt
                    
                    init_op = tfv1.variables_initializer(
                        [v for v in tf.global_variables()
                        if any(layer in v.op.name
                                for layer in drop_source_layers)
                        ] + missing_variables2)
                    tfv1.get_default_graph().finalize()
                    session.run(init_op)
                   
			
            elif FLAGS.load in ['auto', 'init']:
                log_info('Initializing variables...')
                tfv1.get_default_graph().finalize()
                session.run(initializer)
            else:
                log_error('Unable to load %s model from specified checkpoint dir'
                        ' - consider using load option "auto" or "init".' % FLAGS.load)
                sys.exit(1)


        def run_set(set_name, epoch, init_op, dataset=None):
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict

            total_loss = 0.0
            step_count = 0

            step_summary_writer = step_summary_writers.get(set_name)
            checkpoint_time = time.time()

            # Setup progress bar
            class LossWidget(progressbar.widgets.FormatLabel):
                def __init__(self):
                    progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')

                def __call__(self, progress, data, **kwargs):
                    data['mean_loss'] = total_loss / step_count if step_count else 0.0
                    return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

            prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
            widgets = [' | ', progressbar.widgets.Timer(),
                       ' | Steps: ', progressbar.widgets.Counter(),
                       ' | ', LossWidget()]
            suffix = ' | Dataset: {}'.format(dataset) if dataset else None
            pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop
            while True:
                try:
                    _, current_step, batch_loss, problem_files, step_summary = \
                        session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
                                    feed_dict=feed_dict)
                except tf.errors.InvalidArgumentError as err:
                    if FLAGS.augmentation_sparse_warp:
                        log_info("Ignoring sparse warp error: {}".format(err))
                        continue
                    else:
                        raise
                except tf.errors.OutOfRangeError:
                    break

                if problem_files.size > 0:
                    problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
                    log_error('The following files caused an infinite (or NaN) '
                              'loss: {}'.format(','.join(problem_files)))

                total_loss += batch_loss
                step_count += 1

                pbar.update(step_count)

                step_summary_writer.add_summary(step_summary, current_step)

                if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                    checkpoint_time = time.time()

            pbar.finish()
            mean_loss = total_loss / step_count if step_count > 0 else 0.0
            return mean_loss, step_count

        log_info('STARTING Optimization')
        train_start_time = datetime.utcnow()
        best_dev_loss = float('inf')
        dev_losses = []
        try:
            for epoch in range(FLAGS.epochs):
                # Training
                log_progress('Training epoch %d...' % epoch)
                train_loss, _ = run_set('train', epoch, train_init_op)
                log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
                checkpoint_saver.save(session, checkpoint_path, global_step=global_step)

                if FLAGS.dev_files:
                    # Validation
                    dev_loss = 0.0
                    total_steps = 0
                    for csv, init_op in zip(dev_csvs, dev_init_ops):
                        log_progress('Validating epoch %d on %s...' % (epoch, csv))
                        set_loss, steps = run_set('dev', epoch, init_op, dataset=csv)
                        dev_loss += set_loss * steps
                        total_steps += steps
                        log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss))
                    dev_loss = dev_loss / total_steps
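                    # (dev_loss here is the step-weighted mean loss over all --dev_files CSVs.)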

                    dev_losses.append(dev_loss)

                    if dev_loss < best_dev_loss:
                        best_dev_loss = dev_loss
                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
                        log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))

                    # Early stopping
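                    # (Training stops early when the newest validation loss is worse than all
                    #  of the previous es_steps-1 losses, or when it has plateaued, i.e. its
                    #  distance to their mean is below --es_mean_th and their standard
                    #  deviation is below --es_std_th.)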
                    if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
                        mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
                        std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
                        dev_losses = dev_losses[-FLAGS.es_steps:]
                        log_debug('Checking for early stopping (last %d steps) validation loss: '
                                  '%f, with standard deviation: %f and mean: %f' %
                                  (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
                        if dev_losses[-1] > np.max(dev_losses[:-1]) or \
                           (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
                            log_info('Early stop triggered as (for last %d steps) validation loss:'
                                     ' %f with standard deviation: %f and mean: %f' %
                                     (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
                            break
        except KeyboardInterrupt:
            pass
        log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
    log_debug('Session closed.')
Example #26
def train():
    # Create training and validation datasets
    train_set = create_dataset(FLAGS.train_files.split(','),
                               batch_size=FLAGS.train_batch_size,
                               cache_path=FLAGS.train_cached_features_path)

    iterator = tf.data.Iterator.from_structure(train_set.output_types,
                                               train_set.output_shapes,
                                               output_classes=train_set.output_classes)

    # Make initialization ops for switching between the two sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_set = create_dataset(FLAGS.dev_files.split(','),
                                 batch_size=FLAGS.dev_batch_size,
                                 cache_path=FLAGS.dev_cached_features_path)
        dev_init_op = iterator.make_initializer(dev_set)

    # Dropout
    dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {
        rate: 0. for rate in dropout_rates
    }

    # Building the graph
    optimizer = create_optimizer()
    gradients, loss = get_tower_results(iterator, optimizer, dropout_rates)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)

    # global_step is automagically incremented by the optimizer
    global_step = tf.train.get_or_create_global_step()
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

    # Summaries
    step_summaries_op = tf.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
    checkpoint_filename = 'checkpoint'

    best_dev_saver = tf.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
    best_dev_filename = 'best_dev_checkpoint'

    initializer = tf.global_variables_initializer()

    with tf.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        tf.get_default_graph().finalize()

        # Loading or initializing
        loaded = False
        if FLAGS.load in ['auto', 'last']:
            loaded = try_loading(session, checkpoint_saver, checkpoint_filename, 'most recent')
        if not loaded and FLAGS.load in ['auto', 'best']:
            loaded = try_loading(session, best_dev_saver, best_dev_filename, 'best validation')
        if not loaded:
            if FLAGS.load in ['auto', 'init']:
                log_info('Initializing variables...')
                session.run(initializer)
            else:
                log_error('Unable to load %s model from specified checkpoint dir'
                          ' - consider using load option "auto" or "init".' % FLAGS.load)
                sys.exit(1)

        def run_set(set_name, init_op):
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict

            total_loss = 0.0
            step_count = 0

            step_summary_writer = step_summary_writers.get(set_name)
            checkpoint_time = time.time()

            class LossWidget(progressbar.widgets.FormatLabel):
                def __init__(self):
                    progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')

                def __call__(self, progress, data, **kwargs):
                    data['mean_loss'] = total_loss / step_count if step_count else 0.0
                    return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

            if FLAGS.show_progressbar:
                pbar = progressbar.ProgressBar(widgets=['Epoch {}'.format(epoch),
                                                        ' | ', progressbar.widgets.Timer(),
                                                        ' | Steps: ', progressbar.widgets.Counter(),
                                                        ' | ', LossWidget()])
                pbar.start()

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop
            while True:
                try:
                    _, current_step, batch_loss, step_summary = \
                        session.run([train_op, global_step, loss, step_summaries_op],
                                    feed_dict=feed_dict)
                except tf.errors.OutOfRangeError:
                    break

                total_loss += batch_loss
                step_count += 1

                if FLAGS.show_progressbar:
                    pbar.update(step_count)

                step_summary_writer.add_summary(step_summary, current_step)

                if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                    checkpoint_time = time.time()

            if FLAGS.show_progressbar:
                pbar.finish()

            return total_loss / step_count

        log_info('STARTING Optimization')
        best_dev_loss = float('inf')
        dev_losses = []
        try:
            for epoch in range(FLAGS.epochs):
                # Training
                if not FLAGS.show_progressbar:
                    log_info('Training epoch %d...' % epoch)
                train_loss = run_set('train', train_init_op)
                if not FLAGS.show_progressbar:
                    log_info('Finished training epoch %d - loss: %f' % (epoch, train_loss))
                checkpoint_saver.save(session, checkpoint_path, global_step=global_step)

                if FLAGS.dev_files:
                    # Validation
                    if not FLAGS.show_progressbar:
                        log_info('Validating epoch %d...' % epoch)
                    dev_loss = run_set('dev', dev_init_op)
                    if not FLAGS.show_progressbar:
                        log_info('Finished validating epoch %d - loss: %f' % (epoch, dev_loss))
                    dev_losses.append(dev_loss)

                    if dev_loss < best_dev_loss:
                        best_dev_loss = dev_loss
                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename=best_dev_filename)
                        log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))

                    # Early stopping
                    if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
                        mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
                        std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
                        dev_losses = dev_losses[-FLAGS.es_steps:]
                        log_debug('Checking for early stopping (last %d steps) validation loss: '
                                  '%f, with standard deviation: %f and mean: %f' %
                                  (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
                        if dev_losses[-1] > np.max(dev_losses[:-1]) or \
                           (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
                            log_info('Early stop triggered as (for last %d steps) validation loss:'
                                     ' %f with standard deviation: %f and mean: %f' %
                                     (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
                            break
        except KeyboardInterrupt:
            pass
    log_debug('Session closed.')
Example #27
def export():
    r'''
    Restores the trained variables into a simpler graph that will be exported for serving.
    '''
    log_info('Exporting the model...')
    from tensorflow.python.framework.ops import Tensor, Operation

    inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)

    graph_version = int(file_relative_read('GRAPH_VERSION').strip())
    assert graph_version > 0

    outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
    outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
    outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
    outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
    outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')

    if FLAGS.export_language:
        outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')

    output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
    output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
    output_names = ",".join(output_names_tensors + output_names_ops)

    # Create a saver using variables from the above newly created graph
    saver = tfv1.train.Saver()

    # Restore variables from training checkpoint
    checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    checkpoint_path = checkpoint.model_checkpoint_path

    output_filename = 'output_graph.pb'
    if FLAGS.remove_export:
        if os.path.isdir(FLAGS.export_dir):
            log_info('Removing old export')
            shutil.rmtree(FLAGS.export_dir)
    try:
        output_graph_path = os.path.join(FLAGS.export_dir, output_filename)

        if not os.path.isdir(FLAGS.export_dir):
            os.makedirs(FLAGS.export_dir)

        def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=''):
            frozen = freeze_graph.freeze_graph_with_def_protos(
                input_graph_def=tfv1.get_default_graph().as_graph_def(),
                input_saver_def=saver.as_saver_def(),
                input_checkpoint=checkpoint_path,
                output_node_names=output_node_names,
                restore_op_name=None,
                filename_tensor_name=None,
                output_graph=output_file,
                clear_devices=False,
                variable_names_blacklist=variables_blacklist,
                initializer_nodes='')

            input_node_names = []
            return strip_unused_lib.strip_unused(
                input_graph_def=frozen,
                input_node_names=input_node_names,
                output_node_names=output_node_names.split(','),
                placeholder_type_enum=tf.float32.as_datatype_enum)

        frozen_graph = do_graph_freeze(output_node_names=output_names)

        if not FLAGS.export_tflite:
            with open(output_graph_path, 'wb') as fout:
                fout.write(frozen_graph.SerializeToString())
        else:
            output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))

            converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
            converter.optimizations = [ tf.lite.Optimize.DEFAULT ]
            # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
            converter.allow_custom_ops = True
            tflite_model = converter.convert()

            with open(output_tflite_path, 'wb') as fout:
                fout.write(tflite_model)

            log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))

        log_info('Models exported at %s' % (FLAGS.export_dir))
    except RuntimeError as e:
        log_error(str(e))
Example #28
def export():
    r'''
    Restores the trained variables into a simpler graph that will be exported for serving.
    '''
    log_info('Exporting the model...')
    from tensorflow.python.framework.ops import Tensor, Operation

    inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
    output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
    output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
    output_names = ",".join(output_names_tensors + output_names_ops)

    if not FLAGS.export_tflite:
        mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
    else:
        # Create a saver using variables from the above newly created graph
        def fixup(name):
            if name.startswith('rnn/lstm_cell/'):
                return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
            return name

        mapping = {fixup(v.op.name): v for v in tf.global_variables()}

    saver = tf.train.Saver(mapping)

    # Restore variables from training checkpoint
    checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    checkpoint_path = checkpoint.model_checkpoint_path

    output_filename = 'output_graph.pb'
    if FLAGS.remove_export:
        if os.path.isdir(FLAGS.export_dir):
            log_info('Removing old export')
            shutil.rmtree(FLAGS.export_dir)
    try:
        output_graph_path = os.path.join(FLAGS.export_dir, output_filename)

        if not os.path.isdir(FLAGS.export_dir):
            os.makedirs(FLAGS.export_dir)

        def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
            return freeze_graph.freeze_graph_with_def_protos(
                input_graph_def=tf.get_default_graph().as_graph_def(),
                input_saver_def=saver.as_saver_def(),
                input_checkpoint=checkpoint_path,
                output_node_names=output_node_names,
                restore_op_name=None,
                filename_tensor_name=None,
                output_graph=output_file,
                clear_devices=False,
                variable_names_blacklist=variables_blacklist,
                initializer_nodes='')

        if not FLAGS.export_tflite:
            frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
            frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())

            # Add a no-op node to the graph with metadata information to be loaded by the native client
            metadata = frozen_graph.node.add()
            metadata.name = 'model_metadata'
            metadata.op = 'NoOp'
            metadata.attr['sample_rate'].i = FLAGS.audio_sample_rate
            metadata.attr['feature_win_len'].i = FLAGS.feature_win_len
            metadata.attr['feature_win_step'].i = FLAGS.feature_win_step
            if FLAGS.export_language:
                metadata.attr['language'].s = FLAGS.export_language.encode('ascii')

            with open(output_graph_path, 'wb') as fout:
                fout.write(frozen_graph.SerializeToString())
        else:
            frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='')
            output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))

            converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
            converter.post_training_quantize = True
            # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
            converter.allow_custom_ops = True
            tflite_model = converter.convert()

            with open(output_tflite_path, 'wb') as fout:
                fout.write(tflite_model)

            log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))

        log_info('Models exported at %s' % (FLAGS.export_dir))
    except RuntimeError as e:
        log_error(str(e))
Example #29
def initialize_globals():
    c = AttrDict()

    # CPU device
    c.cpu_device = '/cpu:0'

    # Available GPU devices
    c.available_devices = get_available_gpus()

    # If there is no GPU available, we fall back to CPU based operation
    if not c.available_devices:
        c.available_devices = [c.cpu_device]

    # Set default dropout rates
    if FLAGS.dropout_rate2 < 0:
        FLAGS.dropout_rate2 = FLAGS.dropout_rate
    if FLAGS.dropout_rate3 < 0:
        FLAGS.dropout_rate3 = FLAGS.dropout_rate
    if FLAGS.dropout_rate6 < 0:
        FLAGS.dropout_rate6 = FLAGS.dropout_rate

    # Set default checkpoint dir
    if not FLAGS.checkpoint_dir:
        FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('deepspeech', 'checkpoints'))

    if FLAGS.load not in ['last', 'best', 'init', 'auto']:
        FLAGS.load = 'auto'

    # Set default summary dir
    if not FLAGS.summary_dir:
        FLAGS.summary_dir = xdg.save_data_path(os.path.join('deepspeech', 'summaries'))

    # Standard session configuration that'll be used for all new sessions.
    c.session_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement,
                                      inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
                                      intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads)

    c.alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path))

    # Geometric Constants
    # ===================

    # For an explanation of the meaning of the geometric constants, please refer to
    # doc/Geometry.md

    # Number of MFCC features
    c.n_input = 26 # TODO: Determine this programmatically from the sample rate

    # The number of frames in the context
    c.n_context = 9 # TODO: Determine the optimal value using a validation data set

    # Number of units in hidden layers
    c.n_hidden = FLAGS.n_hidden

    c.n_hidden_1 = c.n_hidden

    c.n_hidden_2 = c.n_hidden

    c.n_hidden_5 = c.n_hidden

    # LSTM cell state dimension
    c.n_cell_dim = c.n_hidden

    # The number of units in the third layer, which feeds in to the LSTM
    c.n_hidden_3 = c.n_cell_dim

    # Units in the sixth layer = number of characters in the target language plus one
    c.n_hidden_6 = c.alphabet.size() + 1 # +1 for CTC blank label

    # Size of audio window in samples
    c.audio_window_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_len / 1000)

    # Stride for feature computations in samples
    c.audio_step_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_step / 1000)
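    # For example, with the usual defaults (16 kHz audio, a 32 ms feature window and a
    # 20 ms step) this yields audio_window_samples == 512 and audio_step_samples == 320.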

    if FLAGS.one_shot_infer:
        if not os.path.exists(FLAGS.one_shot_infer):
            log_error('Path specified in --one_shot_infer is not a valid file.')
            exit(1)

    ConfigSingleton._config = c # pylint: disable=protected-access
Example #30
def export():
    r'''
    Restores the trained variables into a simpler graph that will be exported for serving.
    '''
    log_info('Exporting the model...')
    with tf.device('/cpu:0'):
        from tensorflow.python.framework.ops import Tensor, Operation

        tf.reset_default_graph()
        session = tf.Session(config=Config.session_config)

        inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
        input_names = ",".join(tensor.op.name for tensor in inputs.values())
        output_names_tensors = [ tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor) ]
        output_names_ops = [ tensor.name for tensor in outputs.values() if isinstance(tensor, Operation) ]
        output_names = ",".join(output_names_tensors + output_names_ops)
        input_shapes = ":".join(",".join(map(str, tensor.shape)) for tensor in inputs.values())

        if not FLAGS.export_tflite:
            mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
        else:
            # Create a saver using variables from the above newly created graph
            def fixup(name):
                if name.startswith('rnn/lstm_cell/'):
                    return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
                return name

            mapping = {fixup(v.op.name): v for v in tf.global_variables()}

        saver = tf.train.Saver(mapping)

        # Restore variables from training checkpoint
        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        checkpoint_path = checkpoint.model_checkpoint_path

        output_filename = 'output_graph.pb'
        if FLAGS.remove_export:
            if os.path.isdir(FLAGS.export_dir):
                log_info('Removing old export')
                shutil.rmtree(FLAGS.export_dir)
        try:
            output_graph_path = os.path.join(FLAGS.export_dir, output_filename)

            if not os.path.isdir(FLAGS.export_dir):
                os.makedirs(FLAGS.export_dir)

            def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
                freeze_graph.freeze_graph_with_def_protos(
                    input_graph_def=session.graph_def,
                    input_saver_def=saver.as_saver_def(),
                    input_checkpoint=checkpoint_path,
                    output_node_names=output_node_names,
                    restore_op_name=None,
                    filename_tensor_name=None,
                    output_graph=output_file,
                    clear_devices=False,
                    variable_names_blacklist=variables_blacklist,
                    initializer_nodes='')

            if not FLAGS.export_tflite:
                do_graph_freeze(output_file=output_graph_path, output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
            else:
                temp_fd, temp_freeze = tempfile.mkstemp(dir=FLAGS.export_dir)
                os.close(temp_fd)
                do_graph_freeze(output_file=temp_freeze, output_node_names=output_names, variables_blacklist='')
                output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
                class TFLiteFlags():
                    def __init__(self):
                        self.graph_def_file = temp_freeze
                        self.inference_type = 'FLOAT'
                        self.input_arrays   = input_names
                        self.input_shapes   = input_shapes
                        self.output_arrays  = output_names
                        self.output_file    = output_tflite_path
                        self.output_format  = 'TFLITE'

                        default_empty = [
                            'inference_input_type',
                            'mean_values',
                            'default_ranges_min', 'default_ranges_max',
                            'drop_control_dependency',
                            'reorder_across_fake_quant',
                            'change_concat_input_ranges',
                            'allow_custom_ops',
                            'converter_mode',
                            'post_training_quantize',
                            'dump_graphviz_dir',
                            'dump_graphviz_video'
                        ]
                        for e in default_empty:
                            self.__dict__[e] = None

                flags = TFLiteFlags()
                tflite_convert._convert_model(flags)
                os.unlink(temp_freeze)
                log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))

            log_info('Models exported at %s' % (FLAGS.export_dir))
        except RuntimeError as e:
            log_error(str(e))
Example #31
def train(server=None):
    r'''
    Trains the network on a given server of a cluster.
    If no server provided, it performs single process training.
    '''

    # Initializing and starting the training coordinator
    coord = TrainingCoordinator(Config.is_chief)
    coord.start()

    # Create a variable to hold the global_step.
    # It will automagically get incremented by the optimizer.
    global_step = tf.Variable(0, trainable=False, name='global_step')

    dropout_rates = [
        tf.placeholder(tf.float32, name='dropout_{}'.format(i))
        for i in range(6)
    ]

    # Reading training set
    train_data = preprocess(FLAGS.train_files.split(','),
                            FLAGS.train_batch_size,
                            Config.n_input,
                            Config.n_context,
                            Config.alphabet,
                            hdf5_cache_path=FLAGS.train_cached_features_path)

    train_set = DataSet(train_data,
                        FLAGS.train_batch_size,
                        limit=FLAGS.limit_train,
                        next_index=lambda i: coord.get_next_index('train'))

    # Reading validation set
    dev_data = preprocess(FLAGS.dev_files.split(','),
                          FLAGS.dev_batch_size,
                          Config.n_input,
                          Config.n_context,
                          Config.alphabet,
                          hdf5_cache_path=FLAGS.dev_cached_features_path)

    dev_set = DataSet(dev_data,
                      FLAGS.dev_batch_size,
                      limit=FLAGS.limit_dev,
                      next_index=lambda i: coord.get_next_index('dev'))

    # Combining all sets to a multi set model feeder
    model_feeder = ModelFeeder(train_set,
                               dev_set,
                               Config.n_input,
                               Config.n_context,
                               Config.alphabet,
                               tower_feeder_count=len(
                                   Config.available_devices))

    # Create the optimizer
    optimizer = create_optimizer()

    # Synchronous distributed training is facilitated by a special proxy-optimizer
    if server is not None:
        optimizer = tf.train.SyncReplicasOptimizer(
            optimizer,
            replicas_to_aggregate=FLAGS.replicas_to_agg,
            total_num_replicas=FLAGS.replicas)

    # Get the data_set specific graph end-points
    gradients, loss = get_tower_results(model_feeder, optimizer, dropout_rates)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)

    # Add summaries of all variables and gradients to log
    log_grads_and_vars(avg_tower_gradients)

    # Op to merge all summaries for the summary hook
    merge_all_summaries_op = tf.summary.merge_all()

    # These are saved on every step
    step_summaries_op = tf.summary.merge_all('step_summaries')

    step_summary_writers = {
        'train':
        tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'),
                              max_queue=120),
        'dev':
        tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'),
                              max_queue=120)
    }

    # Apply gradients to modify the model
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients,
                                                  global_step=global_step)

    if FLAGS.early_stop is True and not FLAGS.validation_step > 0:
        log_warn(
            'Parameter --validation_step needs to be >0 for early stopping to work'
        )

    class CoordHook(tf.train.SessionRunHook):
        r'''
        Embedded coordination hook-class that will use variables of the
        surrounding Python context.
        '''
        def after_create_session(self, session, coord):
            log_debug('Starting queue runners...')
            model_feeder.start_queue_threads(session, coord)
            log_debug('Queue runners started.')

        def end(self, session):
            # Closing the data_set queues
            log_debug('Closing queues...')
            model_feeder.close_queues(session)
            log_debug('Queues closed.')

            # Telling the ps that we are done
            send_token_to_ps(session)

    # Collecting the hooks
    hooks = [CoordHook()]

    # Hook to handle initialization and queues for sync replicas.
    if server is not None:
        hooks.append(optimizer.make_session_run_hook(Config.is_chief))

    # Hook to save TensorBoard summaries
    if FLAGS.summary_secs > 0:
        hooks.append(
            tf.train.SummarySaverHook(save_secs=FLAGS.summary_secs,
                                      output_dir=FLAGS.summary_dir,
                                      summary_op=merge_all_summaries_op))

    # Hook with the number of checkpoint files to save in checkpoint_dir
    if FLAGS.train and FLAGS.max_to_keep > 0:
        saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
        hooks.append(
            tf.train.CheckpointSaverHook(checkpoint_dir=FLAGS.checkpoint_dir,
                                         save_secs=FLAGS.checkpoint_secs,
                                         saver=saver))

    no_dropout_feed_dict = {
        dropout_rates[0]: 0.,
        dropout_rates[1]: 0.,
        dropout_rates[2]: 0.,
        dropout_rates[3]: 0.,
        dropout_rates[4]: 0.,
        dropout_rates[5]: 0.,
    }

    # Progress Bar
    def update_progressbar(set_name):
        if not hasattr(update_progressbar, 'current_set_name'):
            update_progressbar.current_set_name = None

        if (update_progressbar.current_set_name != set_name
                or update_progressbar.current_job_index
                == update_progressbar.total_jobs):

            # finish prev pbar if it exists
            if hasattr(update_progressbar, 'pbar') and update_progressbar.pbar:
                update_progressbar.pbar.finish()

            update_progressbar.total_jobs = None
            update_progressbar.current_job_index = 0

            current_epoch = coord._epoch - 1

            if set_name == "train":
                log_info('Training epoch %i...' % current_epoch)
                update_progressbar.total_jobs = coord._num_jobs_train
            else:
                log_info('Validating epoch %i...' % current_epoch)
                update_progressbar.total_jobs = coord._num_jobs_dev

            # recreate pbar
            update_progressbar.pbar = progressbar.ProgressBar(
                max_value=update_progressbar.total_jobs,
                redirect_stdout=True).start()

            update_progressbar.current_set_name = set_name

        if update_progressbar.pbar:
            update_progressbar.pbar.update(
                update_progressbar.current_job_index + 1, force=True)

        update_progressbar.current_job_index += 1

    # Initialize update_progressbar()'s function attributes to safe values
    update_progressbar.pbar = None

    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    try:
        with tf.train.MonitoredTrainingSession(
                master='' if server is None else server.target,
                is_chief=Config.is_chief,
                hooks=hooks,
                checkpoint_dir=FLAGS.checkpoint_dir,
                save_checkpoint_secs=None,  # already taken care of by a hook
                log_step_count_steps=0,  # disable steps/s logging to avoid a TF warning during validation
                config=Config.session_config) as session:
            tf.get_default_graph().finalize()

            try:
                if Config.is_chief:
                    # Retrieving global_step from the (potentially restored) model
                    model_feeder.set_data_set(no_dropout_feed_dict,
                                              model_feeder.train)
                    step = session.run(global_step,
                                       feed_dict=no_dropout_feed_dict)
                    coord.start_coordination(model_feeder, step)

                # Get the first job
                job = coord.get_job()

                while job and not session.should_stop():
                    log_debug('Computing %s...' % job)

                    is_train = job.set_name == 'train'

                    # The feed_dict (mainly for switching between queues)
                    if is_train:
                        feed_dict = {
                            dropout_rates[0]: FLAGS.dropout_rate,
                            dropout_rates[1]: FLAGS.dropout_rate2,
                            dropout_rates[2]: FLAGS.dropout_rate3,
                            dropout_rates[3]: FLAGS.dropout_rate4,
                            dropout_rates[4]: FLAGS.dropout_rate5,
                            dropout_rates[5]: FLAGS.dropout_rate6,
                        }
                    else:
                        feed_dict = no_dropout_feed_dict

                    # Sets the current data_set for the respective placeholder in feed_dict
                    model_feeder.set_data_set(
                        feed_dict, getattr(model_feeder, job.set_name))

                    # Initialize loss aggregator
                    total_loss = 0.0

                    # Setting the training operation in case of training requested
                    train_op = apply_gradient_op if is_train else []
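                    # (an empty fetch list is a harmless no-op, so validation steps
                    # run the graph without applying gradients)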

                    # So far the only extra parameter is the feed_dict
                    extra_params = {'feed_dict': feed_dict}

                    step_summary_writer = step_summary_writers.get(
                        job.set_name)

                    # Loop over the batches
                    for job_step in range(job.steps):
                        if session.should_stop():
                            break

                        log_debug('Starting batch...')
                        # Compute the batch
                        _, current_step, batch_loss, step_summary = session.run(
                            [train_op, global_step, loss, step_summaries_op],
                            **extra_params)

                        # Log step summaries
                        step_summary_writer.add_summary(
                            step_summary, current_step)

                        # Per-step log, useful when debugging race conditions / distributed TF
                        log_debug('Finished batch step %d.' % current_step)

                        # Add batch to loss
                        total_loss += batch_loss

                    # Gather job results: report the average loss per batch
                    job.loss = total_loss / job.steps

                    # Display progressbar
                    if FLAGS.show_progressbar:
                        update_progressbar(job.set_name)

                    # Send the current job to coordinator and receive the next one
                    log_debug('Sending %s...' % job)
                    job = coord.next_job(job)

                if update_progressbar.pbar:
                    update_progressbar.pbar.finish()

            except Exception as e:
                log_error(str(e))
                traceback.print_exc()
                # Calling all hook's end() methods to end blocking calls
                for hook in hooks:
                    hook.end(session)
                # Only the chief has a SyncReplicasOptimizer queue runner that needs to be
                # stopped to unblock process exit. A rather graceful way to do this is by
                # stopping the ps. Only one party can send the kill token without failing.
                if Config.is_chief:
                    send_token_to_ps(session, kill=True)
                sys.exit(1)

        log_debug('Session closed.')

    except tf.errors.InvalidArgumentError as e:
        log_error(str(e))
        log_error(
            'The checkpoint in {0} does not match the shapes of the model.'
            ' Did you change alphabet.txt or the --n_hidden parameter'
            ' between train runs using the same checkpoint dir? Try moving'
            ' or removing the contents of {0}.'.format(FLAGS.checkpoint_dir))
        sys.exit(1)

    # Stopping the coordinator
    coord.stop()
Example #32
def _load_checkpoint(session, checkpoint_path):
    # Load the checkpoint and put all variables into the loading list.
    # Variables we do not wish to load are excluded below and
    # initialized from scratch instead.
    ckpt = tfv1.train.load_checkpoint(checkpoint_path)
    vars_in_ckpt = frozenset(ckpt.get_variable_to_shape_map().keys())
    load_vars = set(tfv1.global_variables())
    init_vars = set()
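    # Every variable ends up in exactly one of the two sets: load_vars are
    # restored from the checkpoint, init_vars get their initializers run instead.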

    # We explicitly allow the learning rate variable to be missing for backwards
    # compatibility with older checkpoints.
    lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
    if lr_var and ('learning_rate' not in vars_in_ckpt or FLAGS.force_initialize_learning_rate):
        assert len(lr_var) <= 1
        load_vars -= lr_var
        init_vars |= lr_var

    if FLAGS.load_cudnn:
        # Initialize training from a CuDNN RNN checkpoint
        # Identify the variables which we cannot load, and set them
        # for initialization
        missing_vars = set()
        for v in load_vars:
            if v.op.name not in vars_in_ckpt:
                log_warn('CUDNN variable not found: %s' % (v.op.name))
                missing_vars.add(v)
                init_vars.add(v)

        load_vars -= init_vars

        # Check that the only missing variables (i.e. those to be initialized)
        # are the Adam moment tensors; if not, we have a problem
        missing_var_names = [v.op.name for v in missing_vars]
        if any('Adam' not in v for v in missing_var_names):
            log_error('Tried to load a CuDNN RNN checkpoint but there were '
                      'more missing variables than just the Adam moment '
                      'tensors. Missing variables: {}'.format(missing_var_names))
            sys.exit(1)

    if FLAGS.drop_source_layers > 0:
        # This transfer learning approach requires supplying
        # the number of layers to exclude from the source model.
        # Say we want to keep only the first layer: that means
        # dropping five layers in total, so: drop_source_layers=5
        # If we want to use all layers from the source model except
        # the last one, we use: drop_source_layers=1
        if FLAGS.drop_source_layers >= 6:
            log_warn('The checkpoint only has 6 layers, but you are trying to drop '
                     'all of them or more. Continuing and dropping only the last '
                     '5 layers (the first layer is always kept).')
            FLAGS.drop_source_layers = 5

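        # For example, drop_source_layers=2 makes the slice below yield
        # ['5', '6'], i.e. only the last two layers get re-initialized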
        dropped_layers = ['2', '3', 'lstm', '5', '6'][-1 * int(FLAGS.drop_source_layers):]
        # Initialize all variables needed for DS, but not loaded from ckpt
        for v in load_vars:
            if any(layer in v.op.name for layer in dropped_layers):
                init_vars.add(v)
        load_vars -= init_vars

    for v in sorted(load_vars, key=lambda v: v.op.name):
        log_info('Loading variable from checkpoint: %s' % (v.op.name))
        v.load(ckpt.get_tensor(v.op.name), session=session)

    for v in sorted(init_vars, key=lambda v: v.op.name):
        log_info('Initializing variable: %s' % (v.op.name))
        session.run(v.initializer)
Example #33
def evaluate(test_csvs, create_model, try_loading):
    scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                    FLAGS.lm_binary_path, FLAGS.lm_trie_path,
                    Config.alphabet)

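    # NOTE: the test_csvs argument is overridden here by the --test_files flag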
    test_csvs = FLAGS.test_files.split(',')
    test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size) for csv in test_csvs]
    iterator = tf.data.Iterator.from_structure(test_sets[0].output_types,
                                               test_sets[0].output_shapes,
                                               output_classes=test_sets[0].output_classes)
    test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]

    (batch_x, batch_x_len), batch_y = iterator.get_next()

    # One rate per layer
    no_dropout = [None] * 6
    logits, _ = create_model(batch_x=batch_x,
                             seq_length=batch_x_len,
                             dropout=no_dropout)

    # Transpose to batch major and apply softmax for decoder
    transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))

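    # tf.nn.ctc_loss expects time-major logits by default, so the un-transposed
    # tensor is passed here; it returns one loss value per batch element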
    loss = tf.nn.ctc_loss(labels=batch_y,
                          inputs=logits,
                          sequence_length=batch_x_len)

    tf.train.get_or_create_global_step()

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    # Create a saver using variables from the above newly created graph
    saver = tf.train.Saver()

    with tf.Session(config=Config.session_config) as session:
        # Restore variables from training checkpoint
        loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
        if not loaded:
            loaded = try_loading(session, saver, 'checkpoint', 'most recent')
        if not loaded:
            log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
            exit(1)

        def run_test(init_op, dataset):
            logitses = []
            losses = []
            seq_lengths = []
            ground_truths = []

            bar = create_progressbar(prefix='Computing acoustic model predictions | ',
                                     widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
            log_progress('Computing acoustic model predictions...')

            step_count = 0

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # First pass, compute losses and transposed logits for decoding
            while True:
                try:
                    logits, loss_, lengths, transcripts = session.run([transposed, loss, batch_x_len, batch_y])
                except tf.errors.OutOfRangeError:
                    break

                step_count += 1
                bar.update(step_count)

                logitses.append(logits)
                losses.extend(loss_)
                seq_lengths.append(lengths)
                ground_truths.extend(sparse_tensor_value_to_texts(transcripts, Config.alphabet))

            bar.finish()

            predictions = []

            bar = create_progressbar(max_value=step_count,
                                     prefix='Decoding predictions | ').start()
            log_progress('Decoding predictions...')

            # Second pass, decode logits and compute WER and edit distance metrics
            for logits, seq_length in bar(zip(logitses, seq_lengths)):
                decoded = ctc_beam_search_decoder_batch(logits, seq_length, Config.alphabet, FLAGS.beam_width,
                                                        num_processes=num_processes, scorer=scorer)
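                # Each entry in decoded is a list of (confidence, transcript)
                # candidates for one sample; keep the transcript of the top one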
                predictions.extend(d[0][1] for d in decoded)

            distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]

            wer, cer, samples = calculate_report(ground_truths, predictions, distances, losses)
            mean_loss = np.mean(losses)

            # Take only the first report_count items
            report_samples = itertools.islice(samples, FLAGS.report_count)

            print('Test on %s - WER: %f, CER: %f, loss: %f' %
                  (dataset, wer, cer, mean_loss))
            print('-' * 80)
            for sample in report_samples:
                print('WER: %f, CER: %f, loss: %f' %
                      (sample.wer, sample.distance, sample.loss))
                print(' - src: "%s"' % sample.src)
                print(' - res: "%s"' % sample.res)
                print('-' * 80)

            return samples

        samples = []
        for csv, init_op in zip(test_csvs, test_init_ops):
            print('Testing model on {}'.format(csv))
            samples.extend(run_test(init_op, dataset=csv))
        return samples
Example #34
        fout.write(FLAGS.flags_into_string())

    initializer = tfv1.global_variables_initializer()

    with tfv1.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        # Loading or initializing
        loaded = False

        # Initialize training from a CuDNN RNN checkpoint
        if FLAGS.cudnn_checkpoint:
            if FLAGS.use_cudnn_rnn:
                log_error('Trying to use --cudnn_checkpoint but --use_cudnn_rnn '
                          'was specified. The --cudnn_checkpoint flag is only '
                          'needed when converting a CuDNN RNN checkpoint to '
                          'a CPU-capable graph. If your system is capable of '
                          'using CuDNN RNN, you can just specify the CuDNN RNN '
                          'checkpoint normally with --checkpoint_dir.')
                exit(1)

            log_info('Converting CuDNN RNN checkpoint from {}'.format(FLAGS.cudnn_checkpoint))
            ckpt = tfv1.train.load_checkpoint(FLAGS.cudnn_checkpoint)
            missing_variables = []

            # Load compatible variables from checkpoint
            for v in tfv1.global_variables():
                try:
                    v.load(ckpt.get_tensor(v.op.name), session=session)
                except tf.errors.NotFoundError:
                    missing_variables.append(v)