def send_token_to_ps(session, kill=False):
    # Send our token (the task_index, which doubles as a debugging aid) to each parameter server.
    # Kill-switch tokens are negative and offset by 1 so that task_index 0 remains distinguishable.
    token = -FLAGS.task_index-1 if kill else FLAGS.task_index
    kind = 'kill switch' if kill else 'stop'
    for index, enqueue in enumerate(Config.done_enqueues):
        log_debug('Sending %s token to ps %d...' % (kind, index))
        session.run(enqueue, feed_dict={ Config.token_placeholder: token })
        log_debug('Sent %s token to ps %d.' % (kind, index))
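
# Illustrative sketch (not part of the original script): how the token encoding used by
# send_token_to_ps() round-trips. The helper names are hypothetical; the decode step
# mirrors the abs(token + 1) logic applied on the parameter-server side in main() below.
def _encode_token(task_index, kill=False):
    return -task_index - 1 if kill else task_index

def _decode_token(token):
    # Returns (worker_index, is_kill_switch)
    return (abs(token + 1), True) if token < 0 else (token, False)

assert _decode_token(_encode_token(0, kill=True)) == (0, True)
assert _decode_token(_encode_token(3)) == (3, False)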
    def start(self):
        '''Starts Training Coordinator. If chief, it starts a web server for
        communication with non-chief instances.
        '''
        if self.is_chief:
            log_debug('Starting coordinator...')
            self._thread = Thread(target=self._httpd.serve_forever)
            self._thread.daemon = True
            self._thread.start()
            log_debug('Coordinator started. Thread id {}'.format(self._thread.ident))
    def _next_epoch(self):
        # State-machine of the coordination process

        # Indicates whether 'new' epoch(s) were provided
        result = False

        # Early stopping requires validation runs to be enabled and enough recorded validation losses
        if FLAGS.early_stop and FLAGS.validation_step > 0 and len(self._dev_losses) >= FLAGS.earlystop_nsteps:

            # Calculate the mean of losses for past epochs
            mean_loss = np.mean(self._dev_losses[-FLAGS.earlystop_nsteps:-1])
            # Calculate the standard deviation of the validation losses over the past epochs
            std_loss = np.std(self._dev_losses[-FLAGS.earlystop_nsteps:-1])
            # Keep only the losses of the last earlystop_nsteps validation runs
            self._dev_losses = self._dev_losses[-FLAGS.earlystop_nsteps:]
            log_debug('Checking for early stopping (last %d steps): validation loss = %f, standard deviation = %f, mean = %f' % (FLAGS.earlystop_nsteps, self._dev_losses[-1], std_loss, mean_loss))

            # Trigger early stopping if the validation loss has started to increase, or if it is no
            # longer decreasing substantially - the mean/std thresholds keep slight fluctuations
            # from interfering with the early-stop decision
            if self._dev_losses[-1] > np.max(self._dev_losses[:-1]) or (abs(self._dev_losses[-1] - mean_loss) < FLAGS.estop_mean_thresh and std_loss < FLAGS.estop_std_thresh):
                # Time to early stop
                log_info('Early stop triggered: over the last %d steps, validation loss = %f with standard deviation = %f and mean = %f' % (FLAGS.earlystop_nsteps, self._dev_losses[-1], std_loss, mean_loss))
                self._dev_losses = []
                self._end_training()
                self._train = False

        if self._train:
            # We are in train mode
            if self._num_jobs_train_left > 0:
                # There are still jobs left
                num_jobs_train = min(self._num_jobs_train_left, self._num_jobs_train)
                self._num_jobs_train_left -= num_jobs_train

                # Reset the counters to best preserve the curriculum-learning sample order
                self._reset_counters()

                # Append the training epoch
                self._epochs_running.append(Epoch(self, self._epoch, num_jobs_train, set_name='train'))

                if FLAGS.validation_step > 0 and (FLAGS.validation_step == 1 or self._epoch > 0) and self._epoch % FLAGS.validation_step == 0:
                    # The current epoch should also have a validation part
                    self._epochs_running.append(Epoch(self, self._epoch, self._num_jobs_dev, set_name='dev'))


                # Indicate that new epoch(s) were provided
                result = True
            else:
                # No jobs left, but still in train mode: concluding training
                self._end_training()
                self._train = False

        if result:
            # Increment the epoch index
            self._epoch += 1
        return result
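
# Illustrative sketch (not part of the original class): the early-stopping test from
# _next_epoch() above, extracted as a pure function. Argument names mirror the FLAGS
# used there but are plain parameters here; the helper itself is hypothetical.
import numpy as np

def _should_early_stop(dev_losses, nsteps, mean_thresh, std_thresh):
    if len(dev_losses) < nsteps:
        return False
    window = dev_losses[-nsteps:]
    mean_loss = np.mean(window[:-1])
    std_loss = np.std(window[:-1])
    # Stop if the latest loss increased past the window, or if it merely plateaued
    return (window[-1] > np.max(window[:-1]) or
            (abs(window[-1] - mean_loss) < mean_thresh and std_loss < std_thresh))

# A plateauing validation loss triggers the stop ...
assert _should_early_stop([10.0, 9.9, 9.9, 9.9, 9.9], 5, 0.1, 0.1)
# ... while a clearly decreasing one does not.
assert not _should_early_stop([10.0, 9.0, 8.0, 7.0, 6.0], 5, 0.1, 0.1)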
    def stop(self, wait_for_running_epochs=True):
        '''Stops Training Coordinator. If chief, it waits for all epochs to be
        'done' and then shuts down the web server.
        '''
        if self.is_chief and self._thread:
            if wait_for_running_epochs:
                while len(self._epochs_running) > 0:
                    log_traffic('Coordinator is waiting for epochs to finish...')
                    time.sleep(5)
            log_debug('Stopping coordinator...')
            self._httpd.shutdown()
            log_debug('Coordinator stopped.')
def main(_):
    initialize_globals()

    if FLAGS.train or FLAGS.test:
        if len(FLAGS.worker_hosts) == 0:
            # Only one local task: this process (default case - no cluster)
            with tf.Graph().as_default():
                tf.set_random_seed(FLAGS.random_seed)
                train()
            # Now do a final test epoch
            if FLAGS.test:
                with tf.Graph().as_default():
                    test()
            log_debug('Done.')
        else:
            # Create and start a server for the local task.
            server = tf.train.Server(Config.cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
            if FLAGS.job_name == 'ps':
                # We are a parameter server and therefore we just wait for all workers to finish
                # by waiting for their stop tokens.
                with tf.Session(server.target) as session:
                    for worker in FLAGS.worker_hosts:
                        log_debug('Waiting for stop token...')
                        token = session.run(Config.done_dequeues[FLAGS.task_index])
                        if token < 0:
                            log_debug('Got a kill switch token from worker %i.' % abs(token + 1))
                            break
                        log_debug('Got a stop token from worker %i.' % token)
                log_debug('Session closed.')

                if FLAGS.test:
                    test()
            elif FLAGS.job_name == 'worker':
                # We are a worker and therefore we have to do some work.
                # Assigns ops to the local worker by default.
                with tf.device(tf.train.replica_device_setter(
                               worker_device=Config.worker_device,
                               cluster=Config.cluster)):

                    # Do the training
                    train(server)

            log_debug('Server stopped.')

    # Are we the chief (main) process?
    if Config.is_chief:
        # Solo/post-processing work is done on the chief only...
        # Export the model
        if FLAGS.export_dir:
            export()

    if FLAGS.one_shot_infer:
        do_single_file_inference(FLAGS.one_shot_infer)
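
# Illustrative sketch (not part of the original script): the role split handled by main()
# above for a hypothetical one-ps / two-worker cluster, assuming (as is conventional)
# that worker 0 acts as chief.
#
#   --job_name ps     --task_index 0   -> waits for one stop token per worker host
#   --job_name worker --task_index 0   -> chief: train(server), coordination, export
#   --job_name worker --task_index 1   -> non-chief: train(server) only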
def train(server=None):
    r'''
    Trains the network on a given server of a cluster.
    If no server is provided, it performs single-process training.
    '''

    # Initializing and starting the training coordinator
    coord = TrainingCoordinator(Config.is_chief)
    coord.start()

    # Create a variable to hold the global_step.
    # It will automagically get incremented by the optimizer.
    global_step = tf.Variable(0, trainable=False, name='global_step')

    dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]

    # Reading training set
    train_data = preprocess(FLAGS.train_files.split(','),
                            FLAGS.train_batch_size,
                            Config.n_input,
                            Config.n_context,
                            Config.alphabet,
                            hdf5_cache_path=FLAGS.train_cached_features_path)

    train_set = DataSet(train_data,
                        FLAGS.train_batch_size,
                        limit=FLAGS.limit_train,
                        next_index=lambda i: coord.get_next_index('train'))

    # Reading validation set
    dev_data = preprocess(FLAGS.dev_files.split(','),
                          FLAGS.dev_batch_size,
                          Config.n_input,
                          Config.n_context,
                          Config.alphabet,
                          hdf5_cache_path=FLAGS.dev_cached_features_path)

    dev_set = DataSet(dev_data,
                      FLAGS.dev_batch_size,
                      limit=FLAGS.limit_dev,
                      next_index=lambda i: coord.get_next_index('dev'))

    # Combining all sets to a multi set model feeder
    model_feeder = ModelFeeder(train_set,
                               dev_set,
                               Config.n_input,
                               Config.n_context,
                               Config.alphabet,
                               tower_feeder_count=len(Config.available_devices))

    # Create the optimizer
    optimizer = create_optimizer()

    # Synchronous distributed training is facilitated by a special proxy-optimizer
    if server is not None:
        optimizer = tf.train.SyncReplicasOptimizer(optimizer,
                                                   replicas_to_aggregate=FLAGS.replicas_to_agg,
                                                   total_num_replicas=FLAGS.replicas)

    # Get the data_set specific graph end-points
    gradients, loss = get_tower_results(model_feeder, optimizer, dropout_rates)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)

    # Add summaries of all variables and gradients to log
    log_grads_and_vars(avg_tower_gradients)

    # Op to merge all summaries for the summary hook
    merge_all_summaries_op = tf.summary.merge_all()

    # These are saved on every step
    step_summaries_op = tf.summary.merge_all('step_summaries')

    step_summary_writers = {
        'train': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
    }

    # Apply gradients to modify the model
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)


    if FLAGS.early_stop and FLAGS.validation_step <= 0:
        log_warn('Parameter --validation_step needs to be >0 for early stopping to work')

    class CoordHook(tf.train.SessionRunHook):
        r'''
        Embedded coordination hook-class that will use variables of the
        surrounding Python context.
        '''
        def after_create_session(self, session, coord):
            log_debug('Starting queue runners...')
            model_feeder.start_queue_threads(session, coord)
            log_debug('Queue runners started.')

        def end(self, session):
            # Closing the data_set queues
            log_debug('Closing queues...')
            model_feeder.close_queues(session)
            log_debug('Queues closed.')

            # Telling the ps that we are done
            send_token_to_ps(session)

    # Collecting the hooks
    hooks = [CoordHook()]

    # Hook to handle initialization and queues for sync replicas.
    if server is not None:
        hooks.append(optimizer.make_session_run_hook(Config.is_chief))

    # Hook to save TensorBoard summaries
    if FLAGS.summary_secs > 0:
        hooks.append(tf.train.SummarySaverHook(save_secs=FLAGS.summary_secs, output_dir=FLAGS.summary_dir, summary_op=merge_all_summaries_op))

    # Hook with the number of checkpoint files to keep in checkpoint_dir
    if FLAGS.train and FLAGS.max_to_keep > 0:
        saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
        hooks.append(tf.train.CheckpointSaverHook(checkpoint_dir=FLAGS.checkpoint_dir, save_secs=FLAGS.checkpoint_secs, saver=saver))

    no_dropout_feed_dict = {
        dropout_rates[0]: 0.,
        dropout_rates[1]: 0.,
        dropout_rates[2]: 0.,
        dropout_rates[3]: 0.,
        dropout_rates[4]: 0.,
        dropout_rates[5]: 0.,
    }
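
    # Note (illustrative alternative, not in the original): the same mapping could be
    # built more compactly as
    #     no_dropout_feed_dict = {rate: 0. for rate in dropout_rates}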

    # Progress bar helper - keeps its state (current set name, job index, pbar)
    # as attributes on the function object itself
    def update_progressbar(set_name):
        if not hasattr(update_progressbar, 'current_set_name'):
            update_progressbar.current_set_name = None

        if (update_progressbar.current_set_name != set_name or
            update_progressbar.current_job_index == update_progressbar.total_jobs):

            # finish prev pbar if it exists
            if hasattr(update_progressbar, 'pbar') and update_progressbar.pbar:
                update_progressbar.pbar.finish()

            update_progressbar.total_jobs = None
            update_progressbar.current_job_index = 0

            current_epoch = coord._epoch-1

            if set_name == "train":
                log_info('Training epoch %i...' % current_epoch)
                update_progressbar.total_jobs = coord._num_jobs_train
            else:
                log_info('Validating epoch %i...' % current_epoch)
                update_progressbar.total_jobs = coord._num_jobs_dev

            # recreate pbar
            update_progressbar.pbar = progressbar.ProgressBar(max_value=update_progressbar.total_jobs,
                                                              redirect_stdout=True).start()

            update_progressbar.current_set_name = set_name

        if update_progressbar.pbar:
            update_progressbar.pbar.update(update_progressbar.current_job_index+1, force=True)

        update_progressbar.current_job_index += 1

    # Initialize update_progressbar()'s child fields to safe values
    update_progressbar.pbar = None

    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    try:
        with tf.train.MonitoredTrainingSession(master='' if server is None else server.target,
                                               is_chief=Config.is_chief,
                                               hooks=hooks,
                                               checkpoint_dir=FLAGS.checkpoint_dir,
                                               save_checkpoint_secs=None, # already taken care of by a hook
                                               log_step_count_steps=0, # disable logging of steps/s to avoid TF warning in validation sets
                                               config=Config.session_config) as session:
            tf.get_default_graph().finalize()

            try:
                if Config.is_chief:
                    # Retrieving global_step from the (potentially restored) model
                    model_feeder.set_data_set(no_dropout_feed_dict, model_feeder.train)
                    step = session.run(global_step, feed_dict=no_dropout_feed_dict)
                    coord.start_coordination(model_feeder, step)

                # Get the first job
                job = coord.get_job()

                while job and not session.should_stop():
                    log_debug('Computing %s...' % job)

                    is_train = job.set_name == 'train'

                    # The feed_dict (mainly for switching between queues)
                    if is_train:
                        feed_dict = {
                            dropout_rates[0]: FLAGS.dropout_rate,
                            dropout_rates[1]: FLAGS.dropout_rate2,
                            dropout_rates[2]: FLAGS.dropout_rate3,
                            dropout_rates[3]: FLAGS.dropout_rate4,
                            dropout_rates[4]: FLAGS.dropout_rate5,
                            dropout_rates[5]: FLAGS.dropout_rate6,
                        }
                    else:
                        feed_dict = no_dropout_feed_dict

                    # Sets the current data_set for the respective placeholder in feed_dict
                    model_feeder.set_data_set(feed_dict, getattr(model_feeder, job.set_name))

                    # Initialize loss aggregator
                    total_loss = 0.0

                    # Set the training op only when training was requested
                    train_op = apply_gradient_op if is_train else []

                    # So far the only extra parameter is the feed_dict
                    extra_params = { 'feed_dict': feed_dict }

                    step_summary_writer = step_summary_writers.get(job.set_name)

                    # Loop over the batches
                    for job_step in range(job.steps):
                        if session.should_stop():
                            break

                        log_debug('Starting batch...')
                        # Compute the batch
                        _, current_step, batch_loss, step_summary = session.run([train_op, global_step, loss, step_summaries_op], **extra_params)

                        # Log step summaries
                        step_summary_writer.add_summary(step_summary, current_step)

                        # Log the finished batch step - useful when debugging race conditions / distributed TF
                        log_debug('Finished batch step %d.' % current_step)

                        # Add batch to loss
                        total_loss += batch_loss

                    # Gathering job results
                    job.loss = total_loss / job.steps

                    # Display progressbar
                    if FLAGS.show_progressbar:
                        update_progressbar(job.set_name)

                    # Send the current job to coordinator and receive the next one
                    log_debug('Sending %s...' % job)
                    job = coord.next_job(job)

                if update_progressbar.pbar:
                    update_progressbar.pbar.finish()

            except Exception as e:
                log_error(str(e))
                traceback.print_exc()
                # Call all hooks' end() methods to end blocking calls
                for hook in hooks:
                    hook.end(session)
                # Only the chief has a SyncReplicasOptimizer queue runner that needs to be
                # stopped to unblock process exit. A rather graceful way to do this is by
                # stopping the ps - only one party can send the kill token without failing.
                if Config.is_chief:
                    send_token_to_ps(session, kill=True)
                sys.exit(1)

        log_debug('Session closed.')

    except tf.errors.InvalidArgumentError as e:
        log_error(str(e))
        log_error('The checkpoint in {0} does not match the shapes of the model.'
                  ' Did you change alphabet.txt or the --n_hidden parameter'
                  ' between train runs using the same checkpoint dir? Try moving'
                  ' or removing the contents of {0}.'.format(FLAGS.checkpoint_dir))
        sys.exit(1)

    # Stopping the coordinator
    coord.stop()
    def start_coordination(self, model_feeder, step=0):
        '''Starts coordinating epochs and jobs among workers on the basis of
        data-set sizes, the (global) step and FLAGS parameters.

        Args:
            model_feeder (ModelFeeder): data-sets to be used for coordinated training

        Kwargs:
            step (int): global step of a loaded model to determine starting point
        '''
        with self._lock:
            self._init()

            # Number of GPUs per worker - currently fixed by the local hardware or the cluster setup
            gpus_per_worker = len(Config.available_devices)

            # Number of batches processed per job per worker
            batches_per_job  = gpus_per_worker * max(1, FLAGS.iters_per_worker)

            # Number of batches per global step
            batches_per_step = gpus_per_worker * max(1, FLAGS.replicas_to_agg)

            # Number of global steps per epoch - to be at least 1
            steps_per_epoch = max(1, model_feeder.train.total_batches // batches_per_step)

            # The start epoch of our training
            self._epoch = step // steps_per_epoch

            # Number of additional 'jobs' trained already 'on top of' our start epoch
            jobs_trained = (step % steps_per_epoch) * batches_per_step // batches_per_job

            # Total number of train/dev jobs covering their respective whole sets (one epoch)
            self._num_jobs_train = max(1, model_feeder.train.total_batches // batches_per_job)
            self._num_jobs_dev   = max(1, model_feeder.dev.total_batches   // batches_per_job)

            if FLAGS.epoch < 0:
                # A negative epoch means to add its absolute number to the epochs already computed
                self._target_epoch = self._epoch + abs(FLAGS.epoch)
            else:
                self._target_epoch = FLAGS.epoch

            # State variables
            # We only have to train if we are told to and have not yet reached the target epoch
            self._train = FLAGS.train and self._target_epoch > self._epoch

            if self._train:
                # The total number of jobs for all additional epochs to be trained
                # Will be decremented for each job that is produced/put into state 'open'
                self._num_jobs_train_left = (self._target_epoch - self._epoch) * self._num_jobs_train - jobs_trained
                log_info('STARTING Optimization')
                self._training_time = stopwatch()

            # Important for debugging
            log_debug('step: %d' % step)
            log_debug('epoch: %d' % self._epoch)
            log_debug('target epoch: %d' % self._target_epoch)
            log_debug('steps per epoch: %d' % steps_per_epoch)
            log_debug('number of batches in train set: %d' % model_feeder.train.total_batches)
            log_debug('batches per job: %d' % batches_per_job)
            log_debug('batches per step: %d' % batches_per_step)
            log_debug('number of jobs in train set: %d' % self._num_jobs_train)
            log_debug('number of jobs already trained in first epoch: %d' % jobs_trained)

            self._next_epoch()

        # The coordinator is ready to serve
        self.started = True
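
# Illustrative sketch (not part of the original class): the job/epoch arithmetic from
# start_coordination() with concrete, made-up numbers - 2 GPUs per worker, 1000
# training batches, --iters_per_worker 2, --replicas_to_agg 2, resuming at step 150.
gpus_per_worker  = 2
total_batches    = 1000
batches_per_job  = gpus_per_worker * max(1, 2)                # 4 batches per job
batches_per_step = gpus_per_worker * max(1, 2)                # 4 batches per global step
steps_per_epoch  = max(1, total_batches // batches_per_step)  # 250 steps per epoch
step             = 150
epoch            = step // steps_per_epoch                    # 0 - still in the first epoch
jobs_trained     = (step % steps_per_epoch) * batches_per_step // batches_per_job  # 150 jobs already done
num_jobs_train   = max(1, total_batches // batches_per_job)   # 250 jobs per training epoch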
    def _log_all_jobs(self):
        '''Use this to debug-print epoch state'''
        log_debug('Epochs - running: %d, done: %d' % (len(self._epochs_running), len(self._epochs_done)))
        for epoch in self._epochs_running:
            log_debug('       - running: ' + epoch.job_status())