def get_keras_model(bottleneck_dimension,
                    output_dimension,
                    alpha=1.0,
                    mobilenet_size='small',
                    frontend=True,
                    avg_pool=False,
                    compressor=None,
                    quantize_aware_training=False,
                    tflite=False):
    """Make a Keras student model.

    Args:
      bottleneck_dimension: Width of the dense 'distilled_output' layer. If
        falsy, no dense bottleneck is built and 'distilled_output' is instead a
        plain Flatten of the MobileNet output.
      output_dimension: If truthy, a Dense 'embedding_to_target' layer of this
        width is added on top of the embedding.
      alpha: Width multiplier forwarded to the MobileNet constructor.
      mobilenet_size: 'tiny', 'small' or 'large' (case-insensitive).
      frontend: If True, the model input is raw audio ('audio_samples') and
        features are computed in-graph; otherwise the input is a
        (96, 64, 1) 'log_mel_spectrogram' tensor.
      avg_pool: If True, MobileNet is built with global average pooling.
      compressor: Optional compression object. When given together with a
        bottleneck, the bottleneck is a CompressedDense layer.
      quantize_aware_training: If True, the bottleneck layer is annotated with
        `tfmot.quantization.keras.quantize_annotate_layer`.
      tflite: If True, the audio input uses a fixed batch size of 1, and a
        compressed bottleneck has its `kernel` and
        `compression_op.a_matrix_tfvar` cleared so only the smaller matrices
        are used for inference.

    Returns:
      A `tf.keras.Model` whose outputs are a dict with key 'embedding' and,
      if `output_dimension` is truthy, 'embedding_to_target'.

    Raises:
      ValueError: If `mobilenet_size` is not a known size.
    """
    # For debugging, log hyperparameter values. Use `%s` for the dimensions:
    # both may be None/0 (they are only truth-tested below), and `%i` cannot
    # format None.
    logging.info('bottleneck_dimension: %s', bottleneck_dimension)
    logging.info('output_dimension: %s', output_dimension)
    logging.info('alpha: %s', alpha)
    logging.info('mobilenet_size: %s', mobilenet_size)
    logging.info('frontend: %s', frontend)
    logging.info('avg_pool: %s', avg_pool)
    logging.info('compressor: %s', compressor)
    logging.info('quantize_aware_training: %s', quantize_aware_training)
    logging.info('tflite: %s', tflite)

    output_dict = {}  # Dictionary of model outputs.

    def _map_mobilenet_func(mnet_size):
        """Return the MobileNet constructor for `mnet_size`."""
        mnet_size_map = {
            'tiny': mobilenetv3_tiny,
            'small': tf.keras.applications.MobileNetV3Small,
            'large': tf.keras.applications.MobileNetV3Large,
        }
        key = mnet_size.lower()
        if key not in mnet_size_map:
            raise ValueError('Unknown MobileNet size %s.' % mnet_size)
        return mnet_size_map[key]

    # Resolve the backbone constructor up front so an unknown size fails
    # before any graph construction.
    mobilenet_fn = _map_mobilenet_func(mobilenet_size)

    # TFLite use-cases usually use non-batched inference, and this also enables
    # hardware acceleration.
    num_batches = 1 if tflite else None
    if frontend:
        frontend_args = tf_frontend.frontend_args_from_flags()
        logging.info('frontend_args: %s', frontend_args)
        model_in = tf.keras.Input((None,), name='audio_samples',
                                  batch_size=num_batches)
        frontend_fn = _get_feats_map_fn(tflite, frontend_args)
        feats = tf.keras.layers.Lambda(frontend_fn)(model_in)
        # Collapse leading dims so every 96x64 frame becomes one backbone
        # example matching the `input_shape` below.
        feats = tf.reshape(feats, [-1, 96, 64, 1])
    else:
        model_in = tf.keras.Input((96, 64, 1), name='log_mel_spectrogram')
        feats = model_in
    inputs = [model_in]

    backbone = mobilenet_fn(
        input_shape=[96, 64, 1],
        alpha=alpha,
        minimalistic=False,
        include_top=False,
        weights=None,
        pooling='avg' if avg_pool else None,
        dropout_rate=0.0)
    model_out = backbone(feats)
    if avg_pool:
        model_out.shape.assert_is_compatible_with([None, None])
    else:
        model_out.shape.assert_is_compatible_with([None, 1, 1, None])

    if bottleneck_dimension:
        if compressor is not None:
            bottleneck = CompressedDense(
                bottleneck_dimension,
                compression_obj=compressor,
                name='distilled_output')
        else:
            bottleneck = tf.keras.layers.Dense(
                bottleneck_dimension, name='distilled_output')
        if quantize_aware_training:
            bottleneck = tfmot.quantization.keras.quantize_annotate_layer(
                bottleneck)
        embeddings = tf.keras.layers.Flatten()(model_out)
        embeddings = bottleneck(embeddings)
    else:
        embeddings = tf.keras.layers.Flatten(
            name='distilled_output')(model_out)

    # Construct optional final layer, and create output dictionary.
    output_dict['embedding'] = embeddings
    if output_dimension:
        output = tf.keras.layers.Dense(
            output_dimension, name='embedding_to_target')(embeddings)
        output_dict['embedding_to_target'] = output
    output_model = tf.keras.Model(inputs=inputs, outputs=output_dict)

    # Optional modifications to the model for TFLite. Only the compressed
    # bottleneck (built above when both `bottleneck_dimension` and
    # `compressor` are set) carries a `compression_op`. Without a bottleneck,
    # 'distilled_output' is a Flatten layer, and touching it would raise
    # AttributeError.
    if tflite and compressor is not None and bottleneck_dimension:
        # If model employs compression, this ensures that the TFLite model
        # just uses the smaller matrices for inference.
        distilled = output_model.get_layer('distilled_output')
        distilled.kernel = None
        distilled.compression_op.a_matrix_tfvar = None

    return output_model
def get_frontend_output_shape():
    """Return the shape of frontend features computed for an all-zeros input.

    The input waveform has `frontend_args['n_required']` samples, where
    `frontend_args` comes from `tf_frontend.frontend_args_from_flags()`.
    Features are computed in non-TFLite mode.

    Returns:
      The static shape of the feature tensor returned by
      `_sample_to_features`.
    """
    frontend_args = tf_frontend.frontend_args_from_flags()
    x = tf.zeros([frontend_args['n_required']], dtype=tf.float32)
    # `_sample_to_features(x, export_tflite=False)` reads the frontend args
    # from flags itself. The previous call passed `frontend_args` positionally
    # (into `export_tflite`) plus an unknown `tflite=` keyword, which raised
    # TypeError.
    return _sample_to_features(x, export_tflite=False).shape
def _sample_to_features(x, export_tflite=False):
    """Run the flag-configured frontend on `x`, treating it as 16000 Hz audio.

    Args:
      x: Audio samples passed straight through to
        `tf_frontend.compute_frontend_features`.
      export_tflite: Forwarded to the frontend as its `tflite` switch.

    Returns:
      Whatever `tf_frontend.compute_frontend_features` returns.
    """
    sample_rate = 16000
    flag_args = tf_frontend.frontend_args_from_flags()
    features = tf_frontend.compute_frontend_features(
        x, sample_rate, tflite=export_tflite, **flag_args)
    return features
def default_feature_fn(samples, sample_rate):
    """Compute flag-configured frontend features as a float32 NumPy array.

    The feature shape is logged before a trailing axis of size 1 is appended.

    Args:
      samples: Audio samples passed to `tf_frontend.compute_frontend_features`.
      sample_rate: Sample rate of `samples`, forwarded to the frontend.

    Returns:
      `np.ndarray` of dtype float32: the features with an extra last axis.
    """
    feature_args = tf_frontend.frontend_args_from_flags()
    features = tf_frontend.compute_frontend_features(
        samples, sample_rate, **feature_args)
    logging.info('Feats shape: %s', features.shape)
    with_channel = tf.expand_dims(features, axis=-1)
    return with_channel.numpy().astype(np.float32)
def _default_feature_fn(samples, sample_rate):
    """Return flag-configured frontend features with a trailing unit axis.

    Same computation as the public feature function, but without logging.

    Args:
      samples: Audio samples passed to `tf_frontend.compute_frontend_features`.
      sample_rate: Sample rate of `samples`, forwarded to the frontend.

    Returns:
      `np.ndarray` of dtype float32.
    """
    kwargs = tf_frontend.frontend_args_from_flags()
    expanded = tf.expand_dims(
        tf_frontend.compute_frontend_features(samples, sample_rate, **kwargs),
        axis=-1)
    as_numpy = expanded.numpy()
    return as_numpy.astype(np.float32)