def get_keras_model(model_type,
                    output_dimension,
                    truncate_output=False,
                    frontend=True,
                    tflite=False,
                    spec_augment=False):
  """Make a Keras student model.

  The input and features come from `_frontend_keras`, are optionally passed
  through `augmentation.SpecAugment`, and are then fed to the main network
  built by `_build_main_net`. The flattened network output is the embedding.
  If its size differs from `output_dimension`, it is either padded/sliced in
  place (`truncate_output=True`) or followed by a fully-connected layer
  (`truncate_output=False`).

  Args:
    model_type: Identifier of the main network; forwarded to
      `_build_main_net`.
    output_dimension: Desired size of the model target. A falsy value (e.g.
      `None` or `0`) disables both the truncation and the final layer, so the
      flattened network output is used as is.
    truncate_output: Only relevant when the embedding size differs from
      `output_dimension`. If True, the embedding itself is zero-padded or
      sliced to `output_dimension`. If False, the embedding is left untouched
      and a `Dense(output_dimension)` layer produces the target.
    frontend: Forwarded to `_frontend_keras`.
    tflite: Forwarded to `_frontend_keras`.
    spec_augment: If True, features go through `augmentation.SpecAugment`;
      otherwise they pass through `tf.identity`.

  Returns:
    A `tf.keras.Model` with a single input and a dictionary of outputs:
      'embedding': the (possibly padded/sliced) flattened network output.
      'embedding_to_target': the output of the final `Dense` layer when one
        is built, otherwise the same tensor as 'embedding'.
  """
  # For debugging, log hyperparameter values.
  logging.info('model name: %s', model_type)
  logging.info('truncate_output: %s', truncate_output)
  # `%s` rather than `%i`: `output_dimension` may be `None` (handled below),
  # and `%i` cannot format `None`. Integers are rendered identically.
  logging.info('output_dimension: %s', output_dimension)
  logging.info('frontend: %s', frontend)
  logging.info('tflite: %s', tflite)
  logging.info('spec_augment: %s', spec_augment)

  output_dict = {}  # Dictionary of model outputs.

  # Construct model input and frontend.
  model_in, feats = _frontend_keras(frontend, tflite)
  feats.shape.assert_is_compatible_with([None, None, None, 1])
  spec_augment_fn = augmentation.SpecAugment(
  ) if spec_augment else tf.identity
  feats = spec_augment_fn(feats)

  inputs = [model_in]
  logging.info('Features shape: %s', feats.shape)

  # Build network.
  model_out = _build_main_net(model_type, feats)
  embeddings = tf.keras.layers.Flatten(name='distilled_output')(model_out)

  # The last fully-connected layer can sometimes be the single largest
  # layer in the entire network. It's also not always very valuable. We try
  # two methods of getting the right output dimension:
  # 1) A FC layer
  # 2) Taking the first `output_dimension` elements.
  need_final_layer = (output_dimension and
                      embeddings.shape[1] != output_dimension)

  # If we need to truncate, do it before we save the embedding. Otherwise,
  # the embedding will contain some garbage dimensions.
  if need_final_layer and truncate_output:
    if embeddings.shape[1] < output_dimension:
      # Too short: pad the feature axis on the right up to the target size.
      embeddings = tf.pad(
          embeddings, [[0, 0], [0, output_dimension - embeddings.shape[1]]])
    else:
      # Too long: keep only the leading `output_dimension` features.
      embeddings = embeddings[:, :output_dimension]

  # Construct optional final layer, and create output dictionary.
  output_dict['embedding'] = embeddings

  target = embeddings
  if need_final_layer and not truncate_output:
    target = tf.keras.layers.Dense(
        output_dimension, name='embedding_to_target')(target)
  # Always exposed, so callers can rely on the key even without a final layer.
  output_dict['embedding_to_target'] = target
  output_model = tf.keras.Model(inputs=inputs, outputs=output_dict)
  return output_model
def test_spec_augment_training(self):
  """SpecAugment must change its input when called with training=True."""
  # Batch of all-ones inputs shaped like a log Mel spectrogram.
  expected_shape = [3, 96, 64, 1]
  ones = tf.ones(expected_shape, dtype=tf.float32)

  layers = [
      tf.keras.layers.Input((96, 64, 1)),
      augmentation.SpecAugment(),
  ]
  model = tf.keras.Sequential(layers)
  augmented = model(ones, training=True)

  # The shape is preserved, but the values must differ from the input.
  self.assertListEqual(list(augmented.shape), expected_shape)
  self.assertNotAllEqual(augmented, ones)
def test_spec_augment_inference(self):
  """SpecAugment must leave its input untouched when training=False."""
  # Batch of all-ones inputs shaped like a log Mel spectrogram.
  expected_shape = [3, 96, 64, 1]
  ones = tf.ones(expected_shape, dtype=tf.float32)

  layers = [
      tf.keras.layers.Input((96, 64, 1)),
      augmentation.SpecAugment(),
  ]
  model = tf.keras.Sequential(layers)
  passthrough = model(ones, training=False)

  # Both the shape and every value must match the input exactly.
  self.assertListEqual(list(passthrough.shape), expected_shape)
  self.assertAllEqual(passthrough, ones)