Example #1
# NOTE: the snippets on this page assume their modules' other imports are in
# scope; tf_frontend, _get_feats_map_fn, mobilenetv3_tiny, and CompressedDense
# are project-internal helpers. Standard dependencies:
from absl import logging
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

def get_keras_model(bottleneck_dimension,
                    output_dimension,
                    alpha=1.0,
                    mobilenet_size='small',
                    frontend=True,
                    avg_pool=False,
                    compressor=None,
                    quantize_aware_training=False,
                    tflite=False):
    """Make a Keras student model."""
    # For debugging, log hyperparameter values.
    logging.info('bottleneck_dimension: %i', bottleneck_dimension)
    logging.info('output_dimension: %i', output_dimension)
    logging.info('alpha: %s', alpha)
    logging.info('frontend: %s', frontend)
    logging.info('avg_pool: %s', avg_pool)
    logging.info('compressor: %s', compressor)
    logging.info('quantize_aware_training: %s', quantize_aware_training)
    logging.info('tflite: %s', tflite)

    output_dict = {}  # Dictionary of model outputs.

    def _map_mobilenet_func(mnet_size):
        mnet_size_map = {
            'tiny': mobilenetv3_tiny,
            'small': tf.keras.applications.MobileNetV3Small,
            'large': tf.keras.applications.MobileNetV3Large,
        }
        if mnet_size.lower() not in mnet_size_map:
            raise ValueError('Unknown MobileNet size %s.' % mnet_size)
        return mnet_size_map[mnet_size.lower()]

    # TFLite use-cases usually use non-batched inference, and this also enables
    # hardware acceleration.
    num_batches = 1 if tflite else None
    if frontend:
        frontend_args = tf_frontend.frontend_args_from_flags()
        logging.info('frontend_args: %s', frontend_args)
        model_in = tf.keras.Input((None,),
                                  name='audio_samples',
                                  batch_size=num_batches)
        frontend_fn = _get_feats_map_fn(tflite, frontend_args)
        feats = tf.keras.layers.Lambda(frontend_fn)(model_in)
        feats = tf.reshape(feats, [-1, 96, 64, 1])
    else:
        model_in = tf.keras.Input((96, 64, 1), name='log_mel_spectrogram')
        feats = model_in
    inputs = [model_in]

    model = _map_mobilenet_func(mobilenet_size)(
        input_shape=[96, 64, 1],
        alpha=alpha,
        minimalistic=False,
        include_top=False,
        weights=None,
        pooling='avg' if avg_pool else None,
        dropout_rate=0.0)
    model_out = model(feats)
    if avg_pool:
        model_out.shape.assert_is_compatible_with([None, None])
    else:
        model_out.shape.assert_is_compatible_with([None, 1, 1, None])
    if bottleneck_dimension:
        if compressor is not None:
            bottleneck = CompressedDense(bottleneck_dimension,
                                         compression_obj=compressor,
                                         name='distilled_output')
        else:
            bottleneck = tf.keras.layers.Dense(bottleneck_dimension,
                                               name='distilled_output')
            if quantize_aware_training:
                bottleneck = tfmot.quantization.keras.quantize_annotate_layer(
                    bottleneck)
        embeddings = tf.keras.layers.Flatten()(model_out)
        embeddings = bottleneck(embeddings)
    else:
        embeddings = tf.keras.layers.Flatten(
            name='distilled_output')(model_out)

    # Construct optional final layer, and create output dictionary.
    output_dict['embedding'] = embeddings
    if output_dimension:
        output = tf.keras.layers.Dense(output_dimension,
                                       name='embedding_to_target')(embeddings)
        output_dict['embedding_to_target'] = output
    output_model = tf.keras.Model(inputs=inputs, outputs=output_dict)

    # Optional modifications to the model for TFLite.
    if tflite:
        if compressor is not None:
            # If model employs compression, this ensures that the TFLite model
            # just uses the smaller matrices for inference.
            output_model.get_layer('distilled_output').kernel = None
            output_model.get_layer(
                'distilled_output').compression_op.a_matrix_tfvar = None

    return output_model
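
A minimal usage sketch, assuming eager mode, 16 kHz audio, and that the flags
read by tf_frontend.frontend_args_from_flags() have already been parsed; the
dimensions below are illustrative, not required by the API:

# Hypothetical usage: build a small student model and run it on dummy audio.
model = get_keras_model(bottleneck_dimension=512,
                        output_dimension=1024,
                        mobilenet_size='small',
                        frontend=True)
dummy_audio = tf.zeros([2, 16000], dtype=tf.float32)  # [batch, samples]
outputs = model(dummy_audio)
# 'embedding' has one row per spectrogram chunk the frontend emits, with
# width equal to bottleneck_dimension (512 here).
print(outputs['embedding'].shape)
print(outputs['embedding_to_target'].shape)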
Example #2
def get_frontend_output_shape():
    """Returns the frontend's output shape for a minimal-length input."""
    frontend_args = tf_frontend.frontend_args_from_flags()
    # Run the frontend once on a silent clip of the minimum required length.
    x = tf.zeros([frontend_args['n_required']], dtype=tf.float32)
    return _sample_to_features(x, frontend_args, tflite=False).shape
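
A hedged sketch of calling this helper, again assuming the frontend flags are
already parsed; it simply reports the static shape before any model is built:

# Hypothetical check: log the frontend's output shape, e.g. to size a
# downstream tf.keras.Input instead of hard-coding it.
logging.info('Frontend output shape: %s', get_frontend_output_shape())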
Example #3
def _sample_to_features(x, export_tflite=False):
    """Computes frontend features for audio samples at a fixed 16 kHz rate."""
    frontend_args = tf_frontend.frontend_args_from_flags()
    return tf_frontend.compute_frontend_features(x,
                                                 16000,
                                                 tflite=export_tflite,
                                                 **frontend_args)
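
A usage sketch under the same assumptions (flags parsed, 16 kHz mono input);
the one-second clip length is illustrative and may need to meet the frontend's
minimum length (frontend_args['n_required']):

# Hypothetical call on one second of silence at the hard-coded 16 kHz rate.
samples = tf.zeros([16000], dtype=tf.float32)
feats = _sample_to_features(samples, export_tflite=False)
logging.info('Frontend features shape: %s', feats.shape)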
Example #4
def default_feature_fn(samples, sample_rate):
    """Computes frontend features and appends a trailing channel axis."""
    frontend_args = tf_frontend.frontend_args_from_flags()
    feats = tf_frontend.compute_frontend_features(samples, sample_rate,
                                                  **frontend_args)
    logging.info('Feats shape: %s', feats.shape)
    return tf.expand_dims(feats, axis=-1).numpy().astype(np.float32)
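
A sketch of calling the feature function eagerly; it ends in .numpy(), so TF
eager execution is required, and numpy is assumed to be imported as np:

# Hypothetical usage: featurize a one-second numpy clip of silence.
samples = np.zeros([16000], dtype=np.float32)
feats = default_feature_fn(samples, sample_rate=16000)
print(feats.shape, feats.dtype)  # trailing channel axis appended; float32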
Example #5
def _default_feature_fn(samples, sample_rate):
    """Computes frontend features with a trailing channel axis (no logging)."""
    frontend_args = tf_frontend.frontend_args_from_flags()
    feats = tf_frontend.compute_frontend_features(samples, sample_rate,
                                                  **frontend_args)
    return tf.expand_dims(feats, axis=-1).numpy().astype(np.float32)