@classmethod
def from_name(cls,
              model_name: str,
              model_weights_path: Optional[str] = None,
              checkpoint_format: Optional[str] = 'tf_checkpoint',
              overrides: Optional[Mapping[str, Any]] = None):
  """Constructs a MobilenetEdgeTPUV2 model from a predefined model name.

  E.g., `MobilenetEdgeTPUV2.from_name('mobilenet_edgetpu_v2_s')`.

  Args:
    model_name: the predefined model name.
    model_weights_path: the path to the weights (h5 file or saved model dir).
    checkpoint_format: the model weights format. One of 'tf_checkpoint' or
      'keras_checkpoint'.
    overrides: (optional) a dict containing keys that can override config.

  Returns:
    A constructed MobilenetEdgeTPUV2 instance.
  """
  overrides = dict(overrides) if overrides else {}
  # One can define their own custom models if necessary.
  MODEL_CONFIGS.update(overrides.pop('model_config', {}))
  model = cls(model_config_name=model_name, overrides=overrides)
  if model_weights_path:
    common_modules.load_weights(model,
                                model_weights_path,
                                checkpoint_format=checkpoint_format)
  return model
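# Illustrative usage sketch (not part of the original source). It assumes
# `MobilenetEdgeTPUV2` and `common_modules` are importable from this project
# and that 'mobilenet_edgetpu_v2_s' is a key in MODEL_CONFIGS, as the
# docstring above suggests. The checkpoint path is hypothetical.
model = MobilenetEdgeTPUV2.from_name(
    'mobilenet_edgetpu_v2_s',
    model_weights_path='/path/to/checkpoint',  # hypothetical path
    checkpoint_format='tf_checkpoint')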
def run_export():
  """Exports a TFLite model with post-training quantization (PTQ)."""
  export_config = get_export_config_from_flags()
  model = export_util.build_experiment_model(
      experiment_type=export_config.model_name)

  if export_config.ckpt_path:
    logging.info('Loading checkpoint from %s', FLAGS.ckpt_path)
    common_modules.load_weights(
        model,
        export_config.ckpt_path,
        checkpoint_format=export_config.ckpt_format)
  else:
    logging.info('No checkpoint provided. Using randomly initialized weights.')

  if export_config.output_layer is not None:
    all_layer_names = {l.name for l in model.layers}
    if export_config.output_layer not in all_layer_names:
      model.summary()
      logging.info(
          'Cannot find the layer %s in the model. See the above summary to '
          'choose an output layer.', export_config.output_layer)
      return
    output_layer = model.get_layer(export_config.output_layer)
    model = tf.keras.Model(model.input, output_layer.output)

  model_input = tf.keras.Input(
      shape=(export_config.image_size, export_config.image_size, 3),
      batch_size=1)
  model_output = export_util.finalize_serving(model(model_input),
                                              export_config)
  model_for_inference = tf.keras.Model(model_input, model_output)

  # Convert to tflite. Quantize if quantization parameters are specified.
  converter = tf.lite.TFLiteConverter.from_keras_model(model_for_inference)
  export_util.configure_tflite_converter(export_config, converter)
  tflite_buffer = converter.convert()

  # Make sure the base directory exists and write the tflite flatbuffer.
  tf.io.gfile.makedirs(os.path.dirname(export_config.output_dir))
  tflite_path = os.path.join(export_config.output_dir,
                             f'{export_config.model_name}.tflite')
  tf.io.gfile.GFile(tflite_path, 'wb').write(tflite_buffer)
  print('TfLite model exported to {}'.format(tflite_path))

  # Export saved model.
  saved_model_path = os.path.join(export_config.output_dir,
                                  export_config.model_name)
  if FLAGS.export_keras_model:
    model_for_inference.save(saved_model_path)
  else:
    tf.saved_model.save(model_for_inference, saved_model_path)
  print('SavedModel exported to {}'.format(saved_model_path))
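# Illustrative sanity check (not part of the original export script): a
# minimal sketch, assuming the export above produced a .tflite file at
# `tflite_path`, that runs the exported model once with the standard
# tf.lite.Interpreter API on a zero-valued input.
import numpy as np
import tensorflow as tf

def check_tflite_model(tflite_path):
  """Loads an exported TFLite model and runs a single dummy inference."""
  interpreter = tf.lite.Interpreter(model_path=tflite_path)
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()
  # Feed an all-zeros tensor matching the exported input shape and dtype.
  dummy_input = np.zeros(input_details[0]['shape'],
                         dtype=input_details[0]['dtype'])
  interpreter.set_tensor(input_details[0]['index'], dummy_input)
  interpreter.invoke()
  return interpreter.get_tensor(output_details[0]['index'])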
@classmethod
def from_name(cls,
              model_name: str,
              model_weights_path: Optional[str] = None,
              checkpoint_format: Optional[str] = 'tf_checkpoint',
              overrides: Optional[Dict[str, Any]] = None):
  """Constructs a MobilenetEdgeTPU model from a predefined model name.

  E.g., `MobilenetEdgeTPU.from_name('mobilenet_edgetpu')`.

  Args:
    model_name: the predefined model name.
    model_weights_path: the path to the weights (h5 file or saved model dir).
    checkpoint_format: the model weights format. One of 'tf_checkpoint' or
      'keras_checkpoint'.
    overrides: (optional) a dict containing keys that can override config.

  Returns:
    A constructed MobilenetEdgeTPU instance.
  """
  model_configs = dict(MODEL_CONFIGS)
  overrides = dict(overrides) if overrides else {}
  # One can define their own custom models if necessary.
  model_configs.update(overrides.pop('model_config', {}))

  if model_name not in model_configs:
    raise ValueError('Unknown model name {}'.format(model_name))

  config = model_configs[model_name]
  model = cls(config=config, overrides=overrides)
  if model_weights_path:
    common_modules.load_weights(model,
                                model_weights_path,
                                checkpoint_format=checkpoint_format)
  return model
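# Illustrative usage sketch (not from the original source), assuming
# `MobilenetEdgeTPU` is importable from this project and that
# 'mobilenet_edgetpu' is a key in MODEL_CONFIGS, as the docstring above
# suggests. The weights path is hypothetical; a name missing from
# MODEL_CONFIGS (and not supplied via the 'model_config' override) raises
# ValueError, per the check above.
model = MobilenetEdgeTPU.from_name(
    'mobilenet_edgetpu',
    model_weights_path='/path/to/keras_checkpoint',  # hypothetical path
    checkpoint_format='keras_checkpoint')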