Example #1
def main(unused_argv):
    if FLAGS.use_tpu:
        # Determine the gRPC URL of the TPU device to use
        if FLAGS.master is None and FLAGS.tpu_name is None:
            raise RuntimeError(
                'You must specify either --master or --tpu_name.')

        if FLAGS.master is not None:
            if FLAGS.tpu_name is not None:
                tf.logging.warn(
                    'Both --master and --tpu_name are set. Ignoring'
                    ' --tpu_name and using --master.')
            tpu_grpc_url = FLAGS.master
        else:
            tpu_cluster_resolver = (
                tf.contrib.cluster_resolver.TPUClusterResolver(
                    FLAGS.tpu_name,
                    zone=FLAGS.tpu_zone,
                    project=FLAGS.gcp_project))
            tpu_grpc_url = tpu_cluster_resolver.get_master()
    else:
        # URL is unused if running locally without TPU
        tpu_grpc_url = None

    config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        evaluation_master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_cores))

    resnet_classifier = tpu_estimator.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=resnet_model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    # Input pipelines are slightly different (with regard to shuffling and
    # preprocessing) between training and evaluation.
    imagenet_train = imagenet_input.ImageNetInput(is_training=True,
                                                  data_dir=FLAGS.data_dir)
    imagenet_eval = imagenet_input.ImageNetInput(is_training=False,
                                                 data_dir=FLAGS.data_dir)

    current_step = estimator._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
    batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
    tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                    ' step %d.' % (FLAGS.train_steps, FLAGS.train_steps /
                                   batches_per_epoch, current_step))
    # Train up to --train_steps. At the end of training, a checkpoint is
    # written to --model_dir.
    resnet_classifier.train(input_fn=imagenet_train.input_fn,
                            max_steps=FLAGS.train_steps)
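
These examples read their configuration from absl flags defined elsewhere in each script. For reference, a minimal sketch of the flag definitions Example #1 assumes (flag names are taken from the snippet; the defaults and help strings here are guesses):

from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_bool('use_tpu', True, 'Use a TPU rather than CPU/GPU.')
flags.DEFINE_string('master', None, 'gRPC URL of the TPU master.')
flags.DEFINE_string('tpu_name', None, 'Name of the Cloud TPU to resolve.')
flags.DEFINE_string('tpu_zone', None, 'GCE zone of the Cloud TPU.')
flags.DEFINE_string('gcp_project', None, 'Project owning the Cloud TPU.')
flags.DEFINE_string('model_dir', None, 'Directory for checkpoints/summaries.')
flags.DEFINE_string('data_dir', None, 'Directory with ImageNet TFRecords.')
flags.DEFINE_integer('iterations_per_loop', 100, 'Steps per TPU training loop.')
flags.DEFINE_integer('num_cores', 8, 'Number of TPU cores (shards).')
flags.DEFINE_integer('train_batch_size', 1024, 'Global training batch size.')
flags.DEFINE_integer('eval_batch_size', 1024, 'Global evaluation batch size.')
flags.DEFINE_integer('train_steps', 112590, 'Total number of training steps.')
flags.DEFINE_integer('steps_per_eval', 1251, 'Train steps between evaluations.')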
Example #2
def main(unused_argv):
  assert FLAGS.data is not None, 'Provide training data path via --data.'

  batch_size = FLAGS.num_cores * PER_CORE_BATCH_SIZE
  training_steps_per_epoch = int(APPROX_IMAGENET_TRAINING_IMAGES // batch_size)
  validation_steps = int(IMAGENET_VALIDATION_IMAGES // batch_size)

  model_dir = FLAGS.model_dir if FLAGS.model_dir else DEFAULT_MODEL_DIR
  logging.info('Saving tensorboard summaries at %s', model_dir)

  logging.info('Use TPU at %s', FLAGS.tpu if FLAGS.tpu is not None else 'local')
  resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
  tf.contrib.distribute.initialize_tpu_system(resolver)
  strategy = tf.contrib.distribute.TPUStrategy(resolver)

  logging.info('Use bfloat16: %s.', USE_BFLOAT16)
  logging.info('Use global batch size: %s.', batch_size)
  logging.info('Enable top 5 accuracy: %s.', FLAGS.eval_top_5_accuracy)
  logging.info('Training model using data in directory "%s".', FLAGS.data)

  with strategy.scope():
    logging.info('Building Keras ResNet-50 model')
    model = resnet_model.ResNet50(num_classes=NUM_CLASSES)

    logging.info('Compiling model.')
    metrics = ['sparse_categorical_accuracy']

    if FLAGS.eval_top_5_accuracy:
      metrics.append(sparse_top_k_categorical_accuracy)

    model.compile(
        optimizer=gradient_descent.SGD(
            learning_rate=BASE_LEARNING_RATE, momentum=0.9, nesterov=True),
        loss='sparse_categorical_crossentropy',
        metrics=metrics)

  imagenet_train = imagenet_input.ImageNetInput(
      is_training=True, data_dir=FLAGS.data, batch_size=batch_size,
      use_bfloat16=USE_BFLOAT16)
  imagenet_eval = imagenet_input.ImageNetInput(
      is_training=False, data_dir=FLAGS.data, batch_size=batch_size,
      use_bfloat16=USE_BFLOAT16)

  lr_schedule_cb = LearningRateBatchScheduler(
      schedule=learning_rate_schedule_wrapper(training_steps_per_epoch))
  tensorboard_cb = eval_utils.TensorBoardWithValidation(
      log_dir=model_dir,
      validation_imagenet_input=imagenet_eval,
      validation_steps=validation_steps,
      validation_epochs=[30, 60, 90])

  training_callbacks = [lr_schedule_cb, tensorboard_cb]

  model.fit(
      imagenet_train.input_fn(),
      epochs=EPOCHS,
      steps_per_epoch=training_steps_per_epoch,
      callbacks=training_callbacks)

  model_saving_utils.save_model(model, model_dir, WEIGHTS_TXT)
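
Example #2 appends a sparse_top_k_categorical_accuracy metric that is defined elsewhere in the script. A plausible one-line reconstruction, assuming it simply fixes k=5 on the stock Keras metric:

def sparse_top_k_categorical_accuracy(y_true, y_pred):
  """Top-5 accuracy (hypothetical sketch; wraps the built-in Keras metric)."""
  return tf.keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=5)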
Example #3
    def build_imagenet_input(is_training):
        """Generate ImageNetInput for training and eval."""
        if FLAGS.bigtable_instance:
            tf.logging.info('Using Bigtable dataset, table %s',
                            FLAGS.bigtable_table)
            select_train, select_eval = _select_tables_from_flags()
            return imagenet_input.ImageNetBigtableInput(
                is_training=is_training,
                use_bfloat16=FLAGS.use_bfloat16,
                transpose_input=FLAGS.transpose_input,
                selection=select_train if is_training else select_eval,
                include_background_label=include_background_label,
                autoaugment_name=FLAGS.autoaugment_name)
        else:
            if FLAGS.data_dir == FAKE_DATA_DIR:
                tf.logging.info('Using fake dataset.')
            else:
                tf.logging.info('Using dataset: %s', FLAGS.data_dir)

            return imagenet_input.ImageNetInput(
                is_training=is_training,
                data_dir=FLAGS.data_dir,
                transpose_input=FLAGS.transpose_input,
                cache=FLAGS.use_cache and is_training,
                image_size=input_image_size,
                num_parallel_calls=FLAGS.num_parallel_calls,
                use_bfloat16=FLAGS.use_bfloat16,
                include_background_label=include_background_label,
                autoaugment_name=FLAGS.autoaugment_name)
Example #4
def representative_dataset_gen():
  """Gets a python generator of image numpy arrays for ImageNet."""
  params = dict(batch_size=FLAGS.batch_size)
  imagenet_eval = imagenet_input.ImageNetInput(
      is_training=False,
      data_dir=FLAGS.data_dir,
      transpose_input=False,
      cache=False,
      image_size=FLAGS.image_size,
      num_parallel_calls=1,
      use_bfloat16=False)

  data = imagenet_eval.input_fn(params)

  def preprocess_map_fn(images, labels):
    del labels
    if FLAGS.input_name == "truediv":
      images -= tf.constant(
          imagenet_input.MEAN_RGB, shape=[1, 1, 3], dtype=images.dtype)
      images /= tf.constant(
          imagenet_input.STDDEV_RGB, shape=[1, 1, 3], dtype=images.dtype)
    return images

  data = data.map(preprocess_map_fn)
  iterator = data.make_one_shot_iterator()
  for _ in range(FLAGS.num_steps):
    # In eager context, we can get a python generator from a dataset iterator.
    images = iterator.get_next()
    yield [images]
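
Example #4's generator is shaped for TFLite post-training quantization, which consumes a representative dataset during conversion. A sketch of how such a generator is typically wired into the converter, using the TF2-style API (the saved_model_dir path is hypothetical):

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# The converter calls the generator to collect activation ranges.
converter.representative_dataset = representative_dataset_gen
tflite_model = converter.convert()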
Example #5
def main(argv):
  logging.info('Building Keras ResNet-50 model.')
  model = tf.keras.applications.resnet50.ResNet50(
      include_top=True,
      weights=None,
      input_tensor=None,
      input_shape=None,
      pooling=None,
      classes=NUM_CLASSES)

  per_core_batch_size = 128
  num_cores = 8
  batch_size = per_core_batch_size * num_cores

  if FLAGS.tpu is not None:
    logging.info('Converting from CPU to TPU model.')
    strategy = keras_support.TPUDistributionStrategy(
        num_cores_per_host=num_cores)
    model = keras_support.tpu_model(model, strategy=strategy,
                                    tpu_name_or_address=FLAGS.tpu)

  logging.info('Compiling model.')
  model.compile(
      optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0),
      loss='sparse_categorical_crossentropy',
      metrics=['sparse_categorical_accuracy'])

  if FLAGS.data is None:
    training_images = np.random.randn(
        batch_size, IMAGE_SIZE, IMAGE_SIZE, 3).astype(np.float32)
    training_labels = np.random.randint(NUM_CLASSES, size=batch_size,
                                        dtype=np.int32)
    logging.info('Training model using synthetic data.')
    num_epochs = 100  # TPUs are very fast when running a single step per epoch!
    model.fit(training_images, training_labels, epochs=num_epochs,
              batch_size=batch_size)
    logging.info('Evaluating the model on synthetic data.')
    model.evaluate(training_images, training_labels, verbose=0)
  else:
    imagenet_train, imagenet_eval = [imagenet_input.ImageNetInput(
        is_training=is_training,
        data_dir=FLAGS.data,
        per_core_batch_size=per_core_batch_size)
                                     for is_training in [True, False]]
    logging.info('Training model using real data in directory "%s".',
                 FLAGS.data)
    num_epochs = 90  # Standard imagenet training regime.
    model.fit(imagenet_train.input_fn,
              epochs=num_epochs,
              steps_per_epoch=int(APPROX_IMAGENET_TRAINING_IMAGES / batch_size))
    logging.info('Evaluating the model on the validation dataset.')
    model.evaluate(imagenet_eval.input_fn)
Example #6
    def build_imagenet_input(self, is_training):
        """Generate ImageNetInput for training and eval."""
        # For the ImageNet dataset, include the background label if the number
        # of output classes is 1001.
        include_background_label = (FLAGS.num_label_classes == 1001)

        tf.logging.info('Using dataset: %s', FLAGS.data_dir)

        return imagenet_input.ImageNetInput(
            is_training=is_training,
            data_dir=FLAGS.data_dir,
            transpose_input=FLAGS.transpose_input,
            cache=FLAGS.use_cache and is_training,
            image_size=FLAGS.input_image_size,
            num_parallel_calls=FLAGS.num_parallel_calls,
            use_bfloat16=FLAGS.use_bfloat16,
            include_background_label=include_background_label)
Example #7
    def build_imagenet_input(is_training):
        """Generate ImageNetInput for training and eval."""
        if FLAGS.bigtable_instance:
            logging.info('Using Bigtable dataset, table %s',
                         FLAGS.bigtable_table)
            select_train, select_eval = _select_tables_from_flags()
            return imagenet_input.ImageNetBigtableInput(
                is_training=is_training,
                use_bfloat16=FLAGS.use_bfloat16,
                transpose_input=FLAGS.transpose_input,
                selection=select_train if is_training else select_eval,
                num_label_classes=FLAGS.num_label_classes,
                include_background_label=include_background_label,
                augment_name=FLAGS.augment_name,
                mixup_alpha=FLAGS.mixup_alpha,
                randaug_num_layers=FLAGS.randaug_num_layers,
                randaug_magnitude=FLAGS.randaug_magnitude,
                resize_method=resize_method)
        else:
            if FLAGS.data_dir == FAKE_DATA_DIR:
                logging.info('Using fake dataset.')
            else:
                logging.info('Using dataset: %s', FLAGS.data_dir)

            return imagenet_input.ImageNetInput(
                is_training=is_training,
                data_dir=FLAGS.data_dir,
                transpose_input=FLAGS.transpose_input,
                cache=FLAGS.use_cache and is_training,
                image_size=input_image_size,
                num_parallel_calls=FLAGS.num_parallel_calls,
                use_bfloat16=FLAGS.use_bfloat16,
                num_label_classes=FLAGS.num_label_classes,
                include_background_label=include_background_label,
                augment_name=FLAGS.augment_name,
                mixup_alpha=FLAGS.mixup_alpha,
                randaug_num_layers=FLAGS.randaug_num_layers,
                randaug_magnitude=FLAGS.randaug_magnitude,
                resize_method=resize_method,
                holdout_shards=FLAGS.holdout_shards)
Example #8
def build_imagenet_input(context, is_training):
    """Generate ImageNetInput for training and eval."""
    input_image_size = model_builder_factory.get_model_input_size(
        context.get_hparam("model_name"))
    include_background_label = (context.get_hparam("num_label_classes") == 1001)
    data_dir = context.get_data_config().get("data_dir")
    logging.info("Using dataset: %s", data_dir)

    return imagenet_input.ImageNetInput(
        is_training=is_training,
        data_dir=data_dir,
        transpose_input=False,  # context.get_hparam("transpose_input")
        cache=False,  # context.get_hparam("use_cache") and is_training
        image_size=input_image_size,
        num_parallel_calls=context.get_hparam("num_parallel_calls"),
        num_label_classes=context.get_hparam("num_label_classes"),
        include_background_label=include_background_label,
        # augment_name=context.get_hparam("augment_name"),
        mixup_alpha=context.get_hparam("mixup_alpha"),
        randaug_num_layers=context.get_hparam("randaug_num_layers"),
        randaug_magnitude=context.get_hparam("randaug_magnitude"),
        resize_method=None,
        use_bfloat16=False,
        context=context,
    )
Example #9
def main(unused_argv):
    tf.enable_v2_behavior()
    model_dir = FLAGS.model_dir if FLAGS.model_dir else DEFAULT_MODEL_DIR
    batch_size = PER_CORE_BATCH_SIZE * FLAGS.num_cores
    steps_per_epoch = FLAGS.steps_per_epoch or (int(
        APPROX_IMAGENET_TRAINING_IMAGES // batch_size))
    steps_per_eval = int(math.ceil(IMAGENET_VALIDATION_IMAGES / batch_size))
    logging.info('Saving checkpoints at %s', model_dir)
    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')

    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

    imagenet_train = imagenet_input.ImageNetInput(
        is_training=True,
        data_dir=FLAGS.data,
        batch_size=PER_CORE_BATCH_SIZE,
        use_bfloat16=_USE_BFLOAT16)
    imagenet_eval = imagenet_input.ImageNetInput(
        is_training=False,
        data_dir=FLAGS.data,
        batch_size=PER_CORE_BATCH_SIZE,
        use_bfloat16=_USE_BFLOAT16)
    train_dataset = strategy.experimental_distribute_datasets_from_function(
        imagenet_train.input_fn)
    test_dataset = strategy.experimental_distribute_datasets_from_function(
        imagenet_eval.input_fn)

    if _USE_BFLOAT16:
        policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    with strategy.scope():
        logging.info('Building Keras ResNet-50 model')
        model = resnet_model.ResNet50(num_classes=NUM_CLASSES)
        base_lr = _BASE_LEARNING_RATE * batch_size / 256
        optimizer = tf.keras.optimizers.SGD(
            learning_rate=ResnetLearningRateSchedule(steps_per_epoch, base_lr),
            momentum=0.9,
            nesterov=True)
        training_loss = tf.keras.metrics.Mean('training_loss',
                                              dtype=tf.float32)
        training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'training_accuracy', dtype=tf.float32)
        test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
        test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_accuracy', dtype=tf.float32)
        logging.info('Finished building Keras ResNet-50 model')

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(model_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    # Create summary writers
    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(model_dir, 'summaries/train'))
    test_summary_writer = tf.summary.create_file_writer(
        os.path.join(model_dir, 'summaries/test'))

    @tf.function
    def train_step(iterator):
        """Training StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
            with tf.GradientTape() as tape:
                predictions = model(images, training=True)
                if _USE_BFLOAT16:
                    predictions = tf.cast(predictions, tf.float32)

                # Loss calculations.
                #
                # Part 1: Prediction loss.
                prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, predictions)
                loss1 = tf.reduce_mean(prediction_loss)
                # Part 2: Model weights regularization
                loss2 = tf.reduce_sum(model.losses)

                # Scale the loss given the TPUStrategy will reduce sum all gradients.
                loss = loss1 + loss2
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            training_loss.update_state(loss)
            training_accuracy.update_state(labels, predictions)

        strategy.experimental_run_v2(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator):
        """Evaluation StepFn."""
        def step_fn(inputs):
            images, labels = inputs
            predictions = model(images, training=False)
            if _USE_BFLOAT16:
                predictions = tf.cast(predictions, tf.float32)
            loss = tf.keras.losses.sparse_categorical_crossentropy(
                labels, predictions)
            loss = safe_mean(loss)
            test_loss.update_state(loss)
            test_accuracy.update_state(labels, predictions)

        strategy.experimental_run_v2(step_fn, args=(next(iterator), ))

    train_iterator = iter(train_dataset)
    for epoch in range(initial_epoch, FLAGS.num_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        with train_summary_writer.as_default():
            for step in range(steps_per_epoch):
                if step % 20 == 0:
                    logging.info('Running step %s in epoch %s', step, epoch)
                train_step(train_iterator)
            tf.summary.scalar('loss',
                              training_loss.result(),
                              step=optimizer.iterations)
            tf.summary.scalar('accuracy',
                              training_accuracy.result(),
                              step=optimizer.iterations)
            logging.info('Training loss: %s, accuracy: %s%%',
                         round(float(training_loss.result()), 4),
                         round(float(training_accuracy.result()) * 100, 2))
            training_loss.reset_states()
            training_accuracy.reset_states()

        with test_summary_writer.as_default():
            test_iterator = iter(test_dataset)
            for step in range(steps_per_eval):
                if step % 20 == 0:
                    logging.info('Starting to run eval step %s of epoch: %s',
                                 step, epoch)
                test_step(test_iterator)
            tf.summary.scalar('loss',
                              test_loss.result(),
                              step=optimizer.iterations)
            tf.summary.scalar('accuracy',
                              test_accuracy.result(),
                              step=optimizer.iterations)
            logging.info('Test loss: %s, accuracy: %s%%',
                         round(float(test_loss.result()), 4),
                         round(float(test_accuracy.result()) * 100, 2))
            test_loss.reset_states()
            test_accuracy.reset_states()

        checkpoint_name = checkpoint.save(os.path.join(model_dir,
                                                       'checkpoint'))
        logging.info('Saved checkpoint to %s', checkpoint_name)
Example #10
def main(unused_argv):
    """Starts a ResNet training session."""
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    # Estimator looks at the master it connects to for MonitoredTrainingSession
    # by reading the `TF_CONFIG` environment variable.
    tf_config_env = {
        'session_master': tpu_cluster_resolver.get_master(),
        'eval_session_master': tpu_cluster_resolver.get_master()
    }
    os.environ['TF_CONFIG'] = json.dumps(tf_config_env)

    steps_per_run_train = _NUM_TRAIN_IMAGES // (FLAGS.train_batch_size *
                                                FLAGS.num_cores)
    steps_per_run_eval = _NUM_EVAL_IMAGES // (FLAGS.eval_batch_size *
                                              FLAGS.num_cores)
    steps_per_eval = steps_per_run_train

    train_distribution = tf.contrib.distribute.TPUStrategy(
        tpu_cluster_resolver, steps_per_run=steps_per_run_train)
    eval_distribution = tf.contrib.distribute.TPUStrategy(
        tpu_cluster_resolver, steps_per_run=steps_per_run_eval)
    config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir,
                                    train_distribute=train_distribution,
                                    eval_distribute=eval_distribution,
                                    save_checkpoints_steps=steps_per_eval,
                                    save_checkpoints_secs=None,
                                    keep_checkpoint_max=10)

    resnet_estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)

    train_input, eval_input = [
        imagenet_input.ImageNetInput(
            is_training=is_training,
            data_dir=FLAGS.data_dir,
            transpose_input=FLAGS.transpose_input,
            use_bfloat16=(FLAGS.precision == 'bfloat16'))
        for is_training in [True, False]
    ]

    try:
        current_step = resnet_estimator.get_variable_value(
            tf.GraphKeys.GLOBAL_STEP)
    except ValueError:
        current_step = 0

    while current_step < _TRAIN_STEPS:
        next_checkpoint = min(current_step + steps_per_eval, _TRAIN_STEPS)

        resnet_estimator.train(
            input_fn=lambda: train_input.input_fn(  # pylint: disable=g-long-lambda
                {'batch_size': FLAGS.train_batch_size}),
            max_steps=next_checkpoint)
        current_step = next_checkpoint

        eval_results = resnet_estimator.evaluate(
            input_fn=lambda: eval_input.input_fn(  # pylint: disable=g-long-lambda
                {'batch_size': FLAGS.eval_batch_size}),
            steps=_NUM_EVAL_IMAGES //
            (FLAGS.eval_batch_size * FLAGS.num_cores))

        tf.logging.info('Eval results: %s' % eval_results)
Example #11
def main(unused_argv):
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    config = tpu_config.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=max(600, FLAGS.iterations_per_loop),
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_cores,
            per_host_input_for_training=tpu_config.InputPipelineConfig.PER_HOST_V2))  # pylint: disable=line-too-long

    resnet_classifier = tpu_estimator.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=resnet_model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    assert FLAGS.precision == 'bfloat16' or FLAGS.precision == 'float32', (
        'Invalid value for --precision flag; must be bfloat16 or float32.')
    tf.logging.info('Precision: %s', FLAGS.precision)
    use_bfloat16 = FLAGS.precision == 'bfloat16'

    # Input pipelines are slightly different (with regard to shuffling and
    # preprocessing) between training and evaluation.
    imagenet_train, imagenet_eval = [
        imagenet_input.ImageNetInput(is_training=is_training,
                                     data_dir=FLAGS.data_dir,
                                     transpose_input=FLAGS.transpose_input,
                                     use_bfloat16=use_bfloat16)
        for is_training in [True, False]
    ]

    if FLAGS.mode == 'eval':
        eval_steps = NUM_EVAL_IMAGES // FLAGS.eval_batch_size

        # Run evaluation when there's a new checkpoint
        for ckpt in evaluation.checkpoints_iterator(
                FLAGS.model_dir, timeout=FLAGS.eval_timeout):
            tf.logging.info('Starting to evaluate.')
            try:
                start_timestamp = time.time()  # Includes compilation time
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=eval_steps,
                    checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d' %
                                (eval_results, elapsed_time))

                # Terminate eval job when final checkpoint is reached
                current_step = int(os.path.basename(ckpt).split('-')[1])
                if current_step >= FLAGS.train_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d' %
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint' %
                    ckpt)

    else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        current_step = estimator._load_global_step_from_checkpoint_dir(
            FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
        batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
        tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                        ' step %d.' % (FLAGS.train_steps, FLAGS.train_steps /
                                       batches_per_epoch, current_step))

        start_timestamp = time.time()  # This time includes compilation time
        if FLAGS.mode == 'train':
            resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                    max_steps=FLAGS.train_steps)

        else:
            assert FLAGS.mode == 'train_and_eval'
            while current_step < FLAGS.train_steps:
                # Train for up to steps_per_eval number of steps.
                # At the end of training, a checkpoint will be written to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      FLAGS.train_steps)
                resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                        max_steps=next_checkpoint)
                current_step = next_checkpoint

                # Evaluate the model on the most recent model in --model_dir.
                # Since evaluation happens in batches of --eval_batch_size, some images
                # may be consistently excluded modulo the batch size.
                tf.logging.info('Starting to evaluate.')
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=NUM_EVAL_IMAGES // FLAGS.eval_batch_size)
                tf.logging.info('Eval results: %s' % eval_results)

        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info(
            'Finished training up to step %d. Elapsed seconds %d.' %
            (FLAGS.train_steps, elapsed_time))

        if FLAGS.export_dir is not None:
            # The guide to serving an exported TensorFlow model is at:
            #    https://www.tensorflow.org/serving/serving_basic
            tf.logging.info('Starting to export model.')
            resnet_classifier.export_savedmodel(
                export_dir_base=FLAGS.export_dir,
                serving_input_receiver_fn=imagenet_input.image_serving_input_fn
            )
Example #12
def main(unused_argv):
    assert FLAGS.data is not None, 'Provide training data path via --data.'
    tf.enable_v2_behavior()
    tf.compat.v1.disable_eager_execution()  # todo

    batch_size = FLAGS.num_cores * PER_CORE_BATCH_SIZE

    training_steps_per_epoch = FLAGS.steps_per_epoch or (int(
        APPROX_IMAGENET_TRAINING_IMAGES // batch_size))
    validation_steps = int(
        math.ceil(1.0 * IMAGENET_VALIDATION_IMAGES / batch_size))

    model_dir = FLAGS.model_dir if FLAGS.model_dir else DEFAULT_MODEL_DIR
    logging.info('Saving tensorboard summaries at %s', model_dir)

    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

    logging.info('Use bfloat16: %s.', USE_BFLOAT16)
    logging.info('Use global batch size: %s.', batch_size)
    logging.info('Enable top 5 accuracy: %s.', FLAGS.eval_top_5_accuracy)
    logging.info('Training model using data in directory "%s".', FLAGS.data)

    with strategy.scope():
        logging.info('Building Keras ResNet-50 model')
        model = resnet_model.ResNet50(num_classes=NUM_CLASSES)
        # model = keras_applications.mobilenet_v2.MobileNetV2(classes=NUM_CLASSES, weights=None)

        logging.info('Compiling model.')
        metrics = ['sparse_categorical_accuracy']

        if FLAGS.eval_top_5_accuracy:
            metrics.append(sparse_top_k_categorical_accuracy)

        model.compile(optimizer=tf.keras.optimizers.SGD(
            learning_rate=BASE_LEARNING_RATE, momentum=0.9, nesterov=True),
                      loss='sparse_categorical_crossentropy',
                      metrics=metrics)

    imagenet_train = imagenet_input.ImageNetInput(is_training=True,
                                                  data_dir=FLAGS.data,
                                                  batch_size=batch_size,
                                                  use_bfloat16=USE_BFLOAT16)
    imagenet_eval = imagenet_input.ImageNetInput(is_training=False,
                                                 data_dir=FLAGS.data,
                                                 batch_size=batch_size,
                                                 use_bfloat16=USE_BFLOAT16)

    lr_schedule_cb = LearningRateBatchScheduler(
        schedule=learning_rate_schedule_wrapper(training_steps_per_epoch))
    tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=model_dir)

    training_callbacks = [lr_schedule_cb, tensorboard_cb]

    model.fit(imagenet_train.input_fn(),
              epochs=FLAGS.num_epochs,
              steps_per_epoch=training_steps_per_epoch,
              callbacks=training_callbacks,
              validation_data=imagenet_eval.input_fn(),
              validation_steps=validation_steps,
              validation_freq=5)

    model_saving_utils.save_model(model, model_dir, WEIGHTS_TXT)
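
Examples #2 and #12 both build their callback from learning_rate_schedule_wrapper, which is defined elsewhere in those scripts. A sketch under the usual assumption of linear warmup followed by step decay (the LR_SCHEDULE boundaries here are assumptions):

LR_SCHEDULE = [(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)]  # (multiplier, start epoch)

def learning_rate_schedule_wrapper(training_steps_per_epoch):
  """Returns a schedule(epoch, batch) -> lr function for the batch scheduler."""
  def learning_rate_schedule(current_epoch, current_batch):
    # Fractional epoch, so the rate can change within an epoch.
    epoch = current_epoch + float(current_batch) / training_steps_per_epoch
    warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
    if epoch < warmup_end_epoch:
      # Linear warmup from 0 to the base learning rate.
      return BASE_LEARNING_RATE * warmup_lr_multiplier * epoch / warmup_end_epoch
    learning_rate = BASE_LEARNING_RATE * warmup_lr_multiplier
    for multiplier, start_epoch in LR_SCHEDULE:
      if epoch >= start_epoch:
        learning_rate = BASE_LEARNING_RATE * multiplier
    return learning_rate
  return learning_rate_schedule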
Example #13
def main(argv):
  logging.info('Building Keras ResNet-50 model')
  model = resnet_model.ResNet50(num_classes=NUM_CLASSES)

  if FLAGS.use_tpu:
    logging.info('Converting from CPU to TPU model.')
    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    strategy = tf.contrib.tpu.TPUDistributionStrategy(resolver)
    model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)

  logging.info('Compiling model.')
  model.compile(
      optimizer=tf.keras.optimizers.SGD(lr=BASE_LEARNING_RATE,
                                        momentum=0.9,
                                        nesterov=True),
      loss='sparse_categorical_crossentropy',
      metrics=['sparse_categorical_accuracy'])

  if FLAGS.data is None:
    training_images = np.random.randn(
        BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3).astype(np.float32)
    training_labels = np.random.randint(NUM_CLASSES, size=BATCH_SIZE,
                                        dtype=np.int32)
    logging.info('Training model using synthetic data.')
    model.fit(
        training_images,
        training_labels,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE)
    logging.info('Evaluating the model on synthetic data.')
    model.evaluate(training_images, training_labels, verbose=0)
  else:
    model_dir = FLAGS.model_dir if FLAGS.model_dir else DEFAULT_MODEL_DIR
    imagenet_train = imagenet_input.ImageNetInput(
        is_training=True,
        data_dir=FLAGS.data,
        per_core_batch_size=PER_CORE_BATCH_SIZE)
    logging.info('Training model using real data in directory "%s".',
                 FLAGS.data)
    # If evaluating top 5 accuracy, we feed the inputs from a Python generator,
    # so we need to build a single batch for all of the cores, which will be
    # split on TPU.
    per_core_batch_size = (
        BATCH_SIZE if FLAGS.eval_top_5_accuracy else PER_CORE_BATCH_SIZE)
    imagenet_validation = imagenet_input.ImageNetInput(
        is_training=False,
        data_dir=FLAGS.data,
        per_core_batch_size=per_core_batch_size)

    callbacks = [
        LearningRateBatchScheduler(schedule=learning_rate_schedule),
        eval_utils.TensorBoardWithValidation(
            log_dir=model_dir,
            validation_imagenet_input=imagenet_validation,
            validation_steps=VALIDATION_STEPS,
            validation_epochs=[30, 60, 90],
            eval_top_k_accuracy=FLAGS.eval_top_5_accuracy),
    ]

    model.fit(imagenet_train.input_fn,
              epochs=EPOCHS,
              steps_per_epoch=TRAINING_STEPS_PER_EPOCH,
              callbacks=callbacks)

    if HAS_H5PY:
      weights_file = os.path.join(model_dir, WEIGHTS_TXT)
      logging.info('Save weights into %s', weights_file)
      model.save_weights(weights_file, overwrite=True)
Example #14
def main(argv):
    logging.info('Building Keras ResNet-50 model')
    model = tf.keras.applications.resnet50.ResNet50(include_top=True,
                                                    weights=None,
                                                    input_tensor=None,
                                                    input_shape=None,
                                                    pooling=None,
                                                    classes=NUM_CLASSES)

    if FLAGS.use_tpu:
        logging.info('Converting from CPU to TPU model.')
        resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        strategy = tf.contrib.tpu.TPUDistributionStrategy(resolver)
        model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
        session_master = resolver.master()
    else:
        session_master = ''

    logging.info('Compiling model.')
    model.compile(optimizer=tf.keras.optimizers.SGD(lr=BASE_LEARNING_RATE,
                                                    momentum=0.9,
                                                    nesterov=True),
                  loss='sparse_categorical_crossentropy',
                  metrics=['sparse_categorical_accuracy'])

    callbacks = [LearningRateBatchScheduler(schedule=learning_rate_schedule)]
    if FLAGS.model_dir:
        callbacks.append(
            tf.keras.callbacks.TensorBoard(log_dir=FLAGS.model_dir))

    if FLAGS.data is None:
        training_images = np.random.randn(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE,
                                          3).astype(np.float32)
        training_labels = np.random.randint(NUM_CLASSES,
                                            size=BATCH_SIZE,
                                            dtype=np.int32)
        logging.info('Training model using synthetic data.')
        model.fit(training_images,
                  training_labels,
                  epochs=EPOCHS,
                  batch_size=BATCH_SIZE,
                  callbacks=callbacks)
        logging.info('Evaluating the model on synthetic data.')
        model.evaluate(training_images, training_labels, verbose=0)
    else:
        imagenet_train = imagenet_input.ImageNetInput(
            is_training=True,
            use_bfloat16=FLAGS.use_bfloat16,
            data_dir=FLAGS.data,
            per_core_batch_size=PER_CORE_BATCH_SIZE)
        logging.info('Training model using real data in directory "%s".',
                     FLAGS.data)
        model.fit(imagenet_train.input_fn,
                  epochs=EPOCHS,
                  steps_per_epoch=TRAINING_STEPS_PER_EPOCH,
                  callbacks=callbacks)

        logging.info('Evaluating the model on the validation dataset.')
        if FLAGS.eval_top_5_accuracy:
            logging.info('Evaluating top 1 and top 5 accuracy using a Python '
                         'generator.')
            # We feed the inputs from a Python generator, so we need to build a single
            # batch for all of the cores, which will be split on TPU.
            imagenet_eval = imagenet_input.ImageNetInput(
                is_training=False,
                use_bfloat16=FLAGS.use_bfloat16,
                data_dir=FLAGS.data,
                per_core_batch_size=BATCH_SIZE)
            score = eval_utils.multi_top_k_accuracy(
                model, imagenet_eval.evaluation_generator(K.get_session()),
                EVAL_STEPS)
        else:
            imagenet_eval = imagenet_input.ImageNetInput(
                is_training=False,
                use_bfloat16=FLAGS.use_bfloat16,
                data_dir=FLAGS.data,
                per_core_batch_size=PER_CORE_BATCH_SIZE)
            score = model.evaluate(imagenet_eval.input_fn,
                                   steps=EVAL_STEPS,
                                   verbose=1)
        print('Evaluation score', score)

        if HAS_H5PY:
            weights_file = os.path.join(
                FLAGS.model_dir if FLAGS.model_dir else '/tmp', WEIGHTS_TXT)
            logging.info('Save weights into %s', weights_file)
            model.save_weights(weights_file, overwrite=True)
Example #15
def main(unused_argv):
  """tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu if (FLAGS.tpu or FLAGS.use_tpu) else '',
      zone=FLAGS.tpu_zone,
      project=FLAGS.gcp_project)"""

  if FLAGS.use_async_checkpointing:
    save_checkpoints_steps = None
  else:
    save_checkpoints_steps = max(100, FLAGS.iterations_per_loop)
  """config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=save_checkpoints_steps,
      log_step_count_steps=FLAGS.log_step_count_steps,
      session_config=tf.ConfigProto(
          graph_options=tf.GraphOptions(
              rewrite_options=rewriter_config_pb2.RewriterConfig(
                  disable_meta_optimizer=True))),
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig
          .PER_HOST_V2))  # pylint: disable=line-too-long"""
  if FLAGS.num_gpus == 1:
    distribution_strategy = None
  else:
    distribution_strategy = tf.contrib.distribute.MirroredStrategy(
        num_gpus=FLAGS.num_gpus)
  config = tf.estimator.RunConfig(
      model_dir=FLAGS.model_dir,
      train_distribute=distribution_strategy,
      save_checkpoints_steps=save_checkpoints_steps,
      log_step_count_steps=FLAGS.log_step_count_steps,
      session_config=tf.ConfigProto(
          graph_options=tf.GraphOptions(
              rewrite_options=rewriter_config_pb2.RewriterConfig(
                  disable_meta_optimizer=True))),
    )

  # Initializes model parameters.
  params = dict(
      steps_per_epoch=FLAGS.num_train_images / FLAGS.train_batch_size,
      use_bfloat16=FLAGS.use_bfloat16,
      batch_size=FLAGS.train_batch_size)
  """mnasnet_est = tf.contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=mnasnet_model_fn,
      config=config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      export_to_tpu=FLAGS.export_to_tpu,
      params=params)"""
  mnasnet_est = tf.estimator.Estimator(
      model_fn=mnasnet_model_fn,
      config=config,
      params=params
  )

  tf.logging.info('Using dataset: %s', FLAGS.data_dir)
  imagenet_train, imagenet_eval = [
      imagenet_input.ImageNetInput(
          is_training=is_training,
          data_dir=FLAGS.data_dir,
          transpose_input=FLAGS.transpose_input,
          cache=FLAGS.use_cache and is_training,
          image_size=FLAGS.input_image_size,
          num_parallel_calls=FLAGS.num_parallel_calls,
          use_bfloat16=FLAGS.use_bfloat16) for is_training in [True, False]
  ]

  if FLAGS.mode == 'eval':
    eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size
    # Run evaluation when there's a new checkpoint
    for ckpt in evaluation.checkpoints_iterator(
        FLAGS.model_dir, timeout=FLAGS.eval_timeout):
      tf.logging.info('Starting to evaluate.')
      try:
        start_timestamp = time.time()  # This time will include compilation time
        eval_results = mnasnet_est.evaluate(
            input_fn=imagenet_eval.input_fn,
            steps=eval_steps,
            checkpoint_path=ckpt)
        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info('Eval results: %s. Elapsed seconds: %d', eval_results,
                        elapsed_time)

        # Terminate eval job when final checkpoint is reached
        current_step = int(os.path.basename(ckpt).split('-')[1])
        if current_step >= FLAGS.train_steps:
          tf.logging.info('Evaluation finished after training step %d',
                          current_step)
          break

      except tf.errors.NotFoundError:
        # Since the coordinator is on a different job than the TPU worker,
        # sometimes the TPU worker does not finish initializing until long after
        # the CPU job tells it to start evaluating. In this case, the checkpoint
        # file could have been deleted already.
        tf.logging.info('Checkpoint %s no longer exists, skipping checkpoint',
                        ckpt)

    if FLAGS.export_dir:
      export(mnasnet_est, FLAGS.export_dir, FLAGS.post_quantize)
  else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
    current_step = load_global_step_from_checkpoint_dir(FLAGS.model_dir)

    tf.logging.info(
        'Training for %d steps (%.2f epochs in total). Current'
        ' step %d.', FLAGS.train_steps,
        FLAGS.train_steps / params['steps_per_epoch'], current_step)

    start_timestamp = time.time()  # This time will include compilation time

    if FLAGS.mode == 'train':
      hooks = []
      hook = tf.train.ProfilerHook(save_steps=16, output_dir=FLAGS.model_dir)
      hooks.append(hook)
      """if FLAGS.use_async_checkpointing:
        hooks.append(
            async_checkpoint.AsyncCheckpointSaverHook(
                checkpoint_dir=FLAGS.model_dir,
                save_steps=max(100, FLAGS.iterations_per_loop)))"""
      #with tf.contrib.tfprof.ProfileContext(save_steps=10,output_dir=FLAGS.model_dir) as pctx:
      mnasnet_est.train(
          input_fn=imagenet_train.input_fn,
          max_steps=FLAGS.train_steps,
          hooks=hooks)

    else:
      assert FLAGS.mode == 'train_and_eval'
      while current_step < FLAGS.train_steps:
        # Train for up to steps_per_eval number of steps.
        # At the end of training, a checkpoint will be written to --model_dir.
        next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                              FLAGS.train_steps)
        mnasnet_est.train(
            input_fn=imagenet_train.input_fn, max_steps=next_checkpoint)
        current_step = next_checkpoint

        tf.logging.info('Finished training up to step %d. Elapsed seconds %d.',
                        next_checkpoint, int(time.time() - start_timestamp))

        # Evaluate the model on the most recent model in --model_dir.
        # Since evaluation happens in batches of --eval_batch_size, some images
        # may be excluded modulo the batch size. As long as the batch size is
        # consistent, the evaluated images are also consistent.
        tf.logging.info('Starting to evaluate.')
        eval_results = mnasnet_est.evaluate(
            input_fn=imagenet_eval.input_fn,
            steps=FLAGS.num_eval_images // FLAGS.eval_batch_size)
        tf.logging.info('Eval results at step %d: %s', next_checkpoint,
                        eval_results)

      elapsed_time = int(time.time() - start_timestamp)
      tf.logging.info('Finished training up to step %d. Elapsed seconds %d.',
                      FLAGS.train_steps, elapsed_time)
      if FLAGS.export_dir:
        export(mnasnet_est, FLAGS.export_dir, FLAGS.post_quantize)
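
Example #15 calls load_global_step_from_checkpoint_dir without showing it. A sketch of the usual helper, assuming it reads the global_step tensor out of the latest checkpoint and falls back to 0 when there is no checkpoint yet:

def load_global_step_from_checkpoint_dir(checkpoint_dir):
  """Returns the global step stored in the latest checkpoint, or 0."""
  try:
    checkpoint_reader = tf.train.NewCheckpointReader(
        tf.train.latest_checkpoint(checkpoint_dir))
    return checkpoint_reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
  except:  # pylint: disable=bare-except
    return 0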
Example #16
def main(unused_argv):

    model_dir = FLAGS.model_dir if FLAGS.model_dir else DEFAULT_MODEL_DIR
    batch_size = PER_CORE_BATCH_SIZE * FLAGS.num_cores
    steps_per_epoch = FLAGS.steps_per_epoch or (int(
        APPROX_IMAGENET_TRAINING_IMAGES // batch_size))
    steps_per_eval = IMAGENET_VALIDATION_IMAGES // batch_size

    logging.info('Saving checkpoints at %s', model_dir)

    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

    imagenet_train = imagenet_input.ImageNetInput(is_training=True,
                                                  data_dir=FLAGS.data,
                                                  batch_size=batch_size,
                                                  use_bfloat16=_USE_BFLOAT16)
    imagenet_eval = imagenet_input.ImageNetInput(is_training=False,
                                                 data_dir=FLAGS.data,
                                                 batch_size=batch_size,
                                                 use_bfloat16=_USE_BFLOAT16)

    train_iterator = strategy.experimental_distribute_dataset(
        imagenet_train.input_fn()).make_initializable_iterator()
    test_iterator = strategy.experimental_distribute_dataset(
        imagenet_eval.input_fn()).make_initializable_iterator()

    with strategy.scope():
        logging.info('Building Keras ResNet-50 model')
        model = resnet_model.ResNet50(num_classes=NUM_CLASSES)
        optimizer = tf.keras.optimizers.SGD(learning_rate=_BASE_LEARNING_RATE,
                                            momentum=0.9,
                                            nesterov=True)
        training_loss = tf.keras.metrics.Mean('training_loss',
                                              dtype=tf.float32)
        training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'training_accuracy', dtype=tf.float32)
        test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
        test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_accuracy', dtype=tf.float32)
        logging.info('Finished building Keras ResNet-50 model')

    def train_step(inputs):
        """Training StepFn."""
        images, labels = inputs
        with tf.GradientTape() as tape:
            predictions = model(images, training=True)

            # Loss calculations.
            #
            # Part 1: Prediction loss.
            prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
                labels, predictions)
            loss1 = tf.reduce_mean(prediction_loss)
            # Part 2: Model weights regularization
            loss2 = tf.reduce_sum(model.losses)

            # Scale the loss given the TPUStrategy will reduce sum all gradients.
            loss = loss1 + loss2
            scaled_loss = loss / strategy.num_replicas_in_sync

        grads = tape.gradient(scaled_loss, model.trainable_variables)
        update_vars = optimizer.apply_gradients(
            zip(grads, model.trainable_variables))
        update_loss = training_loss.update_state(loss)
        update_accuracy = training_accuracy.update_state(labels, predictions)
        with tf.control_dependencies(
            [update_vars, update_loss, update_accuracy]):
            return tf.identity(loss)

    def test_step(inputs):
        """Evaluation StepFn."""
        images, labels = inputs
        predictions = model(images, training=False)
        loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, predictions)
        loss = tf.reduce_mean(loss)
        update_loss = test_loss.update_state(loss)
        update_accuracy = test_accuracy.update_state(labels, predictions)
        with tf.control_dependencies([update_loss, update_accuracy]):
            return tf.identity(loss)

    dist_train = strategy.experimental_local_results(
        strategy.run(train_step, args=(next(train_iterator), )))
    dist_test = strategy.experimental_local_results(
        strategy.run(test_step, args=(next(test_iterator), )))

    training_loss_result = training_loss.result()
    training_accuracy_result = training_accuracy.result()
    test_loss_result = test_loss.result()
    test_accuracy_result = test_accuracy.result()

    train_iterator_init = train_iterator.initialize()
    test_iterator_init = test_iterator.initialize()

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    cluster_spec = resolver.cluster_spec()
    if cluster_spec:
        config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
    with tf.Session(target=resolver.master(), config=config) as sess:
        all_variables = (tf.global_variables() + training_loss.variables +
                         training_accuracy.variables + test_loss.variables +
                         test_accuracy.variables)
        sess.run([v.initializer for v in all_variables])
        sess.run(train_iterator_init)

        for epoch in range(0, FLAGS.num_epochs):
            logging.info('Starting to run epoch: %s', epoch)
            for step in range(steps_per_epoch):
                learning_rate = compute_learning_rate(epoch + 1 +
                                                      (float(step) /
                                                       steps_per_epoch))
                sess.run(optimizer.lr.assign(learning_rate))
                if step % 20 == 0:
                    logging.info('Learning rate at step %s in epoch %s is %s',
                                 step, epoch, learning_rate)
                sess.run(dist_train)
                if step % 20 == 0:
                    logging.info(
                        'Training loss: %s, accuracy: %s%%',
                        round(sess.run(training_loss_result), 4),
                        round(sess.run(training_accuracy_result) * 100, 2))
                training_loss.reset_states()
                training_accuracy.reset_states()

            sess.run(test_iterator_init)
            for step in range(steps_per_eval):
                if step % 20 == 0:
                    logging.info('Starting to run eval step %s of epoch: %s',
                                 step, epoch)
                sess.run(dist_test)
                if step % 20 == 0:
                    logging.info(
                        'Test loss: %s, accuracy: %s%%',
                        round(sess.run(test_loss_result), 4),
                        round(sess.run(test_accuracy_result) * 100, 2))
                test_loss.reset_states()
                test_accuracy.reset_states()
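
Example #16 assigns optimizer.lr from compute_learning_rate, which is defined elsewhere in the script. A sketch under the common assumption of linear warmup for five epochs followed by step decay (the boundaries and multipliers are assumptions):

def compute_learning_rate(lr_epoch):
  """Learning rate for a (possibly fractional) epoch, per the assumed schedule."""
  if lr_epoch < 5:
    return _BASE_LEARNING_RATE * lr_epoch / 5  # Linear warmup from zero.
  decay = 1.0
  for start_epoch, multiplier in [(30, 0.1), (60, 0.01), (80, 0.001)]:
    if lr_epoch >= start_epoch:
      decay = multiplier
  return _BASE_LEARNING_RATE * decay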
Example #17
def main(unused_argv):
    if FLAGS.use_tpu:
        # Determine the gRPC URL of the TPU device to use
        if FLAGS.master is None and FLAGS.tpu_name is None:
            raise RuntimeError(
                'You must specify either --master or --tpu_name.')

        if FLAGS.master is not None:
            if FLAGS.tpu_name is not None:
                tf.logging.warn(
                    'Both --master and --tpu_name are set. Ignoring'
                    ' --tpu_name and using --master.')
            tpu_grpc_url = FLAGS.master
        else:
            tpu_cluster_resolver = (
                tf.contrib.cluster_resolver.TPUClusterResolver(
                    FLAGS.tpu_name,
                    zone=FLAGS.tpu_zone,
                    project=FLAGS.gcp_project))
            tpu_grpc_url = tpu_cluster_resolver.get_master()
    else:
        # URL is unused if running locally without TPU
        tpu_grpc_url = None

    config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        evaluation_master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_cores))

    resnet_classifier = tpu_estimator.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=resnet_model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    # Input pipelines are slightly different (with regard to shuffling and
    # preprocessing) between training and evaluation.
    imagenet_train = imagenet_input.ImageNetInput(is_training=True,
                                                  data_dir=FLAGS.data_dir)
    imagenet_eval = imagenet_input.ImageNetInput(is_training=False,
                                                 data_dir=FLAGS.data_dir)

    current_step = estimator._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
    batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
    tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                    ' step %d.' % (FLAGS.train_steps, FLAGS.train_steps /
                                   batches_per_epoch, current_step))
    start_timestamp = time.time()
    while current_step < FLAGS.train_steps:
        # Train for up to steps_per_eval number of steps. At the end of training, a
        # checkpoint will be written to --model_dir.
        next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                              FLAGS.train_steps)
        resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                max_steps=next_checkpoint)
        current_step = next_checkpoint

        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info(
            'Finished training up to step %d. Elapsed seconds %d.' %
            (current_step, elapsed_time))

        # Evaluate the model on the most recent model in --model_dir.
        # Since evaluation happens in batches of --eval_batch_size, some images may
        # be excluded modulo the batch size. As long as the batch size is
        # consistent, the evaluated images are also consistent.
        tf.logging.info('Starting to evaluate.')
        eval_results = resnet_classifier.evaluate(
            input_fn=imagenet_eval.input_fn,
            steps=NUM_EVAL_IMAGES // FLAGS.eval_batch_size)
        tf.logging.info('Eval results: %s' % eval_results)

    if FLAGS.export_dir is not None:
        # The guide to serving an exported TensorFlow model is at:
        #    https://www.tensorflow.org/serving/serving_basic
        tf.logging.info('Starting to export model.')
        resnet_classifier.export_savedmodel(
            export_dir_base=FLAGS.export_dir,
            serving_input_receiver_fn=imagenet_input.image_serving_input_fn)
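The export step above relies on imagenet_input.image_serving_input_fn, which is not shown in these examples. A minimal sketch, assuming the serving signature accepts raw JPEG bytes and the model expects 224x224 RGB inputs (both assumptions, not confirmed by the source):

def image_serving_input_fn():
    """Hypothetical sketch of a serving input receiver for image bytes."""
    def _preprocess(image_bytes):
        # Assumed preprocessing; the real pipeline likely mirrors eval preprocessing.
        image = tf.image.decode_jpeg(image_bytes, channels=3)
        image = tf.image.resize_images(image, [224, 224])
        return tf.image.convert_image_dtype(image, dtype=tf.float32)

    image_bytes_list = tf.placeholder(shape=[None], dtype=tf.string)
    images = tf.map_fn(_preprocess, image_bytes_list, dtype=tf.float32)
    return tf.estimator.export.ServingInputReceiver(
        images, {'image_bytes': image_bytes_list})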
Example #18
def main(unused_argv):
    params = params_dict.ParamsDict(mnasnet_config.MNASNET_CFG,
                                    mnasnet_config.MNASNET_RESTRICTIONS)
    params = params_dict.override_params_dict(params,
                                              FLAGS.config_file,
                                              is_strict=True)
    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)

    params = flags_to_params.override_params_from_input_flags(params, FLAGS)

    additional_params = {
        'steps_per_epoch': params.num_train_images / params.train_batch_size,
        'quantized_training': FLAGS.quantized_training,
    }

    params = params_dict.override_params_dict(params,
                                              additional_params,
                                              is_strict=False)

    params.validate()
    params.lock()

    if FLAGS.tpu or params.use_tpu:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
        tpu_cluster_resolver = None

    if params.use_async_checkpointing:
        save_checkpoints_steps = None
    else:
        save_checkpoints_steps = max(100, params.iterations_per_loop)
    config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        log_step_count_steps=FLAGS.log_step_count_steps,
        session_config=tf.ConfigProto(
            graph_options=tf.GraphOptions(
                rewrite_options=rewriter_config_pb2.RewriterConfig(
                    disable_meta_optimizer=True))),
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=params.iterations_per_loop,
            per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))  # pylint: disable=line-too-long

    # Validates Flags.
    if params.precision == 'bfloat16' and params.use_keras:
        raise ValueError(
            'Keras layers do not fully support bfloat16 activation training.'
            ' You have set precision as %s and use_keras as %s' %
            (params.precision, params.use_keras))

    # Initializes model parameters.
    mnasnet_est = tf.contrib.tpu.TPUEstimator(
        use_tpu=params.use_tpu,
        model_fn=mnasnet_model_fn,
        config=config,
        train_batch_size=params.train_batch_size,
        eval_batch_size=params.eval_batch_size,
        export_to_tpu=FLAGS.export_to_tpu,
        params=params.as_dict())

    if FLAGS.mode == 'export_only':
        export(mnasnet_est, FLAGS.export_dir, params, FLAGS.post_quantize)
        return

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    if FLAGS.bigtable_instance:
        tf.logging.info('Using Bigtable dataset, table %s',
                        FLAGS.bigtable_table)
        select_train, select_eval = _select_tables_from_flags()
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetBigtableInput(
                is_training=is_training,
                use_bfloat16=False,
                transpose_input=params.transpose_input,
                selection=selection)
            for (is_training,
                 selection) in [(True, select_train), (False, select_eval)]
        ]
    else:
        if FLAGS.data_dir == FAKE_DATA_DIR:
            tf.logging.info('Using fake dataset.')
        else:
            tf.logging.info('Using dataset: %s', FLAGS.data_dir)
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetInput(
                is_training=is_training,
                data_dir=FLAGS.data_dir,
                transpose_input=params.transpose_input,
                cache=params.use_cache and is_training,
                image_size=params.input_image_size,
                num_parallel_calls=params.num_parallel_calls,
                use_bfloat16=(params.precision == 'bfloat16'))
            for is_training in [True, False]
        ]

    if FLAGS.mode == 'eval':
        eval_steps = params.num_eval_images // params.eval_batch_size
        # Run evaluation when there's a new checkpoint
        for ckpt in evaluation.checkpoints_iterator(
                FLAGS.model_dir, timeout=FLAGS.eval_timeout):
            tf.logging.info('Starting to evaluate.')
            try:
                start_timestamp = time.time()  # This time will include compilation time
                eval_results = mnasnet_est.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=eval_steps,
                    checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                                eval_results, elapsed_time)
                utils.archive_ckpt(eval_results,
                                   eval_results['top_1_accuracy'], ckpt)

                # Terminate eval job when final checkpoint is reached
                current_step = int(os.path.basename(ckpt).split('-')[1])
                if current_step >= params.train_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d',
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint',
                    ckpt)

        if FLAGS.export_dir:
            export(mnasnet_est, FLAGS.export_dir, params, FLAGS.post_quantize)
    else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        current_step = estimator._load_global_step_from_checkpoint_dir(  # pylint: disable=protected-access
            FLAGS.model_dir)

        tf.logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', params.train_steps,
            params.train_steps / params.steps_per_epoch, current_step)

        start_timestamp = time.time()  # This time will include compilation time

        if FLAGS.mode == 'train':
            hooks = []
            if params.use_async_checkpointing:
                hooks.append(
                    async_checkpoint.AsyncCheckpointSaverHook(
                        checkpoint_dir=FLAGS.model_dir,
                        save_steps=max(100, params.iterations_per_loop)))
            mnasnet_est.train(input_fn=imagenet_train.input_fn,
                              max_steps=params.train_steps,
                              hooks=hooks)

        else:
            assert FLAGS.mode == 'train_and_eval'
            while current_step < params.train_steps:
                # Train for up to steps_per_eval number of steps.
                # At the end of training, a checkpoint will be written to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      params.train_steps)
                mnasnet_est.train(input_fn=imagenet_train.input_fn,
                                  max_steps=next_checkpoint)
                current_step = next_checkpoint

                tf.logging.info(
                    'Finished training up to step %d. Elapsed seconds %d.',
                    next_checkpoint, int(time.time() - start_timestamp))

                # Evaluate the model on the most recent model in --model_dir.
                # Since evaluation happens in batches of --eval_batch_size, some images
                # may be excluded modulo the batch size. As long as the batch size is
                # consistent, the evaluated images are also consistent.
                tf.logging.info('Starting to evaluate.')
                eval_results = mnasnet_est.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=params.num_eval_images // params.eval_batch_size)
                tf.logging.info('Eval results at step %d: %s', next_checkpoint,
                                eval_results)
                ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
                utils.archive_ckpt(eval_results,
                                   eval_results['top_1_accuracy'], ckpt)

            elapsed_time = int(time.time() - start_timestamp)
            tf.logging.info(
                'Finished training up to step %d. Elapsed seconds %d.',
                params.train_steps, elapsed_time)
            if FLAGS.export_dir:
                export(mnasnet_est, FLAGS.export_dir, params,
                       FLAGS.post_quantize)
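Example #18 archives checkpoints with utils.archive_ckpt(eval_results, eval_results['top_1_accuracy'], ckpt), a helper that is not shown. A plausible sketch, assuming it keeps whichever checkpoint has the best objective so far in an 'archive' subdirectory (names and layout are assumptions; the real utility may differ):

def archive_ckpt(eval_results, objective, ckpt_path):
    """Hypothetical sketch: archive ckpt_path if objective beats the best so far."""
    archive_dir = os.path.join(os.path.dirname(ckpt_path), 'archive')
    best_file = os.path.join(archive_dir, 'best_objective.txt')
    best = -1.0
    if tf.gfile.Exists(best_file):
        with tf.gfile.GFile(best_file, 'r') as f:
            best = float(f.readline())
    if objective <= best:
        return False
    # Replace any previous archive with the files of the new best checkpoint.
    if tf.gfile.Exists(archive_dir):
        tf.gfile.DeleteRecursively(archive_dir)
    tf.gfile.MakeDirs(archive_dir)
    for src in tf.gfile.Glob(ckpt_path + '.*'):
        tf.gfile.Copy(src, os.path.join(archive_dir, os.path.basename(src)))
    with tf.gfile.GFile(best_file, 'w') as f:
        f.write('%f\n%s' % (objective, eval_results))
    return True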
Example #19
def main(unused_argv):
  assert FLAGS.data is not None, 'Provide training data path via --data.'

  batch_size = FLAGS.num_cores * PER_CORE_BATCH_SIZE

  training_steps_per_epoch = FLAGS.steps_per_epoch or (
      int(APPROX_IMAGENET_TRAINING_IMAGES // batch_size))
  validation_steps = int(IMAGENET_VALIDATION_IMAGES // batch_size)

  model_dir = FLAGS.model_dir
  logging.info('Saving tensorboard summaries at %s', model_dir)

  logging.info('Use TPU at %s', FLAGS.tpu if FLAGS.tpu is not None else 'local')
  logging.info('Use bfloat16: %s.', USE_BFLOAT16)
  logging.info('Use global batch size: %s.', batch_size)
  logging.info('Enable top 5 accuracy: %s.', FLAGS.eval_top_5_accuracy)
  logging.info('Training model using data in directory "%s".', FLAGS.data)

  logging.info('Building Keras ResNet-50 model')
  # tpu_model = resnet_model.ResNet50(num_classes=NUM_CLASSES)

  base_model = tf.keras.applications.resnet50.ResNet50(
      include_top=False, weights='imagenet', input_shape=(224, 224, 3))
  for layer in base_model.layers:
    layer.trainable = False

  x = base_model.output
  x = tf.keras.layers.GlobalAveragePooling2D(name='avg_pool')(x)
  x = tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')(x)
  tpu_model = tf.keras.Model(base_model.input, x)

  # tpu_model.load_weights("model/model.ckpt-112603")

  strategy = tf.contrib.tpu.TPUDistributionStrategy(
      tf.contrib.cluster_resolver.TPUClusterResolver(
          tpu='grpc://' + os.environ['COLAB_TPU_ADDR']))
  tpu_model = tf.contrib.tpu.keras_to_tpu_model(tpu_model, strategy)
  
  logging.info('Compiling model.')
  metrics = ['sparse_categorical_accuracy']

  if FLAGS.eval_top_5_accuracy:
    metrics.append(sparse_top_k_categorical_accuracy)

  tpu_model.compile(
      optimizer=tf.keras.optimizers.SGD(lr=BASE_LEARNING_RATE, momentum=0.9,
                                        nesterov=True),
      loss='sparse_categorical_crossentropy',
      metrics=metrics)

  imagenet_train = imagenet_input.ImageNetInput(
      is_training=True, data_dir=FLAGS.data, batch_size=batch_size,
      use_bfloat16=USE_BFLOAT16)
  imagenet_eval = imagenet_input.ImageNetInput(
      is_training=False, data_dir=FLAGS.data, batch_size=batch_size,
      use_bfloat16=USE_BFLOAT16)

  lr_schedule_cb = LearningRateBatchScheduler(
      schedule=learning_rate_schedule_wrapper(training_steps_per_epoch))
  tensorboard_cb = tf.keras.callbacks.TensorBoard(
      log_dir=model_dir)

  # checkpoint_path = "model/model.ckpt-112603"
  #
  # # Create checkpoint callback
  # cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
  #                                                  save_weights_only=True,
  #                                                  verbose=1)

  training_callbacks = [lr_schedule_cb, tensorboard_cb]

  tpu_model.fit(
      imagenet_train.input_fn().make_one_shot_iterator(),
      epochs=EPOCHS,
      steps_per_epoch=training_steps_per_epoch,
      callbacks=training_callbacks,
      validation_data=imagenet_eval.input_fn().make_one_shot_iterator(),
      validation_steps=validation_steps)


  model_saving_utils.save_model(tpu_model, model_dir, WEIGHTS_TXT)
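When --eval_top_5_accuracy is set, the snippet appends sparse_top_k_categorical_accuracy to the metrics list without defining it. Assuming it is a thin k=5 wrapper around the built-in Keras metric (an assumption), it could be:

def sparse_top_k_categorical_accuracy(y_true, y_pred):
    # Assumed to report top-5 accuracy via the built-in Keras metric (k=5).
    return tf.keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=5)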
Example #20
File: resnet50.py  Project: StefanL19/tpu
def main(argv):
    logging.info('Building Keras ResNet-50 model.')
    model = tf.keras.applications.resnet50.ResNet50(include_top=True,
                                                    weights=None,
                                                    input_tensor=None,
                                                    input_shape=None,
                                                    pooling=None,
                                                    classes=NUM_CLASSES)

    num_cores = 8
    batch_size = PER_CORE_BATCH_SIZE * num_cores

    if FLAGS.tpu is not None:
        logging.info('Converting from CPU to TPU model.')
        resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        strategy = tf.contrib.tpu.TPUDistributionStrategy(resolver)
        model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
        session_master = resolver.master()
    else:
        session_master = ''

    logging.info('Compiling model.')
    model.compile(
        optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0),
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy'])

    if FLAGS.data is None:
        training_images = np.random.randn(batch_size, IMAGE_SIZE, IMAGE_SIZE,
                                          3).astype(np.float32)
        training_labels = np.random.randint(NUM_CLASSES,
                                            size=batch_size,
                                            dtype=np.int32)
        logging.info('Training model using synthetic data.')
        num_epochs = 100  # TPUs are very fast when running a single step per epoch!
        model.fit(training_images,
                  training_labels,
                  epochs=num_epochs,
                  batch_size=batch_size)
        logging.info('Evaluating the model on synthetic data.')
        model.evaluate(training_images, training_labels, verbose=0)
    else:

        imagenet_train = imagenet_input.ImageNetInput(
            is_training=True,
            data_dir=FLAGS.data,
            per_core_batch_size=PER_CORE_BATCH_SIZE)
        logging.info('Training model using real data in directory "%s".',
                     FLAGS.data)
        num_epochs = 90  # Standard imagenet training regime.
        model.fit(imagenet_train.input_fn,
                  epochs=num_epochs,
                  steps_per_epoch=int(APPROX_IMAGENET_TRAINING_IMAGES /
                                      batch_size))

        logging.info('Evaluating the model on the validation dataset.')
        # Direct evaluation with datasets is coming in TF 1.11.  For now,
        # we can perform evaluation using a standard Python generator.
        imagenet_eval = imagenet_input.ImageNetInput(
            is_training=False,
            data_dir=FLAGS.data,
            # In normal execution, our dataset would generate data for each TPU
            # core.  In this case, because we are feeding in from a Keras generator,
            # we want to build a single batch for all of the cores, which will then
            # be split for us.
            per_core_batch_size=batch_size)
        score = model.evaluate_generator(
            imagenet_eval.evaluation_generator(session_master),
            steps=int(APPROX_IMAGENET_TEST_IMAGES // batch_size),
            verbose=1)
        logging.info('Evaluation score %s', score)
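The generator-based evaluation above calls imagenet_eval.evaluation_generator(session_master), a method not shown here. A minimal sketch, assuming it simply runs the eval dataset in a tf.Session against the given master and yields numpy (images, labels) batches (all assumptions):

def evaluation_generator(self, session_master):
    """Hypothetical sketch: yield numpy batches from the eval dataset."""
    next_batch = self.input_fn().make_one_shot_iterator().get_next()
    with tf.Session(session_master) as sess:
        while True:
            try:
                yield sess.run(next_batch)
            except tf.errors.OutOfRangeError:
                return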
Example #21
def main(argv):
    logging.info('Building Keras ResNet-50 model.')
    model = tf.keras.applications.resnet50.ResNet50(include_top=True,
                                                    weights=None,
                                                    input_tensor=None,
                                                    input_shape=None,
                                                    pooling=None,
                                                    classes=NUM_CLASSES)

    num_cores = 8
    batch_size = PER_CORE_BATCH_SIZE * num_cores

    if FLAGS.use_tpu:
        logging.info('Converting from CPU to TPU model.')
        resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        strategy = tf.contrib.tpu.TPUDistributionStrategy(resolver)
        model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
        session_master = resolver.master()
    else:
        session_master = ''

    logging.info('Compiling model.')
    model.compile(optimizer=tf.keras.optimizers.SGD(lr=BASE_LEARNING_RATE,
                                                    momentum=0.9,
                                                    nesterov=True),
                  loss='sparse_categorical_crossentropy',
                  metrics=['sparse_categorical_accuracy'])

    if FLAGS.data is None:
        training_images = np.random.randn(batch_size, IMAGE_SIZE, IMAGE_SIZE,
                                          3).astype(np.float32)
        training_labels = np.random.randint(NUM_CLASSES,
                                            size=batch_size,
                                            dtype=np.int32)
        logging.info('Training model using synthetic data.')
        model.fit(training_images,
                  training_labels,
                  epochs=EPOCHS,
                  batch_size=batch_size)
        logging.info('Evaluating the model on synthetic data.')
        model.evaluate(training_images, training_labels, verbose=0)
    else:
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetInput(
                is_training=is_training,
                data_dir=FLAGS.data,
                per_core_batch_size=PER_CORE_BATCH_SIZE)
            for is_training in [True, False]
        ]
        logging.info('Training model using real data in directory "%s".',
                     FLAGS.data)
        model.fit(imagenet_train.input_fn,
                  epochs=EPOCHS,
                  steps_per_epoch=int(APPROX_IMAGENET_TRAINING_IMAGES /
                                      batch_size))

        if HAS_H5PY:
            logging.info('Save weights into %s', WEIGHTS_TXT)
            model.save_weights(WEIGHTS_TXT, overwrite=True)

        logging.info('Evaluating the model on the validation dataset.')
        score = model.evaluate(imagenet_eval.input_fn,
                               steps=int(APPROX_IMAGENET_TEST_IMAGES //
                                         batch_size),
                               verbose=1)
        print('Evaluation score', score)
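The HAS_H5PY flag gates saving weights to the HDF5 file WEIGHTS_TXT; it is presumably set at import time with a guarded import, along these lines (an assumption):

try:
    import h5py  # only probing for availability; unused otherwise
    HAS_H5PY = True
except ImportError:
    HAS_H5PY = False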
Example #22
File: resnet_main.py  Project: tgrel/tpu
def main(unused_argv):
    if FLAGS.use_tpu:
        # Determine the gRPC URL of the TPU device to use
        if FLAGS.master is None and FLAGS.tpu_name is None:
            raise RuntimeError(
                'You must specify either --master or --tpu_name.')

        if FLAGS.master is not None:
            if FLAGS.tpu_name is not None:
                tf.logging.warn(
                    'Both --master and --tpu_name are set. Ignoring'
                    ' --tpu_name and using --master.')
            tpu_grpc_url = FLAGS.master
        else:
            tpu_cluster_resolver = (
                tf.contrib.cluster_resolver.TPUClusterResolver(
                    FLAGS.tpu_name,
                    zone=FLAGS.tpu_zone,
                    project=FLAGS.gcp_project))
            tpu_grpc_url = tpu_cluster_resolver.get_master()
    else:
        # URL is unused if running locally without TPU
        tpu_grpc_url = None

    config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        evaluation_master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        keep_checkpoint_max=None,
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_cores))

    resnet_classifier = tpu_estimator.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=resnet_model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    imagenet_train = imagenet_input.ImageNetInput(
        is_training=True,
        data_dir=FLAGS.data_dir,
        num_parallel_calls=FLAGS.num_parallel_calls)
    imagenet_eval = imagenet_input.ImageNetInput(
        is_training=False,
        data_dir=FLAGS.data_dir,
        num_parallel_calls=FLAGS.num_parallel_calls)

    current_step = estimator._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
    batches_per_epoch = NUM_TRAIN_IMAGES // FLAGS.train_batch_size
    start_timestamp = time.time()
    current_epoch = current_step // batches_per_epoch
    results = []
    while current_epoch < 95:
        next_checkpoint = (current_epoch + 1) * batches_per_epoch
        resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                max_steps=next_checkpoint)
        current_epoch += 1

        tf.logging.info(
            'Finished training up to step %d. Elapsed seconds %d.' %
            (next_checkpoint, int(time.time() - start_timestamp)))

        # Evaluate the model on the most recent model in --model_dir.
        # Since evaluation happens in batches of --eval_batch_size, some images may
        # be excluded modulo the batch size. As long as the batch size is
        # consistent, the evaluated images are also consistent.
        tf.logging.info('Starting to evaluate.')
        eval_results = resnet_classifier.evaluate(
            input_fn=imagenet_eval.input_fn,
            steps=NUM_EVAL_IMAGES // FLAGS.eval_batch_size)
        tf.logging.info('Eval results: %s' % eval_results)

        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info('Finished epoch %d after %d seconds' %
                        (current_epoch, elapsed_time))
        results.append([
            current_epoch,
            elapsed_time / 3600.0,
            '{0:.2f}'.format(eval_results['Top-1 accuracy'] * 100),
            '{0:.2f}'.format(eval_results['Top-5 accuracy'] * 100),
        ])

    with tf.gfile.GFile(FLAGS.model_dir + '/epoch_results.tsv',
                        'wb') as tsv_file:
        writer = csv.writer(tsv_file, delimiter='\t')
        writer.writerow(['epoch', 'hours', 'top1Accuracy', 'top5Accuracy'])
        writer.writerows(results)

    if FLAGS.export_dir is not None:
        # The guide to serving an exported TensorFlow model is at:
        #    https://www.tensorflow.org/serving/serving_basic
        tf.logging.info('Starting to export model.')
        resnet_classifier.export_savedmodel(
            export_dir_base=FLAGS.export_dir,
            serving_input_receiver_fn=imagenet_input.image_serving_input_fn)
Example #23
def main(unused_argv):
    if FLAGS.use_tpu:
        # Determine the gRPC URL of the TPU device to use
        if FLAGS.master is None and FLAGS.tpu_name is None:
            raise RuntimeError(
                'You must specify either --master or --tpu_name.')

        if FLAGS.master is not None:
            if FLAGS.tpu_name is not None:
                tf.logging.warn(
                    'Both --master and --tpu_name are set. Ignoring'
                    ' --tpu_name and using --master.')
            tpu_grpc_url = FLAGS.master
        else:
            tpu_cluster_resolver = (
                tf.contrib.cluster_resolver.TPUClusterResolver(
                    FLAGS.tpu_name,
                    zone=FLAGS.tpu_zone,
                    project=FLAGS.gcp_project))
            tpu_grpc_url = tpu_cluster_resolver.get_master()
    else:
        # URL is unused if running locally without TPU
        tpu_grpc_url = None

    config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        evaluation_master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=FLAGS.iterations_per_loop,
        keep_checkpoint_max=5,
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_cores,
            per_host_input_for_training=tpu_config.InputPipelineConfig.
            PER_HOST_V2))

    resnet_classifier = tpu_estimator.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=resnet_model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    imagenet_train = imagenet_input.ImageNetInput(
        is_training=True,
        data_dir=FLAGS.data_dir,
        num_parallel_calls=FLAGS.num_parallel_calls,
        use_transpose=FLAGS.use_transpose)
    imagenet_eval = imagenet_input.ImageNetInput(
        is_training=False,
        data_dir=FLAGS.data_dir,
        num_parallel_calls=FLAGS.num_parallel_calls,
        use_transpose=FLAGS.use_transpose)

    current_step = estimator._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
    steps_per_epoch = NUM_TRAIN_IMAGES // FLAGS.train_batch_size
    start_timestamp = time.time()
    current_epoch = current_step // steps_per_epoch

    if FLAGS.mode == 'train':
        resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                max_steps=FLAGS.train_steps)
        training_time = time.time() - start_timestamp
        tf.logging.info('Finished training in %d seconds' % training_time)

        with tf.gfile.GFile(FLAGS.model_dir + '/total_time_%s.txt' % training_time, 'w') as f:  # pylint: disable=line-too-long
            f.write('Total training time was %s seconds' % training_time)

    elif FLAGS.mode == 'eval':
        results = []

        # Run evaluation when there's a new checkpoint
        for ckpt in evaluation.checkpoints_iterator(FLAGS.model_dir):
            tf.logging.info('Starting to evaluate.')
            try:
                start_timestamp = time.time()  # This time will include compilation time
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=NUM_EVAL_IMAGES // FLAGS.eval_batch_size,
                    checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d' %
                                (eval_results, elapsed_time))

                current_step = int(os.path.basename(ckpt).split('-')[1])
                current_epoch = current_step // steps_per_epoch
                results.append([
                    current_epoch,
                    '{0:.2f}'.format(eval_results['top_1_accuracy'] * 100),
                    '{0:.2f}'.format(eval_results['top_5_accuracy'] * 100),
                ])

                # Terminate eval job when final checkpoint is reached
                if current_step >= FLAGS.train_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d' %
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint' %
                    ckpt)

        with tf.gfile.GFile(FLAGS.model_dir + '/epoch_results_eval.tsv', 'wb') as tsv_file:  # pylint: disable=line-too-long
            writer = csv.writer(tsv_file, delimiter='\t')
            writer.writerow(['epoch', 'top1Accuracy', 'top5Accuracy'])
            writer.writerows(results)

    elif FLAGS.mode == 'train_and_eval':
        results = []
        while current_epoch < 95:
            next_checkpoint = (current_epoch + 1) * steps_per_epoch
            resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                    max_steps=next_checkpoint)
            current_epoch += 1

            tf.logging.info(
                'Finished training up to step %d. Elapsed seconds %d.' %
                (next_checkpoint, int(time.time() - start_timestamp)))

            # Evaluate the model on the most recent model in --model_dir.
            # Since evaluation happens in batches of --eval_batch_size, some images
            # may be excluded modulo the batch size. As long as the batch size is
            # consistent, the evaluated images are also consistent.
            tf.logging.info('Starting to evaluate.')
            eval_results = resnet_classifier.evaluate(
                input_fn=imagenet_eval.input_fn,
                steps=NUM_EVAL_IMAGES // FLAGS.eval_batch_size)
            tf.logging.info('Eval results: %s' % eval_results)

            elapsed_time = int(time.time() - start_timestamp)
            tf.logging.info('Finished epoch %d after %d seconds' %
                            (current_epoch, elapsed_time))
            results.append([
                current_epoch,
                elapsed_time / 3600.0,
                '{0:.2f}'.format(eval_results['top_1_accuracy'] * 100),
                '{0:.2f}'.format(eval_results['top_5_accuracy'] * 100),
            ])

        with tf.gfile.GFile(FLAGS.model_dir + '/epoch_results_train_eval.tsv', 'wb') as tsv_file:  # pylint: disable=line-too-long
            writer = csv.writer(tsv_file, delimiter='\t')
            writer.writerow(['epoch', 'hours', 'top1Accuracy', 'top5Accuracy'])
            writer.writerows(results)
    else:
        tf.logging.info('Unknown --mode: %s.', FLAGS.mode)

    if FLAGS.export_dir is not None:
        # The guide to serving an exported TensorFlow model is at:
        #    https://www.tensorflow.org/serving/serving_basic
        tf.logging.info('Starting to export model.')
        resnet_classifier.export_savedmodel(
            export_dir_base=FLAGS.export_dir,
            serving_input_receiver_fn=imagenet_input.image_serving_input_fn)
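The 'eval' branches above block on evaluation.checkpoints_iterator, which yields each new checkpoint path that appears in --model_dir. Its polling behavior can be approximated in a few lines (a simplified sketch, not the actual TensorFlow implementation):

def naive_checkpoints_iterator(checkpoint_dir, poll_secs=60, timeout=None):
    """Yield new checkpoint paths as they appear; stop after `timeout` idle seconds."""
    last_ckpt, idle = None, 0
    while True:
        ckpt = tf.train.latest_checkpoint(checkpoint_dir)
        if ckpt is not None and ckpt != last_ckpt:
            last_ckpt, idle = ckpt, 0
            yield ckpt
        else:
            if timeout is not None and idle >= timeout:
                return
            time.sleep(poll_secs)
            idle += poll_secs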
Example #24
def main(unused_argv):
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=FLAGS.iterations_per_loop,
        keep_checkpoint_max=None,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_cores,
            per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))  # pylint: disable=line-too-long

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    imagenet_train = imagenet_input.ImageNetInput(
        is_training=True,
        data_dir=FLAGS.data_dir,
        use_transpose=FLAGS.use_transpose)
    imagenet_eval = imagenet_input.ImageNetInput(
        is_training=False,
        data_dir=FLAGS.data_dir,
        use_transpose=FLAGS.use_transpose)

    if FLAGS.use_fast_lr:
        resnet_main.LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
            (1.0, 4), (0.1, 21), (0.01, 35), (0.001, 43)
        ]
        imagenet_train_small = imagenet_input.ImageNetInput(
            is_training=True,
            image_size=128,
            data_dir=FLAGS.data_dir_small,
            num_parallel_calls=FLAGS.num_parallel_calls,
            use_transpose=FLAGS.use_transpose,
            cache=True)
        imagenet_eval_small = imagenet_input.ImageNetInput(
            is_training=False,
            image_size=128,
            data_dir=FLAGS.data_dir_small,
            num_parallel_calls=FLAGS.num_parallel_calls,
            use_transpose=FLAGS.use_transpose,
            cache=True)
        imagenet_train_large = imagenet_input.ImageNetInput(
            is_training=True,
            image_size=288,
            data_dir=FLAGS.data_dir,
            num_parallel_calls=FLAGS.num_parallel_calls,
            use_transpose=FLAGS.use_transpose)
        imagenet_eval_large = imagenet_input.ImageNetInput(
            is_training=False,
            image_size=288,
            data_dir=FLAGS.data_dir,
            num_parallel_calls=FLAGS.num_parallel_calls,
            use_transpose=FLAGS.use_transpose)

    resnet_classifier = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=resnet_main.resnet_model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    if FLAGS.mode == 'train':
        current_step = estimator._load_global_step_from_checkpoint_dir(
            FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
        batches_per_epoch = resnet_main.NUM_TRAIN_IMAGES / FLAGS.train_batch_size
        tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                        ' step %d.' % (FLAGS.train_steps, FLAGS.train_steps /
                                       batches_per_epoch, current_step))

        start_timestamp = time.time()  # This time will include compilation time

        # Write a dummy file at the start of training so that we can measure the
        # runtime at each checkpoint from the file write time.
        tf.gfile.MkDir(FLAGS.model_dir)
        if not tf.gfile.Exists(os.path.join(FLAGS.model_dir, 'START')):
            with tf.gfile.GFile(os.path.join(FLAGS.model_dir, 'START'),
                                'w') as f:
                f.write(str(start_timestamp))

        if FLAGS.use_fast_lr:
            resnet_classifier.train(input_fn=imagenet_train_small.input_fn,
                                    max_steps=18 * 1251)
            resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                    max_steps=41 * 1251)
            resnet_classifier.train(input_fn=imagenet_train_large.input_fn,
                                    max_steps=min(50 * 1251,
                                                  FLAGS.train_steps))
        else:
            resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                    max_steps=FLAGS.train_steps)

    else:
        assert FLAGS.mode == 'eval'

        start_timestamp = tf.gfile.Stat(os.path.join(FLAGS.model_dir,
                                                     'START')).mtime_nsec
        results = []
        eval_steps = resnet_main.NUM_EVAL_IMAGES // FLAGS.eval_batch_size

        ckpt_steps = set()
        all_files = tf.gfile.ListDirectory(FLAGS.model_dir)
        for f in all_files:
            mat = re.match(CKPT_PATTERN, f)
            if mat is not None:
                ckpt_steps.add(int(mat.group('gs')))
        ckpt_steps = sorted(list(ckpt_steps))
        tf.logging.info('Steps to be evaluated: %s' % str(ckpt_steps))

        for step in ckpt_steps:
            ckpt = os.path.join(FLAGS.model_dir, 'model.ckpt-%d' % step)

            batches_per_epoch = resnet_main.NUM_TRAIN_IMAGES // FLAGS.train_batch_size
            current_epoch = step // batches_per_epoch

            if FLAGS.use_fast_lr:
                if current_epoch < 18:
                    eval_input_fn = imagenet_eval_small.input_fn
                elif current_epoch < 41:
                    eval_input_fn = imagenet_eval.input_fn
                else:
                    eval_input_fn = imagenet_eval_large.input_fn
            else:
                eval_input_fn = imagenet_eval.input_fn

            end_timestamp = tf.gfile.Stat(ckpt + '.index').mtime_nsec
            elapsed_hours = (end_timestamp - start_timestamp) / (1e9 * 3600.0)

            tf.logging.info('Starting to evaluate.')
            eval_start = time.time()  # This time will include compilation time
            eval_results = resnet_classifier.evaluate(input_fn=eval_input_fn,
                                                      steps=eval_steps,
                                                      checkpoint_path=ckpt)
            eval_time = int(time.time() - eval_start)
            tf.logging.info('Eval results: %s. Elapsed seconds: %d' %
                            (eval_results, eval_time))
            results.append([
                current_epoch,
                elapsed_hours,
                '%.2f' % (eval_results['top_1_accuracy'] * 100),
                '%.2f' % (eval_results['top_5_accuracy'] * 100),
            ])

            time.sleep(60)

        with tf.gfile.GFile(os.path.join(FLAGS.model_dir, 'results.tsv'), 'wb') as tsv_file:  # pylint: disable=line-too-long
            writer = csv.writer(tsv_file, delimiter='\t')
            writer.writerow(['epoch', 'hours', 'top1Accuracy', 'top5Accuracy'])
            writer.writerows(results)
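The eval branch above scans --model_dir for filenames matching CKPT_PATTERN and extracts the global step from a group named 'gs'. The pattern is defined elsewhere; one that matches checkpoint files such as model.ckpt-112603.index with that named group (an assumption) would be:

import re

CKPT_PATTERN = re.compile(r'model\.ckpt-(?P<gs>[0-9]+)')  # step number in group 'gs'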
Example #25
def main(unused_argv):
    tpu_grpc_url = None
    tpu_cluster_resolver = None
    if FLAGS.use_tpu:
        # Determine the gRPC URL of the TPU device to use
        if not FLAGS.master and not FLAGS.tpu_name:
            raise RuntimeError(
                'You must specify either --master or --tpu_name.')

        if FLAGS.master:
            if FLAGS.tpu_name:
                tf.logging.warn(
                    'Both --master and --tpu_name are set. Ignoring'
                    ' --tpu_name and using --master.')
            tpu_grpc_url = FLAGS.master
        else:
            tpu_cluster_resolver = (
                tf.contrib.cluster_resolver.TPUClusterResolver(
                    FLAGS.tpu_name,
                    zone=FLAGS.tpu_zone,
                    project=FLAGS.gcp_project))
    else:
        # URL is unused if running locally without TPU
        tpu_grpc_url = None

    config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        evaluation_master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=FLAGS.iterations_per_loop,
        keep_checkpoint_max=None,
        cluster=tpu_cluster_resolver,
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_cores,
            per_host_input_for_training=tpu_config.InputPipelineConfig.PER_HOST_V2))  # pylint: disable=line-too-long

    resnet_classifier = tpu_estimator.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=resnet_main.resnet_model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    imagenet_train = imagenet_input.ImageNetInput(
        is_training=True,
        data_dir=FLAGS.data_dir,
        transpose_input=FLAGS.transpose_input)
    imagenet_eval = imagenet_input.ImageNetInput(
        is_training=False,
        data_dir=FLAGS.data_dir,
        transpose_input=FLAGS.transpose_input)

    if FLAGS.mode == 'train':
        current_step = estimator._load_global_step_from_checkpoint_dir(
            FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
        batches_per_epoch = resnet_main.NUM_TRAIN_IMAGES / FLAGS.train_batch_size
        tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                        ' step %d.' % (FLAGS.train_steps, FLAGS.train_steps /
                                       batches_per_epoch, current_step))

        start_timestamp = time.time()  # This time will include compilation time

        # Write a dummy file at the start of training so that we can measure the
        # runtime at each checkpoint from the file write time.
        tf.gfile.MkDir(FLAGS.model_dir)
        if not tf.gfile.Exists(os.path.join(FLAGS.model_dir, 'START')):
            with tf.gfile.GFile(os.path.join(FLAGS.model_dir, 'START'),
                                'w') as f:
                f.write(str(start_timestamp))

        resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                max_steps=FLAGS.train_steps)

    else:
        assert FLAGS.mode == 'eval'

        start_timestamp = tf.gfile.Stat(os.path.join(FLAGS.model_dir,
                                                     'START')).mtime_nsec
        results = []
        eval_steps = resnet_main.NUM_EVAL_IMAGES // FLAGS.eval_batch_size

        ckpt_steps = set()
        all_files = tf.gfile.ListDirectory(FLAGS.model_dir)
        for f in all_files:
            mat = re.match(CKPT_PATTERN, f)
            if mat is not None:
                ckpt_steps.add(int(mat.group('gs')))
        ckpt_steps = sorted(list(ckpt_steps))
        tf.logging.info('Steps to be evaluated: %s' % str(ckpt_steps))

        for step in ckpt_steps:
            ckpt = os.path.join(FLAGS.model_dir, 'model.ckpt-%d' % step)

            batches_per_epoch = resnet_main.NUM_TRAIN_IMAGES // FLAGS.train_batch_size
            current_epoch = step // batches_per_epoch

            end_timestamp = tf.gfile.Stat(ckpt + '.index').mtime_nsec
            elapsed_hours = (end_timestamp - start_timestamp) / (1e9 * 3600.0)

            tf.logging.info('Starting to evaluate.')
            eval_start = time.time()  # This time will include compilation time
            eval_results = resnet_classifier.evaluate(
                input_fn=imagenet_eval.input_fn,
                steps=eval_steps,
                checkpoint_path=ckpt)
            eval_time = int(time.time() - eval_start)
            tf.logging.info('Eval results: %s. Elapsed seconds: %d' %
                            (eval_results, eval_time))
            results.append([
                current_epoch,
                elapsed_hours,
                '%.2f' % (eval_results['top_1_accuracy'] * 100),
                '%.2f' % (eval_results['top_5_accuracy'] * 100),
            ])

            time.sleep(60)

        with tf.gfile.GFile(os.path.join(FLAGS.model_dir, 'results.tsv'), 'wb') as tsv_file:  # pylint: disable=line-too-long
            writer = csv.writer(tsv_file, delimiter='\t')
            writer.writerow(['epoch', 'hours', 'top1Accuracy', 'top5Accuracy'])
            writer.writerows(results)
Example #26
def main(unused_argv):
    tf.enable_v2_behavior()
    num_workers = 1
    job_name = 'worker'
    primary_cpu_task = '/job:%s' % job_name

    is_tpu_pod = num_workers > 1
    model_dir = FLAGS.model_dir if FLAGS.model_dir else DEFAULT_MODEL_DIR
    batch_size = PER_CORE_BATCH_SIZE * FLAGS.num_cores
    steps_per_epoch = FLAGS.steps_per_epoch or (int(
        APPROX_IMAGENET_TRAINING_IMAGES // batch_size))
    steps_per_eval = int(1.0 *
                         math.ceil(IMAGENET_VALIDATION_IMAGES / batch_size))

    logging.info('Saving checkpoints at %s', model_dir)

    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu, job_name=job_name)
    tf.config.experimental_connect_to_host(resolver.master())  # pylint: disable=line-too-long
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

    with tf.device(primary_cpu_task):
        # TODO(b/130307853): In TPU Pod, we have to use
        # `strategy.experimental_distribute_datasets_from_function` instead of
        # `strategy.experimental_distribute_dataset` because dataset cannot be
        # cloned in eager mode. And when using
        # `strategy.experimental_distribute_datasets_from_function`, we should use
        # per core batch size instead of global batch size, because no re-batch is
        # happening in this case.
        if is_tpu_pod:
            imagenet_train = imagenet_input.ImageNetInput(
                is_training=True,
                data_dir=FLAGS.data,
                batch_size=PER_CORE_BATCH_SIZE,
                use_bfloat16=_USE_BFLOAT16)
            imagenet_eval = imagenet_input.ImageNetInput(
                is_training=False,
                data_dir=FLAGS.data,
                batch_size=PER_CORE_BATCH_SIZE,
                use_bfloat16=_USE_BFLOAT16)
            train_dataset = strategy.experimental_distribute_datasets_from_function(
                imagenet_train.input_fn)
            test_dataset = strategy.experimental_distribute_datasets_from_function(
                imagenet_eval.input_fn)
        else:
            imagenet_train = imagenet_input.ImageNetInput(
                is_training=True,
                data_dir=FLAGS.data,
                batch_size=batch_size,
                use_bfloat16=_USE_BFLOAT16)
            imagenet_eval = imagenet_input.ImageNetInput(
                is_training=False,
                data_dir=FLAGS.data,
                batch_size=batch_size,
                use_bfloat16=_USE_BFLOAT16)
            train_dataset = strategy.experimental_distribute_dataset(
                imagenet_train.input_fn())
            test_dataset = strategy.experimental_distribute_dataset(
                imagenet_eval.input_fn())

        with strategy.scope():
            logging.info('Building Keras ResNet-50 model')
            model = resnet_model.ResNet50(num_classes=NUM_CLASSES)
            optimizer = tf.keras.optimizers.SGD(
                learning_rate=ResnetLearningRateSchedule(
                    steps_per_epoch, _BASE_LEARNING_RATE),
                momentum=0.9,
                nesterov=True)
            training_loss = tf.keras.metrics.Mean('training_loss',
                                                  dtype=tf.float32)
            training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
                'training_accuracy', dtype=tf.float32)
            test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
            test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
                'test_accuracy', dtype=tf.float32)
            logging.info('Finished building Keras ResNet-50 model')

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(model_dir)
        initial_epoch = 0
        if latest_checkpoint:
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

        # Create summary writers
        train_summary_writer = tf.summary.create_file_writer(
            os.path.join(model_dir, 'summaries/train'))
        test_summary_writer = tf.summary.create_file_writer(
            os.path.join(model_dir, 'summaries/test'))

        @tf.function
        def train_step(iterator):
            """Training StepFn."""
            def step_fn(inputs):
                """Per-Replica StepFn."""
                images, labels = inputs
                with tf.GradientTape() as tape:
                    logits = model(images, training=True)

                    # Loss calculations.
                    #
                    # Part 1: Prediction loss.
                    prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
                        labels, logits)
                    loss1 = tf.reduce_mean(prediction_loss)
                    # Part 2: Model weights regularization
                    loss2 = tf.reduce_sum(model.losses)

                    # Scale the loss, since TPUStrategy will sum-reduce gradients across replicas.
                    loss = loss1 + loss2
                    loss = loss / strategy.num_replicas_in_sync

                grads = tape.gradient(loss, model.trainable_variables)
                optimizer.apply_gradients(zip(grads,
                                              model.trainable_variables))
                training_loss.update_state(loss)
                training_accuracy.update_state(labels, logits)

            strategy.experimental_run_v2(step_fn, args=(next(iterator), ))

        @tf.function
        def test_step(iterator):
            """Evaluation StepFn."""
            def step_fn(inputs):
                images, labels = inputs
                logits = model(images, training=False)
                loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, logits)
                loss = tf.reduce_mean(loss) / strategy.num_replicas_in_sync
                test_loss.update_state(loss)
                test_accuracy.update_state(labels, logits)

            strategy.experimental_run_v2(step_fn, args=(next(iterator), ))

        train_iterator = iter(train_dataset)
        for epoch in range(initial_epoch, FLAGS.num_epochs):
            logging.info('Starting to run epoch: %s', epoch)
            with train_summary_writer.as_default():
                for step in range(steps_per_epoch):
                    if step % 20 == 0:
                        logging.info('Running step %s in epoch %s', step,
                                     epoch)
                    train_step(train_iterator)
                tf.summary.scalar('loss',
                                  training_loss.result(),
                                  step=optimizer.iterations)
                tf.summary.scalar('accuracy',
                                  training_accuracy.result(),
                                  step=optimizer.iterations)
                logging.info('Training loss: %s, accuracy: %s%%',
                             round(training_loss.result(), 4),
                             round(training_accuracy.result() * 100, 2))
                training_loss.reset_states()
                training_accuracy.reset_states()

            with test_summary_writer.as_default():
                test_iterator = iter(test_dataset)
                for step in range(steps_per_eval):
                    if step % 20 == 0:
                        logging.info(
                            'Starting to run eval step %s of epoch: %s', step,
                            epoch)
                    test_step(test_iterator)
                tf.summary.scalar('loss',
                                  test_loss.result(),
                                  step=optimizer.iterations)
                tf.summary.scalar('accuracy',
                                  test_accuracy.result(),
                                  step=optimizer.iterations)
                logging.info('Test loss: %s, accuracy: %s%%',
                             round(test_loss.result(), 4),
                             round(test_accuracy.result() * 100, 2))
                test_loss.reset_states()
                test_accuracy.reset_states()

            checkpoint_name = checkpoint.save(
                os.path.join(model_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)
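The division by strategy.num_replicas_in_sync in train_step is what makes the sum-reduced gradients match the gradient of the global-batch mean loss. With hypothetical numbers:

# 8 replicas, each computing a mean loss over its shard of the global batch.
# TPUStrategy SUM-reduces gradients, so each replica contributes grad(mean_i / 8),
# and the cross-replica sum equals grad((mean_1 + ... + mean_8) / 8), i.e. the
# gradient of the mean loss over the whole global batch.
num_replicas = 8
per_replica_mean_losses = [2.0, 1.8, 2.2, 2.1, 1.9, 2.0, 2.0, 2.0]
scaled_sum = sum(l / num_replicas for l in per_replica_mean_losses)
assert abs(scaled_sum - sum(per_replica_mean_losses) / num_replicas) < 1e-9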
Example #27
File: resnet_main.py  Project: vishh/tpu
def main(unused_argv):
    tpu_grpc_url = None
    tpu_cluster_resolver = None
    if FLAGS.use_tpu:
        # Determine the gRPC URL of the TPU device to use
        if not FLAGS.master and not FLAGS.tpu_name:
            raise RuntimeError(
                'You must specify either --master or --tpu_name.')

        if FLAGS.master:
            if FLAGS.tpu_name:
                tf.logging.warn(
                    'Both --master and --tpu_name are set. Ignoring'
                    ' --tpu_name and using --master.')
            tpu_grpc_url = FLAGS.master
        else:
            tpu_cluster_resolver = (
                tf.contrib.cluster_resolver.TPUClusterResolver(
                    FLAGS.tpu_name,
                    zone=FLAGS.tpu_zone,
                    project=FLAGS.gcp_project))
    else:
        # URL is unused if running locally without TPU
        tpu_grpc_url = None

    config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        evaluation_master=tpu_grpc_url,
        model_dir=FLAGS.model_dir,
        cluster=tpu_cluster_resolver,
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_cores))

    resnet_classifier = tpu_estimator.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=resnet_model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    imagenet_train = imagenet_input.ImageNetInput(is_training=True,
                                                  data_dir=FLAGS.data_dir)
    imagenet_eval = imagenet_input.ImageNetInput(is_training=False,
                                                 data_dir=FLAGS.data_dir)

    if FLAGS.mode == 'eval':
        eval_steps = NUM_EVAL_IMAGES // FLAGS.eval_batch_size

        # Run evaluation when there's a new checkpoint
        for ckpt in evaluation.checkpoints_iterator(FLAGS.model_dir):
            tf.logging.info('Starting to evaluate.')
            try:
                start_timestamp = time.time()  # This time will include compilation time
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=eval_steps,
                    checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d' %
                                (eval_results, elapsed_time))

                # Terminate eval job when final checkpoint is reached
                current_step = int(os.path.basename(ckpt).split('-')[1])
                if current_step >= FLAGS.train_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d' %
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint' %
                    ckpt)

    else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        current_step = estimator._load_global_step_from_checkpoint_dir(
            FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
        batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
        tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                        ' step %d.' % (FLAGS.train_steps, FLAGS.train_steps /
                                       batches_per_epoch, current_step))

        start_timestamp = time.time()  # This time will include compilation time
        if FLAGS.mode == 'train':
            resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                    max_steps=FLAGS.train_steps)

        else:
            assert FLAGS.mode == 'train_and_eval'
            while current_step < FLAGS.train_steps:
                # Train for up to steps_per_eval number of steps.
                # At the end of training, a checkpoint will be written to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      FLAGS.train_steps)
                resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                        max_steps=next_checkpoint)
                current_step = next_checkpoint

                # Evaluate the model on the most recent model in --model_dir.
                # Since evaluation happens in batches of --eval_batch_size, some images
                # may be consistently excluded modulo the batch size.
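                # For example, assuming NUM_EVAL_IMAGES=50000 (the standard
                # ImageNet validation split) and --eval_batch_size=1024:
                # 50000 // 1024 = 48 steps covering 49152 images, so 848
                # images are skipped on every run.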
                tf.logging.info('Starting to evaluate.')
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=NUM_EVAL_IMAGES // FLAGS.eval_batch_size)
                tf.logging.info('Eval results: %s' % eval_results)

        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info(
            'Finished training up to step %d. Elapsed seconds %d.' %
            (FLAGS.train_steps, elapsed_time))

        if FLAGS.export_dir is not None:
            # The guide to serving an exported TensorFlow model is at:
            #    https://www.tensorflow.org/serving/serving_basic
            tf.logging.info('Starting to export model.')
            resnet_classifier.export_savedmodel(
                export_dir_base=FLAGS.export_dir,
                serving_input_receiver_fn=imagenet_input.image_serving_input_fn)
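The eval branch above hinges on evaluation.checkpoints_iterator, which blocks
until a new checkpoint appears under --model_dir and yields its path. A
minimal standalone sketch of the same polling loop, assuming TF 1.x and
placeholder names (my_estimator, my_input_fn, final_step are not from the
original):

import tensorflow as tf
from tensorflow.contrib.training.python.training import evaluation


def continuous_eval(model_dir, my_estimator, my_input_fn, final_step):
    # Yields each new checkpoint path as it is written; blocks in between.
    for ckpt in evaluation.checkpoints_iterator(model_dir):
        results = my_estimator.evaluate(input_fn=my_input_fn,
                                        checkpoint_path=ckpt)
        tf.logging.info('Eval results for %s: %s', ckpt, results)
        # Checkpoint basenames end in the global step (model.ckpt-<step>).
        if int(ckpt.split('-')[-1]) >= final_step:
            break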
Example #28
      tpu_config=tpu_config.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_cores,
          per_host_input_for_training=tpu_config.InputPipelineConfig.PER_HOST_V2))  # pylint: disable=line-too-long

  resnet_classifier = tpu_estimator.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=resnet_main.resnet_model_fn,
      config=config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size)

  # Input pipelines are slightly different (with regards to shuffling and
  # preprocessing) between training and evaluation.
  imagenet_train = imagenet_input.ImageNetInput(
      is_training=True,
      data_dir=FLAGS.data_dir,
      transpose_input=FLAGS.transpose_input)
  imagenet_eval = imagenet_input.ImageNetInput(
      is_training=False,
      data_dir=FLAGS.data_dir,
      transpose_input=FLAGS.transpose_input)

  if FLAGS.mode == 'train':
    current_step = estimator._load_global_step_from_checkpoint_dir(FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
    batches_per_epoch = resnet_main.NUM_TRAIN_IMAGES / FLAGS.train_batch_size
    tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                    ' step %d.' % (FLAGS.train_steps,
                                   FLAGS.train_steps / batches_per_epoch,
                                   current_step))

    start_timestamp = time.time()  # This time will include compilation time
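Example #28 sets per_host_input_for_training=InputPipelineConfig.PER_HOST_V2,
under which TPUEstimator invokes the input_fn once per host and injects the
per-core batch size through params. A minimal sketch of a compatible input_fn
(FILE_PATTERN and parse_record are placeholders, not from the original):

def input_fn(params):
    # TPUEstimator supplies the per-core batch size under PER_HOST_V2.
    batch_size = params['batch_size']
    dataset = tf.data.TFRecordDataset(tf.gfile.Glob(FILE_PATTERN))
    dataset = dataset.map(parse_record)
    # TPUs need static shapes, so drop the final partial batch.
    return dataset.batch(batch_size, drop_remainder=True)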
Example #29
def main(unused_argv):
  # tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
  #     FLAGS.tpu if (FLAGS.tpu or FLAGS.use_tpu) else '',
  #     zone=FLAGS.tpu_zone,
  #     project=FLAGS.gcp_project)

  if FLAGS.use_async_checkpointing:
    save_checkpoints_steps = None
  else:
    save_checkpoints_steps = max(100, FLAGS.iterations_per_loop)
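  # Setting save_checkpoints_steps=None disables the default CheckpointSaverHook
  # so that AsyncCheckpointSaverHook (attached below in 'train' mode) can own
  # checkpointing without the two savers racing.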

  num_gpus = len(get_available_gpus())
  distribution = tf.contrib.distribute.MirroredStrategy(num_gpus=num_gpus)
  gpu_options = tf.GPUOptions(allow_growth=True)

  # config = tf.contrib.tpu.RunConfig(
  #     cluster=tpu_cluster_resolver,
  #     model_dir=FLAGS.model_dir,
  #     save_checkpoints_steps=save_checkpoints_steps,
  #     log_step_count_steps=FLAGS.log_step_count_steps,
  #     session_config=tf.ConfigProto(
  #         graph_options=tf.GraphOptions(
  #             rewrite_options=rewriter_config_pb2.RewriterConfig(
  #                 disable_meta_optimizer=True))),
  #     tpu_config=tf.contrib.tpu.TPUConfig(
  #         iterations_per_loop=FLAGS.iterations_per_loop,
  #         per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig
  #         .PER_HOST_V2))  # pylint: disable=line-too-long

  config = tf.estimator.RunConfig(
      # cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=save_checkpoints_steps,
      log_step_count_steps=FLAGS.log_step_count_steps,
      session_config=tf.ConfigProto(
          allow_soft_placement=True,
          gpu_options=gpu_options,
          graph_options=tf.GraphOptions(
              rewrite_options=rewriter_config_pb2.RewriterConfig(
                  disable_meta_optimizer=True))),
      train_distribute=distribution,
      # tpu_config=tf.contrib.tpu.TPUConfig(
      #     iterations_per_loop=FLAGS.iterations_per_loop,
      #     per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig
      #     .PER_HOST_V2)
  )
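  # Note: train_distribute only distributes training; evaluation runs on a
  # single device unless eval_distribute is also set on the RunConfig.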
  # Initializes model parameters.
  # params = dict(steps_per_epoch=FLAGS.num_train_images / FLAGS.train_batch_size)
  # model_est = tf.estimator.Estimator(
  #     use_tpu=FLAGS.use_tpu,
  #     model_fn=final_model_fn,
  #     config=config,
  #     train_batch_size=FLAGS.train_batch_size,
  #     eval_batch_size=FLAGS.eval_batch_size,
  #     export_to_tpu=FLAGS.export_to_tpu,
  #     params=params)
  params = dict(
      steps_per_epoch=FLAGS.num_train_images / FLAGS.train_batch_size,
      batch_size=FLAGS.train_batch_size)
  model_est = tf.estimator.Estimator(
      model_fn=final_model_fn,
      config=config,
      params=params)

  # Input pipelines are slightly different (with regards to shuffling and
  # preprocessing) between training and evaluation.
  if FLAGS.bigtable_instance:
    tf.logging.info('Using Bigtable dataset, table %s', FLAGS.bigtable_table)
    select_train, select_eval = _select_tables_from_flags()
    imagenet_train, imagenet_eval = [imagenet_input.ImageNetBigtableInput(
        is_training=is_training,
        use_bfloat16=False,
        transpose_input=FLAGS.transpose_input,
        selection=selection) for (is_training, selection) in
                                     [(True, select_train),
                                      (False, select_eval)]]
  else:
    if FLAGS.data_dir == FAKE_DATA_DIR:
      tf.logging.info('Using fake dataset.')
    else:
      tf.logging.info('Using dataset: %s', FLAGS.data_dir)
    imagenet_train, imagenet_eval = [
        imagenet_input.ImageNetInput(
            is_training=is_training,
            data_dir=FLAGS.data_dir,
            transpose_input=FLAGS.transpose_input,
            cache=FLAGS.use_cache and is_training,
            image_size=FLAGS.input_image_size,
            num_parallel_calls=FLAGS.num_parallel_calls,
            use_bfloat16=False) for is_training in [True, False]
    ]

  if FLAGS.mode == 'eval':
    eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size
    # Run evaluation when there's a new checkpoint
    for ckpt in evaluation.checkpoints_iterator(
        FLAGS.model_dir, timeout=FLAGS.eval_timeout):
      tf.logging.info('Starting to evaluate.')
      try:
        start_timestamp = time.time()  # This time will include compilation time
        eval_results = model_est.evaluate(
            input_fn=imagenet_eval.input_fn,
            steps=eval_steps,
            checkpoint_path=ckpt)
        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                        eval_results, elapsed_time)

        # Terminate eval job when final checkpoint is reached
        current_step = int(os.path.basename(ckpt).split('-')[1])
        if current_step >= FLAGS.train_steps:
          tf.logging.info(
              'Evaluation finished after training step %d', current_step)
          break

      except tf.errors.NotFoundError:
        # Since the coordinator is on a different job than the TPU worker,
        # sometimes the TPU worker does not finish initializing until long after
        # the CPU job tells it to start evaluating. In this case, the checkpoint
        # file could have been deleted already.
        tf.logging.info(
            'Checkpoint %s no longer exists, skipping checkpoint', ckpt)

    if FLAGS.export_dir:
      export(model_est, FLAGS.export_dir, FLAGS.post_quantize)
  else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
    current_step = estimator._load_global_step_from_checkpoint_dir(FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long

    tf.logging.info(
        'Training for %d steps (%.2f epochs in total). Current'
        ' step %d.', FLAGS.train_steps,
        FLAGS.train_steps / params['steps_per_epoch'], current_step)

    start_timestamp = time.time()  # This time will include compilation time

    if FLAGS.mode == 'train':
      hooks = []
      if FLAGS.use_async_checkpointing:
        hooks.append(
            async_checkpoint.AsyncCheckpointSaverHook(
                checkpoint_dir=FLAGS.model_dir,
                save_steps=max(100, FLAGS.iterations_per_loop)))
      model_est.train(
          input_fn=imagenet_train.input_fn,
          max_steps=FLAGS.train_steps,
          hooks=hooks)

    else:
      assert FLAGS.mode == 'train_and_eval'
      while current_step < FLAGS.train_steps:
        # Train for up to steps_per_eval number of steps.
        # At the end of training, a checkpoint will be written to --model_dir.
        next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                              FLAGS.train_steps)
        model_est.train(
            input_fn=imagenet_train.input_fn, max_steps=next_checkpoint)
        current_step = next_checkpoint

        tf.logging.info('Finished training up to step %d. Elapsed seconds %d.',
                        next_checkpoint, int(time.time() - start_timestamp))

        # Evaluate the model on the most recent model in --model_dir.
        # Since evaluation happens in batches of --eval_batch_size, some images
        # may be excluded modulo the batch size. As long as the batch size is
        # consistent, the evaluated images are also consistent.
        tf.logging.info('Starting to evaluate.')
        eval_results = model_est.evaluate(
            input_fn=imagenet_eval.input_fn,
            steps=FLAGS.num_eval_images // FLAGS.eval_batch_size)
        tf.logging.info('Eval results at step %d: %s',
                        next_checkpoint, eval_results)

      elapsed_time = int(time.time() - start_timestamp)
      tf.logging.info('Finished training up to step %d. Elapsed seconds %d.',
                      FLAGS.train_steps, elapsed_time)
      if FLAGS.export_dir:
        export(model_est, FLAGS.export_dir, FLAGS.post_quantize)
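Example #29 calls a get_available_gpus() helper that is not shown. A common
TF 1.x implementation via device_lib (a sketch, not necessarily the author's
version):

from tensorflow.python.client import device_lib


def get_available_gpus():
  # Enumerate local devices and keep only the GPUs.
  local_devices = device_lib.list_local_devices()
  return [d.name for d in local_devices if d.device_type == 'GPU']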
Example #30
def main(unused_argv):
    # MnasNet optimization: set the proper image data format.
    tf.keras.backend.set_image_data_format(FLAGS.data_format)
    # MnasNet optimization: performance tuning flags (left disabled).
    # gpu_thread_count = 2
    # os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'
    # os.environ['TF_GPU_THREAD_COUNT'] = str(gpu_thread_count)
    # os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
    # os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
    # Enable mixed precision? Not much benefit seen yet.
    # os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"

    # Horovod: initialize Horovod.
    if FLAGS.use_horovod:
        hvd.init()
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu if (FLAGS.tpu or FLAGS.use_tpu) else '',
        zone=FLAGS.tpu_zone,
        project=FLAGS.gcp_project)

    if FLAGS.use_async_checkpointing:
        save_checkpoints_steps = None
    else:
        if not FLAGS.use_horovod:
            save_checkpoints_steps = max(100, FLAGS.iterations_per_loop)
        else:
            save_checkpoints_steps = max(
                100, FLAGS.iterations_per_loop) if hvd.rank() == 0 else None
    config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        log_step_count_steps=FLAGS.log_step_count_steps,
        session_config=tf.ConfigProto(
            graph_options=tf.GraphOptions(
                rewrite_options=rewriter_config_pb2.RewriterConfig(
                    disable_meta_optimizer=True))),
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig
            .PER_HOST_V2))  # pylint: disable=line-too-long

    if FLAGS.use_xla:
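        # global_jit_level=ON_1 enables XLA auto-clustering at its most
        # conservative setting.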
        config.session_config.graph_options.optimizer_options.global_jit_level = (
            tf.OptimizerOptions.ON_1)

    # Horovod: pin GPU to be used to process local rank (one GPU per process)
    if FLAGS.use_horovod:
        config.session_config.gpu_options.allow_growth = True
        config.session_config.gpu_options.visible_device_list = str(
            hvd.local_rank())

    # Validates Flags.
    if FLAGS.use_bfloat16 and FLAGS.use_keras:
        raise ValueError(
            'Keras layers do not fully support bfloat16 activation training.'
            ' You have set use_bfloat16=%s and use_keras=%s.' %
            (FLAGS.use_bfloat16, FLAGS.use_keras))

    # Initializes model parameters.
    steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
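    # Each Horovod rank consumes train_batch_size examples per step, so an
    # epoch takes proportionally fewer steps per rank; e.g. 1281167 images /
    # batch 1024 ~= 1251 steps, or ~156 per rank across 8 ranks.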
    if FLAGS.use_horovod:
        steps_per_epoch = steps_per_epoch // hvd.size()
    params = dict(steps_per_epoch=steps_per_epoch,
                  use_bfloat16=FLAGS.use_bfloat16,
                  quantized_training=FLAGS.quantized_training)
    if FLAGS.use_horovod:
        params['hvd'] = True
        params['hvd_curr_host'] = hvd.rank()
        params['hvd_num_hosts'] = hvd.size()
    mnasnet_est = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=mnasnet_model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        export_to_tpu=FLAGS.export_to_tpu,
        params=params)

    # Horovod: BroadcastGlobalVariablesHook broadcasts initial variable states from
    # rank 0 to all other processes. This is necessary to ensure consistent
    # initialization of all workers when training is started with random weights or
    # restored from a checkpoint.
    if FLAGS.use_horovod:
        bcast_hook = hvd.BroadcastGlobalVariablesHook(0)

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    if FLAGS.bigtable_instance:
        tf.logging.info('Using Bigtable dataset, table %s',
                        FLAGS.bigtable_table)
        select_train, select_eval = _select_tables_from_flags()
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetBigtableInput(
                is_training=is_training,
                use_bfloat16=False,
                transpose_input=FLAGS.transpose_input,
                selection=selection)
            for (is_training,
                 selection) in [(True, select_train), (False, select_eval)]
        ]
    else:
        if FLAGS.data_dir == FAKE_DATA_DIR:
            tf.logging.info('Using fake dataset.')
        else:
            tf.logging.info('Using dataset: %s', FLAGS.data_dir)
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetInput(
                is_training=is_training,
                data_dir=FLAGS.data_dir,
                transpose_input=FLAGS.transpose_input,
                cache=FLAGS.use_cache and is_training,
                image_size=FLAGS.input_image_size,
                num_parallel_calls=FLAGS.num_parallel_calls,
                use_bfloat16=FLAGS.use_bfloat16)
            for is_training in [True, False]
        ]

    if FLAGS.mode == 'eval':
        eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size
        # Run evaluation when there's a new checkpoint
        for ckpt in evaluation.checkpoints_iterator(
                FLAGS.model_dir, timeout=FLAGS.eval_timeout):
            tf.logging.info('Starting to evaluate.')
            try:
                start_timestamp = time.time()  # This time will include compilation time
                eval_results = mnasnet_est.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=eval_steps,
                    checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                                eval_results, elapsed_time)

                # Terminate eval job when final checkpoint is reached
                current_step = int(os.path.basename(ckpt).split('-')[1])
                if current_step >= FLAGS.train_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d',
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint',
                    ckpt)

        if FLAGS.export_dir:
            export(mnasnet_est, FLAGS.export_dir, FLAGS.post_quantize)
    else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        current_step = estimator._load_global_step_from_checkpoint_dir(  # pylint: disable=protected-access
            FLAGS.model_dir)

        tf.logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', FLAGS.train_steps,
            FLAGS.train_steps / params['steps_per_epoch'], current_step)

        start_timestamp = time.time()  # This time will include compilation time

        if FLAGS.mode == 'train':
            hooks = []
            if FLAGS.use_async_checkpointing:
                hooks.append(
                    async_checkpoint.AsyncCheckpointSaverHook(
                        checkpoint_dir=FLAGS.model_dir,
                        save_steps=max(100, FLAGS.iterations_per_loop)))
            mnasnet_est.train(input_fn=imagenet_train.input_fn,
                              max_steps=FLAGS.train_steps,
                              hooks=hooks)

        else:
            assert FLAGS.mode == 'train_and_eval'
            curr_rank = 0
            if FLAGS.use_horovod:
                curr_rank = hvd.rank()
            while current_step < FLAGS.train_steps:
                # Train for up to steps_per_eval number of steps.
                # At the end of training, a checkpoint will be written to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      FLAGS.train_steps)
                if FLAGS.use_horovod:
                    # Try the DALI input pipeline.
                    mnasnet_est.train(input_fn=imagenet_train.train_data_fn,
                                      max_steps=next_checkpoint,
                                      hooks=[bcast_hook])
                    # The old tf.data pipeline:
                    # mnasnet_est.train(
                    #     input_fn=imagenet_train.input_fn, max_steps=next_checkpoint, hooks=[bcast_hook])
                else:
                    mnasnet_est.train(input_fn=imagenet_train.input_fn,
                                      max_steps=next_checkpoint)
                current_step = next_checkpoint

                tf.logging.info(
                    'Finished training up to step %d. Elapsed seconds %d. Hvd rank %d',
                    next_checkpoint, int(time.time() - start_timestamp),
                    curr_rank)

                # Evaluate the model on the most recent model in --model_dir.
                # Since evaluation happens in batches of --eval_batch_size, some images
                # may be excluded modulo the batch size. As long as the batch size is
                # consistent, the evaluated images are also consistent.
                eval_on_single_gpu = FLAGS.eval_on_single_gpu
                tf.logging.info('Starting to evaluate.')
                if not eval_on_single_gpu or curr_rank == 0:
                    eval_results = mnasnet_est.evaluate(
                        input_fn=imagenet_eval.train_data_fn,  # was: imagenet_eval.input_fn
                        steps=FLAGS.num_eval_images // FLAGS.eval_batch_size)
                    tf.logging.info('Eval results at step %d: %s. Hvd rank %d',
                                    next_checkpoint, eval_results, curr_rank)

            elapsed_time = int(time.time() - start_timestamp)
            tf.logging.info(
                'Finished training up to step %d. Elapsed seconds %d.',
                FLAGS.train_steps, elapsed_time)
            if FLAGS.export_dir:
                export(mnasnet_est, FLAGS.export_dir, FLAGS.post_quantize)
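Example #30 wires up hvd.init(), GPU pinning, and the broadcast hook, but the
gradient averaging itself lives inside mnasnet_model_fn, which is not shown.
The canonical missing piece is wrapping the optimizer in
hvd.DistributedOptimizer; a minimal sketch under that assumption (the base
learning rate of 0.1 is illustrative, not from the original):

import horovod.tensorflow as hvd
import tensorflow as tf


def build_train_op(loss, params):
    # Conventionally, the learning rate is scaled by the number of ranks
    # for synchronous data-parallel training.
    learning_rate = 0.1 * (hvd.size() if params.get('hvd') else 1)
    optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
    if params.get('hvd'):
        # Allreduce-averages gradients across ranks before applying them.
        optimizer = hvd.DistributedOptimizer(optimizer)
    return optimizer.minimize(loss, global_step=tf.train.get_global_step())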