def test_get_dataset_raises_error_for_empty_name(self):
   dataset_config_pbtext_filename = _test_dataset_config(
       'test_get_dataset_raises_error_for_empty_name.pbtxt')
   with six.assertRaisesRegex(self, ValueError,
                              'dataset_config needs to have a name'):
     data_providers.get_input_fn_from_dataset(
         dataset_config_pbtext_filename, mode=tf.estimator.ModeKeys.EVAL)
 def test_get_dataset_raises_error_for_empty_data_split(self):
   dataset_config_pbtext_filename = _test_dataset_config(
       'test_get_dataset_raises_error_for_empty_data_split.pbtxt',
       name='some_dataset_name')
   expected_exception_message = (
       'The dataset in the config {} does not '
       'have a tfrecord_path.'.format(dataset_config_pbtext_filename))
   with six.assertRaisesRegex(self, ValueError, expected_exception_message):
     data_providers.get_input_fn_from_dataset(
         dataset_config_pbtext_filename, mode=tf.estimator.ModeKeys.EVAL)
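
The _test_dataset_config helper used by these tests is not shown in this listing. Judging from the call sites and the error messages, it writes a dataset config as a text proto with optional name, tfrecord_path and num_examples fields and returns the path to that file. A minimal hand-rolled sketch (the helper below is hypothetical, not the project's actual implementation) could look like:

import os
import tempfile


def _write_test_dataset_config(filename, name=None, tfrecord_path=None,
                               num_examples=None):
  """Hypothetical stand-in: writes a pbtxt containing only the fields given."""
  lines = []
  if name is not None:
    lines.append('name: "{}"'.format(name))
  if tfrecord_path is not None:
    lines.append('tfrecord_path: "{}"'.format(tfrecord_path))
  if num_examples is not None:
    lines.append('num_examples: {}'.format(num_examples))
  path = os.path.join(tempfile.gettempdir(), filename)
  with open(path, 'w') as f:
    f.write('\n'.join(lines) + '\n')
  return path
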
Example #3
 def test_get_dataset_raises_error_for_empty_num_examples(self):
     dataset_config_pbtext_filename = _test_dataset_config(
         'test_get_dataset_raises_error_for_empty_num_examples.pbtxt',
         name='some_dataset_name',
         tfrecord_path='/path/to/dataset')
     expected_exception_message = (
         'The dataset in the config {} does not have '
         'a num_examples.'.format(dataset_config_pbtext_filename))
     with self.assertRaisesRegexp(ValueError, expected_exception_message):
         data_providers.get_input_fn_from_dataset(
             dataset_config_pbtext_filename,
             mode=tf.estimator.ModeKeys.EVAL)
Example #4
def run(target, unused_is_chief, device_fn, use_tpu):
  """Run training.

  Args:
     target: The target of the TensorFlow standard server to use. Can be the
       empty string to run locally using an in-process server.
     device_fn: Device function used to assign ops to devices.
     use_tpu: Whether to use the TPU code path.
  """
  if not FLAGS.dataset_config_pbtxt:
    logging.error('Need to specify --dataset_config_pbtxt')
    return

  g = tf.Graph()
  with g.as_default():
    with tf.device(device_fn):
      # If ps_tasks is zero, the local device is used. When using multiple
      # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
      # across the different devices.

      tf_dataset = data_providers.get_input_fn_from_dataset(
          dataset_config_filename=FLAGS.dataset_config_pbtxt,
          mode=tf.estimator.ModeKeys.TRAIN,
          max_examples=FLAGS.max_examples,
          use_tpu=use_tpu)
      model = modeling.get_model(FLAGS.model_name)
      logging.info('Running training on %s with model %s and tpu %s',
                   tf_dataset, FLAGS.model_name, use_tpu)

      batches_per_epoch = tf_dataset.num_examples // FLAGS.batch_size
      logging.info('Batches per epoch %s', batches_per_epoch)
      params = dict(batches_per_epoch=batches_per_epoch,)
      estimator = model.make_estimator(
          batch_size=FLAGS.batch_size,
          model_dir=FLAGS.train_dir,
          params=params,
          use_tpu=use_tpu,
          master=target,
          start_from_checkpoint=FLAGS.start_from_checkpoint,
      )

      training_hooks = None
      if FLAGS.use_early_stopping:
        # Early stopping is not implemented in this code path.
        raise ValueError('Currently not implemented.')

      estimator.train(
          input_fn=tf_dataset,
          max_steps=FLAGS.number_of_steps,
          hooks=training_hooks)
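
run() pulls its configuration from absl flags that are defined elsewhere in the module. A hedged sketch of how the flags referenced above might be declared (the flag names are taken from the code; types, defaults and help strings are assumptions):

from absl import flags

FLAGS = flags.FLAGS

# Types and defaults below are illustrative guesses, not the real definitions.
flags.DEFINE_string('dataset_config_pbtxt', None,
                    'Path to the dataset config in text-proto format.')
flags.DEFINE_string('model_name', None, 'Name of the model to train.')
flags.DEFINE_string('train_dir', None, 'Directory for checkpoints and logs.')
flags.DEFINE_string('start_from_checkpoint', None,
                    'Optional checkpoint to warm-start training from.')
flags.DEFINE_integer('batch_size', 64, 'Per-step batch size.')
flags.DEFINE_integer('max_examples', None,
                     'If set, cap on the number of training examples read.')
flags.DEFINE_integer('number_of_steps', None, 'Maximum number of train steps.')
flags.DEFINE_boolean('use_early_stopping', False,
                     'Whether to take the early-stopping code path.')
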
  def test_get_dataset(self):
    dataset_config_pbtext_filename = _test_dataset_config(
        'golden.dataset_config.pbtxt',
        name='some_dataset_name',
        tfrecord_path='/dev/null',
        num_examples=1000)
    ds = data_providers.get_input_fn_from_dataset(
        dataset_config_pbtext_filename,
        mode=tf.estimator.ModeKeys.EVAL,
        tensor_shape=[3, 4, dv_constants.PILEUP_NUM_CHANNELS])

    self.assertEqual('some_dataset_name', ds.name)
    self.assertEqual('/dev/null', ds.input_file_spec)
    self.assertEqual(1000, ds.num_examples)
    self.assertEqual([3, 4, dv_constants.PILEUP_NUM_CHANNELS], ds.tensor_shape)
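
test_get_dataset shows that the object returned by get_input_fn_from_dataset is used in two ways: it is handed to an Estimator as input_fn, and it also carries metadata (name, input_file_spec, num_examples, tensor_shape). A minimal sketch of that callable-with-attributes pattern, with illustrative names only:

class _InputFnSketch(object):
  """Illustration only: a callable input_fn that also exposes metadata."""

  def __init__(self, name, input_file_spec, num_examples, tensor_shape):
    self.name = name
    self.input_file_spec = input_file_spec
    self.num_examples = num_examples
    self.tensor_shape = tensor_shape

  def __call__(self, params):
    # An Estimator invokes input_fn(params) to build the input pipeline; a
    # real implementation would read self.input_file_spec and return a
    # batched tf.data.Dataset of (features, labels).
    raise NotImplementedError('sketch only')
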
  def test_reading_sharded_dataset(self, compressed_inputs, use_tpu):
    golden_dataset = make_golden_dataset(compressed_inputs, use_tpu=use_tpu)
    n_shards = 3
    sharded_path = test_utils.test_tmpfile('sharded@{}'.format(n_shards))
    tfrecord.write_tfrecords(
        tfrecord.read_tfrecords(golden_dataset.input_file_spec), sharded_path)

    config_file = _test_dataset_config(
        'test_sharded.pbtxt',
        name='sharded_test',
        tfrecord_path=sharded_path,
        num_examples=golden_dataset.num_examples)

    self.assertTfDataSetExamplesMatchExpected(
        data_providers.get_input_fn_from_dataset(
            config_file, mode=tf.estimator.ModeKeys.EVAL),
        golden_dataset,
        # workaround_list_files is needed because wildcard matches, and hence
        # sharded files, are nondeterministically ordered (for now).
        workaround_list_files=True,
    )
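
The workaround_list_files flag is needed because tf.data.Dataset.list_files shuffles the matched filenames by default, so sharded inputs come back in a different order from run to run. A small TF 1.x sketch of the two behaviors (the glob pattern is a placeholder):

import tensorflow as tf

# Default: matched filenames are shuffled, so shard order is nondeterministic.
shuffled_files = tf.data.Dataset.list_files('/path/to/sharded-*')

# shuffle=False yields the matches in a deterministic order, which is the
# kind of behavior the workaround above relies on when comparing datasets.
ordered_files = tf.data.Dataset.list_files('/path/to/sharded-*', shuffle=False)
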
def eval_loop(master,
              dataset_config_pbtxt,
              checkpoint_dir,
              model_name,
              batch_size,
              max_examples,
              eval_name,
              max_evaluations,
              use_tpu=False):
    """Evaluate incoming checkpoints, until the specified end."""
    logging.info('Running fixed eval for: %s', dataset_config_pbtxt)

    tf_dataset = data_providers.get_input_fn_from_dataset(
        dataset_config_filename=dataset_config_pbtxt,
        mode=tf.estimator.ModeKeys.EVAL,
        use_tpu=use_tpu,
    )

    best_ckpt = None
    ckpt_metric = FLAGS.best_checkpoint_metric
    ckpt_metric_increasing = ckpt_metric in increasing_metrics

    model = modeling.get_model(model_name)
    logging.info('Running evaluations on %s with model %s', tf_dataset, model)

    # Compute when to stop reading, in terms of batches.
    num_examples = tf_dataset.num_examples
    if max_examples is not None:
        num_examples = min(max_examples, num_examples)
    num_batches = num_examples // batch_size
    num_samples = batch_size * num_batches
    logging.info(
        'Dataset has %s samples, doing eval over %s; '
        'max_examples is %s, num examples to be used %s; num_batches is %s',
        tf_dataset.num_examples, num_samples, max_examples, num_examples,
        num_batches)

    # This loads EMA variables.
    eval_hooks = [h(checkpoint_dir) for h in model.session_eval_hooks()]

    classifier = model.make_estimator(batch_size=batch_size,
                                      model_dir=checkpoint_dir,
                                      use_tpu=use_tpu,
                                      master=master)

    def terminate_eval():
        logging.info('Terminating eval after %d seconds of no checkpoints',
                     FLAGS.eval_timeout)
        return True

    # Run evaluation when there's a new checkpoint
    num_evaluations = 0
    for ckpt in checkpoints_iterator(
            checkpoint_dir=checkpoint_dir,
            min_interval_secs=FLAGS.min_eval_interval_s,
            timeout=FLAGS.eval_timeout,
            timeout_fn=terminate_eval):

        logging.info('Starting to evaluate.')

        # For each step, calls input_fn, which returns one batch of data.
        # Evaluates until either `steps` batches have been processed or input_fn
        # raises an end-of-input exception (OutOfRangeError or StopIteration).
        eval_results = classifier.evaluate(input_fn=tf_dataset,
                                           steps=num_batches,
                                           hooks=eval_hooks,
                                           checkpoint_path=ckpt,
                                           name=eval_name)
        logging.info('Eval results: %s', eval_results)

        # Track best checkpoint seen so far, measured by ckpt_metric.
        if not best_ckpt:
            # If the training jobs died, pick up where we left off.
            try:
                best_metrics = read_metrics(ckpt, eval_name,
                                            'best_checkpoint.metrics')
                logging.info('Found existing best_checkpoint: %s',
                             best_metrics)
                best_ckpt = (best_metrics, ckpt)
            except NotFoundError:
                logging.info('best_checkpoint file does not exist.')
                best_ckpt = (eval_results, ckpt)
                _write_best_checkpoint(ckpt, eval_results, eval_name)
        if ((ckpt_metric_increasing
             and eval_results[ckpt_metric] > best_ckpt[0][ckpt_metric]) or
            (not ckpt_metric_increasing
             and eval_results[ckpt_metric] < best_ckpt[0][ckpt_metric])):
            best_ckpt = (eval_results, ckpt)
            _write_best_checkpoint(ckpt, eval_results, eval_name)

        _write_checkpoint_metrics(ckpt, eval_results, eval_name)

        # An alternative strategy might check step-number-of-ckpt >= train_steps.
        num_evaluations += 1
        if max_evaluations is not None and num_evaluations >= max_evaluations:
            logging.info('Evaluation finished after %d evaluations',
                         num_evaluations)
            break

    return
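
The best-checkpoint bookkeeping above depends on whether the selected metric is better when it increases (e.g. an F1 score) or when it decreases (e.g. a loss). The comparison can be read as the following standalone helper (a hypothetical name, shown only to restate the condition):

def _is_better_checkpoint(new_metrics, best_metrics, metric, metric_increasing):
  """Returns True if new_metrics improves on best_metrics for `metric`."""
  if metric_increasing:
    return new_metrics[metric] > best_metrics[metric]
  return new_metrics[metric] < best_metrics[metric]
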
Example #8
def eval_loop(master,
              dataset_config_pbtxt,
              checkpoint_dir,
              model_name,
              batch_size,
              max_examples,
              eval_name,
              max_evaluations,
              use_tpu=False):
    """Evaluate incoming checkpoints, until the specified end."""
    logging.info('Running fixed eval for: %s', dataset_config_pbtxt)

    tf_dataset = data_providers.get_input_fn_from_dataset(
        dataset_config_filename=dataset_config_pbtxt,
        mode=tf.estimator.ModeKeys.EVAL,
        use_tpu=use_tpu,
    )

    model = modeling.get_model(model_name)
    logging.info('Running evaluations on %s with model %s', tf_dataset, model)

    # Compute when to stop reading, in terms of batches.
    num_batches = min(max_examples, tf_dataset.num_examples) // batch_size
    num_samples = batch_size * num_batches
    logging.info(
        'Dataset has %d samples, doing eval over %d; '
        'max_examples is %d, num_batches is %d', tf_dataset.num_examples,
        num_samples, max_examples, num_batches)
    batches_per_epoch = tf_dataset.num_examples / batch_size

    # This loads EMA variables.
    eval_hooks = [h(checkpoint_dir) for h in model.session_eval_hooks()]

    classifier = model.make_estimator(
        batch_size=batch_size,
        model_dir=checkpoint_dir,
        params={'batches_per_epoch': batches_per_epoch},
        use_tpu=use_tpu,
        master=master,
    )

    def terminate_eval():
        logging.info('Terminating eval after %d seconds of no checkpoints',
                     FLAGS.eval_timeout)
        return True

    # Run evaluation when there's a new checkpoint
    num_evaluations = 0
    for ckpt in checkpoints_iterator(
            checkpoint_dir=checkpoint_dir,
            min_interval_secs=FLAGS.min_eval_interval_s,
            timeout=FLAGS.eval_timeout,
            timeout_fn=terminate_eval):

        logging.info('Starting to evaluate.')

        # For each step, calls input_fn, which returns one batch of data.
        # Evaluates until either `steps` batches have been processed or input_fn
        # raises an end-of-input exception (OutOfRangeError or StopIteration).
        eval_results = classifier.evaluate(input_fn=tf_dataset,
                                           steps=num_batches,
                                           hooks=eval_hooks,
                                           checkpoint_path=ckpt,
                                           name=eval_name)
        logging.info('Eval results: %s', eval_results)

        _write_checkpoint_metrics(ckpt, eval_results, eval_name)

        # An alternative strategy might check step-number-of-ckpt >= train_steps.
        num_evaluations += 1
        if max_evaluations is not None and num_evaluations >= max_evaluations:
            logging.info('Evaluation finished after %d evaluations',
                         num_evaluations)
            break

    return
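
A worked example of the batch arithmetic above, with made-up numbers: a dataset of 10,000 examples, max_examples of 2,500 and a batch size of 64 gives 39 full batches, i.e. 2,496 evaluated samples, with the remaining 4 examples dropped by the floor division.

# Illustrative numbers only.
dataset_num_examples = 10000
max_examples = 2500
batch_size = 64

num_batches = min(max_examples, dataset_num_examples) // batch_size  # 39
num_samples = batch_size * num_batches                               # 2496
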
Example #9
def run(target, unused_is_chief, device_fn, use_tpu):
  """Run training.

  Args:
     target: The target of the TensorFlow standard server to use. Can be the
       empty string to run locally using an in-process server.
     device_fn: Device function used to assign ops to devices.
     use_tpu: Whether to use the TPU code path.
  """
  if not FLAGS.dataset_config_pbtxt:
    logging.error('Need to specify --dataset_config_pbtxt')
    return

  g = tf.Graph()
  with g.as_default():
    with tf.device(device_fn):
      # If ps_tasks is zero, the local device is used. When using multiple
      # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
      # across the different devices.

      tf_dataset = data_providers.get_input_fn_from_dataset(
          dataset_config_filename=FLAGS.dataset_config_pbtxt,
          mode=tf.estimator.ModeKeys.TRAIN,
          max_examples=FLAGS.max_examples,
          use_tpu=use_tpu)
      model = modeling.get_model(FLAGS.model_name)
      logging.info('Running training on %s with model %s and tpu %s',
                   tf_dataset, FLAGS.model_name, use_tpu)

      batches_per_epoch = tf_dataset.num_examples // FLAGS.batch_size
      logging.info('Batches per epoch %s', batches_per_epoch)
      params = dict(batches_per_epoch=batches_per_epoch,)
      estimator = model.make_estimator(
          batch_size=FLAGS.batch_size,
          model_dir=FLAGS.train_dir,
          params=params,
          use_tpu=use_tpu,
          master=target,
          start_from_checkpoint=FLAGS.start_from_checkpoint,
      )

      training_hooks = None
      if FLAGS.use_early_stopping:
        # Early stopping hook depends on existence of events directory.
        eval_dir = os.path.join(FLAGS.train_dir, FLAGS.early_stopping_directory)
        tf.gfile.MakeDirs(eval_dir)

        plateau_decrease = True
        if FLAGS.early_stopping_metric_direction == 'increase':
          plateau_decrease = False

        early_stopping_hook = metrics_hook.EarlyStoppingHook(
            events_dir=eval_dir,
            tag=FLAGS.early_stopping_tag,
            num_plateau_steps=FLAGS.early_stopping_num_plateau_steps,
            plateau_delta=FLAGS.early_stopping_plateau_delta,
            plateau_decrease=plateau_decrease,
            every_n_steps=FLAGS.early_stopping_every_n_steps)

        training_hooks = [early_stopping_hook]

      estimator.train(
          input_fn=tf_dataset,
          max_steps=FLAGS.number_of_steps,
          hooks=training_hooks)
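
For the EarlyStoppingHook above to see anything, scalar summaries with the matching tag have to end up as events files under eval_dir; in this setup they presumably come from the evaluation job. A hedged TF 1.x sketch of writing such a scalar by hand (the function name, tag and value are placeholders):

import tensorflow as tf


def write_eval_scalar(events_dir, tag, value, global_step):
  """Appends one scalar summary to an events file under events_dir."""
  writer = tf.summary.FileWriter(events_dir)
  summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
  writer.add_summary(summary, global_step=global_step)
  writer.flush()
  writer.close()
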