Example #1
 def test_steps_and_saves_reloads(self):
   est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
   est.train(dummy_input_fn, steps=5)
   self.assertEqual(
       5, estimator._load_global_step_from_checkpoint_dir(est.model_dir))
   est.train(dummy_input_fn, steps=5)
   self.assertEqual(
       10, estimator._load_global_step_from_checkpoint_dir(est.model_dir))
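Every example on this page calls the private helper estimator._load_global_step_from_checkpoint_dir. As a rough sketch of what it does (an approximation built from public APIs, not the actual TensorFlow source): it reads the saved global_step variable from the latest checkpoint in a directory and returns 0 when no checkpoint exists yet.

import tensorflow as tf

def load_global_step(checkpoint_dir):
    # Approximate public-API equivalent of the private helper used in these
    # examples: read global_step from the latest checkpoint, default to 0.
    try:
        return int(tf.train.load_variable(checkpoint_dir,
                                          tf.GraphKeys.GLOBAL_STEP))
    except (tf.errors.NotFoundError, ValueError):
        # No checkpoint written yet (or no global_step variable in it).
        return 0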
Example #2
 def test_max_step(self):
   est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
   est.train(dummy_input_fn, max_steps=5)
   self.assertEqual(
       5, estimator._load_global_step_from_checkpoint_dir(est.model_dir))
   est.train(dummy_input_fn, max_steps=5)
   self.assertEqual(
       5, estimator._load_global_step_from_checkpoint_dir(est.model_dir))
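Read together, Examples #1 and #2 show the difference between the steps and max_steps arguments: steps trains for that many additional steps on top of the restored global step, while max_steps is an absolute target, so a second call with the same max_steps is a no-op. A small self-contained sketch of the same behaviour (illustrative names, not taken from the original test file):

import tensorflow as tf

def _incrementing_model_fn(features, labels, mode):
    # Trivial model_fn whose train_op only bumps the global step.
    del features, labels
    global_step = tf.train.get_or_create_global_step()
    return tf.estimator.EstimatorSpec(
        mode,
        loss=tf.constant(0.0),
        train_op=tf.assign_add(global_step, 1))

def _dummy_input_fn():
    return tf.data.Dataset.from_tensors(({'x': [0.0]}, [0.0])).repeat()

est = tf.estimator.Estimator(model_fn=_incrementing_model_fn,
                             model_dir='/tmp/global_step_demo')
est.train(_dummy_input_fn, steps=5)      # global step: 0 -> 5
est.train(_dummy_input_fn, steps=5)      # relative: 5 -> 10
est.train(_dummy_input_fn, max_steps=5)  # absolute target already reached: no-op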
Example #3
 def test_max_step(self):
   est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
   est.train(dummy_input_fn, max_steps=5)
   self.assertEqual(
       5, estimator._load_global_step_from_checkpoint_dir(est.model_dir))
   est.train(dummy_input_fn, max_steps=5)
   self.assertEqual(
       5, estimator._load_global_step_from_checkpoint_dir(est.model_dir))
Example #4
 def test_steps_and_saves_reloads(self):
   est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
   est.train(dummy_input_fn, steps=5)
   self.assertEqual(
       5, estimator._load_global_step_from_checkpoint_dir(est.model_dir))
   est.train(dummy_input_fn, steps=5)
   self.assertEqual(
       10, estimator._load_global_step_from_checkpoint_dir(est.model_dir))
Example #5
def main(unused_argv):
    if FLAGS.use_tpu:
        # Determine the gRPC URL of the TPU device to use
        if FLAGS.master is None and FLAGS.tpu_name is None:
            raise RuntimeError(
                'You must specify either --master or --tpu_name.')

        if FLAGS.master is not None:
            if FLAGS.tpu_name is not None:
                tf.logging.warn(
                    'Both --master and --tpu_name are set. Ignoring'
                    ' --tpu_name and using --master.')
            tpu_grpc_url = FLAGS.master
        else:
            tpu_cluster_resolver = (
                tf.contrib.cluster_resolver.TPUClusterResolver(
                    FLAGS.tpu_name,
                    zone=FLAGS.tpu_zone,
                    project=FLAGS.gcp_project))
            tpu_grpc_url = tpu_cluster_resolver.get_master()
    else:
        # URL is unused if running locally without TPU
        tpu_grpc_url = None

    config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        evaluation_master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_cores))

    resnet_classifier = tpu_estimator.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=resnet_model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    imagenet_train = imagenet_input.ImageNetInput(is_training=True,
                                                  data_dir=FLAGS.data_dir)
    imagenet_eval = imagenet_input.ImageNetInput(is_training=False,
                                                 data_dir=FLAGS.data_dir)

    current_step = estimator._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
    batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
    tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                    ' step %d.' % (FLAGS.train_steps, FLAGS.train_steps /
                                   batches_per_epoch, current_step))
    # The original train/eval loop is disabled here; training runs straight to
    # --train_steps and a checkpoint is written to --model_dir at the end.
    # start_timestamp = time.time()
    # while current_step < FLAGS.train_steps:
    #     # Train for up to steps_per_eval number of steps.
    #     next_checkpoint = min(current_step + FLAGS.steps_per_eval,
    #                           FLAGS.train_steps)
    resnet_classifier.train(input_fn=imagenet_train.input_fn,
                            max_steps=FLAGS.train_steps)
Example #6
    def train_process(self):
        """Whole train process of the TrainWorker specified in config.

        After training, the model and validation results are saved to local_worker_path and s3_path.
        """
        self._init_estimator()
        self._init_dataloader()
        logging_hook = []
        if self.horovod:
            logging_hook += [hvd.BroadcastGlobalVariablesHook(0)]
        train_steps = self.train_data.data_len
        valid_steps = self.valid_data.data_len
        if self.horovod:
            train_steps = train_steps // hvd.size()
            valid_steps = valid_steps // hvd.size()
        start_step = est._load_global_step_from_checkpoint_dir(self.get_local_worker_path())
        for i in range(self.cfg.epochs):
            logging.info('train epoch [{0}/{1}]'.format(i, self.cfg.epochs))
            current_max_step = start_step + train_steps
            start_step = current_max_step
            self.estimator.train(input_fn=self.train_data.input_fn,
                                 max_steps=current_max_step,
                                 hooks=logging_hook)
            eval_results = self.estimator.evaluate(input_fn=self.valid_data.input_fn, steps=valid_steps)
            logging.info(eval_results)
        self.save_backup(eval_results)
Example #7
def run_toy_model_tpu():
  """Run a toy model on TPU."""
  iterations_per_loop = FLAGS.iterations
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  config = tpu_config.RunConfig(
      master=FLAGS.master,
      evaluation_master=FLAGS.master,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=None,  # Disable the default saver
      save_checkpoints_secs=None,  # Disable the default saver
      log_step_count_steps=iterations_per_loop,
      tpu_config=tpu_config.TPUConfig(
          num_shards=mesh_shape.size,
          iterations_per_loop=iterations_per_loop,
          num_cores_per_replica=1,
          per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))
  classifier = tpu_estimator.TPUEstimator(
      use_tpu=True,
      model_fn=model_fn,
      config=config,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size)
  current_step = estimator_lib._load_global_step_from_checkpoint_dir(FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
  logging.info('Current step %d', current_step)
  while current_step < FLAGS.train_steps:
    next_checkpoint = min(current_step + FLAGS.steps_per_checkpoint,
                          FLAGS.train_steps)
    classifier.train(input_fn=ToyModelInput(), max_steps=next_checkpoint)
    current_step = next_checkpoint

    tf.logging.info('Starting to evaluate.')
    eval_results = classifier.evaluate(
        input_fn=ToyModelInput(),
        steps=156)  # since we have 10000 examples and batch_size = 64 per host
    logging.info('Eval results: %s', eval_results)
Example #8
def main():

    filenames = tf.matching_files(TRAIN_DIR)

    train_dataset = InputReader(TRAIN_DIR, True)

    steps_per_epoch = NUM_TRAIN_IMAGES // BATCH_SIZE
    eval_steps = NUM_VAL_IMAGES // BATCH_SIZE

    current_step = estimator._load_global_step_from_checkpoint_dir(MODEL_DIR)
    steps_per_epoch = NUM_TRAIN_IMAGES // BATCH_SIZE

    params = {'batch_size': BATCH_SIZE}
    train_data = train_dataset(params)

    iterator = train_data.make_one_shot_iterator()
    get = iterator.get_next()

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    files = sess.run(filenames)
    print(files)
    print(get)
    i = 0
    # Iterate until the one-shot iterator is exhausted.
    try:
        while True:
            a, b = sess.run(get)
            print(a.shape)
            print(b.shape)
            i += 1
            print(i)
    except tf.errors.OutOfRangeError:
        pass

    print(train_data)
Example #9
def run_evaluation_conserving_best(
        estimator, batch_size, batch_threads, dataset_factory, image_size,
        evaluation_summary_writer: MatlabEvaluationSummaryWriter):
    global best_result

    check_init_best_result(estimator)

    result = run_prediction_and_evaluation(batch_size=batch_size,
                                           batch_threads=batch_threads,
                                           dataset_factory=dataset_factory,
                                           estimator=estimator,
                                           image_size=image_size)
    global_step = _load_global_step_from_checkpoint_dir(estimator.model_dir)
    evaluation_summary_writer.write_evaluation_result(global_step=global_step,
                                                      evaluation_result=result)

    if (best_result.rank1 + best_result.mAP) < (result.rank1 + result.mAP):
        print('Current checkpoint is better => replacing the old best')

        best_prediction_directory = get_best_prediction_directory(estimator)

        if os.path.exists(best_prediction_directory):
            os.rename(best_prediction_directory, best_prediction_directory +
                      ".old")  # rename it first to give nfs time to delete
            shutil.rmtree(best_prediction_directory + ".old")

        print('source: %s' % get_prediction_directory(estimator))
        print('dest: %s' % best_prediction_directory)
        os.rename(get_prediction_directory(estimator),
                  best_prediction_directory)
Example #10
def main(argv):

    del argv

    tf.gfile.MakeDirs(os.path.join(FLAGS.model_dir))

    resolution = FLAGS.end_resolution
    initial_checkpoint = None
    while initial_checkpoint is None and resolution != 1:
        model_dir = os.path.join(FLAGS.model_dir,
                                 'resolution_' + str(resolution))
        initial_checkpoint = tf.train.latest_checkpoint(model_dir)
        resolution = resolution // 2
    if initial_checkpoint is None or resolution == 1:
        resolution = FLAGS.start_resolution
        model_dir = os.path.join(FLAGS.model_dir,
                                 'resolution_' + str(resolution))
    else:
        resolution *= 2
        model_dir = os.path.join(FLAGS.model_dir,
                                 'resolution_' + str(resolution))

    est, local_est = get_estimator(model_dir, resolution)

    current_step = estimator._load_global_step_from_checkpoint_dir(model_dir)  # pylint: disable=protected-access,line-too-long

    tf.logging.info('Starting training for %d steps, current step: %d' %
                    (FLAGS.train_steps, current_step))
    while current_step < FLAGS.train_steps:
        if current_step != 0 and current_step % FLAGS.resolution_steps == 0 and resolution != FLAGS.end_resolution:
            resolution *= 2
            tf.logging.info('Change of resolution from %d to %d' %
                            (resolution // 2, resolution))
            model_dir = os.path.join(FLAGS.model_dir,
                                     'resolution_' + str(resolution))
            change_resolution(resolution)
            est, local_est = get_estimator(model_dir, resolution)
        next_checkpoint = min(current_step + FLAGS.train_steps_per_eval,
                              FLAGS.train_steps)
        est.train(input_fn=dataset.TrainInputFunction(FLAGS.noise_dim,
                                                      resolution, 'NHWC'),
                  max_steps=next_checkpoint)
        current_step = next_checkpoint
        tf.logging.info('Finished training step %d' % current_step)

        if FLAGS.eval_loss:
            metrics = est.evaluate(
                input_fn=dataset.TrainInputFunction(FLAGS.noise_dim,
                                                    resolution, 'NHWC'),
                steps=FLAGS.num_eval_images // FLAGS.batch_size)
            tf.logging.info('Finished evaluating')
            tf.logging.info(metrics)

        generated_iter = local_est.predict(input_fn=noise_input_fn)
        images = [p['generated_images'][:, :, :] for p in generated_iter]
        filename = os.path.join(
            FLAGS.model_dir,
            '%s-%s.png' % (str(current_step).zfill(5), 'x' + str(resolution)))
        utils.write_images(images, filename, 'NHWC')
        tf.logging.info('Finished generating images')
Example #11
def train_and_maybe_evaluate(model_est, imagenet_train, imagenet_eval, params):
    """Trains the model and maybe runs evaluation when the mode flag is set to 'train_and_eval'.
    Args:
        model_est: `TPUEstimator` instance for the discovered model
        imagenet_train: Input pipeline for the training set
        imagenet_eval: Input pipeline for the validation set
        params: Dictionary containing parameters
    """
    current_step = estimator._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long

    tf.logging.info(
        'Training for %d steps (%.2f epochs in total). Current'
        ' step %d.', FLAGS.train_steps,
        FLAGS.train_steps / params['steps_per_epoch'], current_step)

    start_timestamp = time.time()  # This time will include compilation time

    if FLAGS.mode == 'train':
        hooks = []
        if FLAGS.use_async_checkpointing:
            hooks.append(
                async_checkpoint.AsyncCheckpointSaverHook(
                    checkpoint_dir=FLAGS.model_dir,
                    save_steps=max(100, FLAGS.iterations_per_loop)))
        model_est.train(input_fn=imagenet_train.input_fn,
                        max_steps=FLAGS.train_steps,
                        hooks=hooks)

    else:
        while current_step < FLAGS.train_steps:
            # Train for up to steps_per_eval number of steps.
            # At the end of training, a checkpoint will be written to --model_dir.
            next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                  FLAGS.train_steps)
            model_est.train(input_fn=imagenet_train.input_fn,
                            max_steps=int(next_checkpoint))
            current_step = next_checkpoint

            tf.logging.info(
                'Finished training up to step %d. Elapsed seconds %d.',
                next_checkpoint, int(time.time() - start_timestamp))

            # Evaluate the model on the most recent model in --model_dir.
            # Since evaluation happens in batches of --eval_batch_size, some images
            # may be excluded modulo the batch size. As long as the batch size is
            # consistent, the evaluated images are also consistent.
            tf.logging.info('Starting to evaluate.')
            eval_results = model_est.evaluate(input_fn=imagenet_eval.input_fn,
                                              steps=FLAGS.num_eval_images //
                                              FLAGS.eval_batch_size)
            tf.logging.info('Eval results at step %d: %s', next_checkpoint,
                            eval_results)

        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info('Finished training up to step %d. Elapsed seconds %d.',
                        FLAGS.train_steps, elapsed_time)
        if FLAGS.export_dir:
            export(model_est, FLAGS.export_dir)
Example #12
def main(argv):
    del argv

    global is_bias
    global noise_dim
    is_bias = FLAGS.condition == 'bias'
    noise_dim = 100 if is_bias else 90

    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        keep_checkpoint_max=None,
        tpu_config=tf.contrib.tpu.TPUConfig(
            num_shards=FLAGS.num_shards,
            iterations_per_loop=FLAGS.iterations_per_loop))

    # Set module-level global variable so that model_fn and input_fn can be
    # identical for each different kind of dataset and model
    global dataset, model
    dataset = tpu_input
    model = tpu_model

    # TPU-based estimator used for TRAIN and EVAL
    est = tf.contrib.tpu.TPUEstimator(model_fn=model_fn,
                                      use_tpu=FLAGS.use_tpu,
                                      config=config,
                                      train_batch_size=FLAGS.batch_size,
                                      eval_batch_size=FLAGS.batch_size)

    # CPU-based estimator used for PREDICT (generating images)
    cpu_est = tf.contrib.tpu.TPUEstimator(model_fn=model_fn,
                                          use_tpu=False,
                                          config=config,
                                          predict_batch_size=_NUM_VIZ_AUDIO)

    current_step = estimator._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
    tf.logging.info('Starting training for %d steps, current step: %d' %
                    (FLAGS.train_steps, current_step))

    while current_step < FLAGS.train_steps:
        next_checkpoint = min(current_step + FLAGS.train_steps_per_eval,
                              FLAGS.train_steps)
        est.train(input_fn=generate_input_fn(True), max_steps=next_checkpoint)
        current_step = next_checkpoint
        tf.logging.info('Finished training step %d' % current_step)

        if FLAGS.eval_loss:
            # Evaluate loss on test set
            metrics = est.evaluate(input_fn=generate_input_fn(False),
                                   steps=dataset.NUM_EVAL_IMAGES //
                                   FLAGS.batch_size)
            tf.logging.info('Finished evaluating')
            tf.logging.info(metrics)
Example #13
def train_and_eval(deeplab_estimator, train_dataset, eval_dataset,
                   num_batches_per_epoch):
    """Interleaves training and evaluation."""
    # pylint: disable=protected-access
    current_step = estimator._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)
    tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                    ' step %d.' %
                    (FLAGS.train_steps,
                     FLAGS.train_steps / num_batches_per_epoch,
                     current_step))
    start_timestamp = time.time()
    while current_step < FLAGS.train_steps:
        # Train for up to steps_per_eval number of steps. At the end of training,
        # a checkpoint will be written to --model_dir.
        next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                              FLAGS.train_steps)

        # Data pipeline - input function with train_dataset
        # ----------------------------------------------------------------------
        train_input_fn = data_pipeline.InputReader(
            train_dataset,
            FLAGS.train_split,
            is_training=True,
            model_variant=FLAGS.model_variant
        )

        # Train with estimator
        # ----------------------------------------------------------------------
        deeplab_estimator.train(
            input_fn=train_input_fn,
            max_steps=next_checkpoint
        )
        current_step = next_checkpoint

        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info('Finished training up to step %d. Elapsed seconds %d.' %
                        (current_step, elapsed_time))

        tf.logging.info('Starting to evaluate.')

        # Data pipeline - input function with eval_dataset
        # ----------------------------------------------------------------------
        eval_input_fn = data_pipeline.InputReader(
            eval_dataset,
            FLAGS.eval_split,
            is_training=False,
            model_variant=FLAGS.model_variant
        )

        # Evaluate with estimator
        # ----------------------------------------------------------------------
        eval_results = deeplab_estimator.evaluate(
            input_fn=eval_input_fn,
            steps=eval_dataset.num_samples // FLAGS.eval_batch_size
        )
        tf.logging.info('Eval results: %s' % eval_results)
Example #14
    def train_and_eval(self):
        assert FLAGS.mode == 'train_and_eval'

        current_step = estimator._load_global_step_from_checkpoint_dir(
            FLAGS.model_dir)

        train_epochs = self.params['train_steps'] / \
            self.params['steps_per_epoch']
        tf.logging.info(
            'Training for %d steps (%.2f epochs in total). Current step %d.',
            self.params['train_steps'], train_epochs, current_step)

        # neptune.log_text('Train INFO', f"Training for {self.params['train_steps']} steps {train_epochs} epochs in total)\n Current step {current_step}")

        start_timestamp = time.time()  # This time will include compilation time

        eval_results = None
        while current_step < self.params['train_steps']:
            # Train for up to steps_per_eval number of steps.
            # At the end of training, a checkpoint will be written to --model_dir.
            steps_per_eval = int(FLAGS.epochs_per_eval *
                                 self.params['steps_per_epoch'])
            next_eval = (current_step // steps_per_eval) * \
                steps_per_eval + steps_per_eval
            print("next eval point : ", next_eval)
            next_checkpoint = min(next_eval, self.params['train_steps'])
            self.est.train(input_fn=self.imagenet_train.input_fn,
                           max_steps=int(next_checkpoint))
            current_step = next_checkpoint

            tf.logging.info(
                'Finished training up to step %d. Elapsed seconds %d.',
                next_checkpoint, int(time.time() - start_timestamp))
            neptune.log_text(
                'train INFO',
                'Finished training up to step {}. Elapsed seconds {}'.format(
                    next_checkpoint, int(time.time() - start_timestamp)))

            eval_results = self.eval()

        if eval_results is None:
            eval_results = self.eval()

        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info('Finished training up to step %d. Elapsed seconds %d.',
                        self.params['train_steps'], elapsed_time)
        neptune.log_text(
            'train INFO',
            'Finished training up to step {} . Elapsed seconds {}'.format(
                self.params["train_steps"], elapsed_time))

        tf.keras.backend.clear_session()
        tf.reset_default_graph()

        return eval_results['top_1_accuracy'].item()
Example #15
  def after_run(self, run_context, run_values):
    global_step = run_values.results + 1
    if global_step >= self._last_step:
      # Check latest global step in the checkpoint to ensure that the targeted
      # last step is written on disk.

      step = estimator_lib._load_global_step_from_checkpoint_dir(
          self._model_dir)
      if step >= self._last_step:
        run_context.request_stop()
      else:
        time.sleep(self._wait_after_file_check_secs)
Example #16
  def after_run(self, run_context, run_values):
    global_step = run_values.results + 1
    if global_step >= self._last_step:
      # Check latest global step in the checkpoint to ensure that the targeted
      # last step is written on disk.

      step = estimator_lib._load_global_step_from_checkpoint_dir(
          self._model_dir)
      if step >= self._last_step:
        run_context.request_stop()
      else:
        time.sleep(self._wait_after_file_check_secs)
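Examples #15 and #16 (two copies of the same hook) show only the after_run method. For context, a complete hook built around the same idea might look like the following sketch (class and argument names are assumed, not the original ones):

import time
import tensorflow as tf

class _StopAtCheckpointedStepHook(tf.train.SessionRunHook):
    """Requests a stop once the checkpoint on disk has reached last_step."""

    def __init__(self, model_dir, last_step, wait_after_file_check_secs=5):
        self._model_dir = model_dir
        self._last_step = last_step
        self._wait_after_file_check_secs = wait_after_file_check_secs

    def begin(self):
        self._global_step_tensor = tf.train.get_or_create_global_step()

    def before_run(self, run_context):
        # Ask the session for the in-memory global step.
        return tf.train.SessionRunArgs(self._global_step_tensor)

    def after_run(self, run_context, run_values):
        global_step = run_values.results + 1
        if global_step >= self._last_step:
            # Check the latest global step in the checkpoint to ensure the
            # targeted last step has actually been written to disk.
            try:
                step = int(tf.train.load_variable(self._model_dir,
                                                  tf.GraphKeys.GLOBAL_STEP))
            except (tf.errors.NotFoundError, ValueError):
                step = 0
            if step >= self._last_step:
                run_context.request_stop()
            else:
                time.sleep(self._wait_after_file_check_secs)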
Example #17
def main(argv):
    del argv
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        tpu_config=tf.contrib.tpu.TPUConfig(
            num_shards=FLAGS.num_shards,
            iterations_per_loop=FLAGS.iterations_per_loop))

    # Set module-level global variable so that model_fn and input_fn can be
    # identical for each different kind of dataset and model
    global dataset, model
    dataset = bias_input
    model = bias_model

    # TPU-based estimator used for TRAIN and EVAL
    est = tf.contrib.tpu.TPUEstimator(model_fn=model_fn,
                                      use_tpu=FLAGS.use_tpu,
                                      config=config,
                                      train_batch_size=FLAGS.batch_size,
                                      eval_batch_size=FLAGS.batch_size)

    # CPU-based estimator used for PREDICT (generating images)
    cpu_est = tf.contrib.tpu.TPUEstimator(model_fn=model_fn,
                                          use_tpu=False,
                                          config=config,
                                          predict_batch_size=_NUM_VIZ_AUDIO)

    current_step = estimator._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
    tf.logging.info('Starting training for %d steps, current step: %d' %
                    (FLAGS.train_steps, current_step))

    # Render some generated images
    G_z = cpu_est.predict(input_fn=noise_input_fn)
    G_z = [p['generated_audio'][:, :] for p in G_z]
    G_z = np.array(G_z)
    preview_dir = './preview'
    if not os.path.isdir(preview_dir):
        os.makedirs(preview_dir)

    for i in range(len(G_z)):
        audio = np.int16(G_z[i] / np.max(np.abs(G_z[i])) * 32767)
        preview_fp = os.path.join(
            preview_dir, '{}_{}_{}.wav'.format(str(i % 10), str(current_step),
                                               str(i)))
        wavwrite(preview_fp, _FS, audio)

    tf.logging.info('Finished generating images')
Example #18
    def train(self, train_steps, generate_input_fn):
        """Train the model

        Args:
            train_steps (int): Number of training steps
            generate_input_fn (function): Function that returns an input_fn
            function. (see example or tf.Estimator documentation)
            noise_input_fn (function): input_fn that returns a noise vector
        """

        current_step = estimator._load_global_step_from_checkpoint_dir(
            self.model_dir)  # pylint: disable=protected-access,line-too-long
        tf.logging.info('Starting training for %d steps, current step: %d' %
                        (train_steps, current_step))
        tf.gfile.MakeDirs(os.path.join(self.model_dir, 'generated_images'))

        # self.generate_images(generate_input_fn, current_step)

        while current_step < train_steps:
            next_checkpoint = int(
                min(current_step + self.train_steps_per_eval, train_steps))
            tf.logging.info('Step: %s  -- (Next checkpoint %s)', current_step,
                            next_checkpoint)
            self.est.train(input_fn=generate_input_fn('TRAIN'),
                           max_steps=next_checkpoint)
            current_step = next_checkpoint
            tf.logging.info('Finished training step %d' % current_step)

            if self.eval_loss:
                # Evaluate loss on test set
                metrics = self.est.evaluate(
                    input_fn=generate_input_fn('EVAL'),
                    steps=max(self.num_eval_images // self.batch_size, 1))
                tf.logging.info('Finished evaluating')
                tf.logging.info(metrics)

            self.generate_images(generate_input_fn, current_step)
            gc.collect()  # I'm experiencing some kind of memory leak (and it
                          # seems that other people have run into it too).
Example #19
def main(unused_argv):

    # Check flag conditions:
    if FLAGS.mode == 'train':
        tf.logging.info('Mode = train, TPU = %s, Num cores = %d' %
                        (FLAGS.tpu, FLAGS.train_num_cores))

    elif FLAGS.mode == 'evaluate':
        tf.logging.info('Mode = evaluate, TPU = %s, Num cores = %d' %
                        (FLAGS.eval_tpu, FLAGS.eval_num_cores))

    elif FLAGS.mode == 'train_and_eval':
        if FLAGS.train_num_cores > 8:
            tf.logging.info('Mode = train_and_eval, Train TPU = %s, '
                            'Train num cores: %d, Eval TPU = %s, '
                            'Eval num cores: %d' %
                            (FLAGS.tpu, FLAGS.train_num_cores, FLAGS.eval_tpu,
                             FLAGS.eval_num_cores))
        else:
            tf.logging.info('Mode = train_and_eval, TPU = %s, '
                            'Num cores: %d' %
                            (FLAGS.tpu, FLAGS.train_num_cores))

    # Set up general purpose tpu_cluster_resolver based on FLAGS.mode:
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu
        if FLAGS.mode in ['train', 'train_and_eval'] else FLAGS.eval_tpu,
        zone=FLAGS.tpu_zone
        if FLAGS.mode in ['train', 'train_and_eval'] else FLAGS.eval_tpu_zone,
        project=FLAGS.gcp_project)

    # For mode == 'train_and_eval' we can have 2 options:
    # 1. Use same TPU for training and evaluating (only v2-8)
    # 2. Use TPU with more cores for training (v2-32/128/256/512),
    #       and a separate v2-8 for evaluating.
    if FLAGS.mode == 'train_and_eval' and FLAGS.train_num_cores > 8:
        eval_tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.eval_tpu,
            zone=FLAGS.eval_tpu_zone,
            project=FLAGS.gcp_project)

    if FLAGS.use_async_checkpointing:
        save_checkpoints_steps = None
    else:
        save_checkpoints_steps = max(100, FLAGS.iterations_per_loop)

    ##### RunConfig parameters:
    '''Arguments:
        iterations_per_loop: number of training steps running in TPU system
            before returning to CPU host for each Session.run. Global step is
            increased iterations_per_loop times in one Session.run. It is recommended
            to be set as number of global steps for next checkpoint.
        per_host_input_for_training: If True, input_fn is invoked once on each host.
            If PER_HOST_V1: batch size per shard = train_batch_size // #hosts (#cpus)
            If PER_HOST_V2: batch size per shard = train_batch_size // #cores  
        keep_checkpoint_max: If None, keep all checkpoint files, otherwise specify
            'n' to keep latest 'n' files.

    Each TPU device has 8 cores and is connected to a host (CPU). Larger slices have
    multiple hosts. For instance, v2-256 communicates with 16 hosts. So,
    per_host_input_for_training will invoke/create the Dataset pipeline 16 times in
    total for 16 hosts,
    where each host will serve 256/16 = 16 cores. Each core will take a batch size represented
    by flag PER_HOST_V2. This functionality is missing right now in tf.Keras which makes it
    difficult to scale up models to bigger TPU slices.

    '''
    config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        log_step_count_steps=FLAGS.log_step_count_steps,
        keep_checkpoint_max=None,
        session_config=tf.ConfigProto(
            graph_options=tf.GraphOptions(
                rewrite_options=rewriter_config_pb2.RewriterConfig(
                    disable_meta_optimizer=True))),
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.train_num_cores
                if FLAGS.mode in ['train', 'train_and_eval']
                else FLAGS.eval_num_cores,
            per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.\
                PER_HOST_V2))
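    # Worked example of the PER_HOST_V2 batch math described above (numbers are
    # illustrative, not taken from this script): a v2-32 slice has 32 cores
    # spread over 4 hosts, so with train_batch_size=1024 each core sees
    # params['batch_size'] == 1024 // 32 == 32 examples per step, and each
    # host-level input_fn invocation feeds 8 cores, i.e. 256 examples per step.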

    if FLAGS.mode == 'train_and_eval' and FLAGS.train_num_cores > 8:
        config_eval = tf.contrib.tpu.RunConfig(
            cluster=eval_tpu_cluster_resolver,
            model_dir=FLAGS.model_dir,
            save_checkpoints_steps=save_checkpoints_steps,
            log_step_count_steps=FLAGS.log_step_count_steps,
            keep_checkpoint_max=None,
            session_config=tf.ConfigProto(
                graph_options=tf.GraphOptions(
                    rewrite_options=rewriter_config_pb2.RewriterConfig(
                        disable_meta_optimizer=True))),
            tpu_config=tf.contrib.tpu.TPUConfig(
                iterations_per_loop=FLAGS.iterations_per_loop,
                num_shards=FLAGS.eval_num_cores,
                per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.\
                    PER_HOST_V2))

    ##### Estimator story:
    '''Estimator handles running details, such as replicating inputs and models for
        core, and returning to host periodically to run hooks.
        -> TPUEstimator transforms a global batch size in params to a per-shard/core
            batch size when calling input_fn and model_fn. Users SHOULD specify GLOBAL
            batch size in constructor and then get the batch size for EACH shard/core 
            in input_fn and model_fn by PARAMS['BATCH_SIZE'].
        -> For training, model_fn gets per_core_batch_size; input_fn may get
            per-core or per-host batch size depending on per_host_input_for_training in
            TPUConfig. For this model, we use PER_HOST_V2.
        -> For evaluation and prediction, model_fn gets per-core batch size and input_fn
            per-host batch size.

        Current limitations:
            -> TPU prediction only works on a single host (one TPU worker)
            -> input_fn must return a Dataset instance rather than features. In fact,
                train(), and evaluate() also support Dataset as return value.
    '''
    '''Arguments:
        model_fn: Should be a TPUEstimatorSpec. 
        use_tpu: Setting to False for testing. All training, evaluation, and predict will
            be executed on CPU. input_fn and model_fn will receive train_batch_size or
            eval_batch_size unmodified as params['batch_size']. Setting to True, input_fn
            and model_fn will receive per_core batch size. :config plays a role in specifying
            details about TPU workers to the Estimator.
        config: An tpu_config.RunConfig configuration object. Cannot be None.
        params: An optional dict of hyper parameters that will be passed into input_fn and
            model_fn. Keys are names of parameters, values are basic python types. There are
            reserved keys for TPUEstimator, including 'batch_size'. Extra parameters can be 
            added to this dictionary and can be used in input_fn and model_fn scripts.
        train_batch_size: An int representing the global batch size. TPUEstimator transforms
            this global batch size to a per-shard/core batch size, as params['batch_size'],
            when calling input_fn and model_fn. Cannot be None if :use_tpu is True. Must be
            DIVISIBLE by total number of replicas. The per-shard batch size calculation is
            automatically done using TPUConfig details.
        export_to_tpu: If True, export_savedmodel() exports a metagraph for serving on TPU
            besides the one on CPU.
    '''
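    # Sketch of how the per-core batch size described above is typically read
    # back inside an input_fn (illustrative only; the real input pipeline for
    # this script is built further down via inp_pipeline.InputPipelineTFExample):
    #
    #   def input_fn(params):
    #       per_core_batch_size = params['batch_size']
    #       return dataset.batch(per_core_batch_size, drop_remainder=True)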

    if FLAGS.init_checkpoint != 'None':
        warm_start_vars = FLAGS.warm_start_vars.split(',')
        warm_start_vars = [x.strip() for x in warm_start_vars]
        ws = tf.estimator.WarmStartSettings(
            ckpt_to_initialize_from=FLAGS.init_checkpoint,
            vars_to_warm_start=warm_start_vars)

        i3d_classifier = tf.contrib.tpu.TPUEstimator(
            use_tpu=FLAGS.use_tpu,
            model_fn=i3d_model_fn,
            config=config,
            train_batch_size=FLAGS.train_batch_size,
            eval_batch_size=FLAGS.eval_batch_size,
            predict_batch_size=FLAGS.predict_batch_size,
            export_to_tpu=FLAGS.export_to_tpu,
            warm_start_from=ws)
    else:
        i3d_classifier = tf.contrib.tpu.TPUEstimator(
            use_tpu=FLAGS.use_tpu,
            model_fn=i3d_model_fn,
            config=config,
            train_batch_size=FLAGS.train_batch_size,
            eval_batch_size=FLAGS.eval_batch_size,
            predict_batch_size=FLAGS.predict_batch_size,
            export_to_tpu=FLAGS.export_to_tpu)

    if FLAGS.mode == 'train_and_eval' and FLAGS.train_num_cores > 8:
        i3d_eval = tf.contrib.tpu.TPUEstimator(
            use_tpu=FLAGS.use_tpu,
            model_fn=i3d_model_fn,
            config=config_eval,
            train_batch_size=FLAGS.train_batch_size,
            eval_batch_size=FLAGS.eval_batch_size,
            export_to_tpu=FLAGS.export_to_tpu,
            warm_start_from=ws)

    assert FLAGS.precision == 'bfloat16' or FLAGS.precision == 'float32', (
        'Invalid value for --precision flag; must be bfloat16 or float32.')
    tf.logging.info('Precision: %s', FLAGS.precision)

    use_bfloat16 = FLAGS.precision == 'bfloat16'

    tf.logging.info('Using dataset: %s', FLAGS.data_dir)

    list_of_augmentations = [
        'random_crop', 'random_brightness', 'random_contrast'
    ]

    # dataset_train and dataset_eval are the Input pipelines
    dataset_train, dataset_eval, dataset_predict = [
        inp_pipeline.InputPipelineTFExample(
            data_dir=FLAGS.data_dir,
            is_training=is_training,
            cache=FLAGS.use_cache and is_training,
            use_bfloat16=use_bfloat16,
            target_image_size=224,
            num_frames=32,  # num_frames_change_here
            num_classes=15,
            num_parallel_calls=FLAGS.num_parallel_calls,
            list_of_augmentations=list_of_augmentations)
        for is_training in [True, False, False]
    ]

    # num_train_videos = total images in the dataset
    # train_batch_size = total batch size (across all cores)
    steps_per_epoch = FLAGS.num_train_videos // FLAGS.train_batch_size
    eval_steps = FLAGS.num_eval_videos // FLAGS.eval_batch_size

    if FLAGS.mode == 'train' or FLAGS.mode == 'evaluate':

        # Automatically get the latest checkpoint file and latest
        # train step from the model_dir.
        current_step = estimator._load_global_step_from_checkpoint_dir(
            FLAGS.model_dir)

        tf.logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', FLAGS.train_steps, FLAGS.train_steps / steps_per_epoch,
            current_step)

        start_timestamp = time.time()  # Compilation time included

        if FLAGS.mode == 'train':
            hooks = []

            # Not sure what this does. I think this takes care of
            # asynchronously saving checkpoint files, irrespective of
            # training routine on TPU.
            if FLAGS.use_async_checkpointing:
                hooks.append(
                    async_checkpoint.AsyncCheckpointSaverHook(
                        checkpoint_dir=FLAGS.model_dir,
                        save_steps=max(100, FLAGS.iterations_per_loop)))

            # Number of steps between collecting profiles, if larger
            # than 0.
            if FLAGS.profile_every_n_steps > 0:
                hooks.append(
                    tpu_profiler_hook.TPUProfilerHook(
                        save_steps=FLAGS.profile_every_n_steps,
                        output_dir=FLAGS.model_dir,
                        tpu=FLAGS.tpu))

            ##### Estimator training story:
            '''Arguments:
                input_fn: Returns mini batches for training. Function should
                    return tf.data.Dataset object: tuple (features, labels).
                    Both features and labels are consumed by model_fn. They
                    should satisfy the expectation of model_fn for inputs.
                hooks: List of tf.train.SessionRunHook subclass instance. Used
                    for callbacks inside the training loop.
                max_steps: Number of total steps for which to train the model.
            '''
            i3d_classifier.train(input_fn=dataset_train.input_fn,
                                 max_steps=FLAGS.train_steps,
                                 hooks=hooks)

        elif FLAGS.mode == 'evaluate':
            '''
            for ckpt in evaluation.checkpoints_iterator(
                FLAGS.model_dir, timeout=FLAGS.eval_timeout):
                tf.logging.info(
                    'Starting to evaluate using %s',
                    ckpt)
            '''
            f = open(
                'evaluations/dummy_' + FLAGS.model_dir.split('/')[-1] + '.txt',
                'ab')
            #ids = [i for i in range(12600, 14000, 300)]
            #ids.append(14000)
            ids = [14000]
            #import ipdb; ipdb.set_trace()
            for i in ids:
                try:
                    ckpt = FLAGS.model_dir + '/model.ckpt-' + str(i)
                    start_timestamp = time.time()  # Compilation time included
                    eval_results = i3d_classifier.evaluate(
                        input_fn=dataset_eval.input_fn,
                        steps=eval_steps,
                        checkpoint_path=ckpt)
                    elapsed_time = int(time.time() - start_timestamp)
                    tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                                    eval_results, elapsed_time)

                    f.write('step: ' + str(i) + ', stats: ' +
                            str(eval_results) + '\n')
                    f.close()
                    f = open(
                        'evaluations/dummy_' + FLAGS.model_dir.split('/')[-1] +
                        '.txt', 'ab')

                    # Terminate eval job when final checkpoint is reached
                    current_step = int(os.path.basename(ckpt).split('-')[1])
                    if current_step >= FLAGS.train_steps:
                        tf.logging.info(
                            'Evaluation finished after training step %d',
                            current_step)
                        break

                except tf.errors.NotFoundError:
                    tf.logging.info(
                        'Checkpoint %s no longer exists, skipping checkpoint',
                        ckpt)
            f.close()

    elif FLAGS.mode == 'predict':
        i = 1000
        ckpt = FLAGS.model_dir + '/model.ckpt-' + str(i)
        predict_iters = i3d_classifier.predict(
            input_fn=dataset_predict.input_fn,
            checkpoint_path=ckpt,
            yield_single_examples=False)
        all_gt, all_preds = [], []
        count = 0
        for predict_result in predict_iters:
            gt = predict_result['ground_truth']
            preds = predict_result['predictions']
            if count % 10 == 0:
                print('step:{}, shapes:{}'.format(count, gt.shape))
            count += 1

            for j, k in zip(gt, preds):
                all_gt.append(j)
                all_preds.append(k)

        print('Finished, {}'.format(len(all_gt)))
        with open('gt.pkl', 'wb') as handle:
            pickle.dump(all_gt, handle)
        with open('preds.pkl', 'wb') as handle:
            pickle.dump(all_preds, handle)
Example #20
def main(unused_argv):
    tpu_grpc_url = None
    tpu_cluster_resolver = None
    if FLAGS.use_tpu:
        # Determine the gRPC URL of the TPU device to use
        if not FLAGS.master and not FLAGS.tpu_name:
            raise RuntimeError(
                'You must specify either --master or --tpu_name.')

        if FLAGS.master:
            if FLAGS.tpu_name:
                tf.logging.warn(
                    'Both --master and --tpu_name are set. Ignoring'
                    ' --tpu_name and using --master.')
            tpu_grpc_url = FLAGS.master
        else:
            tpu_cluster_resolver = (
                tf.contrib.cluster_resolver.TPUClusterResolver(
                    FLAGS.tpu_name,
                    zone=FLAGS.tpu_zone,
                    project=FLAGS.gcp_project))
    else:
        # URL is unused if running locally without TPU
        tpu_grpc_url = None

    config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        evaluation_master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        cluster=tpu_cluster_resolver,
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_cores))

    resnet_classifier = tpu_estimator.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=resnet_model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    imagenet_train = imagenet_input.ImageNetInput(is_training=True,
                                                  data_dir=FLAGS.data_dir)
    imagenet_eval = imagenet_input.ImageNetInput(is_training=False,
                                                 data_dir=FLAGS.data_dir)

    if FLAGS.mode == 'eval':
        eval_steps = NUM_EVAL_IMAGES // FLAGS.eval_batch_size

        # Run evaluation when there's a new checkpoint
        for ckpt in evaluation.checkpoints_iterator(FLAGS.model_dir):
            tf.logging.info('Starting to evaluate.')
            try:
                start_timestamp = time.time()  # This time will include compilation time
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=eval_steps,
                    checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d' %
                                (eval_results, elapsed_time))

                # Terminate eval job when final checkpoint is reached
                current_step = int(os.path.basename(ckpt).split('-')[1])
                if current_step >= FLAGS.train_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d' %
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint' %
                    ckpt)

    else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        current_step = estimator._load_global_step_from_checkpoint_dir(
            FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
        batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
        tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                        ' step %d.' % (FLAGS.train_steps, FLAGS.train_steps /
                                       batches_per_epoch, current_step))

        start_timestamp = time.time()  # This time will include compilation time
        if FLAGS.mode == 'train':
            resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                    max_steps=FLAGS.train_steps)

        else:
            assert FLAGS.mode == 'train_and_eval'
            while current_step < FLAGS.train_steps:
                # Train for up to steps_per_eval number of steps.
                # At the end of training, a checkpoint will be written to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      FLAGS.train_steps)
                resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                        max_steps=next_checkpoint)
                current_step = next_checkpoint

                # Evaluate the model on the most recent model in --model_dir.
                # Since evaluation happens in batches of --eval_batch_size, some images
                # may be consistently excluded modulo the batch size.
                tf.logging.info('Starting to evaluate.')
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=NUM_EVAL_IMAGES // FLAGS.eval_batch_size)
                tf.logging.info('Eval results: %s' % eval_results)

        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info(
            'Finished training up to step %d. Elapsed seconds %d.' %
            (FLAGS.train_steps, elapsed_time))

        if FLAGS.export_dir is not None:
            # The guide to serve an exported TensorFlow model is at:
            #    https://www.tensorflow.org/serving/serving_basic
            tf.logging.info('Starting to export model.')
            resnet_classifier.export_savedmodel(
                export_dir_base=FLAGS.export_dir,
                serving_input_receiver_fn=imagenet_input.image_serving_input_fn
            )
Example #21
def main(unused_argv):
    if FLAGS.use_tpu:
        # Determine the gRPC URL of the TPU device to use
        if FLAGS.master is None and FLAGS.tpu_name is None:
            raise RuntimeError(
                'You must specify either --master or --tpu_name.')

        if FLAGS.master is not None:
            if FLAGS.tpu_name is not None:
                tf.logging.warn(
                    'Both --master and --tpu_name are set. Ignoring'
                    ' --tpu_name and using --master.')
            tpu_grpc_url = FLAGS.master
        else:
            tpu_cluster_resolver = (
                tf.contrib.cluster_resolver.TPUClusterResolver(
                    FLAGS.tpu_name,
                    zone=FLAGS.tpu_zone,
                    project=FLAGS.gcp_project))
            tpu_grpc_url = tpu_cluster_resolver.get_master()
    else:
        # URL is unused if running locally without TPU
        tpu_grpc_url = None

    config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        evaluation_master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_cores))

    resnet_classifier = tpu_estimator.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=resnet_model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    imagenet_train = imagenet_input.ImageNetInput(is_training=True,
                                                  data_dir=FLAGS.data_dir)
    imagenet_eval = imagenet_input.ImageNetInput(is_training=False,
                                                 data_dir=FLAGS.data_dir)

    current_step = estimator._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
    batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
    tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                    ' step %d.' % (FLAGS.train_steps, FLAGS.train_steps /
                                   batches_per_epoch, current_step))
    start_timestamp = time.time()
    while current_step < FLAGS.train_steps:
        # Train for up to steps_per_eval number of steps. At the end of training, a
        # checkpoint will be written to --model_dir.
        next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                              FLAGS.train_steps)
        resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                max_steps=next_checkpoint)
        current_step = next_checkpoint

        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info(
            'Finished training up to step %d. Elapsed seconds %d.' %
            (current_step, elapsed_time))

        # Evaluate the model on the most recent model in --model_dir.
        # Since evaluation happens in batches of --eval_batch_size, some images may
        # be excluded modulo the batch size. As long as the batch size is
        # consistent, the evaluated images are also consistent.
        tf.logging.info('Starting to evaluate.')
        eval_results = resnet_classifier.evaluate(
            input_fn=imagenet_eval.input_fn,
            steps=NUM_EVAL_IMAGES // FLAGS.eval_batch_size)
        tf.logging.info('Eval results: %s' % eval_results)

    if FLAGS.export_dir is not None:
        # The guide to serve an exported TensorFlow model is at:
        #    https://www.tensorflow.org/serving/serving_basic
        tf.logging.info('Starting to export model.')
        resnet_classifier.export_savedmodel(
            export_dir_base=FLAGS.export_dir,
            serving_input_receiver_fn=imagenet_input.image_serving_input_fn)
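The export_savedmodel call above relies on imagenet_input.image_serving_input_fn, which is defined outside this snippet. A rough sketch of what such a serving input function usually looks like (illustrative, not the actual implementation):

import tensorflow as tf

def image_serving_input_fn():
    # Accept a batch of encoded JPEG strings and decode/resize them.
    image_bytes = tf.placeholder(dtype=tf.string, shape=[None],
                                 name='image_bytes')

    def _decode(serialized):
        image = tf.image.decode_jpeg(serialized, channels=3)
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)  # [0, 1]
        return tf.image.resize_images(image, [224, 224])

    images = tf.map_fn(_decode, image_bytes, dtype=tf.float32)
    return tf.estimator.export.ServingInputReceiver(
        features=images, receiver_tensors={'image_bytes': image_bytes})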
Example #22
def main(unused_argv):
    tpu_grpc_url = None
    tpu_cluster_resolver = None
    if FLAGS.use_tpu:
        # Determine the gRPC URL of the TPU device to use
        if not FLAGS.master and not FLAGS.tpu_name:
            raise RuntimeError(
                'You must specify either --master or --tpu_name.')

        if FLAGS.master:
            if FLAGS.tpu_name:
                tf.logging.warn(
                    'Both --master and --tpu_name are set. Ignoring'
                    ' --tpu_name and using --master.')
            tpu_grpc_url = FLAGS.master
        else:
            tpu_cluster_resolver = (
                tf.contrib.cluster_resolver.TPUClusterResolver(
                    FLAGS.tpu_name,
                    zone=FLAGS.tpu_zone,
                    project=FLAGS.gcp_project))
    else:
        # URL is unused if running locally without TPU
        tpu_grpc_url = None

    config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        evaluation_master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=FLAGS.iterations_per_loop,
        keep_checkpoint_max=None,
        cluster=tpu_cluster_resolver,
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_cores,
            per_host_input_for_training=tpu_config.InputPipelineConfig.PER_HOST_V2))  # pylint: disable=line-too-long

    resnet_classifier = tpu_estimator.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=resnet_main.resnet_model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    imagenet_train = imagenet_input.ImageNetInput(
        is_training=True,
        data_dir=FLAGS.data_dir,
        transpose_input=FLAGS.transpose_input)
    imagenet_eval = imagenet_input.ImageNetInput(
        is_training=False,
        data_dir=FLAGS.data_dir,
        transpose_input=FLAGS.transpose_input)

    if FLAGS.mode == 'train':
        current_step = estimator._load_global_step_from_checkpoint_dir(
            FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
        batches_per_epoch = resnet_main.NUM_TRAIN_IMAGES / FLAGS.train_batch_size
        tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                        ' step %d.' % (FLAGS.train_steps, FLAGS.train_steps /
                                       batches_per_epoch, current_step))

        start_timestamp = time.time()  # This time will include compilation time

        # Write a dummy file at the start of training so that we can measure the
        # runtime at each checkpoint from the file write time.
        tf.gfile.MkDir(FLAGS.model_dir)
        if not tf.gfile.Exists(os.path.join(FLAGS.model_dir, 'START')):
            with tf.gfile.GFile(os.path.join(FLAGS.model_dir, 'START'),
                                'w') as f:
                f.write(str(start_timestamp))

        resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                max_steps=FLAGS.train_steps)

    else:
        assert FLAGS.mode == 'eval'

        start_timestamp = tf.gfile.Stat(os.path.join(FLAGS.model_dir,
                                                     'START')).mtime_nsec
        results = []
        eval_steps = resnet_main.NUM_EVAL_IMAGES // FLAGS.eval_batch_size

        ckpt_steps = set()
        all_files = tf.gfile.ListDirectory(FLAGS.model_dir)
        for f in all_files:
            mat = re.match(CKPT_PATTERN, f)
            if mat is not None:
                ckpt_steps.add(int(mat.group('gs')))
        ckpt_steps = sorted(list(ckpt_steps))
        tf.logging.info('Steps to be evaluated: %s' % str(ckpt_steps))

        for step in ckpt_steps:
            ckpt = os.path.join(FLAGS.model_dir, 'model.ckpt-%d' % step)

            batches_per_epoch = resnet_main.NUM_TRAIN_IMAGES // FLAGS.train_batch_size
            current_epoch = step // batches_per_epoch

            end_timestamp = tf.gfile.Stat(ckpt + '.index').mtime_nsec
            elapsed_hours = (end_timestamp - start_timestamp) / (1e9 * 3600.0)

            tf.logging.info('Starting to evaluate.')
            eval_start = time.time()  # This time will include compilation time
            eval_results = resnet_classifier.evaluate(
                input_fn=imagenet_eval.input_fn,
                steps=eval_steps,
                checkpoint_path=ckpt)
            eval_time = int(time.time() - eval_start)
            tf.logging.info('Eval results: %s. Elapsed seconds: %d' %
                            (eval_results, eval_time))
            results.append([
                current_epoch,
                elapsed_hours,
                '%.2f' % (eval_results['top_1_accuracy'] * 100),
                '%.2f' % (eval_results['top_5_accuracy'] * 100),
            ])

            time.sleep(60)

        with tf.gfile.GFile(os.path.join(FLAGS.model_dir, 'results.tsv'), 'wb') as tsv_file:  # pylint: disable=line-too-long
            writer = csv.writer(tsv_file, delimiter='\t')
            writer.writerow(['epoch', 'hours', 'top1Accuracy', 'top5Accuracy'])
            writer.writerows(results)
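The eval branch above discovers checkpoints by matching file names against a CKPT_PATTERN constant defined elsewhere in the module. A minimal, self-contained sketch of that step follows; the exact regex is an assumption (checkpoint index files are conventionally named `model.ckpt-<step>.index`), so treat CKPT_PATTERN_GUESS as a hypothetical stand-in, not the original constant.

import re

import tensorflow as tf

# Hypothetical stand-in for the module's CKPT_PATTERN constant: matches files
# such as 'model.ckpt-1500.index' and captures the global step in group 'gs'.
CKPT_PATTERN_GUESS = r'^model\.ckpt-(?P<gs>\d+)\.index$'


def list_checkpoint_steps(model_dir, pattern=CKPT_PATTERN_GUESS):
    """Returns the sorted global steps of every checkpoint in `model_dir`."""
    steps = set()
    for fname in tf.gfile.ListDirectory(model_dir):
        match = re.match(pattern, fname)
        if match:
            steps.add(int(match.group('gs')))
    return sorted(steps)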
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  # RevNet specific configuration
  config = main_.get_config(config_name=FLAGS.config, dataset=FLAGS.dataset)

  if FLAGS.use_tpu:
    tf.logging.info("Using TPU.")
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
  else:
    tpu_cluster_resolver = None

  # TPU specific configuration
  tpu_config = tf.contrib.tpu.TPUConfig(
      # Recommended to be set to the number of global steps until the
      # next checkpoint
      iterations_per_loop=FLAGS.iterations_per_loop,
      num_shards=FLAGS.num_shards)

  # Estimator specific configuration
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False),
      tpu_config=tpu_config,
  )

  # Construct TPU Estimator
  estimator = tf.contrib.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=config.tpu_batch_size,
      eval_batch_size=config.tpu_eval_batch_size,
      config=run_config,
      params={"config": config})

  # Construct input functions
  train_input_fn = get_input_fn(
      config=config, data_dir=FLAGS.data_dir, split="train_all")
  eval_input_fn = get_input_fn(
      config=config, data_dir=FLAGS.data_dir, split="test")

  # Disabling a range within an else block currently doesn't work
  # due to https://github.com/PyCQA/pylint/issues/872
  # pylint: disable=protected-access
  if FLAGS.mode == "eval":
    # TPUEstimator.evaluate *requires* a steps argument.
    # Note that the number of examples used during evaluation is
    # --eval_steps * --batch_size.
    # So if you change --batch_size then change --eval_steps too.
    eval_steps = 10000 // config.tpu_eval_batch_size

    # Run evaluation when there's a new checkpoint
    for ckpt in evaluation.checkpoints_iterator(
        FLAGS.model_dir, timeout=FLAGS.eval_timeout):
      tf.logging.info("Starting to evaluate.")
      try:
        start_timestamp = time.time()  # This time will include compilation time
        eval_results = estimator.evaluate(
            input_fn=eval_input_fn, steps=eval_steps, checkpoint_path=ckpt)
        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info("Eval results: %s. Elapsed seconds: %d" %
                        (eval_results, elapsed_time))

        # Terminate eval job when final checkpoint is reached
        current_step = int(os.path.basename(ckpt).split("-")[1])
        if current_step >= config.max_train_iter:
          tf.logging.info(
              "Evaluation finished after training step %d" % current_step)
          break

      except tf.errors.NotFoundError:
        # Since the coordinator is on a different job than the TPU worker,
        # sometimes the TPU worker does not finish initializing until long after
        # the CPU job tells it to start evaluating. In this case, the checkpoint
        # file could have been deleted already.
        tf.logging.info(
            "Checkpoint %s no longer exists, skipping checkpoint" % ckpt)

  else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
    current_step = estimator_._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)
    tf.logging.info("Training for %d steps . Current"
                    " step %d." % (config.max_train_iter, current_step))

    start_timestamp = time.time()  # This time will include compilation time
    if FLAGS.mode == "train":
      estimator.train(input_fn=train_input_fn, max_steps=config.max_train_iter)
    else:
      eval_steps = 10000 // config.tpu_eval_batch_size
      assert FLAGS.mode == "train_and_eval"
      while current_step < config.max_train_iter:
        # Train for up to steps_per_eval number of steps.
        # At the end of training, a checkpoint will be written to --model_dir.
        next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                              config.max_train_iter)
        estimator.train(input_fn=train_input_fn, max_steps=next_checkpoint)
        current_step = next_checkpoint

        # Evaluate the model on the most recent model in --model_dir.
        # Since evaluation happens in batches of --eval_batch_size, some images
        # may be consistently excluded modulo the batch size.
        tf.logging.info("Starting to evaluate.")
        eval_results = estimator.evaluate(
            input_fn=eval_input_fn, steps=eval_steps)
        tf.logging.info("Eval results: %s" % eval_results)

    elapsed_time = int(time.time() - start_timestamp)
    tf.logging.info("Finished training up to step %d. Elapsed seconds %d." %
                    (config.max_train_iter, elapsed_time))
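Every example in this listing reaches for the private `_load_global_step_from_checkpoint_dir` helper. For reference, a rough equivalent can be written with public TF 1.x APIs only; this is a sketch, not the library's implementation, and it assumes the step is stored under the conventional `global_step` variable name.

import tensorflow as tf


def load_global_step(model_dir):
  """Best-effort read of the global step from the latest checkpoint."""
  ckpt = tf.train.latest_checkpoint(model_dir)
  if ckpt is None:
    return 0  # No checkpoint has been written yet.
  return int(tf.train.load_variable(ckpt, tf.GraphKeys.GLOBAL_STEP))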
Exemple #24
0
 def test_run_train_op_and_saves_at_the_end(self):
   est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
   est.fit(dummy_input_fn, steps=5)
   self.assertEqual(
       5, estimator._load_global_step_from_checkpoint_dir(est.model_dir))
Exemple #25
0
def experiment(model_config):
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info("SCRIPT START")

    tf.logging.info("TPU resolver started")

    tpu_cluster_resolver = TPUClusterResolver(
        tpu=os.environ['TPU_NAME'],
        project=os.environ['PROJECT_NAME'],
        zone=os.environ['PROJECT_ZONE'])
    config = tpu_config.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=model_config['model_base_dir'] + os.path.sep + str(model_config["experiment_id"]),
        save_checkpoints_steps=500,
        save_summary_steps=250,
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=500,
            num_shards=8,
            per_host_input_for_training=tpu_config.InputPipelineConfig.PER_HOST_V1))  # pylint: disable=line-too-long

    tf.logging.info("Creating datasets")
    urmp_train, urmp_eval, urmp_test = [
        urmp_input.URMPInput(mode=mode,
                             data_dir=model_config['data_path'],
                             transpose_input=False,
                             use_bfloat16=model_config['use_bfloat16'])
        for mode in ['train', 'eval', 'test']
    ]

    tf.logging.info("Assigning TPUEstimator")
    # Optimize in a supervised fashion until validation loss worsens
    separator = tpu_estimator.TPUEstimator(
        use_tpu=model_config["use_tpu"],
        model_fn=unet_separator,
        config=config,
        train_batch_size=model_config['batch_size'],
        eval_batch_size=model_config['batch_size'],
        predict_batch_size=model_config['batch_size'],
        params={
            i: model_config[i]
            for i in model_config if (i != 'batch_size' and i != 'context')
        })

    if model_config['load_model']:
        tf.logging.info("Load the model")
        current_step = estimator._load_global_step_from_checkpoint_dir(
            model_config['model_base_dir'] + os.path.sep +
            str(model_config["experiment_id"]))

    if model_config['mode'] == 'train_and_eval':
        tf.logging.info("Train the model")
        # Early stopping should be used here, but it only arrives with TF 1.10
        separator.train(input_fn=urmp_train.input_fn,
                        steps=model_config['training_steps'])

        tf.logging.info("Supervised training finished!")
        tf.logging.info("Evaluate model")
        # Evaluate the model.
        eval_result = separator.evaluate(
            input_fn=urmp_eval.input_fn,
            steps=model_config['evaluation_steps'])
        tf.logging.info('Evaluation results: %s' % eval_result)

    elif model_config['mode'] == 'predict':
        tf.logging.info("Test results and save predicted sources:")
        predictions = separator.predict(input_fn=urmp_test.input_fn)

        for prediction in predictions:
            Test.save_prediction(prediction,
                                 estimates_path=model_config["estimates_path"],
                                 sample_rate=model_config["expected_sr"])
        Utils.concat_and_upload(
            model_config["estimates_path"], model_config['model_base_dir'] +
            os.path.sep + str(model_config["experiment_id"]))
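The resolver above indexes os.environ directly, so a missing variable surfaces as a bare KeyError. A slightly more defensive variant is sketched below; the failure mode chosen here is an assumption, not part of the original script.

import os

import tensorflow as tf

tpu_name = os.environ.get('TPU_NAME')
if not tpu_name:
    raise RuntimeError('TPU_NAME must be set when running on a TPU.')
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
    tpu=tpu_name,
    project=os.environ.get('PROJECT_NAME'),
    zone=os.environ.get('PROJECT_ZONE'))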
Exemple #26
0
def main(unused_argv):
    params = resnet_params.from_file(FLAGS.param_file)
    params = resnet_params.override(params, FLAGS.param_overrides)
    resnet_params.log_hparams_to_model_dir(params, FLAGS.model_dir)
    tf.logging.info('Model params: {}'.format(params))

    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu if (FLAGS.tpu or params['use_tpu']) else '',
        zone=FLAGS.tpu_zone,
        project=FLAGS.gcp_project)

    if params['use_async_checkpointing']:
        save_checkpoints_steps = None
    else:
        save_checkpoints_steps = max(100, params['iterations_per_loop'])
    config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        log_step_count_steps=FLAGS.log_step_count_steps,
        session_config=tf.ConfigProto(
            graph_options=tf.GraphOptions(
                rewrite_options=rewriter_config_pb2.RewriterConfig(
                    disable_meta_optimizer=True))),
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=params['iterations_per_loop'],
            num_shards=params['num_cores'],
            per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig
            .PER_HOST_V2))  # pylint: disable=line-too-long

    if FLAGS.inference_with_all_cores:
        resnet_classifier = tf.contrib.tpu.TPUEstimator(
            use_tpu=params['use_tpu'],
            model_fn=resnet_model_fn,
            config=config,
            params=params,
            train_batch_size=params['train_batch_size'],
            eval_batch_size=params['eval_batch_size'],
            export_to_tpu=FLAGS.export_to_tpu,
            experimental_exported_model_uses_all_cores=FLAGS.
            inference_with_all_cores)
    else:
        resnet_classifier = tf.contrib.tpu.TPUEstimator(
            use_tpu=params['use_tpu'],
            model_fn=resnet_model_fn,
            config=config,
            params=params,
            train_batch_size=params['train_batch_size'],
            eval_batch_size=params['eval_batch_size'],
            export_to_tpu=FLAGS.export_to_tpu)
    assert (params['precision'] == 'bfloat16' or params['precision']
            == 'float32'), ('Invalid value for precision parameter; '
                            'must be bfloat16 or float32.')
    tf.logging.info('Precision: %s', params['precision'])
    use_bfloat16 = params['precision'] == 'bfloat16'

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    if FLAGS.bigtable_instance:
        tf.logging.info('Using Bigtable dataset, table %s',
                        FLAGS.bigtable_table)
        select_train, select_eval = _select_tables_from_flags()
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetBigtableInput(
                is_training=is_training,
                use_bfloat16=use_bfloat16,
                transpose_input=params['transpose_input'],
                selection=selection)
            for (is_training,
                 selection) in [(True, select_train), (False, select_eval)]
        ]
    else:
        if FLAGS.data_dir == FAKE_DATA_DIR:
            tf.logging.info('Using fake dataset.')
        else:
            tf.logging.info('Using dataset: %s', FLAGS.data_dir)
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetInput(
                is_training=is_training,
                data_dir=FLAGS.data_dir,
                transpose_input=params['transpose_input'],
                cache=params['use_cache'] and is_training,
                image_size=params['image_size'],
                num_parallel_calls=params['num_parallel_calls'],
                use_bfloat16=use_bfloat16) for is_training in [True, False]
        ]

    steps_per_epoch = params['num_train_images'] // params['train_batch_size']
    eval_steps = params['num_eval_images'] // params['eval_batch_size']

    if FLAGS.mode == 'eval':

        # Run evaluation when there's a new checkpoint
        for ckpt in evaluation.checkpoints_iterator(
                FLAGS.model_dir, timeout=FLAGS.eval_timeout):
            tf.logging.info('Starting to evaluate.')
            try:
                # This time will include compilation time
                start_timestamp = time.time()
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=eval_steps,
                    checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                                eval_results, elapsed_time)

                # Terminate eval job when final checkpoint is reached
                current_step = int(os.path.basename(ckpt).split('-')[1])
                if current_step >= params['train_steps']:
                    tf.logging.info(
                        'Evaluation finished after training step %d',
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint',
                    ckpt)

    else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        current_step = estimator._load_global_step_from_checkpoint_dir(
            FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
        steps_per_epoch = params['num_train_images'] // params[
            'train_batch_size']
        tf.logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', params['train_steps'],
            params['train_steps'] / steps_per_epoch, current_step)

        start_timestamp = time.time()  # This time will include compilation time

        if FLAGS.mode == 'train':
            hooks = []
            if params['use_async_checkpointing']:
                hooks.append(
                    async_checkpoint.AsyncCheckpointSaverHook(
                        checkpoint_dir=FLAGS.model_dir,
                        save_steps=max(100, params['iterations_per_loop'])))
            if FLAGS.profile_every_n_steps > 0:
                hooks.append(
                    tpu_profiler_hook.TPUProfilerHook(
                        save_steps=FLAGS.profile_every_n_steps,
                        output_dir=FLAGS.model_dir,
                        tpu=FLAGS.tpu))
            resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                    max_steps=params['train_steps'],
                                    hooks=hooks)

        else:
            assert FLAGS.mode == 'train_and_eval'
            while current_step < params['train_steps']:
                # Train for up to steps_per_eval number of steps.
                # At the end of training, a checkpoint will be written to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      params['train_steps'])
                resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                        max_steps=next_checkpoint)
                current_step = next_checkpoint

                tf.logging.info(
                    'Finished training up to step %d. Elapsed seconds %d.',
                    next_checkpoint, int(time.time() - start_timestamp))

                # Evaluate the model on the most recent model in --model_dir.
                # Since evaluation happens in batches of --eval_batch_size, some images
                # may be excluded modulo the batch size. As long as the batch size is
                # consistent, the evaluated images are also consistent.
                tf.logging.info('Starting to evaluate.')
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=params['num_eval_images'] //
                    params['eval_batch_size'])
                tf.logging.info('Eval results at step %d: %s', next_checkpoint,
                                eval_results)

            elapsed_time = int(time.time() - start_timestamp)
            tf.logging.info(
                'Finished training up to step %d. Elapsed seconds %d.',
                params['train_steps'], elapsed_time)

        if FLAGS.export_dir is not None:
            # The guide to serve an exported TensorFlow model is at:
            #    https://www.tensorflow.org/serving/serving_basic
            tf.logging.info('Starting to export model.')
            export_path = resnet_classifier.export_saved_model(
                export_dir_base=FLAGS.export_dir,
                serving_input_receiver_fn=imagenet_input.image_serving_input_fn
            )
            if FLAGS.add_warmup_requests:
                inference_warmup.write_warmup_requests(
                    export_path,
                    FLAGS.model_name,
                    params['image_size'],
                    batch_sizes=FLAGS.inference_batch_sizes,
                    image_format='JPEG')
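As a worked example of the step/epoch bookkeeping used throughout these scripts, with the standard ImageNet training-set size and an assumed batch size and step budget:

num_train_images = 1281167   # ImageNet training-set size
train_batch_size = 1024      # assumed value
train_steps = 112590         # assumed value

steps_per_epoch = num_train_images // train_batch_size   # 1251
epochs = train_steps / float(steps_per_epoch)             # 90.0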
Exemple #27
0
def main(unused_argv):
    tf.set_random_seed(FLAGS.random_seed)

    save_checkpoints_steps = 100
    run_config_args = {
        'model_dir': FLAGS.model_dir,
        'save_checkpoints_steps': save_checkpoints_steps,
        'log_step_count_steps': FLAGS.log_step_count_steps,
        'keep_checkpoint_max': 200,
    }

    config = tf_estimator.RunConfig(**run_config_args)

    if FLAGS.warm_start_ckpt_path:
        var_names = []
        checkpoint_path = FLAGS.warm_start_ckpt_path
        reader = tf.train.NewCheckpointReader(checkpoint_path)
        for key in reader.get_variable_to_shape_map():
            keep_str = 'Momentum|global_step|finetune_global_step|Adam|final_dense_dst'
            if not re.findall('({})'.format(keep_str), key):
                var_names.append(key)

        tf.logging.info('Warm-starting tensors: %s', sorted(var_names))

        vars_to_warm_start = var_names
        warm_start_settings = tf_estimator.WarmStartSettings(
            ckpt_to_initialize_from=checkpoint_path,
            vars_to_warm_start=vars_to_warm_start)
    else:
        warm_start_settings = None

    classifier = tf_estimator.Estimator(get_model_fn(),
                                        config=config,
                                        warm_start_from=warm_start_settings)

    def _merge_datasets(train_batch):
        feature, label = train_batch['image'], train_batch['label'],
        features = {
            'feature': feature,
        }
        labels = {
            'label': label,
        }
        return (features, labels)

    def get_dataset(dataset_split):
        """Returns dataset creation function."""
        def make_input_dataset():
            """Returns input dataset."""
            train_data = tfds.load(name=FLAGS.target_dataset,
                                   split=dataset_split)
            train_data = train_data.shuffle(1024).repeat().batch(
                FLAGS.train_batch_size)
            dataset = tf.data.Dataset.zip((train_data, ))
            dataset = dataset.map(_merge_datasets)
            dataset = dataset.prefetch(
                buffer_size=tf.data.experimental.AUTOTUNE)
            return dataset

        return make_input_dataset

    # pylint: disable=protected-access
    current_step = estimator._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)

    train_steps = FLAGS.train_steps
    while current_step < train_steps:
        print('Run {}'.format(current_step))
        next_checkpoint = current_step + 500
        classifier.train(input_fn=get_dataset('train'),
                         max_steps=next_checkpoint)
        current_step = next_checkpoint
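The warm-start block above enumerates checkpoint variables with NewCheckpointReader. The same filtering can be sketched with `tf.train.list_variables`, which returns (name, shape) pairs; the checkpoint path below is hypothetical.

import re

import tensorflow as tf

ckpt_path = '/tmp/pretrained/model.ckpt-10000'  # hypothetical path
skip = re.compile(
    'Momentum|global_step|finetune_global_step|Adam|final_dense_dst')
names_to_warm_start = [
    name for name, _ in tf.train.list_variables(ckpt_path)
    if not skip.search(name)
]
warm_start_settings = tf.estimator.WarmStartSettings(
    ckpt_to_initialize_from=ckpt_path,
    vars_to_warm_start=names_to_warm_start)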
def main(unused_argv):
  # [START tpu-cluster-resolver]
  tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu,
      zone=FLAGS.tpu_zone,
      project=FLAGS.gcp_project)

  config = tpu_config.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=max(600, FLAGS.iterations_per_loop),
      tpu_config=tpu_config.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_cores,
          per_host_input_for_training=tpu_config.InputPipelineConfig.PER_HOST_V2))  # pylint: disable=line-too-long
  # [END tpu-cluster-resolver]

  resnet_classifier = tpu_estimator.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=resnet_model_fn,
      config=config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size)

  assert FLAGS.precision == 'bfloat16' or FLAGS.precision == 'float32', (
      'Invalid value for --precision flag; must be bfloat16 or float32.')
  tf.logging.info('Precision: %s', FLAGS.precision)
  use_bfloat16 = FLAGS.precision == 'bfloat16'

  # Input pipelines are slightly different (with regards to shuffling and
  # preprocessing) between training and evaluation.
  imagenet_train, imagenet_eval = [imagenet_input.ImageNetInput(
      is_training=is_training,
      data_dir=FLAGS.data_dir,
      transpose_input=FLAGS.transpose_input,
      use_bfloat16=use_bfloat16) for is_training in [True, False]]

  if FLAGS.mode == 'eval':
    eval_steps = NUM_EVAL_IMAGES // FLAGS.eval_batch_size

    # Run evaluation when there's a new checkpoint
    for ckpt in evaluation.checkpoints_iterator(
        FLAGS.model_dir, timeout=FLAGS.eval_timeout):
      tf.logging.info('Starting to evaluate.')
      try:
        start_timestamp = time.time()  # This time will include compilation time
        eval_results = resnet_classifier.evaluate(
            input_fn=imagenet_eval.input_fn,
            steps=eval_steps,
            checkpoint_path=ckpt)
        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info('Eval results: %s. Elapsed seconds: %d' %
                        (eval_results, elapsed_time))

        # Terminate eval job when final checkpoint is reached
        current_step = int(os.path.basename(ckpt).split('-')[1])
        if current_step >= FLAGS.train_steps:
          tf.logging.info(
              'Evaluation finished after training step %d' % current_step)
          break

      except tf.errors.NotFoundError:
        # Since the coordinator is on a different job than the TPU worker,
        # sometimes the TPU worker does not finish initializing until long after
        # the CPU job tells it to start evaluating. In this case, the checkpoint
        # file could have been deleted already.
        tf.logging.info(
            'Checkpoint %s no longer exists, skipping checkpoint' % ckpt)

  else:   # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
    current_step = estimator._load_global_step_from_checkpoint_dir(FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
    batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
    tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                    ' step %d.' % (FLAGS.train_steps,
                                   FLAGS.train_steps / batches_per_epoch,
                                   current_step))

    start_timestamp = time.time()  # This time will include compilation time
    if FLAGS.mode == 'train':
      resnet_classifier.train(
          input_fn=imagenet_train.input_fn, max_steps=FLAGS.train_steps)

    else:
      assert FLAGS.mode == 'train_and_eval'
      while current_step < FLAGS.train_steps:
        # Train for up to steps_per_eval number of steps.
        # At the end of training, a checkpoint will be written to --model_dir.
        next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                              FLAGS.train_steps)
        resnet_classifier.train(
            input_fn=imagenet_train.input_fn, max_steps=next_checkpoint)
        current_step = next_checkpoint

        # Evaluate the model on the most recent model in --model_dir.
        # Since evaluation happens in batches of --eval_batch_size, some images
        # may be consistently excluded modulo the batch size.
        tf.logging.info('Starting to evaluate.')
        eval_results = resnet_classifier.evaluate(
            input_fn=imagenet_eval.input_fn,
            steps=NUM_EVAL_IMAGES // FLAGS.eval_batch_size)
        tf.logging.info('Eval results: %s' % eval_results)

    elapsed_time = int(time.time() - start_timestamp)
    tf.logging.info('Finished training up to step %d. Elapsed seconds %d.' %
                    (FLAGS.train_steps, elapsed_time))

    if FLAGS.export_dir is not None:
      # The guide to serve an exported TensorFlow model is at:
      #    https://www.tensorflow.org/serving/serving_basic
      tf.logging.info('Starting to export model.')
      resnet_classifier.export_savedmodel(
          export_dir_base=FLAGS.export_dir,
          serving_input_receiver_fn=imagenet_input.image_serving_input_fn)
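Once export_savedmodel has written a SavedModel, it can be loaded back for quick local inference with tf.contrib.predictor (TF 1.x). This is a sketch only; the export directory, the input key ('image_bytes') and the output key ('probabilities') depend on image_serving_input_fn and the model_fn's export outputs, and are assumptions here.

import tensorflow as tf

export_dir = '/tmp/resnet_export/1554321000'  # hypothetical timestamped export
predict_fn = tf.contrib.predictor.from_saved_model(export_dir)
with tf.gfile.GFile('/tmp/panda.jpg', 'rb') as f:  # hypothetical input image
  image_bytes = f.read()
outputs = predict_fn({'image_bytes': [image_bytes]})
print(outputs['probabilities'].shape)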
Exemple #29
0
def main(argv):

  del argv

  tpu_cluster_resolver = None
  
  if FLAGS.use_tpu:
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu,
      zone=FLAGS.tpu_zone,
      project=FLAGS.gcp_project)

  config = tpu_config.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      tpu_config=tpu_config.TPUConfig(
          num_shards=FLAGS.num_shards,
          iterations_per_loop=FLAGS.iterations_per_loop))

  # Set module-level global variable so that model_fn and input_fn can be
  # identical for each different kind of dataset and model
  global dataset, model  
  dataset = celeba_input
  model = celeba_model

  # TPU-based estimator used for TRAIN and EVAL
  est = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      config=config,
      params={"data_dir": FLAGS.data_dir},
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size)

  # CPU-based estimator used for PREDICT (generating images)
  cpu_est = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      use_tpu=False,
      config=config,
      params={"data_dir": FLAGS.data_dir},
      predict_batch_size=_NUM_VIZ_IMAGES)

  tf.gfile.MakeDirs(FLAGS.model_dir)
  tf.gfile.MakeDirs(os.path.join(FLAGS.model_dir, 'generated_images'))

  current_step = estimator._load_global_step_from_checkpoint_dir(FLAGS.model_dir)   # pylint: disable=protected-access,line-too-long
  tf.logging.info('Starting training for %d steps, current step: %d' %
                  (FLAGS.train_steps, current_step))
  while current_step < FLAGS.train_steps:
    next_checkpoint = min(current_step + FLAGS.train_steps_per_eval,
                          FLAGS.train_steps)
    est.train(input_fn=generate_input_fn(True),
              max_steps=next_checkpoint)
    current_step = next_checkpoint
    tf.logging.info('Finished training step %d' % current_step)

    if FLAGS.eval_loss:
      # Evaluate loss on test set
      metrics = est.evaluate(input_fn=generate_input_fn(False),
                             steps=dataset.NUM_EVAL_IMAGES // FLAGS.batch_size)
      tf.logging.info('Finished evaluating')
      tf.logging.info(metrics)

    # Render some generated images
    generated_iter = cpu_est.predict(input_fn=noise_input_fn)
    images = [p['generated_images'][:, :, :] for p in generated_iter]
    assert len(images) == _NUM_VIZ_IMAGES
    image_rows = [np.concatenate(images[i:i+10], axis=0)
                  for i in range(0, _NUM_VIZ_IMAGES, 10)]
    tiled_image = np.concatenate(image_rows, axis=1)

    img = dataset.convert_array_to_image(tiled_image)

    step_string = str(current_step).zfill(5)
    file_obj = tf.gfile.Open(
        os.path.join(FLAGS.model_dir,
                     'generated_images', 'gen_%s.png' % (step_string)), 'w')
    img.save(file_obj, format='png')
    tf.logging.info('Finished generating images')
Exemple #30
0
def main(unused_argv):
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    config = tpu_config.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=max(600, FLAGS.iterations_per_loop),
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_cores,
            per_host_input_for_training=tpu_config.InputPipelineConfig.PER_HOST_V2))  # pylint: disable=line-too-long

    resnet_classifier = tpu_estimator.TPUEstimator(
        export_to_tpu=False,
        use_tpu=FLAGS.use_tpu,
        model_fn=resnet_model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    assert FLAGS.precision == 'bfloat16' or FLAGS.precision == 'float32', (
        'Invalid value for --precision flag; must be bfloat16 or float32.')
    tf.logging.info('Precision: %s', FLAGS.precision)
    use_bfloat16 = FLAGS.precision == 'bfloat16'

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    imagenet_train, imagenet_eval = [
        imagenet_input.ImageNetInput(is_training=is_training,
                                     data_dir=FLAGS.data_dir,
                                     transpose_input=FLAGS.transpose_input,
                                     use_bfloat16=use_bfloat16)
        for is_training in [True, False]
    ]

    if FLAGS.mode == 'eval':
        eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size

        # Run evaluation when there's a new checkpoint
        for ckpt in evaluation.checkpoints_iterator(
                FLAGS.model_dir, timeout=FLAGS.eval_timeout):
            tf.logging.info('Starting to evaluate.')
            try:
                # This time will include compilation time
                start_timestamp = time.time()
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=eval_steps,
                    checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d' %
                                (eval_results, elapsed_time))

                # Terminate eval job when final checkpoint is reached
                current_step = int(os.path.basename(ckpt).split('-')[1])
                if current_step >= FLAGS.train_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d' %
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint' %
                    ckpt)

    else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        current_step = estimator._load_global_step_from_checkpoint_dir(
            FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
        batches_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
        tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                        ' step %d.' % (FLAGS.train_steps, FLAGS.train_steps /
                                       batches_per_epoch, current_step))

        start_timestamp = time.time()  # This time will include compilation time
        if FLAGS.mode == 'train':
            resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                    max_steps=FLAGS.train_steps)

        else:
            assert FLAGS.mode == 'train_and_eval'
            while current_step < FLAGS.train_steps:
                # Train for up to steps_per_eval number of steps.
                # At the end of training, a checkpoint will be written to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      FLAGS.train_steps)
                resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                        max_steps=next_checkpoint)
                current_step = next_checkpoint

                # Evaluate the model on the most recent model in --model_dir.
                # Since evaluation happens in batches of --eval_batch_size, some images
                # may be consistently excluded modulo the batch size.
                tf.logging.info('Starting to evaluate.')
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=FLAGS.num_eval_images // FLAGS.eval_batch_size)
                tf.logging.info('Eval results: %s' % eval_results)

        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info(
            'Finished training up to step %d. Elapsed seconds %d.' %
            (FLAGS.train_steps, elapsed_time))

        if FLAGS.export_dir is not None:
            # The guide to serve an exported TensorFlow model is at:
            #    https://www.tensorflow.org/serving/serving_basic
            tf.logging.info('Starting to export model.')
            resnet_classifier.export_savedmodel(
                export_dir_base=FLAGS.export_dir,
                serving_input_receiver_fn=imagenet_input.image_serving_input_fn
            )
Exemple #31
0
def main(unused_argv):
    if FLAGS.use_tpu:
        # Determine the gRPC URL of the TPU device to use
        if FLAGS.master is None and FLAGS.tpu_name is None:
            raise RuntimeError(
                'You must specify either --master or --tpu_name.')

        if FLAGS.master is not None:
            if FLAGS.tpu_name is not None:
                tf.logging.warn(
                    'Both --master and --tpu_name are set. Ignoring'
                    ' --tpu_name and using --master.')
            tpu_grpc_url = FLAGS.master
        else:
            tpu_cluster_resolver = (
                tf.contrib.cluster_resolver.TPUClusterResolver(
                    FLAGS.tpu_name,
                    zone=FLAGS.tpu_zone,
                    project=FLAGS.gcp_project))
            tpu_grpc_url = tpu_cluster_resolver.get_master()
    else:
        # URL is unused if running locally without TPU
        tpu_grpc_url = None

    config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        evaluation_master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=FLAGS.iterations_per_loop,
        keep_checkpoint_max=5,
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_cores,
            per_host_input_for_training=tpu_config.InputPipelineConfig.
            PER_HOST_V2))

    resnet_classifier = tpu_estimator.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=resnet_model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    imagenet_train = imagenet_input.ImageNetInput(
        is_training=True,
        data_dir=FLAGS.data_dir,
        num_parallel_calls=FLAGS.num_parallel_calls,
        use_transpose=FLAGS.use_transpose)
    imagenet_eval = imagenet_input.ImageNetInput(
        is_training=False,
        data_dir=FLAGS.data_dir,
        num_parallel_calls=FLAGS.num_parallel_calls,
        use_transpose=FLAGS.use_transpose)

    current_step = estimator._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
    steps_per_epoch = NUM_TRAIN_IMAGES // FLAGS.train_batch_size
    start_timestamp = time.time()
    current_epoch = current_step // steps_per_epoch

    if FLAGS.mode == 'train':
        resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                max_steps=FLAGS.train_steps)
        training_time = time.time() - start_timestamp
        tf.logging.info('Finished training in %d seconds' % training_time)

        with tf.gfile.GFile(FLAGS.model_dir + '/total_time_%s.txt' % training_time, 'w') as f:  # pylint: disable=line-too-long
            f.write('Total training time was %s seconds' % training_time)

    elif FLAGS.mode == 'eval':
        results = []

        # Run evaluation when there's a new checkpoint
        for ckpt in evaluation.checkpoints_iterator(FLAGS.model_dir):
            tf.logging.info('Starting to evaluate.')
            try:
                # This time will include compilation time
                start_timestamp = time.time()
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=NUM_EVAL_IMAGES // FLAGS.eval_batch_size,
                    checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d' %
                                (eval_results, elapsed_time))

                current_step = int(os.path.basename(ckpt).split('-')[1])
                current_epoch = current_step // steps_per_epoch
                results.append([
                    current_epoch,
                    '{0:.2f}'.format(eval_results['top_1_accuracy'] * 100),
                    '{0:.2f}'.format(eval_results['top_5_accuracy'] * 100),
                ])

                # Terminate eval job when final checkpoint is reached
                if current_step >= FLAGS.train_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d' %
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint' %
                    ckpt)

        with tf.gfile.GFile(FLAGS.model_dir + '/epoch_results_eval.tsv', 'wb') as tsv_file:  # pylint: disable=line-too-long
            writer = csv.writer(tsv_file, delimiter='\t')
            writer.writerow(['epoch', 'top1Accuracy', 'top5Accuracy'])
            writer.writerows(results)

    elif FLAGS.mode == 'train_and_eval':
        results = []
        while current_epoch < 95:
            next_checkpoint = (current_epoch + 1) * steps_per_epoch
            resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                    max_steps=next_checkpoint)
            current_epoch += 1

            tf.logging.info(
                'Finished training up to step %d. Elapsed seconds %d.' %
                (next_checkpoint, int(time.time() - start_timestamp)))

            # Evaluate the model on the most recent model in --model_dir.
            # Since evaluation happens in batches of --eval_batch_size, some images
            # may be excluded modulo the batch size. As long as the batch size is
            # consistent, the evaluated images are also consistent.
            tf.logging.info('Starting to evaluate.')
            eval_results = resnet_classifier.evaluate(
                input_fn=imagenet_eval.input_fn,
                steps=NUM_EVAL_IMAGES // FLAGS.eval_batch_size)
            tf.logging.info('Eval results: %s' % eval_results)

            elapsed_time = int(time.time() - start_timestamp)
            tf.logging.info('Finished epoch %d after %d seconds' %
                            (current_epoch, elapsed_time))
            results.append([
                current_epoch,
                elapsed_time / 3600.0,
                '{0:.2f}'.format(eval_results['top_1_accuracy'] * 100),
                '{0:.2f}'.format(eval_results['top_5_accuracy'] * 100),
            ])

        with tf.gfile.GFile(FLAGS.model_dir + '/epoch_results_train_eval.tsv', 'wb') as tsv_file:  # pylint: disable=line-too-long
            writer = csv.writer(tsv_file, delimiter='\t')
            writer.writerow(['epoch', 'hours', 'top1Accuracy', 'top5Accuracy'])
            writer.writerows(results)
    else:
        tf.logging.info('Mode not found.')

    if FLAGS.export_dir is not None:
        # The guide to serve an exported TensorFlow model is at:
        #    https://www.tensorflow.org/serving/serving_basic
        tf.logging.info('Starting to export model.')
        resnet_classifier.export_savedmodel(
            export_dir_base=FLAGS.export_dir,
            serving_input_receiver_fn=imagenet_input.image_serving_input_fn)
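One portability note on the TSV writing used in this and the earlier example: opening the GFile in 'wb' and handing it to csv.writer is a Python 2 pattern; under Python 3 the csv module expects a text-mode file. A sketch of the Python 3 form, with made-up rows:

import csv

import tensorflow as tf

results = [[1, '75.10', '92.30']]  # made-up rows for illustration
with tf.gfile.GFile('/tmp/model_dir/results.tsv', 'w') as tsv_file:
    writer = csv.writer(tsv_file, delimiter='\t')
    writer.writerow(['epoch', 'top1Accuracy', 'top5Accuracy'])
    writer.writerows(results)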
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  # RevNet specific configuration
  revnet_config = {
      "revnet-56": config_.get_hparams_imagenet_56(),
      "revnet-104": config_.get_hparams_imagenet_104()
  }[FLAGS.revnet_config]

  if FLAGS.use_tpu:
    revnet_config.data_format = "channels_last"

  tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  # Estimator specific configuration
  config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=True),
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_shards,
          per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.
          PER_HOST_V2),
  )

  # Input pipelines are slightly different (with regards to shuffling and
  # preprocessing) between training and evaluation.
  imagenet_train, imagenet_eval = [
      imagenet_input.ImageNetInput(
          is_training=is_training,
          data_dir=FLAGS.data_dir,
          transpose_input=FLAGS.transpose_input,
          use_bfloat16=False) for is_training in [True, False]
  ]

  revnet_classifier = tf.contrib.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=revnet_config.tpu_batch_size,
      eval_batch_size=revnet_config.tpu_eval_batch_size,
      config=config,
      export_to_tpu=False,
      params={"revnet_config": revnet_config})

  steps_per_epoch = revnet_config.tpu_iters_per_epoch
  eval_steps = revnet_config.tpu_eval_steps

  # pylint: disable=protected-access
  if FLAGS.mode == "eval":
    # Run evaluation when there's a new checkpoint
    for ckpt in evaluation.checkpoints_iterator(
        FLAGS.model_dir, timeout=FLAGS.eval_timeout):
      tf.logging.info("Starting to evaluate.")
      try:
        start_timestamp = time.time()  # This time will include compilation time
        eval_results = revnet_classifier.evaluate(
            input_fn=imagenet_eval.input_fn,
            steps=eval_steps,
            checkpoint_path=ckpt)
        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info("Eval results: %s. Elapsed seconds: %d" %
                        (eval_results, elapsed_time))

        # Terminate eval job when final checkpoint is reached
        current_step = int(os.path.basename(ckpt).split("-")[1])
        if current_step >= revnet_config.max_train_iter:
          tf.logging.info(
              "Evaluation finished after training step %d" % current_step)
          break

      except tf.errors.NotFoundError:
        # Since the coordinator is on a different job than the TPU worker,
        # sometimes the TPU worker does not finish initializing until long after
        # the CPU job tells it to start evaluating. In this case, the checkpoint
        # file could have been deleted already.
        tf.logging.info(
            "Checkpoint %s no longer exists, skipping checkpoint" % ckpt)

  else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
    current_step = estimator._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)

    tf.logging.info(
        "Training for %d steps (%.2f epochs in total). Current"
        " step %d." % (revnet_config.max_train_iter,
                       revnet_config.max_train_iter / steps_per_epoch,
                       current_step))

    start_timestamp = time.time()  # This time will include compilation time

    if FLAGS.mode == "train":
      revnet_classifier.train(
          input_fn=imagenet_train.input_fn,
          max_steps=revnet_config.max_train_iter)

    else:
      assert FLAGS.mode == "train_and_eval"
      while current_step < revnet_config.max_train_iter:
        # Train for up to steps_per_eval number of steps.
        # At the end of training, a checkpoint will be written to --model_dir.
        next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                              revnet_config.max_train_iter)
        revnet_classifier.train(
            input_fn=imagenet_train.input_fn, max_steps=next_checkpoint)
        current_step = next_checkpoint

        tf.logging.info("Finished training up to step %d. Elapsed seconds %d." %
                        (next_checkpoint, int(time.time() - start_timestamp)))

        # Evaluate the model on the most recent model in --model_dir.
        # Since evaluation happens in batches of --eval_batch_size, some images
        # may be excluded modulo the batch size. As long as the batch size is
        # consistent, the evaluated images are also consistent.
        tf.logging.info("Starting to evaluate.")
        eval_results = revnet_classifier.evaluate(
            input_fn=imagenet_eval.input_fn, steps=eval_steps)
        tf.logging.info("Eval results: %s" % eval_results)

        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info("Finished training up to step %d. Elapsed seconds %d." %
                        (revnet_config.max_train_iter, elapsed_time))

    if FLAGS.export_dir is not None:
      # The guide to serve an exported TensorFlow model is at:
      #    https://www.tensorflow.org/serving/serving_basic
      tf.logging.info("Starting to export model.")
      revnet_classifier.export_savedmodel(
          export_dir_base=FLAGS.export_dir,
          serving_input_receiver_fn=imagenet_input.image_serving_input_fn)
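The manual train/eval alternation in the train_and_eval branch can, in principle, also be expressed with the public tf.estimator.train_and_evaluate helper. The sketch below reuses imagenet_train, imagenet_eval, revnet_classifier, revnet_config and eval_steps from the example above; the throttle value is an illustrative assumption, and TPUEstimator support for this helper comes with caveats.

import tensorflow as tf

train_spec = tf.estimator.TrainSpec(
    input_fn=imagenet_train.input_fn,
    max_steps=revnet_config.max_train_iter)
eval_spec = tf.estimator.EvalSpec(
    input_fn=imagenet_eval.input_fn,
    steps=eval_steps,
    throttle_secs=600)  # illustrative value
tf.estimator.train_and_evaluate(revnet_classifier, train_spec, eval_spec)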
Exemple #33
0
def main(argv):
    del argv

    if FLAGS.use_tpu:
        if FLAGS.master is None and FLAGS.tpu_name is None:
            raise RuntimeError(
                'You must specify either --master or --tpu_name.')

        if FLAGS.master is not None:
            if FLAGS.tpu_name is not None:
                tf.logging.warn(
                    'Both --master and --tpu_name are set. Ignoring '
                    '--tpu_name and using --master.')
            tpu_grpc_url = FLAGS.master
        else:
            tpu_cluster_resolver = (
                tf.contrib.cluster_resolver.TPUClusterResolver(
                    FLAGS.tpu_name,
                    zone=FLAGS.tpu_zone,
                    project=FLAGS.gcp_project))
            tpu_grpc_url = tpu_cluster_resolver.get_master()
    else:
        tpu_grpc_url = None

    config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        evaluation_master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        tpu_config=tpu_config.TPUConfig(
            num_shards=FLAGS.num_shards,
            iterations_per_loop=FLAGS.iterations_per_loop))

    # Set module-level global variable so that model_fn and input_fn can be
    # identical for each different kind of dataset and model
    global dataset, model
    if FLAGS.dataset == 'mnist':
        dataset = mnist_input
        model = mnist_model
    elif FLAGS.dataset == 'cifar':
        dataset = cifar_input
        model = cifar_model
    else:
        raise ValueError('Invalid dataset: %s' % FLAGS.dataset)

    # TPU-based estimator used for TRAIN and EVAL
    est = tpu_estimator.TPUEstimator(model_fn=model_fn,
                                     use_tpu=FLAGS.use_tpu,
                                     config=config,
                                     train_batch_size=FLAGS.batch_size,
                                     eval_batch_size=FLAGS.batch_size)

    # CPU-based estimator used for PREDICT (generating images)
    cpu_est = tpu_estimator.TPUEstimator(model_fn=model_fn,
                                         use_tpu=False,
                                         config=config,
                                         predict_batch_size=_NUM_VIZ_IMAGES)

    tf.gfile.MakeDirs(os.path.join(FLAGS.model_dir, 'generated_images'))

    current_step = estimator._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
    tf.logging.info('Starting training for %d steps, current step: %d' %
                    (FLAGS.train_steps, current_step))
    while current_step < FLAGS.train_steps:
        next_checkpoint = min(current_step + FLAGS.train_steps_per_eval,
                              FLAGS.train_steps)
        est.train(input_fn=generate_input_fn(True), max_steps=next_checkpoint)
        current_step = next_checkpoint
        tf.logging.info('Finished training step %d' % current_step)

        if FLAGS.eval_loss:
            # Evaluate loss on test set
            metrics = est.evaluate(input_fn=generate_input_fn(False),
                                   steps=dataset.NUM_EVAL_IMAGES //
                                   FLAGS.batch_size)
            tf.logging.info('Finished evaluating')
            tf.logging.info(metrics)

        # Render some generated images
        generated_iter = cpu_est.predict(input_fn=noise_input_fn)
        images = [p['generated_images'][:, :, :] for p in generated_iter]
        assert len(images) == _NUM_VIZ_IMAGES
        image_rows = [
            np.concatenate(images[i:i + 10], axis=0)
            for i in range(0, _NUM_VIZ_IMAGES, 10)
        ]
        tiled_image = np.concatenate(image_rows, axis=1)

        img = dataset.convert_array_to_image(tiled_image)

        step_string = str(current_step).zfill(5)
        file_obj = tf.gfile.Open(
            os.path.join(FLAGS.model_dir, 'generated_images',
                         'gen_%s.png' % (step_string)), 'w')
        img.save(file_obj, format='png')
        tf.logging.info('Finished generating images')
Exemple #34
0
def main(unused_argv):
    if FLAGS.task_name == 'svhn':
        FLAGS.input_image_size = 32
        FLAGS.small_image_model = True
        FLAGS.num_label_classes = 10
    if FLAGS.num_train_images is None:
        FLAGS.num_train_images = task_info.get_num_train_images(
            FLAGS.task_name)
    if FLAGS.num_eval_images is None:
        FLAGS.num_eval_images = task_info.get_num_eval_images(FLAGS.task_name)
    if FLAGS.num_test_images is None and FLAGS.task_name != 'imagenet':
        FLAGS.num_test_images = task_info.get_num_test_images(FLAGS.task_name)

    steps_per_epoch = (FLAGS.num_train_images /
                       (FLAGS.train_batch_size * FLAGS.label_data_sample_prob))
    if FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval':
        tf.gfile.MakeDirs(FLAGS.model_dir)
        flags_dict = tf.app.flags.FLAGS.flag_values_dict()
        with tf.gfile.Open(os.path.join(FLAGS.model_dir, 'FLAGS.json'),
                           'w') as ouf:
            json.dump(flags_dict, ouf)
    input_image_size = FLAGS.input_image_size
    if not input_image_size:
        _, _, input_image_size, _ = efficientnet_builder.efficientnet_params(
            FLAGS.model_name)
        FLAGS.input_image_size = input_image_size
    if FLAGS.train_last_step_num == -1:
        FLAGS.train_last_step_num = FLAGS.train_steps
    if FLAGS.train_ratio != 1:
        FLAGS.train_last_step_num *= FLAGS.train_ratio
        FLAGS.train_steps *= FLAGS.train_ratio
        FLAGS.train_last_step_num = int(FLAGS.train_last_step_num)
        FLAGS.train_steps = int(FLAGS.train_steps)

    if (FLAGS.tpu or FLAGS.use_tpu) and not FLAGS.master:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
        tpu_cluster_resolver = None

    # Both the TPU and non-TPU paths currently use the same input pipeline
    # configuration.
    tpu_config = tf.estimator.tpu.TPUConfig(
        iterations_per_loop=FLAGS.iterations_per_loop,
        num_shards=FLAGS.num_tpu_cores,
        per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig.
        PER_HOST_V2)
    config = tf.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=max(FLAGS.save_checkpoints_steps, FLAGS.iterations_per_loop),
        log_step_count_steps=FLAGS.log_step_count_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        session_config=tf.ConfigProto(
            graph_options=tf.GraphOptions(
                rewrite_options=rewriter_config_pb2.RewriterConfig(
                    disable_meta_optimizer=True))),
        tpu_config=tpu_config)  # pylint: disable=line-too-long
    # Initializes model parameters.
    params = dict(steps_per_epoch=steps_per_epoch,
                  use_bfloat16=FLAGS.use_bfloat16)
    est = tf.estimator.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=8,
        params=params)

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    if FLAGS.label_data_dir == FAKE_DATA_DIR:
        tf.logging.info('Using fake dataset.')
    else:
        tf.logging.info('Using dataset: %s', FLAGS.label_data_dir)

    train_data = data_input.DataInput(is_training=True,
                                      data_dir=FLAGS.label_data_dir,
                                      transpose_input=FLAGS.transpose_input,
                                      cache=FLAGS.use_cache,
                                      image_size=input_image_size,
                                      use_bfloat16=FLAGS.use_bfloat16)
    if FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval':
        current_step = estimator._load_global_step_from_checkpoint_dir(
            FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long

        tf.logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', FLAGS.train_last_step_num,
            FLAGS.train_last_step_num / params['steps_per_epoch'],
            current_step)

        # This time will include compilation time
        start_timestamp = time.time()

        if FLAGS.mode == 'train':
            est.train(input_fn=train_data.input_fn,
                      max_steps=FLAGS.train_last_step_num,
                      hooks=[])
    elif FLAGS.mode == 'eval':
        input_fn_mapping = {}
        for subset in ['dev', 'test']:
            input_fn_mapping[subset] = data_input.DataInput(
                is_training=False,
                data_dir=FLAGS.label_data_dir,
                transpose_input=FLAGS.transpose_input,
                cache=False,
                image_size=input_image_size,
                use_bfloat16=FLAGS.use_bfloat16,
                subset=subset).input_fn
            if subset == 'dev':
                num_images = FLAGS.num_eval_images
            else:
                num_images = FLAGS.num_test_images
            eval_results = est.evaluate(input_fn=input_fn_mapping[subset],
                                        steps=num_images //
                                        FLAGS.eval_batch_size)
            tf.logging.info('%s, results: %s', subset, eval_results)
    elif FLAGS.mode == 'predict':
        predict_label.run_prediction(est)
    else:
        assert False, 'Unknown mode: %s' % FLAGS.mode
Exemple #35
0
 def test_run_train_op_and_saves_at_the_end(self):
   est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
   est.fit(dummy_input_fn, steps=5)
   self.assertEqual(
       5, estimator._load_global_step_from_checkpoint_dir(est.model_dir))
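The tests and training mains in these examples read the current global step through the private estimator._load_global_step_from_checkpoint_dir helper. A minimal sketch of an equivalent helper built only on public checkpoint utilities, assuming the step counter is saved under the conventional 'global_step' name (load_global_step is a hypothetical name, not part of the examples above):

import tensorflow as tf

def load_global_step(model_dir):
    """Return the global step from the latest checkpoint in model_dir, else 0."""
    ckpt = tf.train.latest_checkpoint(model_dir)
    if ckpt is None:
        return 0  # No checkpoint has been written yet.
    reader = tf.train.load_checkpoint(ckpt)
    if reader.has_tensor('global_step'):
        return int(reader.get_tensor('global_step'))
    return 0  # Checkpoint does not store a step counter under this name.

Used in place of the private helper, load_global_step(est.model_dir) would back the same assertion as the test above.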
Exemple #36
0
def main(unused_argv):

    input_image_size = FLAGS.input_image_size
    if not input_image_size:
        if FLAGS.model_name.startswith('efficientnet-edgetpu'):
            _, _, input_image_size, _ = efficientnet_edgetpu_builder.efficientnet_edgetpu_params(
                FLAGS.model_name)
        elif FLAGS.model_name.startswith('efficientnet-tpu'):
            _, _, input_image_size, _ = efficientnet_tpu_builder.efficientnet_tpu_params(
                FLAGS.model_name)
        elif FLAGS.model_name.startswith('efficientnet'):
            _, _, input_image_size, _ = efficientnet_builder.efficientnet_params(
                FLAGS.model_name)
        else:
            raise ValueError(
                'input_image_size must be set explicitly for models other than'
                ' EfficientNet variants.')

    # For imagenet dataset, include background label if number of output classes
    # is 1001
    include_background_label = (FLAGS.num_label_classes == 1001)

    if FLAGS.tpu or FLAGS.use_tpu:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
        tpu_cluster_resolver = None

    if FLAGS.use_async_checkpointing:
        save_checkpoints_steps = None
    else:
        save_checkpoints_steps = max(100, FLAGS.iterations_per_loop)
    config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        log_step_count_steps=FLAGS.log_step_count_steps,
        session_config=tf.ConfigProto(
            graph_options=tf.GraphOptions(
                rewrite_options=rewriter_config_pb2.RewriterConfig(
                    disable_meta_optimizer=True))),
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig
            .PER_HOST_V2))  # pylint: disable=line-too-long
    # Initializes model parameters.
    params = dict(steps_per_epoch=FLAGS.num_train_images /
                  FLAGS.train_batch_size,
                  use_bfloat16=FLAGS.use_bfloat16)
    est = tf.contrib.tpu.TPUEstimator(use_tpu=FLAGS.use_tpu,
                                      model_fn=model_fn,
                                      config=config,
                                      train_batch_size=FLAGS.train_batch_size,
                                      eval_batch_size=FLAGS.eval_batch_size,
                                      export_to_tpu=FLAGS.export_to_tpu,
                                      params=params)

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    def build_imagenet_input(is_training):
        """Generate ImageNetInput for training and eval."""
        if FLAGS.bigtable_instance:
            tf.logging.info('Using Bigtable dataset, table %s',
                            FLAGS.bigtable_table)
            select_train, select_eval = _select_tables_from_flags()
            return imagenet_input.ImageNetBigtableInput(
                is_training=is_training,
                use_bfloat16=FLAGS.use_bfloat16,
                transpose_input=FLAGS.transpose_input,
                selection=select_train if is_training else select_eval,
                include_background_label=include_background_label,
                autoaugment_name=FLAGS.autoaugment_name)
        else:
            if FLAGS.data_dir == FAKE_DATA_DIR:
                tf.logging.info('Using fake dataset.')
            else:
                tf.logging.info('Using dataset: %s', FLAGS.data_dir)

            return imagenet_input.ImageNetInput(
                is_training=is_training,
                data_dir=FLAGS.data_dir,
                transpose_input=FLAGS.transpose_input,
                cache=FLAGS.use_cache and is_training,
                image_size=input_image_size,
                num_parallel_calls=FLAGS.num_parallel_calls,
                use_bfloat16=FLAGS.use_bfloat16,
                include_background_label=include_background_label,
                autoaugment_name=FLAGS.autoaugment_name)

    imagenet_train = build_imagenet_input(is_training=True)
    imagenet_eval = build_imagenet_input(is_training=False)

    if FLAGS.mode == 'eval':
        eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size
        # Run evaluation when there's a new checkpoint
        for ckpt in evaluation.checkpoints_iterator(
                FLAGS.model_dir, timeout=FLAGS.eval_timeout):
            tf.logging.info('Starting to evaluate.')
            try:
                # This time will include compilation time
                start_timestamp = time.time()
                eval_results = est.evaluate(input_fn=imagenet_eval.input_fn,
                                            steps=eval_steps,
                                            checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                                eval_results, elapsed_time)
                utils.archive_ckpt(eval_results,
                                   eval_results['top_1_accuracy'], ckpt)

                # Terminate eval job when final checkpoint is reached
                current_step = int(os.path.basename(ckpt).split('-')[1])
                if current_step >= FLAGS.train_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d',
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint',
                    ckpt)
    else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        current_step = estimator._load_global_step_from_checkpoint_dir(
            FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long

        tf.logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', FLAGS.train_steps,
            FLAGS.train_steps / params['steps_per_epoch'], current_step)

        # This time will include compilation time
        start_timestamp = time.time()

        if FLAGS.mode == 'train':
            hooks = []
            if FLAGS.use_async_checkpointing:
                hooks.append(
                    async_checkpoint.AsyncCheckpointSaverHook(
                        checkpoint_dir=FLAGS.model_dir,
                        save_steps=max(100, FLAGS.iterations_per_loop)))
            est.train(input_fn=imagenet_train.input_fn,
                      max_steps=FLAGS.train_steps,
                      hooks=hooks)

        else:
            assert FLAGS.mode == 'train_and_eval'
            while current_step < FLAGS.train_steps:
                # Train for up to steps_per_eval number of steps.
                # At the end of training, a checkpoint will be written to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      FLAGS.train_steps)
                est.train(input_fn=imagenet_train.input_fn,
                          max_steps=next_checkpoint)
                current_step = next_checkpoint

                tf.logging.info(
                    'Finished training up to step %d. Elapsed seconds %d.',
                    next_checkpoint, int(time.time() - start_timestamp))

                # Evaluate the model on the most recent model in --model_dir.
                # Since evaluation happens in batches of --eval_batch_size, some images
                # may be excluded modulo the batch size. As long as the batch size is
                # consistent, the evaluated images are also consistent.
                tf.logging.info('Starting to evaluate.')
                eval_results = est.evaluate(input_fn=imagenet_eval.input_fn,
                                            steps=FLAGS.num_eval_images //
                                            FLAGS.eval_batch_size)
                tf.logging.info('Eval results at step %d: %s', next_checkpoint,
                                eval_results)
                ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
                utils.archive_ckpt(eval_results,
                                   eval_results['top_1_accuracy'], ckpt)

            elapsed_time = int(time.time() - start_timestamp)
            tf.logging.info(
                'Finished training up to step %d. Elapsed seconds %d.',
                FLAGS.train_steps, elapsed_time)
    if FLAGS.export_dir:
        export(est, FLAGS.export_dir, input_image_size)
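The eval-only branch above follows a common pattern: wait for each new checkpoint, evaluate it, and stop once the checkpoint's step reaches the final training step. A condensed sketch of that loop against the public tf.train.checkpoints_iterator API; the function name and parameters are illustrative placeholders rather than code from the example:

import os
import tensorflow as tf

def eval_on_new_checkpoints(est, eval_input_fn, model_dir, eval_steps,
                            final_step, timeout_secs=None):
    """Evaluate every new checkpoint in model_dir until final_step is reached."""
    for ckpt in tf.train.checkpoints_iterator(model_dir, timeout=timeout_secs):
        try:
            results = est.evaluate(input_fn=eval_input_fn, steps=eval_steps,
                                   checkpoint_path=ckpt)
            tf.logging.info('Eval results for %s: %s', ckpt, results)
            # Checkpoint paths end in "-<global_step>", so the step can be
            # parsed from the filename to decide when to stop.
            if int(os.path.basename(ckpt).split('-')[1]) >= final_step:
                break
        except tf.errors.NotFoundError:
            # The checkpoint may have been deleted before the evaluator
            # reached it; skip it and wait for the next one.
            tf.logging.info('Checkpoint %s no longer exists, skipping', ckpt)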
Exemple #37
0
def main(unused_argv):
    params = params_dict.ParamsDict(mnasnet_config.MNASNET_CFG,
                                    mnasnet_config.MNASNET_RESTRICTIONS)
    params = params_dict.override_params_dict(params,
                                              FLAGS.config_file,
                                              is_strict=True)
    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)

    params = flags_to_params.override_params_from_input_flags(params, FLAGS)

    additional_params = {
        'steps_per_epoch': params.num_train_images / params.train_batch_size,
        'quantized_training': FLAGS.quantized_training,
    }

    params = params_dict.override_params_dict(params,
                                              additional_params,
                                              is_strict=False)

    params.validate()
    params.lock()

    if FLAGS.tpu or params.use_tpu:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
        tpu_cluster_resolver = None

    if params.use_async_checkpointing:
        save_checkpoints_steps = None
    else:
        save_checkpoints_steps = max(100, params.iterations_per_loop)
    config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        log_step_count_steps=FLAGS.log_step_count_steps,
        session_config=tf.ConfigProto(
            graph_options=tf.GraphOptions(
                rewrite_options=rewriter_config_pb2.RewriterConfig(
                    disable_meta_optimizer=True))),
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=params.iterations_per_loop,
            per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig
            .PER_HOST_V2))  # pylint: disable=line-too-long

    # Validates Flags.
    if params.precision == 'bfloat16' and params.use_keras:
        raise ValueError(
            'Keras layers do not fully support bfloat16 activation training.'
            ' You have set precision to %s and use_keras to %s.' %
            (params.precision, params.use_keras))

    # Initializes model parameters.
    mnasnet_est = tf.contrib.tpu.TPUEstimator(
        use_tpu=params.use_tpu,
        model_fn=mnasnet_model_fn,
        config=config,
        train_batch_size=params.train_batch_size,
        eval_batch_size=params.eval_batch_size,
        export_to_tpu=FLAGS.export_to_tpu,
        params=params.as_dict())

    if FLAGS.mode == 'export_only':
        export(mnasnet_est, FLAGS.export_dir, params, FLAGS.post_quantize)
        return

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    if FLAGS.bigtable_instance:
        tf.logging.info('Using Bigtable dataset, table %s',
                        FLAGS.bigtable_table)
        select_train, select_eval = _select_tables_from_flags()
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetBigtableInput(
                is_training=is_training,
                use_bfloat16=False,
                transpose_input=params.transpose_input,
                selection=selection)
            for (is_training,
                 selection) in [(True, select_train), (False, select_eval)]
        ]
    else:
        if FLAGS.data_dir == FAKE_DATA_DIR:
            tf.logging.info('Using fake dataset.')
        else:
            tf.logging.info('Using dataset: %s', FLAGS.data_dir)
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetInput(
                is_training=is_training,
                data_dir=FLAGS.data_dir,
                transpose_input=params.transpose_input,
                cache=params.use_cache and is_training,
                image_size=params.input_image_size,
                num_parallel_calls=params.num_parallel_calls,
                use_bfloat16=(params.precision == 'bfloat16'))
            for is_training in [True, False]
        ]

    if FLAGS.mode == 'eval':
        eval_steps = params.num_eval_images // params.eval_batch_size
        # Run evaluation when there's a new checkpoint
        for ckpt in evaluation.checkpoints_iterator(
                FLAGS.model_dir, timeout=FLAGS.eval_timeout):
            tf.logging.info('Starting to evaluate.')
            try:
                # This time will include compilation time
                start_timestamp = time.time()
                eval_results = mnasnet_est.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=eval_steps,
                    checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                                eval_results, elapsed_time)
                utils.archive_ckpt(eval_results,
                                   eval_results['top_1_accuracy'], ckpt)

                # Terminate eval job when final checkpoint is reached
                current_step = int(os.path.basename(ckpt).split('-')[1])
                if current_step >= params.train_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d',
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint',
                    ckpt)

        if FLAGS.export_dir:
            export(mnasnet_est, FLAGS.export_dir, params, FLAGS.post_quantize)
    else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        current_step = estimator._load_global_step_from_checkpoint_dir(  # pylint: disable=protected-access
            FLAGS.model_dir)

        tf.logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', params.train_steps,
            params.train_steps / params.steps_per_epoch, current_step)

        # This time will include compilation time
        start_timestamp = time.time()

        if FLAGS.mode == 'train':
            hooks = []
            if params.use_async_checkpointing:
                hooks.append(
                    async_checkpoint.AsyncCheckpointSaverHook(
                        checkpoint_dir=FLAGS.model_dir,
                        save_steps=max(100, params.iterations_per_loop)))
            mnasnet_est.train(input_fn=imagenet_train.input_fn,
                              max_steps=params.train_steps,
                              hooks=hooks)

        else:
            assert FLAGS.mode == 'train_and_eval'
            while current_step < params.train_steps:
                # Train for up to steps_per_eval number of steps.
                # At the end of training, a checkpoint will be written to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      params.train_steps)
                mnasnet_est.train(input_fn=imagenet_train.input_fn,
                                  max_steps=next_checkpoint)
                current_step = next_checkpoint

                tf.logging.info(
                    'Finished training up to step %d. Elapsed seconds %d.',
                    next_checkpoint, int(time.time() - start_timestamp))

                # Evaluate the model on the most recent model in --model_dir.
                # Since evaluation happens in batches of --eval_batch_size, some images
                # may be excluded modulo the batch size. As long as the batch size is
                # consistent, the evaluated images are also consistent.
                tf.logging.info('Starting to evaluate.')
                eval_results = mnasnet_est.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=params.num_eval_images // params.eval_batch_size)
                tf.logging.info('Eval results at step %d: %s', next_checkpoint,
                                eval_results)
                ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
                utils.archive_ckpt(eval_results,
                                   eval_results['top_1_accuracy'], ckpt)

            elapsed_time = int(time.time() - start_timestamp)
            tf.logging.info(
                'Finished training up to step %d. Elapsed seconds %d.',
                params.train_steps, elapsed_time)
            if FLAGS.export_dir:
                export(mnasnet_est, FLAGS.export_dir, params,
                       FLAGS.post_quantize)
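Exemple #36 and Exemple #37 implement train_and_eval the same way: train in chunks of steps_per_eval, evaluate after each chunk, and stop at train_steps. A stripped-down sketch of that alternation; every parameter below is an illustrative stand-in for the FLAGS/params values the real scripts use:

import time
import tensorflow as tf

def train_then_eval_loop(est, train_input_fn, eval_input_fn, current_step,
                         train_steps, steps_per_eval, eval_steps):
    """Alternate training chunks with evaluation until train_steps is reached."""
    start_timestamp = time.time()
    while current_step < train_steps:
        # Train up to the next evaluation point; a checkpoint is written to
        # the estimator's model_dir when training pauses.
        next_checkpoint = min(current_step + steps_per_eval, train_steps)
        est.train(input_fn=train_input_fn, max_steps=next_checkpoint)
        current_step = next_checkpoint
        tf.logging.info('Finished training up to step %d. Elapsed seconds %d.',
                        current_step, int(time.time() - start_timestamp))
        eval_results = est.evaluate(input_fn=eval_input_fn, steps=eval_steps)
        tf.logging.info('Eval results at step %d: %s', current_step,
                        eval_results)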