Code example #1
 def preprocess_map_fn(images, labels):
   """Standardize images by the model's per-channel RGB statistics.

   The labels argument is accepted only to match the dataset map signature
   and is discarded.
   """
   del labels  # unused; present only for tf.data map compatibility
   builder = model_builder_factory.get_model_builder(FLAGS.model_name)
   mean = tf.constant(builder.MEAN_RGB, shape=[1, 1, 3], dtype=images.dtype)
   stddev = tf.constant(builder.STDDEV_RGB, shape=[1, 1, 3], dtype=images.dtype)
   return (images - mean) / stddev
Code example #2
def main(_):
  """Export the model to a SavedModel and/or a TFLite flatbuffer.

  Builds the graph for FLAGS.model_name, restores weights from
  FLAGS.ckpt_dir, optionally writes a SavedModel to
  FLAGS.output_saved_model_dir, and writes a (optionally post-training
  quantized) TFLite model to FLAGS.output_tflite.

  Raises:
    ValueError: if --quantize is set but --data_dir is not, since post
      training quantization needs a calibration dataset.
  """
  # Enables eager context for TF 1.x. TF 2.x will use eager by default.
  # This is used to conveniently get a representative dataset generator using
  # TensorFlow training input helper.
  tf.enable_eager_execution()

  model_builder = model_builder_factory.get_model_builder(FLAGS.model_name)

  with tf.Graph().as_default(), tf.Session() as sess:
    images = tf.placeholder(
        tf.float32,
        shape=(1, FLAGS.image_size, FLAGS.image_size, 3),
        name="images")

    logits, endpoints = model_builder.build_model(images, FLAGS.model_name,
                                                  False)
    # Export either a named intermediate endpoint or the softmax head.
    if FLAGS.endpoint_name:
      output_tensor = endpoints[FLAGS.endpoint_name]
    else:
      output_tensor = tf.nn.softmax(logits)

    restore_model(sess, FLAGS.ckpt_dir, FLAGS.enable_ema)

    if FLAGS.output_saved_model_dir:
      signature_def_map = {
          "serving_default":
              tf.compat.v1.saved_model.signature_def_utils
              .predict_signature_def({"input": images},
                                     {"output": output_tensor})
      }

      builder = tf.compat.v1.saved_model.Builder(FLAGS.output_saved_model_dir)
      builder.add_meta_graph_and_variables(
          sess, ["serve"], signature_def_map=signature_def_map)
      builder.save()
      print("Saved model written to %s" % FLAGS.output_saved_model_dir)

    converter = tf.lite.TFLiteConverter.from_session(sess, [images],
                                                     [output_tensor])
    if FLAGS.quantize:
      if not FLAGS.data_dir:
        raise ValueError(
            "Post training quantization requires data_dir flag to point to the "
            "calibration dataset. To export a float model, set "
            "--quantize=False.")

      # Full integer post-training quantization: calibrate activations with
      # the representative dataset and restrict ops to the int8 builtin set.
      converter.representative_dataset = tf.lite.RepresentativeDataset(
          representative_dataset_gen)
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
      converter.inference_input_type = tf.lite.constants.QUANTIZED_UINT8
      converter.inference_output_type = tf.lite.constants.QUANTIZED_UINT8
      converter.target_spec.supported_ops = [
          tf.lite.OpsSet.TFLITE_BUILTINS_INT8
      ]

  # NOTE(review): convert() runs after the session context exits; this relies
  # on from_session having already frozen the graph at construction time --
  # confirm against the TF 1.x TFLiteConverter implementation in use.
  tflite_buffer = converter.convert()
  # Use a context manager so the file handle is flushed and closed even if
  # the write raises (the original left the GFile handle open).
  with tf.gfile.GFile(FLAGS.output_tflite, "wb") as f:
    f.write(tflite_buffer)
  print("tflite model written to %s" % FLAGS.output_tflite)
Code example #3
 def build_model():
     """Construct the network selected by --model_name and return its logits."""
     builder = model_builder_factory.get_model_builder(FLAGS.model_name)
     # Standardize inputs with the per-channel statistics the model expects.
     normalized = normalize_features(features, builder.MEAN_RGB,
                                     builder.STDDEV_RGB)
     logits, _ = builder.build_model(
         normalized,
         model_name=FLAGS.model_name,
         training=is_training,
         override_params=override_params,
         model_dir=FLAGS.model_dir)
     return logits
Code example #4
    def build_model(self, features, is_training):
        """Build the network and return its squeezed softmax probabilities."""
        tf.logging.info(self.model_name)
        builder = model_builder_factory.get_model_builder(self.model_name)

        if self.advprop_preprocessing:
            # AdvProp uses Inception preprocessing: scale pixels to [-1, 1].
            features = features * 2.0 / 255 - 1.0
        else:
            # Standardize by the per-channel RGB training statistics.
            rgb_mean = tf.constant(builder.MEAN_RGB,
                                   shape=[1, 1, 3],
                                   dtype=features.dtype)
            rgb_stddev = tf.constant(builder.STDDEV_RGB,
                                     shape=[1, 1, 3],
                                     dtype=features.dtype)
            features = (features - rgb_mean) / rgb_stddev
        logits, _ = builder.build_model(features, self.model_name, is_training)
        return tf.squeeze(tf.nn.softmax(logits))
Code example #5
import model_builder_factory
import tensorflow as tf
import keras

# Module-level builder for the default "efficientnet-lite0" architecture.
model_builder = model_builder_factory.get_model_builder("efficientnet-lite0")


def restore_model(sess, ckpt_dir, enable_ema=True):
    sess.run(tf.global_variables_initializer())
    checkpoint = tf.train.latest_checkpoint(ckpt_dir)
    checkpoint = "./checkpoint_lite0/model.ckpt-2893563"
    print(checkpoint)
    if enable_ema:
        ema = tf.train.ExponentialMovingAverage(decay=0.0)
        ema_vars = tf.trainable_variables() + tf.get_collection("moving_vars")
        for v in tf.global_variables():
            if "moving_mean" in v.name or "moving_variance" in v.name:
                ema_vars.append(v)
        ema_vars = list(set(ema_vars))
        var_dict = ema.variables_to_restore(ema_vars)
    else:
        var_dict = None

    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(var_dict, max_to_keep=1)
    saver.restore(sess, checkpoint)
    tf_session = keras.backend.get_session()
    input_graph_def = tf_session.graph.as_graph_def()
    save_path = saver.save(tf_session, './checkpoint.ckpt')
    tf.train.write_graph(input_graph_def,
                         './',