def main(args):
  """Main function to be called by absl.app.run() after flag parsing."""
  del args
  _check_keras_dependencies()
  hparams = _get_hparams_from_flags()

  image_dir = FLAGS.image_dir if FLAGS.image_dir else lib.get_default_image_dir()

  model, labels, train_result = lib.make_image_classifier(
      FLAGS.tfhub_module, image_dir, hparams, FLAGS.image_size)
  min_accuracy = FLAGS.assert_accuracy_at_least
  if min_accuracy:
    _assert_accuracy(train_result, min_accuracy)
  print("Done with training.")

  labels_path = FLAGS.labels_output_file
  if labels_path:
    with tf.io.gfile.GFile(labels_path, "w") as labels_file:
      labels_file.write("\n".join(labels + ("",)))
    print("Labels written to", labels_path)

  export_dir = FLAGS.saved_model_dir
  if not export_dir and FLAGS.tflite_output_file:
    # TFLite conversion needs a SavedModel even when the user did not ask
    # to keep one, so stage it in a temporary directory.
    export_dir = tempfile.mkdtemp()
  if export_dir:
    tf.saved_model.save(model, export_dir)
    print("SavedModel model exported to", export_dir)

  tflite_path = FLAGS.tflite_output_file
  if tflite_path:
    converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
    lite_model_content = converter.convert()
    with tf.io.gfile.GFile(tflite_path, "wb") as tflite_file:
      tflite_file.write(lite_model_content)
    print("TFLite model exported to", tflite_path)
def main(args):
    """Main function to be called by absl.app.run() after flag parsing."""
    del args
    _check_keras_dependencies()
    hparams = _get_hparams_from_flags()

    image_dir = FLAGS.image_dir or lib.get_default_image_dir()

    if FLAGS.set_memory_growth:
        _set_gpu_memory_growth()

    use_tf_data_input = FLAGS.use_tf_data_input
    tf_has_tf_data_support = (LooseVersion(tf.__version__) >=
                              LooseVersion("2.5.0"))
    # TF preprocessing layers do not support distribution strategy before
    # tensorflow 2.5: reject an explicit request for the tf.data input path
    # there, and default it on for TF >= 2.5 when the flag was left unset.
    # (`is True` / `is None` deliberately distinguish an explicit flag value
    # from the unset default.)
    if use_tf_data_input is True and not tf_has_tf_data_support:
        raise ValueError(
            "use_tf_data_input is not supported for tensorflow<2.5")
    if use_tf_data_input is None and tf_has_tf_data_support:
        use_tf_data_input = True

    model, labels, train_result = lib.make_image_classifier(
        FLAGS.tfhub_module, image_dir, hparams,
        lib.get_distribution_strategy(FLAGS.distribution_strategy),
        FLAGS.image_size, FLAGS.summaries_dir, use_tf_data_input)
    min_accuracy = FLAGS.assert_accuracy_at_least
    if min_accuracy:
        _assert_accuracy(train_result, min_accuracy)
    print("Done with training.")

    labels_path = FLAGS.labels_output_file
    if labels_path:
        with tf.io.gfile.GFile(labels_path, "w") as labels_file:
            labels_file.write("\n".join(labels + ("", )))
        print("Labels written to", labels_path)

    export_dir = FLAGS.saved_model_dir
    if FLAGS.tflite_output_file and not export_dir:
        # TFLite conversion needs a SavedModel even when the user did not
        # request one, so stage it in a temporary directory.
        export_dir = tempfile.mkdtemp()
    if export_dir:
        tf.saved_model.save(model, export_dir)
        print("SavedModel model exported to", export_dir)

    tflite_path = FLAGS.tflite_output_file
    if tflite_path:
        converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
        lite_model_content = converter.convert()
        with tf.io.gfile.GFile(tflite_path, "wb") as tflite_file:
            tflite_file.write(lite_model_content)
        print("TFLite model exported to", tflite_path)
# Exemplo n.º 3 (scraped-snippet separator; the following "0" was the snippet's score)
# 0
    def retrain(self,
                image_dir,
                output_graph,
                training_steps,
                learning_rate,
                desired_training_accuracy=100.0,
                desired_validation_accuracy=100.0,
                flip_left_right=True,
                random_crop=30,
                random_scale=30,
                random_brightness=30):
        """Retrains the hub module on `image_dir` and exports labels + TFLite.

        Trains an image classifier from `self.architecture` (resolved to a
        tfhub.dev module URL) at input size `self.shape[0]`, then writes two
        artifacts next to `output_graph`: `<output_graph>.txt` with one label
        per line, and `<output_graph>.tflite` with the converted model.

        Args:
          image_dir: Directory of training images, as accepted by
            lib.make_image_classifier.
          output_graph: Path prefix for the `.txt` and `.tflite` outputs.
          training_steps: Passed to the hparams as `train_epochs`.
          learning_rate: Passed to the hparams as `learning_rate`.
          desired_training_accuracy: Unused; kept for interface compatibility.
          desired_validation_accuracy: Unused; kept for interface compatibility.
          flip_left_right: Unused; kept for interface compatibility.
          random_crop: Unused; kept for interface compatibility.
          random_scale: Unused; kept for interface compatibility.
          random_brightness: Unused; kept for interface compatibility.
        """
        # NOTE(review): the accuracy thresholds and augmentation arguments are
        # accepted but never applied anywhere in this method — confirm whether
        # they should be forwarded to the training library or removed.
        del desired_training_accuracy, desired_validation_accuracy
        del flip_left_right, random_crop, random_scale, random_brightness

        _check_keras_dependencies()
        hparams = _get_hparams(train_epochs=training_steps,
                               learning_rate=learning_rate)

        image_size = self.shape[0]
        tfhub_module = "https://tfhub.dev/google/" + self.architecture
        model, labels, train_result = lib.make_image_classifier(
            tfhub_module, image_dir, hparams, image_size)
        del train_result  # Training metrics are not checked here.
        print("Done with training.")

        # Trailing "" adds a final newline after the last label.
        labels_output_file = output_graph + ".txt"
        with tf.io.gfile.GFile(labels_output_file, "w") as f:
            f.write("\n".join(labels + ("", )))
        # Print after the `with` block so the file is closed (flushed) before
        # we report success — matches the pattern used by main() above.
        print("Labels written to", labels_output_file)

        # A SavedModel is required as the TFLite converter's input; it is only
        # staged in a temporary directory, not a user-facing artifact.
        saved_model_dir = tempfile.mkdtemp()
        tf.saved_model.save(model, saved_model_dir)
        print("SavedModel model exported to", saved_model_dir)

        tflite_output_file = output_graph + ".tflite"
        converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
        # NOTE(review): OPTIMIZE_FOR_LATENCY is a deprecated alias of
        # tf.lite.Optimize.DEFAULT in current TF releases — consider updating
        # once the supported TF version range is confirmed.
        converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_LATENCY]
        lite_model_content = converter.convert()
        with tf.io.gfile.GFile(tflite_output_file, "wb") as f:
            f.write(lite_model_content)
        print("TFLite model exported to", tflite_output_file)