Example #1
def encode(wav_data, checkpoint_path, sample_length=64000):
    """Padded loading of a wave file.

  Args:
    wav_data: Numpy array [batch_size, sample_length]
    checkpoint_path: Location of the pretrained model.
    sample_length: The total length of the final wave file, padded with 0s.
  Returns:
    encoding: a [mb, 125, 16] encoding (for 64000 sample audio file).
    hop_length: Pooling size of the autoencoder.
  """
    if wav_data.ndim == 1:
        wav_data = np.expand_dims(wav_data, 0)
        batch_size = 1
    elif wav_data.ndim == 2:
        batch_size = wav_data.shape[0]

    # Load up the model for encoding and find the encoding of "wav_data"
    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Graph().as_default(), tf.Session(config=session_config) as sess:
        hop_length = Config().ae_hop_length
        wav_data, sample_length = utils.trim_for_encoding(
            wav_data, sample_length, hop_length)
        net = load_nsynth(batch_size=batch_size, sample_length=sample_length)
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint_path)
        encoding = sess.run(net["encoding"], feed_dict={net["X"]: wav_data})
    return encoding, hop_length
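
Both examples appear to be variants of fastgen.encode from Magenta's NSynth WaveNet code and use the TF1-style Session API. Below is a minimal usage sketch for this first variant, which also returns hop_length. The file name and checkpoint path are hypothetical placeholders, and it assumes utils.load_audio from magenta.models.nsynth is available; each encoding frame corresponds to hop_length audio samples.

from magenta.models.nsynth import utils

# Hypothetical paths; substitute your own 16 kHz audio file and pretrained checkpoint.
wav_data = utils.load_audio("input.wav", sample_length=64000, sr=16000)
encoding, hop_length = encode(wav_data, "wavenet-ckpt/model.ckpt-200000",
                              sample_length=64000)

# Frames times hop_length recovers the trimmed sample length.
print(encoding.shape)                   # e.g. (1, 125, 16) for a 64000-sample clip
print(encoding.shape[1] * hop_length)   # e.g. 125 * 512 = 64000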
Example #2
def encode(wav_data, checkpoint_path, sample_length=64000):
  """Generate an array of embeddings from an array of audio.

  Args:
    wav_data: Numpy array [batch_size, sample_length]
    checkpoint_path: Location of the pretrained model.
    sample_length: The total length of the final wave file, padded with 0s.
  Returns:
    encoding: a [mb, 125, 16] encoding (for a 64000-sample audio file).
  """
  if wav_data.ndim == 1:
    wav_data = np.expand_dims(wav_data, 0)
    batch_size = 1
  elif wav_data.ndim == 2:
    batch_size = wav_data.shape[0]

  # Load up the model for encoding and find the encoding of "wav_data"
  session_config = tf.ConfigProto(allow_soft_placement=True)
  with tf.Graph().as_default(), tf.Session(config=session_config) as sess:
    hop_length = Config().ae_hop_length
    wav_data, sample_length = utils.trim_for_encoding(wav_data, sample_length,
                                                      hop_length)
    net = load_nsynth(batch_size=batch_size, sample_length=sample_length)
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)
    encodings = sess.run(net["encoding"], feed_dict={net["X"]: wav_data})
  return encodings
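
This second variant returns only the encodings. A minimal batching sketch, assuming the hypothetical file names below and the same assumed checkpoint path and utils.load_audio helper as above: clips of equal length are stacked into a [batch_size, sample_length] array before calling encode.

import numpy as np
from magenta.models.nsynth import utils

# Hypothetical file names and checkpoint path; substitute your own.
files = ["note_a.wav", "note_b.wav"]
batch = np.stack([utils.load_audio(f, sample_length=64000, sr=16000)
                  for f in files])      # shape: (2, 64000)

encodings = encode(batch, "wavenet-ckpt/model.ckpt-200000", sample_length=64000)
print(encodings.shape)                  # e.g. (2, 125, 16)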