def preprocess_map_fn(images, labels):
    """Normalize images with the selected model's RGB mean and stddev.

    The model builder is looked up from ``FLAGS.model_name``; its ``MEAN_RGB``
    is subtracted from ``images`` and the result is divided by ``STDDEV_RGB``.

    Args:
        images: Image tensor to normalize. The statistics are broadcast with
            shape [1, 1, 3], so this presumably has 3 channels in its last
            axis -- TODO confirm against the caller.
        labels: Ignored; dropped immediately.

    Returns:
        The normalized images.
    """
    del labels  # Only the images are needed.

    builder = model_builder_factory.get_model_builder(FLAGS.model_name)
    stats_shape = [1, 1, 3]

    images -= tf.constant(
        builder.MEAN_RGB, shape=stats_shape, dtype=images.dtype)
    images /= tf.constant(
        builder.STDDEV_RGB, shape=stats_shape, dtype=images.dtype)
    return images
def main(_):
    """Export the checkpointed model to TFLite, and optionally a SavedModel.

    Steps, in order:
      1. Build the inference graph for ``FLAGS.model_name`` on a float32
         placeholder of shape (1, image_size, image_size, 3).
      2. Pick the output tensor: the endpoint named by ``FLAGS.endpoint_name``
         if set, otherwise the softmax of the logits.
      3. Restore weights via ``restore_model``.
      4. If ``FLAGS.output_saved_model_dir`` is set, write a SavedModel with a
         "serving_default" predict signature.
      5. Convert the session graph to TFLite (with post-training quantization
         when ``FLAGS.quantize`` is set) and write it to
         ``FLAGS.output_tflite``.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If ``FLAGS.quantize`` is set but ``FLAGS.data_dir`` is
            empty, since quantization needs a calibration dataset.
    """
    # Enables eager context for TF 1.x. TF 2.x will use eager by default.
    # This is used to conveniently get a representative dataset generator using
    # TensorFlow training input helper.
    tf.enable_eager_execution()

    model_builder = model_builder_factory.get_model_builder(FLAGS.model_name)

    with tf.Graph().as_default(), tf.Session() as sess:
        images = tf.placeholder(
            tf.float32,
            shape=(1, FLAGS.image_size, FLAGS.image_size, 3),
            name="images")

        logits, endpoints = model_builder.build_model(
            images, FLAGS.model_name, False)
        if FLAGS.endpoint_name:
            output_tensor = endpoints[FLAGS.endpoint_name]
        else:
            output_tensor = tf.nn.softmax(logits)

        restore_model(sess, FLAGS.ckpt_dir, FLAGS.enable_ema)

        if FLAGS.output_saved_model_dir:
            signature_def_map = {
                "serving_default":
                    tf.compat.v1.saved_model.signature_def_utils
                    .predict_signature_def({"input": images},
                                           {"output": output_tensor})
            }

            builder = tf.compat.v1.saved_model.Builder(
                FLAGS.output_saved_model_dir)
            builder.add_meta_graph_and_variables(
                sess, ["serve"], signature_def_map=signature_def_map)
            builder.save()
            print("Saved model written to %s" % FLAGS.output_saved_model_dir)

        converter = tf.lite.TFLiteConverter.from_session(
            sess, [images], [output_tensor])
        if FLAGS.quantize:
            if not FLAGS.data_dir:
                raise ValueError(
                    "Post training quantization requires data_dir flag to point to the "
                    "calibration dataset. To export a float model, set "
                    "--quantize=False.")

            converter.representative_dataset = tf.lite.RepresentativeDataset(
                representative_dataset_gen)
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            converter.inference_input_type = tf.lite.constants.QUANTIZED_UINT8
            converter.inference_output_type = tf.lite.constants.QUANTIZED_UINT8
            converter.target_spec.supported_ops = [
                tf.lite.OpsSet.TFLITE_BUILTINS_INT8
            ]

        tflite_buffer = converter.convert()
        # Use a context manager so the file is flushed and closed
        # deterministically instead of relying on garbage collection.
        with tf.gfile.GFile(FLAGS.output_tflite, "wb") as tflite_file:
            tflite_file.write(tflite_buffer)
        print("tflite model written to %s" % FLAGS.output_tflite)
def build_model():
    """Build the model named by ``FLAGS.model_name`` and return its logits.

    Takes no arguments: ``features``, ``is_training`` and ``override_params``
    are read from the surrounding scope (not visible in this block). The
    features are normalized with the builder's ``MEAN_RGB`` / ``STDDEV_RGB``
    before being passed to the builder.

    Returns:
        The first element of the pair returned by the builder's
        ``build_model``; the second element is discarded.
    """
    builder = model_builder_factory.get_model_builder(FLAGS.model_name)

    inputs = normalize_features(
        features, builder.MEAN_RGB, builder.STDDEV_RGB)

    logits, _unused = builder.build_model(
        inputs,
        model_name=FLAGS.model_name,
        training=is_training,
        override_params=override_params,
        model_dir=FLAGS.model_dir)
    return logits
def build_model(self, features, is_training):
    """Build the model on ``features`` and return squeezed softmax output.

    Args:
        features: Input tensor. It is rescaled as ``x * 2.0 / 255 - 1.0`` when
            ``self.advprop_preprocessing`` is set; otherwise it is normalized
            with the builder's ``MEAN_RGB`` / ``STDDEV_RGB``, broadcast with
            shape [1, 1, 3].
        is_training: Passed through to the model builder's ``build_model``.

    Returns:
        Softmax of the logits with all size-1 dimensions squeezed out.
    """
    tf.logging.info(self.model_name)
    builder = model_builder_factory.get_model_builder(self.model_name)

    if not self.advprop_preprocessing:
        # Standard preprocessing: per-channel mean/stddev normalization.
        stats_shape = [1, 1, 3]
        features -= tf.constant(
            builder.MEAN_RGB, shape=stats_shape, dtype=features.dtype)
        features /= tf.constant(
            builder.STDDEV_RGB, shape=stats_shape, dtype=features.dtype)
    else:
        # AdvProp uses Inception preprocessing.
        features = features * 2.0 / 255 - 1.0

    logits, _unused = builder.build_model(
        features, self.model_name, is_training)
    return tf.squeeze(tf.nn.softmax(logits))
import model_builder_factory import tensorflow as tf import keras model_builder = model_builder_factory.get_model_builder("efficientnet-lite0") def restore_model(sess, ckpt_dir, enable_ema=True): sess.run(tf.global_variables_initializer()) checkpoint = tf.train.latest_checkpoint(ckpt_dir) checkpoint = "./checkpoint_lite0/model.ckpt-2893563" print(checkpoint) if enable_ema: ema = tf.train.ExponentialMovingAverage(decay=0.0) ema_vars = tf.trainable_variables() + tf.get_collection("moving_vars") for v in tf.global_variables(): if "moving_mean" in v.name or "moving_variance" in v.name: ema_vars.append(v) ema_vars = list(set(ema_vars)) var_dict = ema.variables_to_restore(ema_vars) else: var_dict = None sess.run(tf.global_variables_initializer()) saver = tf.train.Saver(var_dict, max_to_keep=1) saver.restore(sess, checkpoint) tf_session = keras.backend.get_session() input_graph_def = tf_session.graph.as_graph_def() save_path = saver.save(tf_session, './checkpoint.ckpt') tf.train.write_graph(input_graph_def, './',