Пример #1
0
def get_keras_model(model_type,
                    output_dimension,
                    truncate_output=False,
                    frontend=True,
                    tflite=False,
                    spec_augment=False):
    """Make a Keras student model.

    The model is assembled as: input/frontend -> optional SpecAugment ->
    main network -> flatten -> optional resize to `output_dimension`.

    Args:
      model_type: Passed to `_build_main_net` to build the main network.
      output_dimension: Requested width of the model output. A falsy value
        (e.g. 0 or None) disables resizing, so the flattened network output
        is returned unchanged.
      truncate_output: Selects how a width mismatch is resolved. If True, the
        flattened output is zero-padded at the end or sliced to
        `output_dimension`. If False, a `Dense` layer named
        'embedding_to_target' maps it to `output_dimension`.
      frontend: Passed to `_frontend_keras`.
      tflite: Passed to `_frontend_keras`.
      spec_augment: If True, `augmentation.SpecAugment` is applied to the
        features; otherwise they pass through `tf.identity`.

    Returns:
      A `tf.keras.Model` whose outputs are a dict with two keys:
        'embedding': the flattened network output, after padding/truncation
          when `truncate_output` applies.
        'embedding_to_target': the embedding after the optional `Dense`
          layer; the same tensor as 'embedding' when no layer is added.
    """
    # For debugging, log hyperparameter values. `%s` is used for every value
    # so that a falsy `output_dimension` such as None is still logged instead
    # of triggering a format error (`%i` only accepts numbers).
    logging.info('model name: %s', model_type)
    logging.info('truncate_output: %s', truncate_output)
    logging.info('output_dimension: %s', output_dimension)
    logging.info('frontend: %s', frontend)
    logging.info('tflite: %s', tflite)
    logging.info('spec_augment: %s', spec_augment)

    output_dict = {}  # Dictionary of model outputs.

    # Construct model input and frontend.
    model_in, feats = _frontend_keras(frontend, tflite)
    # Features must be rank 4 with a trailing dimension of size 1.
    feats.shape.assert_is_compatible_with([None, None, None, 1])
    spec_augment_fn = (
        augmentation.SpecAugment() if spec_augment else tf.identity)
    feats = spec_augment_fn(feats)

    inputs = [model_in]
    logging.info('Features shape: %s', feats.shape)

    # Build network.
    model_out = _build_main_net(model_type, feats)
    embeddings = tf.keras.layers.Flatten(name='distilled_output')(model_out)

    # The last fully-connected layer can sometimes be the single largest
    # layer in the entire network. It's also not always very valuable. We try
    # two methods of getting the right output dimension:
    # 1) A FC layer
    # 2) Taking the first `output_dimension` elements.
    # NOTE(review): the comparisons below assume the flattened width is
    # statically known -- confirm for all `model_type` values.
    embedding_dim = embeddings.shape[1]
    need_final_layer = bool(output_dimension
                            and embedding_dim != output_dimension)

    # If we need to truncate, do it before we save the embedding. Otherwise,
    # the embedding will contain some garbage dimensions.
    if need_final_layer and truncate_output:
        if embedding_dim < output_dimension:
            # Too narrow: append zeros along the feature axis only.
            embeddings = tf.pad(
                embeddings,
                [[0, 0], [0, output_dimension - embedding_dim]])
        else:
            # Too wide: keep the first `output_dimension` features.
            embeddings = embeddings[:, :output_dimension]

    # Construct optional final layer, and create output dictionary.
    output_dict['embedding'] = embeddings

    target = embeddings
    if need_final_layer and not truncate_output:
        target = tf.keras.layers.Dense(output_dimension,
                                       name='embedding_to_target')(target)
    output_dict['embedding_to_target'] = target
    output_model = tf.keras.Model(inputs=inputs, outputs=output_dict)

    return output_model
    def test_spec_augment_training(self):
        """Check that SpecAugment alters its input when training=True."""
        # Shape of a batch of log Mel spectrograms.
        expected_shape = [3, 96, 64, 1]
        all_ones = tf.ones(expected_shape, dtype=tf.float32)

        input_layer = tf.keras.layers.Input((96, 64, 1))
        augment_layer = augmentation.SpecAugment()
        model = tf.keras.Sequential([input_layer, augment_layer])

        augmented = model(all_ones, training=True)

        # The shape must be preserved, but the values must differ from the
        # input.
        self.assertListEqual(list(augmented.shape), expected_shape)
        self.assertNotAllEqual(augmented, all_ones)
    def test_spec_augment_inference(self):
        """Check that SpecAugment leaves its input intact when training=False."""
        # Shape of a batch of log Mel spectrograms.
        expected_shape = [3, 96, 64, 1]
        all_ones = tf.ones(expected_shape, dtype=tf.float32)

        input_layer = tf.keras.layers.Input((96, 64, 1))
        augment_layer = augmentation.SpecAugment()
        model = tf.keras.Sequential([input_layer, augment_layer])

        passthrough = model(all_ones, training=False)

        # Both the shape and the values must match the input exactly.
        self.assertListEqual(list(passthrough.shape), expected_shape)
        self.assertAllEqual(passthrough, all_ones)