def train():
    r'''
    Trains the network in a single process, using the training and validation
    sets configured through FLAGS.
    '''
    # Reading training set
    train_index = SampleIndex()

    train_data = preprocess(FLAGS.train_files.split(','),
                            FLAGS.train_batch_size,
                            Config.n_input,
                            Config.n_context,
                            Config.alphabet,
                            hdf5_cache_path=FLAGS.train_cached_features_path)

    train_set = DataSet(train_data,
                        FLAGS.train_batch_size,
                        limit=FLAGS.limit_train,
                        next_index=train_index.inc)

    # Reading validation set
    dev_index = SampleIndex()

    dev_data = preprocess(FLAGS.dev_files.split(','),
                          FLAGS.dev_batch_size,
                          Config.n_input,
                          Config.n_context,
                          Config.alphabet,
                          hdf5_cache_path=FLAGS.dev_cached_features_path)

    dev_set = DataSet(dev_data,
                      FLAGS.dev_batch_size,
                      limit=FLAGS.limit_dev,
                      next_index=dev_index.inc)

    # Combining all sets to a multi set model feeder
    model_feeder = ModelFeeder(train_set,
                               dev_set,
                               Config.n_input,
                               Config.n_context,
                               Config.alphabet,
                               tower_feeder_count=len(Config.available_devices))

    # Dropout
    dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {
        dropout_rates[0]: 0.,
        dropout_rates[1]: 0.,
        dropout_rates[2]: 0.,
        dropout_rates[3]: 0.,
        dropout_rates[4]: 0.,
        dropout_rates[5]: 0.,
    }

    # Building the graph
    optimizer = create_optimizer()
    gradients, loss = get_tower_results(model_feeder, optimizer, dropout_rates)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)

    # global_step is automagically incremented by the optimizer
    global_step = tf.Variable(0, trainable=False, name='global_step')
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

    # Summaries
    step_summaries_op = tf.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
    checkpoint_filename = 'checkpoint'

    best_dev_saver = tf.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
    best_dev_filename = 'best_dev_checkpoint'

    initializer = tf.global_variables_initializer()

    with tf.Session(config=Config.session_config) as session:
        log_debug('Session opened.')
        tf.get_default_graph().finalize()

        # Loading or initializing
        loaded = False
        if FLAGS.load in ['auto', 'last']:
            loaded = try_loading(session, checkpoint_saver, checkpoint_filename, 'most recent epoch')
        if not loaded and FLAGS.load in ['auto', 'best']:
            loaded = try_loading(session, best_dev_saver, best_dev_filename, 'best validation')
        if not loaded:
            if FLAGS.load in ['auto', 'init']:
                log_info('Initializing...')
                session.run(initializer)
            else:
                log_error('Unable to load %s model from specified checkpoint dir'
                          ' - consider using load option "auto" or "init".' % FLAGS.load)
                sys.exit(1)

        # Retrieving global_step from restored model and setting training parameters accordingly
        model_feeder.set_data_set(no_dropout_feed_dict, train_set)
        step = session.run(global_step, feed_dict=no_dropout_feed_dict)
        num_gpus = len(Config.available_devices)
        steps_per_epoch = max(1, train_set.total_batches // num_gpus)
        steps_trained = step % steps_per_epoch
        current_epoch = step // steps_per_epoch
        target_epoch = current_epoch + abs(FLAGS.epoch) if FLAGS.epoch < 0 else FLAGS.epoch
        train_index.index = steps_trained * num_gpus

        log_debug('step: %d' % step)
        log_debug('epoch: %d' % current_epoch)
        log_debug('target epoch: %d' % target_epoch)
        log_debug('steps per epoch: %d' % steps_per_epoch)
        log_debug('batches per step (GPUs): %d' % num_gpus)
        log_debug('number of batches in train set: %d' % train_set.total_batches)
        log_debug('number of batches already trained in epoch: %d' % train_index.index)

        def run_set(set_name):
            data_set = getattr(model_feeder, set_name)
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
            model_feeder.set_data_set(feed_dict, data_set)
            total_loss = 0.0
            step_summary_writer = step_summary_writers.get(set_name)
            num_steps = max(1, data_set.total_batches // num_gpus)
            checkpoint_time = time.time()

            if FLAGS.show_progressbar:
                pbar = progressbar.ProgressBar(max_value=num_steps, redirect_stdout=True).start()

            # Batch loop
            for step_index in range(steps_trained, num_steps):
                if coord.should_stop():
                    break

                _, current_step, batch_loss, step_summary = \
                    session.run([train_op, global_step, loss, step_summaries_op],
                                feed_dict=feed_dict)

                total_loss += batch_loss
                step_summary_writer.add_summary(step_summary, current_step)

                if FLAGS.show_progressbar:
                    pbar.update(step_index + 1, force=True)

                if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                    checkpoint_time = time.time()

            if FLAGS.show_progressbar:
                pbar.finish()

            return total_loss / num_steps

        if target_epoch > current_epoch:
            log_info('STARTING Optimization')
            best_dev_loss = float('inf')
            dev_losses = []
            coord = tf.train.Coordinator()
            with coord.stop_on_exception():
                log_debug('Starting queue runners...')
                model_feeder.start_queue_threads(session, coord=coord)
                log_debug('Queue runners started.')

                # Epoch loop
                for current_epoch in range(current_epoch, target_epoch):
                    # Training
                    if coord.should_stop():
                        break
                    log_info('Training epoch %d ...' % current_epoch)
                    train_loss = run_set('train')
                    log_info('Finished training epoch %d - loss: %f' % (current_epoch, train_loss))
                    checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
                    steps_trained = 0

                    # Validation
                    log_info('Validating epoch %d ...' % current_epoch)
                    dev_loss = run_set('dev')
                    dev_losses.append(dev_loss)
                    log_info('Finished validating epoch %d - loss: %f' % (current_epoch, dev_loss))

                    if dev_loss < best_dev_loss:
                        best_dev_loss = dev_loss
                        save_path = best_dev_saver.save(session, best_dev_path,
                                                        latest_filename=best_dev_filename)
                        log_info("Saved new best validating model with loss %f to: %s" %
                                 (best_dev_loss, save_path))

                    # Early stopping
                    if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
                        mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
                        std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
                        dev_losses = dev_losses[-FLAGS.es_steps:]
                        log_debug('Checking for early stopping (last %d steps) validation loss: '
                                  '%f, with standard deviation: %f and mean: %f' %
                                  (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
                        if dev_losses[-1] > np.max(dev_losses[:-1]) or \
                           (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
                            log_info('Early stop triggered as (for last %d steps) validation loss:'
                                     ' %f with standard deviation: %f and mean: %f' %
                                     (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
                            break

                log_debug('Closing queues...')
                coord.request_stop()
                model_feeder.close_queues(session)
                log_debug('Queues closed.')
        else:
            log_info('Target epoch already reached - skipped training.')
    log_debug('Session closed.')
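# The step/epoch bookkeeping in train() above can be hard to follow when a run
# resumes from a checkpoint. The helper below is an illustrative sketch (not
# part of the original code) that reproduces the same arithmetic in isolation;
# the name `epoch_progress` is hypothetical.
def epoch_progress(restored_step, total_batches, num_gpus, epoch_flag):
    '''Sketch: derive epoch bookkeeping from a restored global_step.

    One optimizer step consumes one batch per available GPU, so an epoch spans
    total_batches // num_gpus steps; a negative --epoch value means "train this
    many additional epochs on top of the current one".
    '''
    steps_per_epoch = max(1, total_batches // num_gpus)
    steps_trained = restored_step % steps_per_epoch
    current_epoch = restored_step // steps_per_epoch
    target_epoch = current_epoch + abs(epoch_flag) if epoch_flag < 0 else epoch_flag
    return steps_per_epoch, steps_trained, current_epoch, target_epoch

# Example: resuming at global_step 2500 with 1000 training batches on 2 GPUs
# and --epoch -3 gives 500 steps per epoch, 0 steps already trained in the
# current epoch, current epoch 5 and target epoch 8:
# assert epoch_progress(2500, 1000, 2, -3) == (500, 0, 5, 8)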
def train(server=None):
    r'''
    Trains the network on a given server of a cluster.
    If no server provided, it performs single process training.
    '''
    # Initializing and starting the training coordinator
    coord = TrainingCoordinator(Config.is_chief)
    coord.start()

    # Create a variable to hold the global_step.
    # It will automagically get incremented by the optimizer.
    global_step = tf.Variable(0, trainable=False, name='global_step')

    dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]

    # Reading training set
    train_data = preprocess(FLAGS.train_files.split(','),
                            FLAGS.train_batch_size,
                            Config.n_input,
                            Config.n_context,
                            Config.alphabet,
                            hdf5_cache_path=FLAGS.train_cached_features_path)

    train_set = DataSet(train_data,
                        FLAGS.train_batch_size,
                        limit=FLAGS.limit_train,
                        next_index=lambda i: coord.get_next_index('train'))

    # Reading validation set
    dev_data = preprocess(FLAGS.dev_files.split(','),
                          FLAGS.dev_batch_size,
                          Config.n_input,
                          Config.n_context,
                          Config.alphabet,
                          hdf5_cache_path=FLAGS.dev_cached_features_path)

    dev_set = DataSet(dev_data,
                      FLAGS.dev_batch_size,
                      limit=FLAGS.limit_dev,
                      next_index=lambda i: coord.get_next_index('dev'))

    # Combining all sets to a multi set model feeder
    model_feeder = ModelFeeder(train_set,
                               dev_set,
                               Config.n_input,
                               Config.n_context,
                               Config.alphabet,
                               tower_feeder_count=len(Config.available_devices))

    # Create the optimizer
    optimizer = create_optimizer()

    # Synchronous distributed training is facilitated by a special proxy-optimizer
    if server is not None:
        optimizer = tf.train.SyncReplicasOptimizer(optimizer,
                                                   replicas_to_aggregate=FLAGS.replicas_to_agg,
                                                   total_num_replicas=FLAGS.replicas)

    # Get the data_set specific graph end-points
    gradients, loss = get_tower_results(model_feeder, optimizer, dropout_rates)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)

    # Add summaries of all variables and gradients to log
    log_grads_and_vars(avg_tower_gradients)

    # Op to merge all summaries for the summary hook
    merge_all_summaries_op = tf.summary.merge_all()

    # These are saved on every step
    step_summaries_op = tf.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
    }

    # Apply gradients to modify the model
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

    if FLAGS.early_stop and FLAGS.validation_step <= 0:
        log_warn('Parameter --validation_step needs to be >0 for early stopping to work')

    class CoordHook(tf.train.SessionRunHook):
        r'''
        Embedded coordination hook-class that will use variables of the
        surrounding Python context.
        '''
        def after_create_session(self, session, coord):
            log_debug('Starting queue runners...')
            model_feeder.start_queue_threads(session, coord)
            log_debug('Queue runners started.')

        def end(self, session):
            # Closing the data_set queues
            log_debug('Closing queues...')
            model_feeder.close_queues(session)
            log_debug('Queues closed.')

            # Telling the ps that we are done
            send_token_to_ps(session)

    # Collecting the hooks
    hooks = [CoordHook()]

    # Hook to handle initialization and queues for sync replicas.
    if server is not None:
        hooks.append(optimizer.make_session_run_hook(Config.is_chief))

    # Hook to save TensorBoard summaries
    if FLAGS.summary_secs > 0:
        hooks.append(tf.train.SummarySaverHook(save_secs=FLAGS.summary_secs,
                                               output_dir=FLAGS.summary_dir,
                                               summary_op=merge_all_summaries_op))

    # Hook with number of checkpoint files to save in checkpoint_dir
    if FLAGS.train and FLAGS.max_to_keep > 0:
        saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
        hooks.append(tf.train.CheckpointSaverHook(checkpoint_dir=FLAGS.checkpoint_dir,
                                                  save_secs=FLAGS.checkpoint_secs,
                                                  saver=saver))

    no_dropout_feed_dict = {
        dropout_rates[0]: 0.,
        dropout_rates[1]: 0.,
        dropout_rates[2]: 0.,
        dropout_rates[3]: 0.,
        dropout_rates[4]: 0.,
        dropout_rates[5]: 0.,
    }

    # Progress Bar
    def update_progressbar(set_name):
        if not hasattr(update_progressbar, 'current_set_name'):
            update_progressbar.current_set_name = None

        if (update_progressbar.current_set_name != set_name or
                update_progressbar.current_job_index == update_progressbar.total_jobs):

            # finish prev pbar if it exists
            if hasattr(update_progressbar, 'pbar') and update_progressbar.pbar:
                update_progressbar.pbar.finish()

            update_progressbar.total_jobs = None
            update_progressbar.current_job_index = 0

            current_epoch = coord._epoch - 1

            if set_name == "train":
                log_info('Training epoch %i...' % current_epoch)
                update_progressbar.total_jobs = coord._num_jobs_train
            else:
                log_info('Validating epoch %i...' % current_epoch)
                update_progressbar.total_jobs = coord._num_jobs_dev

            # recreate pbar
            update_progressbar.pbar = progressbar.ProgressBar(max_value=update_progressbar.total_jobs,
                                                              redirect_stdout=True).start()

            update_progressbar.current_set_name = set_name

        if update_progressbar.pbar:
            update_progressbar.pbar.update(update_progressbar.current_job_index + 1, force=True)

        update_progressbar.current_job_index += 1

    # Initialize update_progressbar()'s child fields to safe values
    update_progressbar.pbar = None

    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    try:
        with tf.train.MonitoredTrainingSession(master='' if server is None else server.target,
                                               is_chief=Config.is_chief,
                                               hooks=hooks,
                                               checkpoint_dir=FLAGS.checkpoint_dir,
                                               save_checkpoint_secs=None,  # already taken care of by a hook
                                               log_step_count_steps=0,  # disable logging of steps/s to avoid TF warning in validation sets
                                               config=Config.session_config) as session:
            tf.get_default_graph().finalize()

            try:
                if Config.is_chief:
                    # Retrieving global_step from the (potentially restored) model
                    model_feeder.set_data_set(no_dropout_feed_dict, model_feeder.train)
                    step = session.run(global_step, feed_dict=no_dropout_feed_dict)
                    coord.start_coordination(model_feeder, step)

                # Get the first job
                job = coord.get_job()

                while job and not session.should_stop():
                    log_debug('Computing %s...' % job)

                    is_train = job.set_name == 'train'

                    # The feed_dict (mainly for switching between queues)
                    if is_train:
                        feed_dict = {
                            dropout_rates[0]: FLAGS.dropout_rate,
                            dropout_rates[1]: FLAGS.dropout_rate2,
                            dropout_rates[2]: FLAGS.dropout_rate3,
                            dropout_rates[3]: FLAGS.dropout_rate4,
                            dropout_rates[4]: FLAGS.dropout_rate5,
                            dropout_rates[5]: FLAGS.dropout_rate6,
                        }
                    else:
                        feed_dict = no_dropout_feed_dict

                    # Sets the current data_set for the respective placeholder in feed_dict
                    model_feeder.set_data_set(feed_dict, getattr(model_feeder, job.set_name))

                    # Initialize loss aggregator
                    total_loss = 0.0

                    # Setting the training operation in case of training requested
                    train_op = apply_gradient_op if is_train else []

                    # So far the only extra parameter is the feed_dict
                    extra_params = {'feed_dict': feed_dict}

                    step_summary_writer = step_summary_writers.get(job.set_name)

                    # Loop over the batches
                    for job_step in range(job.steps):
                        if session.should_stop():
                            break

                        log_debug('Starting batch...')
                        # Compute the batch
                        _, current_step, batch_loss, step_summary = \
                            session.run([train_op, global_step, loss, step_summaries_op], **extra_params)

                        # Log step summaries
                        step_summary_writer.add_summary(step_summary, current_step)

                        # Log the step (helpful when debugging race conditions / distributed TF)
                        log_debug('Finished batch step %d.' % current_step)

                        # Add batch to loss
                        total_loss += batch_loss

                    # Gathering job results
                    job.loss = total_loss / job.steps

                    # Display progressbar
                    if FLAGS.show_progressbar:
                        update_progressbar(job.set_name)

                    # Send the current job to coordinator and receive the next one
                    log_debug('Sending %s...' % job)
                    job = coord.next_job(job)

                if update_progressbar.pbar:
                    update_progressbar.pbar.finish()

            except Exception as e:
                log_error(str(e))
                traceback.print_exc()
                # Calling all hook's end() methods to end blocking calls
                for hook in hooks:
                    hook.end(session)
                # Only chief has a SyncReplicasOptimizer queue runner that needs to be stopped
                # for unblocking process exit. A rather graceful way to do this is by stopping the ps.
                # Only one party can send it w/o failing.
                if Config.is_chief:
                    send_token_to_ps(session, kill=True)
                sys.exit(1)

        log_debug('Session closed.')

    except tf.errors.InvalidArgumentError as e:
        log_error(str(e))
        log_error('The checkpoint in {0} does not match the shapes of the model.'
                  ' Did you change alphabet.txt or the --n_hidden parameter'
                  ' between train runs using the same checkpoint dir? Try moving'
                  ' or removing the contents of {0}.'.format(FLAGS.checkpoint_dir))
        sys.exit(1)

    # Stopping the coordinator
    coord.stop()
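# Minimal sketch (not part of the original code) of the SessionRunHook pattern
# that CoordHook above relies on: after_create_session() fires once the
# monitored session is ready, end() fires when it closes. Assumes TensorFlow
# 1.x and the module-level `tf` import; the names `LoggingHook` and
# `run_noop_training` are hypothetical.
class LoggingHook(tf.train.SessionRunHook):
    '''Hypothetical hook that only logs session lifecycle events.'''
    def after_create_session(self, session, coord):
        print('Session created - this is where queue runners would be started.')

    def end(self, session):
        print('Session ending - this is where queues would be closed.')

def run_noop_training():
    '''Runs a trivial graph under MonitoredTrainingSession to show hook wiring.'''
    step = tf.train.get_or_create_global_step()
    increment = tf.assign_add(step, 1)
    # Without a checkpoint_dir the session neither restores nor saves; the hook
    # callbacks still fire around the run loop.
    with tf.train.MonitoredTrainingSession(hooks=[LoggingHook()]) as session:
        for _ in range(3):
            session.run(increment)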
def train(server=None):
    r'''
    Trains the network on a given server of a cluster.
    If no server provided, it performs single process training.
    '''
    # The transfer learning approach here needs us to supply the layers which we
    # want to exclude from the source model.
    # Say we want to exclude all layers except for the first one, we can use this:
    #
    #    drop_source_layers=['2', '3', 'lstm', '5', '6']
    #
    # If we want to use all layers from the source model except the last one, we use this:
    #
    #    drop_source_layers=['6']
    #
    drop_source_layers = ['2', '3', 'lstm', '5', '6'][-int(FLAGS.drop_source_layers):]

    # Initializing and starting the training coordinator
    coord = TrainingCoordinator(Config.is_chief)
    coord.start()

    # Create a variable to hold the global_step.
    # It will automagically get incremented by the optimizer.
    global_step = tf.Variable(0, trainable=False, name='global_step')

    dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]

    # Reading training set
    train_data = preprocess(FLAGS.train_files.split(','),
                            FLAGS.train_batch_size,
                            Config.n_input,
                            Config.n_context,
                            Config.alphabet,
                            hdf5_cache_path=FLAGS.train_cached_features_path)

    train_set = DataSet(train_data,
                        FLAGS.train_batch_size,
                        limit=FLAGS.limit_train,
                        next_index=lambda i: coord.get_next_index('train'))

    # Reading validation set
    dev_data = preprocess(FLAGS.dev_files.split(','),
                          FLAGS.dev_batch_size,
                          Config.n_input,
                          Config.n_context,
                          Config.alphabet,
                          hdf5_cache_path=FLAGS.dev_cached_features_path)

    dev_set = DataSet(dev_data,
                      FLAGS.dev_batch_size,
                      limit=FLAGS.limit_dev,
                      next_index=lambda i: coord.get_next_index('dev'))

    # Combining all sets to a multi set model feeder
    model_feeder = ModelFeeder(train_set,
                               dev_set,
                               Config.n_input,
                               Config.n_context,
                               Config.alphabet,
                               tower_feeder_count=len(Config.available_devices))

    # Create the optimizer
    optimizer = create_optimizer()

    # Synchronous distributed training is facilitated by a special proxy-optimizer
    if server is not None:
        optimizer = tf.train.SyncReplicasOptimizer(optimizer,
                                                   replicas_to_aggregate=FLAGS.replicas_to_agg,
                                                   total_num_replicas=FLAGS.replicas)

    # Get the data_set specific graph end-points
    gradients, loss = get_tower_results(model_feeder, optimizer, dropout_rates, drop_source_layers)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)

    # Add summaries of all variables and gradients to log
    log_grads_and_vars(avg_tower_gradients)

    # Op to merge all summaries for the summary hook
    merge_all_summaries_op = tf.summary.merge_all()

    # These are saved on every step
    step_summaries_op = tf.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
    }

    # Apply gradients to modify the model
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

    if FLAGS.early_stop and FLAGS.validation_step <= 0:
        log_warn('Parameter --validation_step needs to be >0 for early stopping to work')

    class CoordHook(tf.train.SessionRunHook):
        r'''
        Embedded coordination hook-class that will use variables of the
        surrounding Python context.
        '''
        def after_create_session(self, session, coord):
            log_debug('Starting queue runners...')
            model_feeder.start_queue_threads(session, coord)
            log_debug('Queue runners started.')

        def end(self, session):
            # Closing the data_set queues
            log_debug('Closing queues...')
            model_feeder.close_queues(session)
            log_debug('Queues closed.')

            # Telling the ps that we are done
            send_token_to_ps(session)

    # Collecting the hooks
    hooks = [CoordHook()]

    # Hook to handle initialization and queues for sync replicas.
    if server is not None:
        hooks.append(optimizer.make_session_run_hook(Config.is_chief))

    # Hook to save TensorBoard summaries
    if FLAGS.summary_secs > 0:
        hooks.append(tf.train.SummarySaverHook(save_secs=FLAGS.summary_secs,
                                               output_dir=FLAGS.summary_dir,
                                               summary_op=merge_all_summaries_op))

    # Hook with number of checkpoint files to save in checkpoint_dir
    if FLAGS.train and FLAGS.max_to_keep > 0:
        saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
        hooks.append(tf.train.CheckpointSaverHook(checkpoint_dir=FLAGS.checkpoint_dir,
                                                  save_secs=FLAGS.checkpoint_secs,
                                                  saver=saver))

    no_dropout_feed_dict = {
        dropout_rates[0]: 0.,
        dropout_rates[1]: 0.,
        dropout_rates[2]: 0.,
        dropout_rates[3]: 0.,
        dropout_rates[4]: 0.,
        dropout_rates[5]: 0.,
    }

    # Progress Bar
    def update_progressbar(set_name):
        if not hasattr(update_progressbar, 'current_set_name'):
            update_progressbar.current_set_name = None

        if (update_progressbar.current_set_name != set_name or
                update_progressbar.current_job_index == update_progressbar.total_jobs):

            # finish prev pbar if it exists
            if hasattr(update_progressbar, 'pbar') and update_progressbar.pbar:
                update_progressbar.pbar.finish()

            update_progressbar.total_jobs = None
            update_progressbar.current_job_index = 0

            current_epoch = coord._epoch - 1

            # Copy the latest checkpoint files into checkpoint_stash and rename
            # them with an epoch-tagged name before the next set starts.
            suffix = "graph_noisySVA_CV_2layers_"
            checkpoint_stash = "/docker_files/ckpt_stash/"
            checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            checkpoint_path = checkpoint.model_checkpoint_path
            ckpt_dest_name = suffix + str(current_epoch - 118) + "_eph"
            str_to_replace = "s/" + checkpoint_path.split('/')[-1] + "/" + ckpt_dest_name + "/"

            subprocess.Popen(["cp", checkpoint_path + ".meta", checkpoint_stash])
            #pdb.set_trace()
            subprocess.Popen(["rename", str_to_replace,
                              checkpoint_stash + checkpoint_path.split('/')[-1] + ".meta"])
            subprocess.Popen(["cp", checkpoint_path + ".data-00000-of-00001", checkpoint_stash])
            subprocess.Popen(["rename", str_to_replace,
                              checkpoint_stash + checkpoint_path.split('/')[-1] + ".data-00000-of-00001"])
            subprocess.Popen(["cp", checkpoint_path + ".index", checkpoint_stash])
            subprocess.Popen(["rename", str_to_replace,
                              checkpoint_stash + checkpoint_path.split('/')[-1] + ".index"])
            #HERE

            if set_name == "train":
                log_info('Training epoch %i...' % current_epoch)
                update_progressbar.total_jobs = coord._num_jobs_train
            else:
                log_info('Validating epoch %i...' % current_epoch)
                update_progressbar.total_jobs = coord._num_jobs_dev

            # recreate pbar
            update_progressbar.pbar = progressbar.ProgressBar(max_value=update_progressbar.total_jobs,
                                                              redirect_stdout=True).start()

            update_progressbar.current_set_name = set_name

        if update_progressbar.pbar:
            update_progressbar.pbar.update(update_progressbar.current_job_index + 1, force=True)

        update_progressbar.current_job_index += 1

    # Initialize update_progressbar()'s child fields to safe values
    update_progressbar.pbar = None

    ### TRANSFER LEARNING ###
    def init_fn(scaffold, session):
        if FLAGS.source_model_checkpoint_dir:
            drop_source_layers.append('layer_6')
            print('Initializing from', FLAGS.source_model_checkpoint_dir)
            ckpt = tf.train.load_checkpoint(FLAGS.source_model_checkpoint_dir)
            variables = list(ckpt.get_variable_to_shape_map().keys())
            # Record the restored node names for later inspection.
            with open("/data/german_DS/deepspeech-german/nodes.txt", "w") as nodetxtfile:
                for v in tf.global_variables():
                    if not any(layer in v.op.name for layer in drop_source_layers):
                        #if not v.name.count('b6') or not v.name.count('h6') or not v.name.count('raw_logits'):
                        print('Loading', v.op.name)
                        nodetxtfile.write(v.op.name + '\n')
                        v.load(ckpt.get_tensor(v.op.name), session=session)

    scaffold = tf.train.Scaffold(
        init_op=tf.variables_initializer(
            [v for v in tf.global_variables()
             if any(layer in v.op.name for layer in drop_source_layers)]
            #or v.name.count('b6')]
        ),
        init_fn=init_fn)
    ### TRANSFER LEARNING ###

    #pdb.set_trace()

    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    try:
        with tf.train.MonitoredTrainingSession(master='' if server is None else server.target,
                                               is_chief=Config.is_chief,
                                               hooks=hooks,
                                               scaffold=scaffold,  # transfer-learning
                                               checkpoint_dir=FLAGS.checkpoint_dir,
                                               save_checkpoint_secs=None,  # already taken care of by a hook
                                               log_step_count_steps=0,  # disable logging of steps/s to avoid TF warning in validation sets
                                               config=Config.session_config) as session:
            #tf.get_default_graph().finalize()
            #do_export = False

            try:
                if Config.is_chief:
                    # Retrieving global_step from the (potentially restored) model
                    model_feeder.set_data_set(no_dropout_feed_dict, model_feeder.train)
                    step = session.run(global_step, feed_dict=no_dropout_feed_dict)
                    coord.start_coordination(model_feeder, step)
                    #if do_export:
                    #    export(session)
                    #    print("########INDISE EXPORT###########")
                    #    do_export = True

                # Get the first job
                job = coord.get_job()

                while job and not session.should_stop():
                    log_debug('Computing %s...' % job)

                    is_train = job.set_name == 'train'

                    # The feed_dict (mainly for switching between queues)
                    if is_train:
                        feed_dict = {
                            dropout_rates[0]: FLAGS.dropout_rate,
                            dropout_rates[1]: FLAGS.dropout_rate2,
                            dropout_rates[2]: FLAGS.dropout_rate3,
                            dropout_rates[3]: FLAGS.dropout_rate4,
                            dropout_rates[4]: FLAGS.dropout_rate5,
                            dropout_rates[5]: FLAGS.dropout_rate6,
                        }
                    else:
                        feed_dict = no_dropout_feed_dict

                    # Sets the current data_set for the respective placeholder in feed_dict
                    model_feeder.set_data_set(feed_dict, getattr(model_feeder, job.set_name))

                    # Initialize loss aggregator
                    total_loss = 0.0

                    # Setting the training operation in case of training requested
                    train_op = apply_gradient_op if is_train else []

                    # So far the only extra parameter is the feed_dict
                    extra_params = {'feed_dict': feed_dict}

                    step_summary_writer = step_summary_writers.get(job.set_name)

                    # Loop over the batches
                    for job_step in range(job.steps):
                        if session.should_stop():
                            break

                        log_debug('Starting batch...')
                        # Compute the batch
                        _, current_step, batch_loss, step_summary = \
                            session.run([train_op, global_step, loss, step_summaries_op], **extra_params)

                        # Log step summaries
                        step_summary_writer.add_summary(step_summary, current_step)

                        # Log the step (helpful when debugging race conditions / distributed TF)
                        log_debug('Finished batch step %d.' % current_step)

                        # Add batch to loss
                        total_loss += batch_loss

                    # Gathering job results
                    job.loss = total_loss / job.steps

                    # Display progressbar
                    if FLAGS.show_progressbar:
                        update_progressbar(job.set_name)

                    # Send the current job to coordinator and receive the next one
                    log_debug('Sending %s...' % job)
                    job = coord.next_job(job)

                if update_progressbar.pbar:
                    update_progressbar.pbar.finish()

                #export()
                #mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
                #saver = tf.train.Saver(mapping)

                #def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
                #    freeze_graph.freeze_graph_with_def_protos(
                #        input_graph_def=session.graph_def,
                #        input_saver_def=saver.as_saver_def(),
                #        input_checkpoint=checkpoint_path,
                #        output_node_names=output_node_names,
                #        restore_op_name=None,
                #        filename_tensor_name=None,
                #        output_graph=output_file,
                #        clear_devices=False,
                #        variable_names_blacklist=variables_blacklist,
                #        initializer_nodes='')

                #output_graph_path = "output_graph.pb"
                #do_graph_freeze(output_file=output_graph_path, output_node_names='logits,initialize_state', variables_blacklist='previous_state_c,previous_state_h')

            except Exception as e:
                log_error(str(e))
                traceback.print_exc()
                # Calling all hook's end() methods to end blocking calls
                for hook in hooks:
                    hook.end(session)
                # Only chief has a SyncReplicasOptimizer queue runner that needs to be stopped
                # for unblocking process exit. A rather graceful way to do this is by stopping the ps.
                # Only one party can send it w/o failing.
                if Config.is_chief:
                    send_token_to_ps(session, kill=True)
                sys.exit(1)

        log_debug('Session closed.')

    except tf.errors.InvalidArgumentError as e:
        log_error(str(e))
        log_error('The checkpoint in {0} does not match the shapes of the model.'
                  ' Did you change alphabet.txt or the --n_hidden parameter'
                  ' between train runs using the same checkpoint dir? Try moving'
                  ' or removing the contents of {0}.'.format(FLAGS.checkpoint_dir))
        sys.exit(1)

    # Stopping the coordinator
    coord.stop()
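# Illustrative sketch (not part of the original code) of the selective-restore
# idea used by init_fn() above: restore every variable from a source checkpoint
# except those belonging to the dropped layers, and leave the dropped layers to
# the regular initializer. The name `restore_except` is hypothetical; assumes
# TensorFlow 1.x and the module-level `tf` import.
def restore_except(session, source_checkpoint_dir, drop_source_layers):
    '''Loads all global variables whose names do not mention a dropped layer.'''
    reader = tf.train.load_checkpoint(source_checkpoint_dir)
    available = set(reader.get_variable_to_shape_map().keys())
    for v in tf.global_variables():
        name = v.op.name
        if any(layer in name for layer in drop_source_layers):
            continue  # this layer is re-initialized and fine-tuned from scratch
        if name in available:
            v.load(reader.get_tensor(name), session=session)

# Example: keep everything from the source model except the output layer.
# restore_except(session, FLAGS.source_model_checkpoint_dir, ['layer_6'])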