Example 1: Keras ResNet-50 ImageNet training with the TF 1.x tf.contrib TPUStrategy APIs.
def main(unused_argv):
  assert FLAGS.data is not None, 'Provide training data path via --data.'

  batch_size = FLAGS.num_cores * PER_CORE_BATCH_SIZE
  training_steps_per_epoch = int(APPROX_IMAGENET_TRAINING_IMAGES / batch_size)
  validation_steps = int(IMAGENET_VALIDATION_IMAGES // batch_size)

  model_dir = FLAGS.model_dir if FLAGS.model_dir else DEFAULT_MODEL_DIR
  logging.info('Saving tensorboard summaries at %s', model_dir)

  logging.info('Use TPU at %s', FLAGS.tpu if FLAGS.tpu is not None else 'local')
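  # TF 1.x contrib TPU setup: resolve the TPU cluster, initialize the TPU
  # system, and build a TPUStrategy for distribution.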
  resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
  tf.contrib.distribute.initialize_tpu_system(resolver)
  strategy = tf.contrib.distribute.TPUStrategy(resolver)

  logging.info('Use bfloat16: %s.', USE_BFLOAT16)
  logging.info('Use global batch size: %s.', batch_size)
  logging.info('Enable top 5 accuracy: %s.', FLAGS.eval_top_5_accuracy)
  logging.info('Training model using data in directory "%s".', FLAGS.data)

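  # Build and compile the model inside the strategy scope so its variables
  # are created under the TPU distribution strategy.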
  with strategy.scope():
    logging.info('Building Keras ResNet-50 model')
    model = resnet_model.ResNet50(num_classes=NUM_CLASSES)

    logging.info('Compiling model.')
    metrics = ['sparse_categorical_accuracy']

    if FLAGS.eval_top_5_accuracy:
      metrics.append(sparse_top_k_categorical_accuracy)

    model.compile(
        optimizer=gradient_descent.SGD(
            learning_rate=BASE_LEARNING_RATE, momentum=0.9, nesterov=True),
        loss='sparse_categorical_crossentropy',
        metrics=metrics)

  imagenet_train = imagenet_input.ImageNetInput(
      is_training=True, data_dir=FLAGS.data, batch_size=batch_size,
      use_bfloat16=USE_BFLOAT16)
  imagenet_eval = imagenet_input.ImageNetInput(
      is_training=False, data_dir=FLAGS.data, batch_size=batch_size,
      use_bfloat16=USE_BFLOAT16)

  lr_schedule_cb = LearningRateBatchScheduler(
      schedule=learning_rate_schedule_wrapper(training_steps_per_epoch))
  tensorboard_cb = eval_utils.TensorBoardWithValidation(
      log_dir=model_dir,
      validation_imagenet_input=imagenet_eval,
      validation_steps=validation_steps,
      validation_epochs=[30, 60, 90])

  training_callbacks = [lr_schedule_cb, tensorboard_cb]

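  # No validation_data is passed to fit(); the TensorBoardWithValidation
  # callback runs evaluation at the epochs listed above (30, 60, 90).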
  model.fit(
      imagenet_train.input_fn(),
      epochs=EPOCHS,
      steps_per_epoch=training_steps_per_epoch,
      callbacks=training_callbacks)

  model_saving_utils.save_model(model, model_dir, WEIGHTS_TXT)
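The LearningRateBatchScheduler callback and learning_rate_schedule_wrapper helper referenced in all three examples are defined outside this excerpt. The sketch below shows one plausible implementation, assuming per-batch step decay with a linear warmup; the schedule values and base rate are illustrative, not the originals.

import tensorflow as tf

# Illustrative (multiplier, start_epoch) pairs and base rate; the real module
# may use different numbers.
LR_SCHEDULE = [(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)]
BASE_LEARNING_RATE = 0.4


def learning_rate_schedule_wrapper(training_steps_per_epoch):
  """Returns a schedule(current_epoch, current_batch) -> learning-rate fn."""

  def learning_rate_schedule(current_epoch, current_batch):
    epoch = current_epoch + float(current_batch) / training_steps_per_epoch
    warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
    if epoch < warmup_end_epoch:
      # Linear warmup from 0 to the base rate over the first epochs.
      return (BASE_LEARNING_RATE * warmup_lr_multiplier * epoch /
              warmup_end_epoch)
    learning_rate = BASE_LEARNING_RATE
    for multiplier, start_epoch in LR_SCHEDULE:
      if epoch >= start_epoch:
        learning_rate = BASE_LEARNING_RATE * multiplier
    return learning_rate

  return learning_rate_schedule


class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
  """Keras callback that updates the optimizer learning rate every batch."""

  def __init__(self, schedule):
    super(LearningRateBatchScheduler, self).__init__()
    self.schedule = schedule
    self.epochs = -1
    self.prev_lr = -1.0

  def on_epoch_begin(self, epoch, logs=None):
    self.epochs += 1

  def on_batch_begin(self, batch, logs=None):
    lr = self.schedule(self.epochs, batch)
    if lr != self.prev_lr:
      tf.keras.backend.set_value(self.model.optimizer.lr, lr)
      self.prev_lr = lr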
Example 2: transfer learning from an ImageNet-pretrained ResNet-50 on a Colab TPU via tf.contrib.tpu.keras_to_tpu_model (TF 1.x).
def main(unused_argv):
  assert FLAGS.data is not None, 'Provide training data path via --data.'

  batch_size = FLAGS.num_cores * PER_CORE_BATCH_SIZE

  training_steps_per_epoch = FLAGS.steps_per_epoch or (
      int(APPROX_IMAGENET_TRAINING_IMAGES // batch_size))
  validation_steps = int(IMAGENET_VALIDATION_IMAGES // batch_size)

  model_dir = FLAGS.model_dir
  logging.info('Saving tensorboard summaries at %s', model_dir)

  logging.info('Use TPU at %s', FLAGS.tpu if FLAGS.tpu is not None else 'local')
  logging.info('Use bfloat16: %s.', USE_BFLOAT16)
  logging.info('Use global batch size: %s.', batch_size)
  logging.info('Enable top 5 accuracy: %s.', FLAGS.eval_top_5_accuracy)
  logging.info('Training model using data in directory "%s".', FLAGS.data)

  logging.info('Building Keras ResNet-50 model')
  # tpu_model = resnet_model.ResNet50(num_classes=NUM_CLASSES)

  # Transfer learning: freeze an ImageNet-pretrained ResNet-50 backbone and
  # attach a new global-average-pooling + softmax classification head.
  base_model = tf.keras.applications.resnet50.ResNet50(
      include_top=False, weights='imagenet', input_shape=(224, 224, 3),
      classes=NUM_CLASSES)
  print(base_model)
  for layer in base_model.layers:
    layer.trainable = False

  x = base_model.output
  x = tf.keras.layers.GlobalAveragePooling2D(name='avg_pool')(x)
  x = tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')(x)
  tpu_model = Model(base_model.input, x)

  # tpu_model.load_weights("model/model.ckpt-112603")

  # Distribute the Keras model across the Colab TPU (TF 1.x contrib API).
  strategy = tf.contrib.tpu.TPUDistributionStrategy(
      tf.contrib.cluster_resolver.TPUClusterResolver(
          tpu='grpc://' + os.environ['COLAB_TPU_ADDR']))
  tpu_model = tf.contrib.tpu.keras_to_tpu_model(tpu_model, strategy)
  
  logging.info('Compiling model.')
  metrics = ['sparse_categorical_accuracy']

  if FLAGS.eval_top_5_accuracy:
    metrics.append(sparse_top_k_categorical_accuracy)

  tpu_model.compile(
      optimizer=optimizers.SGD(
          lr=BASE_LEARNING_RATE, momentum=0.9, nesterov=True),
      loss='sparse_categorical_crossentropy',
      metrics=metrics)

  imagenet_train = imagenet_input.ImageNetInput(
      is_training=True, data_dir=FLAGS.data, batch_size=batch_size,
      use_bfloat16=USE_BFLOAT16)
  imagenet_eval = imagenet_input.ImageNetInput(
      is_training=False, data_dir=FLAGS.data, batch_size=batch_size,
      use_bfloat16=USE_BFLOAT16)

  lr_schedule_cb = LearningRateBatchScheduler(
      schedule=learning_rate_schedule_wrapper(training_steps_per_epoch))
  tensorboard_cb = tf.keras.callbacks.TensorBoard(
      log_dir=model_dir)

  # checkpoint_path = "model/model.ckpt-112603"
  #
  # # Create checkpoint callback
  # cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
  #                                                  save_weights_only=True,
  #                                                  verbose=1)

  training_callbacks = [lr_schedule_cb, tensorboard_cb]

  tpu_model.fit(
      imagenet_train.input_fn().make_one_shot_iterator(),
      epochs=EPOCHS,
      steps_per_epoch=training_steps_per_epoch,
      callbacks=training_callbacks,
      validation_data=imagenet_eval.input_fn().make_one_shot_iterator(),
      validation_steps=validation_steps)


  model_saving_utils.save_model(tpu_model, model_dir, WEIGHTS_TXT)
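model_saving_utils.save_model, called at the end of each example, is likewise not shown. A minimal sketch under the assumption that it simply writes the trained weights under model_dir; the real helper may do more (for example, exporting a SavedModel).

import os

import tensorflow as tf


def save_model(model, model_dir, weights_file):
  """Writes the trained weights under model_dir (sketch; details may differ)."""
  tf.io.gfile.makedirs(model_dir)
  model.save_weights(os.path.join(model_dir, weights_file), overwrite=True)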
Example 3: Keras ResNet-50 ImageNet training with the TF 2.x tf.distribute TPUStrategy APIs and in-loop validation.
def main(unused_argv):
    assert FLAGS.data is not None, 'Provide training data path via --data.'
    tf.enable_v2_behavior()
    tf.compat.v1.disable_eager_execution()  # todo

    batch_size = FLAGS.num_cores * PER_CORE_BATCH_SIZE

    training_steps_per_epoch = FLAGS.steps_per_epoch or (int(
        APPROX_IMAGENET_TRAINING_IMAGES // batch_size))
    validation_steps = int(
        math.ceil(1.0 * IMAGENET_VALIDATION_IMAGES / batch_size))

    model_dir = FLAGS.model_dir if FLAGS.model_dir else DEFAULT_MODEL_DIR
    logging.info('Saving tensorboard summaries at %s', model_dir)

    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
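    # TF 2.x TPU setup: connect to the cluster, initialize the TPU system,
    # and create a TPUStrategy.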
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

    logging.info('Use bfloat16: %s.', USE_BFLOAT16)
    logging.info('Use global batch size: %s.', batch_size)
    logging.info('Enable top 5 accuracy: %s.', FLAGS.eval_top_5_accuracy)
    logging.info('Training model using data in directory "%s".', FLAGS.data)

    with strategy.scope():
        logging.info('Building Keras ResNet-50 model')
        model = resnet_model.ResNet50(num_classes=NUM_CLASSES)
        # model = keras_applications.mobilenet_v2.MobileNetV2(classes=NUM_CLASSES, weights=None)

        logging.info('Compiling model.')
        metrics = ['sparse_categorical_accuracy']

        if FLAGS.eval_top_5_accuracy:
            metrics.append(sparse_top_k_categorical_accuracy)

        model.compile(optimizer=tf.keras.optimizers.SGD(
            learning_rate=BASE_LEARNING_RATE, momentum=0.9, nesterov=True),
                      loss='sparse_categorical_crossentropy',
                      metrics=metrics)

    imagenet_train = imagenet_input.ImageNetInput(is_training=True,
                                                  data_dir=FLAGS.data,
                                                  batch_size=batch_size,
                                                  use_bfloat16=USE_BFLOAT16)
    imagenet_eval = imagenet_input.ImageNetInput(is_training=False,
                                                 data_dir=FLAGS.data,
                                                 batch_size=batch_size,
                                                 use_bfloat16=USE_BFLOAT16)

    lr_schedule_cb = LearningRateBatchScheduler(
        schedule=learning_rate_schedule_wrapper(training_steps_per_epoch))
    tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=model_dir)

    training_callbacks = [lr_schedule_cb, tensorboard_cb]

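    # Train with in-loop validation every 5 epochs (validation_freq=5).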
    model.fit(imagenet_train.input_fn(),
              epochs=FLAGS.num_epochs,
              steps_per_epoch=training_steps_per_epoch,
              callbacks=training_callbacks,
              validation_data=imagenet_eval.input_fn(),
              validation_steps=validation_steps,
              validation_freq=5)

    model_saving_utils.save_model(model, model_dir, WEIGHTS_TXT)
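All three examples also rely on module-level imports, absl flags, dataset constants, and a sparse_top_k_categorical_accuracy metric defined earlier in their source files. The sketch below shows plausible definitions; the flag defaults and constant values are illustrative and may not match the originals.

import math
import os

from absl import flags
from absl import logging
import tensorflow as tf

FLAGS = flags.FLAGS
flags.DEFINE_string('data', None, 'Path to the ImageNet TFRecord data.')
flags.DEFINE_string('model_dir', None, 'Directory for checkpoints and summaries.')
flags.DEFINE_string('tpu', None, 'Name or gRPC address of the TPU to use.')
flags.DEFINE_integer('num_cores', 8, 'Number of TPU cores.')
flags.DEFINE_integer('steps_per_epoch', None, 'Override for training steps per epoch.')
flags.DEFINE_integer('num_epochs', 90, 'Number of training epochs.')
flags.DEFINE_bool('eval_top_5_accuracy', False, 'Also report top-5 accuracy.')

# Illustrative values; the originals may differ.
APPROX_IMAGENET_TRAINING_IMAGES = 1281167
IMAGENET_VALIDATION_IMAGES = 50000
NUM_CLASSES = 1000
PER_CORE_BATCH_SIZE = 128
BASE_LEARNING_RATE = 0.4
EPOCHS = 90
USE_BFLOAT16 = False
DEFAULT_MODEL_DIR = '/tmp/resnet50_keras'
WEIGHTS_TXT = 'resnet50_weights.h5'


def sparse_top_k_categorical_accuracy(y_true, y_pred):
  """Top-5 accuracy for sparse integer labels."""
  return tf.keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=5)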