# Imports assumed by this excerpt (TF 1.x): the standard-library logger and
# tf.contrib.training, which provides wait_for_new_checkpoint.
import logging

from tensorflow.contrib import training as contrib_training


def wait_for_next_checkpoint(log_dir,
                             last_checkpoint=None,
                             seconds_to_sleep=1,
                             timeout=20):
  """Blocking wait until the next checkpoint is written to log_dir.

  Can time out at regular intervals to log a timeout warning (a good
  indicator that the thread is still alive).

  Args:
    log_dir: The directory in which checkpoints are saved.
    last_checkpoint: The last checkpoint path used, or None if we're expecting
      a checkpoint for the first time.
    seconds_to_sleep: The number of seconds to sleep between checks for a new
      checkpoint.
    timeout: The maximum amount of time to wait before logging a timeout
      warning and checking again for a new checkpoint. If None, waits
      indefinitely.

  Returns:
    next_checkpoint: Path to the new checkpoint.
  """
  while True:
    logging.info('Waiting for next policy checkpoint...')
    next_checkpoint = contrib_training.wait_for_new_checkpoint(
        log_dir, last_checkpoint,
        seconds_to_sleep=seconds_to_sleep, timeout=timeout)
    if next_checkpoint is None:
      logging.warn('Timeout waiting for checkpoint, trying again...')
    elif next_checkpoint != last_checkpoint:
      # Found a new checkpoint.
      logging.warn('Found a new checkpoint ("%s").', next_checkpoint)
      break
    else:
      logging.warn('No new checkpoint found, trying again...')
  return next_checkpoint
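# Usage sketch (illustrative only; `evaluate_forever`, `train_dir`, and
# `run_eval_pass` are hypothetical placeholders, not part of the original
# module): evaluate each new checkpoint as the trainer produces them.
def evaluate_forever(train_dir, run_eval_pass):
  checkpoint = None
  while True:
    checkpoint = wait_for_next_checkpoint(train_dir, checkpoint, timeout=60)
    run_eval_pass(checkpoint)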
def evaluate(master,
             model_fn,
             data_fn,
             additional_trial_info,
             model_dir,
             preprocess_examples,
             hparams,
             name,
             num_steps=None):
  """Evaluation loop."""
  estimator = create_estimator(
      model_fn=model_fn, model_dir=model_dir, master=master, hparams=hparams)

  transcription_data_base = functools.partial(
      data_fn,
      preprocess_examples=preprocess_examples,
      is_training=False)

  if num_steps is None:
    transcription_data = functools.partial(
        transcription_data_base,
        shuffle_examples=False,
        skip_n_initial_records=0)
  else:
    # If num_steps is specified, we will evaluate only a subset of the data.
    #
    # The following is a hack that works around the problems of not being able
    # to determine the number of records in a given TFRecord shard without
    # reading the whole thing and not being able to persist a tf.data.Dataset
    # session across multiple estimator evaluate calls.
    #
    # This code tries to select a different subset for every evaluation by
    # doing the following:
    # - Setting shuffle_examples=True. This shuffles not only individual
    #   examples, but also shuffles the order in which shards are read.
    # - Skipping N examples before starting evaluation, where N is selected
    #   randomly for each evaluation run. This provides a different starting
    #   offset.
    # In order to skip a random number of records, we need to provide an upper
    # bound that will still let us run num_steps evaluation steps before
    # running out of data. The following code does a one-time check on startup
    # to see if there are up to num_steps * 5 records available, which would
    # allow a maximum skip range of [0, num_steps*4].
    records_to_check = num_steps * 5
    tf.logging.info('Checking for at least %d records...', records_to_check)
    records_available = 0
    with tf.Graph().as_default():
      record_check_params = copy.deepcopy(hparams)
      record_check_params.batch_size = 1
      iterator = transcription_data_base(
          params=record_check_params,
          shuffle_examples=False,
          skip_n_initial_records=0,
      ).make_initializable_iterator()
      next_record = iterator.get_next()
      with tf.Session() as sess:
        sess.run(iterator.initializer)
        try:
          for i in range(records_to_check):
            del i
            sess.run(next_record)
            records_available += 1
            if records_available % 10 == 0:
              tf.logging.info('Found %d records...', records_available)
        except tf.errors.OutOfRangeError:
          pass

    # Determine max number of records we could skip and still have num_steps
    # records remaining.
    max_records_to_skip = max(0, records_available - num_steps)
    tf.logging.info(
        'Found at least %d records. '
        'Will skip a maximum of %d records during eval runs '
        'in order to support %d evaluation steps.',
        records_available, max_records_to_skip, num_steps)

    # Since we're doing a limited number of steps, we should shuffle the
    # examples we're evaluating so each evaluation is over a different portion
    # of the dataset.
    def transcription_data(params, *args, **kwargs):
      assert not args
      skip_n_initial_records = random.randint(0, max_records_to_skip)
      tf.logging.info('Skipping %d initial record(s)', skip_n_initial_records)
      return transcription_data_base(
          params=params,
          shuffle_examples=True,
          skip_n_initial_records=skip_n_initial_records,
          **kwargs)

  _trial_summary(
      hparams=hparams,
      model_dir=model_dir,
      output_dir=estimator.eval_dir(name),
      additional_trial_info=additional_trial_info)

  checkpoint_path = None
  while True:
    checkpoint_path = contrib_training.wait_for_new_checkpoint(
        model_dir, last_checkpoint=checkpoint_path)
    estimator.evaluate(
        input_fn=transcription_data,
        steps=num_steps,
        checkpoint_path=checkpoint_path,
        name=name)
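# Illustrative sketch of the random-offset trick above (assumption: data_fn
# implements skip_n_initial_records via tf.data's Dataset.skip; this helper is
# hypothetical and not part of the original module). Each call yields a
# different window of the shuffled records, so repeated evaluate() calls cover
# different portions of the dataset; batch_size is fixed at 1 to keep the
# record/step correspondence simple.
def _random_eval_window(records, num_steps, max_records_to_skip):
  skip_n = random.randint(0, max_records_to_skip)
  return records.skip(skip_n).take(num_steps).batch(1)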
    # only. Since learning goes very fast, we save often.
    if i % params['eval_steps'] == 0 or i == params['steps']:
      saver.save(sess, args.model_dir + '/model.ckpt', global_step=i)
else:
  print('Evaluating on %s' % params['partition'])
  # For each checkpoint the entire dataset is evaluated.
  steps_per_eval = params['%s_size' % params['partition']]
  checkpoint = None
  # Basic session since we will only manually save summaries.
  with tf.Session() as sess:
    coord = tf.train.Coordinator()
    # Queue runners will take care of reading data in separate threads.
    threads = tf.train.start_queue_runners(coord=coord)
    while True:
      checkpoint = wait_for_new_checkpoint(
          args.model_dir, checkpoint, seconds_to_sleep=1, timeout=1200)
      if checkpoint is None:
        print('No checkpoint found for 20 min, exiting evaluation.')
        break
      # Init for variables that are not part of the checkpoint,
      # in this case the ones used for metrics.
      sess.run(init)
      # Restore a checkpoint saved by the training run.
      saver.restore(sess, checkpoint)
      # Update the metrics for every element in the dataset.
      batch_steps = int(np.ceil(steps_per_eval / float(params['read_batch'])))
      for i in range(batch_steps):
        sess.run([eval_update])
      # Get the resulting metrics.
      cur_step, cur_reward, cur_summary = sess.run(
          [global_step, mean_reward, merged_summary])
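      # Plausible continuation (sketch, not from the original source): write
      # the fetched summary to an events subdirectory so TensorBoard plots the
      # eval reward curve next to training; the '/eval' name is an assumption.
      summary_writer = tf.summary.FileWriter(args.model_dir + '/eval')
      summary_writer.add_summary(cur_summary, global_step=cur_step)
      summary_writer.close()
      print('Step %d: mean reward %.3f' % (cur_step, cur_reward))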