Code Example #1
def _augment_and_evaluate_impl(spec, config, model_dir, epochs_per_eval=5):
  """Augment and evaluate implementation, see augment_and_evaluate docstring."""
  input_augment, input_test = [
      cifar.CIFARInput(m, config)
      for m in ['augment', 'test']]
  estimator = _create_estimator(spec, config, model_dir,
                                input_augment.num_images)

  # A non-positive train_seconds means train without a wall-clock limit.
  if config['train_seconds'] > 0.0:
    timing = training_time.limit(config['train_seconds'])
  else:
    timing = training_time.limit(None)

  steps_per_epoch = input_augment.num_images / config['batch_size']   # float
  # Resume from the latest checkpoint if one exists; the trailing integer in a
  # TF checkpoint name (e.g. model.ckpt-1234) is the global step.
  ckpt = tf.train.latest_checkpoint(model_dir)
  if not ckpt:
    current_step = 0
  else:
    current_step = int(ckpt.split('-')[-1])
  max_steps = int(config['train_epochs'] * steps_per_epoch)

  # Train in chunks of epochs_per_eval epochs, evaluating after each chunk;
  # only the final test accuracy is reported in the metadata.
  test_accuracy = None
  while current_step < max_steps:
    next_step = current_step + int(epochs_per_eval * steps_per_epoch)
    next_step = min(next_step, max_steps)
    estimator.train(
        input_fn=input_augment.input_fn,
        max_steps=next_step,
        hooks=[timing.train_hook],
        saving_listeners=[timing.saving_listener])
    current_step = next_step

    test_accuracy = _evaluate(estimator, input_test, config)

  if test_accuracy is None:
    # The latest checkpoint had already reached max_steps, so the loop body
    # never ran; evaluate it directly to avoid an unbound test_accuracy.
    test_accuracy = _evaluate(estimator, input_test, config)

  metadata = {
      'trainable_params': _get_param_count(model_dir),
      'test_accuracy': test_accuracy,
  }

  return metadata
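
The example assumes a module-level helper _evaluate(estimator, input_test, config) that returns a scalar test accuracy. A minimal sketch of such a helper, assuming the model exports an 'accuracy' metric; the metric key and the step arithmetic here are illustrative guesses, not the canonical implementation:

def _evaluate(estimator, input_test, config):
  """Evaluates the latest checkpoint on the test split, returns accuracy."""
  # Integer division covers the full test set once, dropping any final
  # partial batch.
  eval_steps = input_test.num_images // config['batch_size']
  results = estimator.evaluate(input_fn=input_test.input_fn, steps=eval_steps)
  return results['accuracy']
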
Code Example #2
  def run(self):
    """Runs training and evaluation."""
    attempts = 0
    while True:
      # Delete everything in the model dir at the start of each attempt
      try:
        tf.gfile.DeleteRecursively(self.model_dir)
      except tf.errors.NotFoundError:
        pass
      tf.gfile.MakeDirs(self.model_dir)

      try:
        # Train
        if self.config['train_seconds'] > 0.0:
          timing = training_time.limit(self.config['train_seconds'])
        else:
          timing = training_time.limit(None)

        # list() matters under Python 3, where map() returns a one-shot
        # iterator that cannot be indexed or appended to.
        evaluations = list(map(float, self.config['intermediate_evaluations']))
        if not evaluations or evaluations[-1] != 1.0:
          evaluations.append(1.0)
        assert evaluations == sorted(evaluations)

        evaluation_results = []
        start_time = time.time()

        # Train for 1 step with 0 LR to initialize the weights, then evaluate
        # once at the start for completeness; accuracies are expected to be
        # near chance. Note that batch norm moving averages change during the
        # step but the trainable weights do not.
        self.estimator.train(
            input_fn=self.input_train.input_fn,
            max_steps=1,
            hooks=[timing.train_hook],
            saving_listeners=[timing.saving_listener])
        evaluation_results.append(self._evaluate_all(0.0, 0))

        for next_evaluation in evaluations:
          epoch = next_evaluation * self.config['train_epochs']
          train_steps = int(epoch * self.input_train.num_images /
                            self.config['batch_size'])
          self.estimator.train(
              input_fn=self.input_train.input_fn,
              max_steps=train_steps,
              hooks=[timing.train_hook],
              saving_listeners=[timing.saving_listener])

          evaluation_results.append(self._evaluate_all(epoch, train_steps))

        all_time = time.time() - start_time
        break     # Break from retry loop on success
      except VALID_EXCEPTIONS as e:   # pylint: disable=catching-non-exception
        attempts += 1
        tf.logging.warning(str(e))
        if attempts >= self.config['max_attempts']:
          raise AbortError(str(e))

    metadata = {
        'trainable_params': _get_param_count(self.model_dir),
        'total_time': all_time,   # includes eval and other metric time
        'evaluation_results': evaluation_results,
    }

    return metadata
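
To make the checkpoint arithmetic in the loop concrete: because estimator.train is called with max_steps (an absolute step target) rather than steps, each iteration resumes from the latest checkpoint and trains up to the next target. A worked example with hypothetical values (the image count, batch size, and epoch count below are made up for illustration, not taken from the code above):

num_images, batch_size, train_epochs = 45000, 256, 108  # hypothetical config
evaluations = [0.5, 1.0]  # intermediate_evaluations with 1.0 appended

for fraction in evaluations:
  epoch = fraction * train_epochs
  train_steps = int(epoch * num_images / batch_size)
  print('evaluate at epoch %.1f -> max_steps=%d' % (epoch, train_steps))

# evaluate at epoch 54.0 -> max_steps=9492
# evaluate at epoch 108.0 -> max_steps=18984
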
Code Example #3
    def run(self):
        """Runs training and evaluation."""
        attempts = 0
        tf.compat.v1.logging.set_verbosity('WARN')
        while True:
            # Delete everything in the model dir at the start of each attempt
            try:
                tf.io.gfile.rmtree(self.model_dir)
            except tf.errors.NotFoundError:
                pass
            tf.io.gfile.makedirs(self.model_dir)

            try:
                # Train
                if self.config['train_seconds'] > 0.0:
                    timing = training_time.limit(self.config['train_seconds'])
                else:
                    timing = training_time.limit(None)

                evaluation_results = []
                start_time = time.time()

                # Train for 1 step with 0 LR to initialize the weights, then
                # evaluate once at the start for completeness; accuracies are
                # expected to be near chance. Note that batch norm moving
                # averages change during the step but the trainable weights
                # do not.
                self.estimator.train(input_fn=self.input_train.input_fn,
                                     max_steps=1,
                                     hooks=[timing.train_hook],
                                     saving_listeners=[timing.saving_listener])
                evaluation_results.append(self._evaluate_all(0.0, 0))
                tf.compat.v1.logging.info(
                    'Evaluated once after weight initialization (epoch 0).')

                epochs = self.config['train_epochs']
                steps_per_epoch = int(self.input_train.num_images /
                                      self.config['batch_size'])

                for e in range(epochs):
                    # steps= (as opposed to max_steps=) trains one additional
                    # epoch on top of whatever the latest checkpoint contains.
                    self.estimator.train(
                        input_fn=self.input_train.input_fn,
                        steps=steps_per_epoch,
                        hooks=[timing.train_hook],
                        saving_listeners=[timing.saving_listener])

                    # e + 1 full epochs have completed at this point.
                    steps_so_far = (e + 1) * steps_per_epoch
                    result = self._evaluate_all(e + 1, steps_so_far)
                    tf.compat.v1.logging.info(
                        'epoch %d: val loss %.4f, val acc %.4f, '
                        'train loss %.4f, train acc %.4f',
                        e + 1,
                        result['validation_loss'],
                        result['validation_accuracy'],
                        result['train_loss'],
                        result['train_accuracy'])
                    evaluation_results.append(result)

                all_time = time.time() - start_time
                break  # Break from retry loop on success
            except VALID_EXCEPTIONS as e:  # pylint: disable=catching-non-exception
                attempts += 1
                tf.compat.v1.logging.warning(str(e))
                if attempts >= self.config['max_attempts']:
                    raise AbortError(str(e))

        metadata = {
            'trainable_params': _get_param_count(self.model_dir),
            'total_time': all_time,  # includes eval and other metric time
            'evaluation_results': evaluation_results,
        }

        return metadata
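
All three examples wrap training in the same retry loop, catching VALID_EXCEPTIONS and giving up with AbortError once config['max_attempts'] attempts have failed. Both names are defined elsewhere in the module; the sketch below is one plausible shape for them, offered as an assumption rather than the module's actual definition:

import tensorflow as tf

# Illustrative guess at the retryable-failure tuple; the real module may list
# different or additional exception types.
VALID_EXCEPTIONS = (
    tf.compat.v1.train.NanLossDuringTrainingError,  # loss diverged to NaN
    tf.errors.ResourceExhaustedError,               # e.g. device out of memory
)


class AbortError(Exception):
  """Raised once config['max_attempts'] consecutive attempts have failed."""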