Beispiel #1
0
def yamnet_frames_model_transfer(params, last_layers):
    """Defines a YAMNet waveform-to-class-scores model with a custom classifier head.

    The waveform is padded and turned into log-mel spectrogram patches.
    ``yamnet_transfer`` maps the patches to per-frame embeddings. Those
    embeddings are then passed, layer by layer, through a separately saved
    classifier model.

    Args:
      params: An instance of Params containing hyperparameters.
      last_layers: Path to the saved classifier model, loaded with
        ``load_model``. Its layers are re-applied in order on top of the
        YAMNet embeddings; any ``InputLayer`` it contains is skipped.

    Returns:
      A model accepting (num_samples,) waveform input and emitting:
      - predictions: (num_patches, num_classes) matrix of class scores per time frame
      - embeddings: (num_patches, embedding size) matrix of embeddings per time frame
    """
    waveform = layers.Input(batch_shape=(None, ), dtype=tf.float32)
    waveform_padded = features_lib.pad_waveform(waveform, params)
    # Only the patches are needed here; this model does not expose the
    # log-mel spectrogram as an output.
    _, features = features_lib.waveform_to_log_mel_spectrogram_patches(
        waveform_padded, params)
    embeddings = yamnet_transfer(features, params)

    # Keep `last_layers` as the path argument; the loaded model gets its own name.
    classifier = load_model(last_layers)
    prediction = embeddings
    for layer in classifier.layers:
        # Skip the classifier's own input placeholder by type rather than by
        # position: functional models list their InputLayer first, but
        # Sequential models omit it from `.layers`, so slicing off layers[0]
        # would silently drop the first trained layer.
        if isinstance(layer, layers.InputLayer):
            continue
        prediction = layer(prediction)

    frames_model = Model(name='yamnet_frames',
                         inputs=waveform,
                         outputs=[prediction, embeddings])
    return frames_model
Beispiel #2
0
def yamnet_frames_model(params):
  """Builds the end-to-end YAMNet model that operates on raw waveforms.

  The input waveform is padded and converted into log-mel spectrogram
  patches. The patches are then run through the YAMNet network.

  Args:
    params: An instance of Params containing hyperparameters.

  Returns:
    A Model named 'yamnet_frames' that accepts a (num_samples,) waveform and
    emits, in this order:
    - predictions: (num_patches, num_classes) matrix of class scores per time frame
    - embeddings: (num_patches, embedding size) matrix of embeddings per time frame
    - log_mel_spectrogram: (num_spectrogram_frames, num_mel_bins) spectrogram feature matrix
  """
  audio_in = layers.Input(batch_shape=(None,), dtype=tf.float32)
  padded_audio = features_lib.pad_waveform(audio_in, params)
  spectrogram, patches = features_lib.waveform_to_log_mel_spectrogram_patches(
      padded_audio, params)
  scores, embedding_vectors = yamnet(patches, params)

  # Output order is part of the contract: scores, embeddings, spectrogram.
  model_outputs = [scores, embedding_vectors, spectrogram]
  return Model(name='yamnet_frames', inputs=audio_in, outputs=model_outputs)