def end(self, session):
    # Closing the data_set queues
    log_debug('Closing queues...')
    model_feeder.close_queues(session)
    log_debug('Queues closed.')

    # Telling the ps that we are done
    send_token_to_ps(session)
def send_token_to_ps(session, kill=False):
    # Sending our token (the task_index as a debug opportunity) to each parameter server.
    # kill switch tokens are negative and decremented by 1 to deal with task_index 0
    token = -FLAGS.task_index-1 if kill else FLAGS.task_index
    kind = 'kill switch' if kill else 'stop'
    for index, enqueue in enumerate(Config.done_enqueues):
        log_debug('Sending %s token to ps %d...' % (kind, index))
        session.run(enqueue, feed_dict={ Config.token_placeholder: token })
        log_debug('Sent %s token to ps %d.' % (kind, index))
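# A minimal sketch (not part of the original file) of one plausible way the coordination
# endpoints used above (Config.done_enqueues, Config.done_dequeues, Config.token_placeholder)
# could be constructed: one FIFOQueue per parameter server, pinned to that ps task so every
# worker enqueues into the same shared queue. The helper name `create_done_queues` and the
# `ps_hosts` argument are illustrative assumptions, not identifiers from this codebase.
import tensorflow as tf

def create_done_queues(ps_hosts):
    # Single scalar placeholder shared by all enqueue ops
    token_placeholder = tf.placeholder(tf.int32, shape=[], name='token')
    done_enqueues = []
    done_dequeues = []
    for i, _ in enumerate(ps_hosts):
        # Pin each queue to its parameter server so all workers see the same queue
        with tf.device('/job:ps/task:%d' % i):
            queue = tf.FIFOQueue(1, tf.int32, shared_name='done_queue_%d' % i)
            done_enqueues.append(queue.enqueue(token_placeholder))
            done_dequeues.append(queue.dequeue())
    return token_placeholder, done_enqueues, done_dequeues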
def start(self):
    '''Starts the Training Coordinator. If chief, it starts a web server for
    communication with non-chief instances.
    '''
    if self.is_chief:
        log_debug('Starting coordinator...')
        self._thread = Thread(target=self._httpd.serve_forever)
        self._thread.daemon = True
        self._thread.start()
        log_debug('Coordinator started. Thread id {}'.format(self._thread.ident))
def _next_epoch(self):
    # State-machine of the coordination process
    # Indicates whether 'new' epoch(s) were provided
    result = False

    # Make sure that early stopping and the validation part are both enabled
    if (FLAGS.early_stop is True) and (FLAGS.validation_step > 0) and (len(self._dev_losses) >= FLAGS.earlystop_nsteps):
        # Calculate the mean of losses for past epochs
        mean_loss = np.mean(self._dev_losses[-FLAGS.earlystop_nsteps:-1])
        # Calculate the standard deviation of the validation losses over the past epochs
        std_loss = np.std(self._dev_losses[-FLAGS.earlystop_nsteps:-1])
        # Update the list of losses incurred
        self._dev_losses = self._dev_losses[-FLAGS.earlystop_nsteps:]
        log_debug('Checking for early stopping (last %d steps) validation loss: %f, with standard deviation: %f and mean: %f' % (FLAGS.earlystop_nsteps, self._dev_losses[-1], std_loss, mean_loss))

        # Check if the validation loss has started increasing or is no longer decreasing
        # substantially, while making sure slight fluctuations don't trigger early stopping
        if self._dev_losses[-1] > np.max(self._dev_losses[:-1]) or (abs(self._dev_losses[-1] - mean_loss) < FLAGS.estop_mean_thresh and std_loss < FLAGS.estop_std_thresh):
            # Time to early stop
            log_info('Early stop triggered as (for last %d steps) validation loss: %f with standard deviation: %f and mean: %f' % (FLAGS.earlystop_nsteps, self._dev_losses[-1], std_loss, mean_loss))
            self._dev_losses = []
            self._end_training()
            self._train = False

    if self._train:
        # We are in train mode
        if self._num_jobs_train_left > 0:
            # There are still jobs left
            num_jobs_train = min(self._num_jobs_train_left, self._num_jobs_train)
            self._num_jobs_train_left -= num_jobs_train

            # Let's try our best to keep the notion of curriculum learning
            self._reset_counters()

            # Append the training epoch
            self._epochs_running.append(Epoch(self, self._epoch, num_jobs_train, set_name='train'))

            if FLAGS.validation_step > 0 and (FLAGS.validation_step == 1 or self._epoch > 0) and self._epoch % FLAGS.validation_step == 0:
                # The current epoch should also have a validation part
                self._epochs_running.append(Epoch(self, self._epoch, self._num_jobs_dev, set_name='dev'))

            # Indicating that there were 'new' epoch(s) provided
            result = True
        else:
            # No jobs left, but still in train mode: concluding training
            self._end_training()
            self._train = False

    if result:
        # Increment the epoch index
        self._epoch += 1

    return result
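# A small self-contained sketch (not part of the original module) that isolates the
# early-stopping criterion above as a pure function, so it can be unit-tested without a
# TrainingCoordinator. The arguments mirror FLAGS.earlystop_nsteps, FLAGS.estop_mean_thresh
# and FLAGS.estop_std_thresh, but are plain parameters here; the function name is hypothetical.
import numpy as np

def should_early_stop(dev_losses, nsteps, mean_thresh, std_thresh):
    # Not enough validation history yet
    if len(dev_losses) < nsteps:
        return False
    window = dev_losses[-nsteps:]
    mean_loss = np.mean(window[:-1])
    std_loss = np.std(window[:-1])
    latest = window[-1]
    # Stop if the latest loss exceeds every earlier loss in the window, or if the loss
    # has flattened out (close to the mean of the window with low spread)
    return latest > np.max(window[:-1]) or (abs(latest - mean_loss) < mean_thresh and std_loss < std_thresh)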
def stop(self, wait_for_running_epochs=True):
    '''Stops the Training Coordinator. If chief, it waits for all epochs to be
    'done' and then shuts down the web server.
    '''
    if self.is_chief and self._thread:
        if wait_for_running_epochs:
            while len(self._epochs_running) > 0:
                log_traffic('Coordinator is waiting for epochs to finish...')
                time.sleep(5)
        log_debug('Stopping coordinator...')
        self._httpd.shutdown()
        log_debug('Coordinator stopped.')
def main(_):
    initialize_globals()

    if FLAGS.train or FLAGS.test:
        if len(FLAGS.worker_hosts) == 0:
            # Only one local task: this process (default case - no cluster)
            with tf.Graph().as_default():
                tf.set_random_seed(FLAGS.random_seed)
                train()
            # Now do a final test epoch
            if FLAGS.test:
                with tf.Graph().as_default():
                    test()
            log_debug('Done.')
        else:
            # Create and start a server for the local task.
            server = tf.train.Server(Config.cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
            if FLAGS.job_name == 'ps':
                # We are a parameter server and therefore we just wait for all workers to finish
                # by waiting for their stop tokens.
                with tf.Session(server.target) as session:
                    for worker in FLAGS.worker_hosts:
                        log_debug('Waiting for stop token...')
                        token = session.run(Config.done_dequeues[FLAGS.task_index])
                        if token < 0:
                            log_debug('Got a kill switch token from worker %i.' % abs(token + 1))
                            break
                        log_debug('Got a stop token from worker %i.' % token)
                log_debug('Session closed.')

                if FLAGS.test:
                    test()
            elif FLAGS.job_name == 'worker':
                # We are a worker and therefore we have to do some work.
                # Assigns ops to the local worker by default.
                with tf.device(tf.train.replica_device_setter(
                        worker_device=Config.worker_device,
                        cluster=Config.cluster)):
                    # Do the training
                    train(server)
                log_debug('Server stopped.')

    # Are we the main process?
    if Config.is_chief:
        # Doing solo/post-processing work just on the main process...
        # Exporting the model
        if FLAGS.export_dir:
            export()

        if len(FLAGS.one_shot_infer):
            do_single_file_inference(FLAGS.one_shot_infer)
def after_create_session(self, session, coord):
    log_debug('Starting queue runners...')
    model_feeder.start_queue_threads(session, coord)
    log_debug('Queue runners started.')
def train(server=None):
    r'''
    Trains the network on a given server of a cluster.
    If no server is provided, it performs single-process training.
    '''
    # Initializing and starting the training coordinator
    coord = TrainingCoordinator(Config.is_chief)
    coord.start()

    # Create a variable to hold the global_step.
    # It will automagically get incremented by the optimizer.
    global_step = tf.Variable(0, trainable=False, name='global_step')

    dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]

    # Reading training set
    train_data = preprocess(FLAGS.train_files.split(','),
                            FLAGS.train_batch_size,
                            Config.n_input,
                            Config.n_context,
                            Config.alphabet,
                            hdf5_cache_path=FLAGS.train_cached_features_path)

    train_set = DataSet(train_data,
                        FLAGS.train_batch_size,
                        limit=FLAGS.limit_train,
                        next_index=lambda i: coord.get_next_index('train'))

    # Reading validation set
    dev_data = preprocess(FLAGS.dev_files.split(','),
                          FLAGS.dev_batch_size,
                          Config.n_input,
                          Config.n_context,
                          Config.alphabet,
                          hdf5_cache_path=FLAGS.dev_cached_features_path)

    dev_set = DataSet(dev_data,
                      FLAGS.dev_batch_size,
                      limit=FLAGS.limit_dev,
                      next_index=lambda i: coord.get_next_index('dev'))

    # Combining all sets to a multi set model feeder
    model_feeder = ModelFeeder(train_set,
                               dev_set,
                               Config.n_input,
                               Config.n_context,
                               Config.alphabet,
                               tower_feeder_count=len(Config.available_devices))

    # Create the optimizer
    optimizer = create_optimizer()

    # Synchronous distributed training is facilitated by a special proxy-optimizer
    if server is not None:
        optimizer = tf.train.SyncReplicasOptimizer(optimizer,
                                                   replicas_to_aggregate=FLAGS.replicas_to_agg,
                                                   total_num_replicas=FLAGS.replicas)

    # Get the data_set specific graph end-points
    gradients, loss = get_tower_results(model_feeder, optimizer, dropout_rates)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)

    # Add summaries of all variables and gradients to log
    log_grads_and_vars(avg_tower_gradients)

    # Op to merge all summaries for the summary hook
    merge_all_summaries_op = tf.summary.merge_all()

    # These are saved on every step
    step_summaries_op = tf.summary.merge_all('step_summaries')

    step_summary_writers = {
        'train': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
    }

    # Apply gradients to modify the model
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

    if FLAGS.early_stop is True and not FLAGS.validation_step > 0:
        log_warn('Parameter --validation_step needs to be >0 for early stopping to work')

    class CoordHook(tf.train.SessionRunHook):
        r'''
        Embedded coordination hook-class that will use variables of the
        surrounding Python context.
        '''
        def after_create_session(self, session, coord):
            log_debug('Starting queue runners...')
            model_feeder.start_queue_threads(session, coord)
            log_debug('Queue runners started.')

        def end(self, session):
            # Closing the data_set queues
            log_debug('Closing queues...')
            model_feeder.close_queues(session)
            log_debug('Queues closed.')

            # Telling the ps that we are done
            send_token_to_ps(session)

    # Collecting the hooks
    hooks = [CoordHook()]

    # Hook to handle initialization and queues for sync replicas.
    if server is not None:
        hooks.append(optimizer.make_session_run_hook(Config.is_chief))

    # Hook to save TensorBoard summaries
    if FLAGS.summary_secs > 0:
        hooks.append(tf.train.SummarySaverHook(save_secs=FLAGS.summary_secs,
                                               output_dir=FLAGS.summary_dir,
                                               summary_op=merge_all_summaries_op))

    # Hook with number of checkpoint files to save in checkpoint_dir
    if FLAGS.train and FLAGS.max_to_keep > 0:
        saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
        hooks.append(tf.train.CheckpointSaverHook(checkpoint_dir=FLAGS.checkpoint_dir,
                                                  save_secs=FLAGS.checkpoint_secs,
                                                  saver=saver))

    no_dropout_feed_dict = {
        dropout_rates[0]: 0.,
        dropout_rates[1]: 0.,
        dropout_rates[2]: 0.,
        dropout_rates[3]: 0.,
        dropout_rates[4]: 0.,
        dropout_rates[5]: 0.,
    }

    # Progress Bar
    def update_progressbar(set_name):
        if not hasattr(update_progressbar, 'current_set_name'):
            update_progressbar.current_set_name = None

        if (update_progressbar.current_set_name != set_name or
            update_progressbar.current_job_index == update_progressbar.total_jobs):

            # finish prev pbar if it exists
            if hasattr(update_progressbar, 'pbar') and update_progressbar.pbar:
                update_progressbar.pbar.finish()

            update_progressbar.total_jobs = None
            update_progressbar.current_job_index = 0

            current_epoch = coord._epoch-1

            if set_name == "train":
                log_info('Training epoch %i...' % current_epoch)
                update_progressbar.total_jobs = coord._num_jobs_train
            else:
                log_info('Validating epoch %i...' % current_epoch)
                update_progressbar.total_jobs = coord._num_jobs_dev

            # recreate pbar
            update_progressbar.pbar = progressbar.ProgressBar(max_value=update_progressbar.total_jobs,
                                                              redirect_stdout=True).start()

            update_progressbar.current_set_name = set_name

        if update_progressbar.pbar:
            update_progressbar.pbar.update(update_progressbar.current_job_index+1, force=True)

        update_progressbar.current_job_index += 1

    # Initialize update_progressbar()'s child fields to safe values
    update_progressbar.pbar = None

    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    try:
        with tf.train.MonitoredTrainingSession(master='' if server is None else server.target,
                                               is_chief=Config.is_chief,
                                               hooks=hooks,
                                               checkpoint_dir=FLAGS.checkpoint_dir,
                                               save_checkpoint_secs=None,  # already taken care of by a hook
                                               log_step_count_steps=0,  # disable logging of steps/s to avoid TF warning in validation sets
                                               config=Config.session_config) as session:
            tf.get_default_graph().finalize()
            try:
                if Config.is_chief:
                    # Retrieving global_step from the (potentially restored) model
                    model_feeder.set_data_set(no_dropout_feed_dict, model_feeder.train)
                    step = session.run(global_step, feed_dict=no_dropout_feed_dict)
                    coord.start_coordination(model_feeder, step)

                # Get the first job
                job = coord.get_job()

                while job and not session.should_stop():
                    log_debug('Computing %s...' % job)

                    is_train = job.set_name == 'train'

                    # The feed_dict (mainly for switching between queues)
                    if is_train:
                        feed_dict = {
                            dropout_rates[0]: FLAGS.dropout_rate,
                            dropout_rates[1]: FLAGS.dropout_rate2,
                            dropout_rates[2]: FLAGS.dropout_rate3,
                            dropout_rates[3]: FLAGS.dropout_rate4,
                            dropout_rates[4]: FLAGS.dropout_rate5,
                            dropout_rates[5]: FLAGS.dropout_rate6,
                        }
                    else:
                        feed_dict = no_dropout_feed_dict

                    # Sets the current data_set for the respective placeholder in feed_dict
                    model_feeder.set_data_set(feed_dict, getattr(model_feeder, job.set_name))

                    # Initialize loss aggregator
                    total_loss = 0.0

                    # Setting the training operation in case of training requested
                    train_op = apply_gradient_op if is_train else []

                    # So far the only extra parameter is the feed_dict
                    extra_params = { 'feed_dict': feed_dict }

                    step_summary_writer = step_summary_writers.get(job.set_name)

                    # Loop over the batches
                    for job_step in range(job.steps):
                        if session.should_stop():
                            break

                        log_debug('Starting batch...')
                        # Compute the batch
                        _, current_step, batch_loss, step_summary = session.run([train_op, global_step, loss, step_summaries_op], **extra_params)

                        # Log step summaries
                        step_summary_writer.add_summary(step_summary, current_step)

                        # Uncomment the next line for debugging race conditions / distributed TF
                        log_debug('Finished batch step %d.' % current_step)

                        # Add batch to loss
                        total_loss += batch_loss

                    # Gathering job results
                    job.loss = total_loss / job.steps

                    # Display progressbar
                    if FLAGS.show_progressbar:
                        update_progressbar(job.set_name)

                    # Send the current job to coordinator and receive the next one
                    log_debug('Sending %s...' % job)
                    job = coord.next_job(job)

                if update_progressbar.pbar:
                    update_progressbar.pbar.finish()

            except Exception as e:
                log_error(str(e))
                traceback.print_exc()
                # Calling all hooks' end() methods to end blocking calls
                for hook in hooks:
                    hook.end(session)
                # Only chief has a SyncReplicasOptimizer queue runner that needs to be stopped for unblocking process exit.
                # A rather graceful way to do this is by stopping the ps.
                # Only one party can send it w/o failing.
                if Config.is_chief:
                    send_token_to_ps(session, kill=True)
                sys.exit(1)

        log_debug('Session closed.')

    except tf.errors.InvalidArgumentError as e:
        log_error(str(e))
        log_error('The checkpoint in {0} does not match the shapes of the model.'
                  ' Did you change alphabet.txt or the --n_hidden parameter'
                  ' between train runs using the same checkpoint dir? Try moving'
                  ' or removing the contents of {0}.'.format(FLAGS.checkpoint_dir))
        sys.exit(1)

    # Stopping the coordinator
    coord.stop()
def start_coordination(self, model_feeder, step=0):
    '''Starts coordinating epochs and jobs among workers based on
    data-set sizes, the (global) step and FLAGS parameters.

    Args:
        model_feeder (ModelFeeder): data-sets to be used for coordinated training

    Kwargs:
        step (int): global step of a loaded model to determine starting point
    '''
    with self._lock:
        self._init()

        # Number of GPUs per worker - fixed for now by local reality or cluster setup
        gpus_per_worker = len(Config.available_devices)

        # Number of batches processed per job per worker
        batches_per_job = gpus_per_worker * max(1, FLAGS.iters_per_worker)

        # Number of batches per global step
        batches_per_step = gpus_per_worker * max(1, FLAGS.replicas_to_agg)

        # Number of global steps per epoch - to be at least 1
        steps_per_epoch = max(1, model_feeder.train.total_batches // batches_per_step)

        # The start epoch of our training
        self._epoch = step // steps_per_epoch

        # Number of additional 'jobs' trained already 'on top of' our start epoch
        jobs_trained = (step % steps_per_epoch) * batches_per_step // batches_per_job

        # Total number of train/dev jobs covering their respective whole sets (one epoch)
        self._num_jobs_train = max(1, model_feeder.train.total_batches // batches_per_job)
        self._num_jobs_dev = max(1, model_feeder.dev.total_batches // batches_per_job)

        if FLAGS.epoch < 0:
            # A negative epoch means to add its absolute number to the epochs already computed
            self._target_epoch = self._epoch + abs(FLAGS.epoch)
        else:
            self._target_epoch = FLAGS.epoch

        # State variables
        # We only have to train if we are told so and are not at the target epoch yet
        self._train = FLAGS.train and self._target_epoch > self._epoch

        if self._train:
            # The total number of jobs for all additional epochs to be trained
            # Will be decremented for each job that is produced/put into state 'open'
            self._num_jobs_train_left = (self._target_epoch - self._epoch) * self._num_jobs_train - jobs_trained
            log_info('STARTING Optimization')
            self._training_time = stopwatch()

        # Important for debugging
        log_debug('step: %d' % step)
        log_debug('epoch: %d' % self._epoch)
        log_debug('target epoch: %d' % self._target_epoch)
        log_debug('steps per epoch: %d' % steps_per_epoch)
        log_debug('number of batches in train set: %d' % model_feeder.train.total_batches)
        log_debug('batches per job: %d' % batches_per_job)
        log_debug('batches per step: %d' % batches_per_step)
        log_debug('number of jobs in train set: %d' % self._num_jobs_train)
        log_debug('number of jobs already trained in first epoch: %d' % jobs_trained)

        self._next_epoch()

    # The coordinator is ready to serve
    self.started = True
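# Worked example (illustrative numbers only, not from the original file) of the bookkeeping
# above: with 2 GPUs per worker, iters_per_worker=2, replicas_to_agg=2 and a train set of
# 1000 batches, the derived quantities would be:
#   batches_per_job  = 2 * 2 = 4
#   batches_per_step = 2 * 2 = 4
#   steps_per_epoch  = 1000 // 4 = 250
#   num_jobs_train   = 1000 // 4 = 250
# Restoring from a checkpoint at global step 300 then gives
#   epoch        = 300 // 250 = 1
#   jobs_trained = (300 % 250) * 4 // 4 = 50
# so the coordinator resumes epoch 1 with 50 of its 250 train jobs already accounted for.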
def _log_all_jobs(self):
    '''Use this to debug-print epoch state'''
    log_debug('Epochs - running: %d, done: %d' % (len(self._epochs_running), len(self._epochs_done)))
    for epoch in self._epochs_running:
        log_debug('       - running: ' + epoch.job_status())