Example #1
    def refresh_cache(self):
        if self.mode == "long":
            dur = self.cached_end - self.cached_begin
            self.signal, self.sr = librosa.load(self.path, sr=None, offset=self.cached_begin, duration=dur, mono=False)
            if len(self.signal.shape) == 1:
                self.signal = self.signal.reshape((self.signal.shape[0], 1))
            else:
                self.signal = self.signal.T
            self.preemph_signal = lfilter([1.0, -0.95], 1, self.signal, axis=0)
            self.downsampled_1000 = resample(self.signal, self.sr, 1000, filter="kaiser_fast", axis=0)
            self.downsampled_100 = resample(self.downsampled_1000, 1000, 100, filter="kaiser_fast", axis=0)
Example #2
def augment_audio_signal(signal, fs, augmentation):
  """Function that performs audio signal augmentation.

  Args:
    signal (np.array): np.array containing raw audio signal.
    fs (float): sampling frequency (samples per second).
    augmentation (dict): dictionary of augmentation parameters. See
        :func:`get_speech_features_from_file` for specification and example.
  Returns:
    np.array: np.array with augmented audio signal.
  """
  signal_float = signal.astype(np.float32) / 32768.0

  if augmentation['time_stretch_ratio'] > 0:
    # time stretch (might be slow)
    stretch_amount = 1.0 + (2.0 * np.random.rand() - 1.0) * \
                     augmentation['time_stretch_ratio']
    signal_float = rs.resample(
      signal_float,
      fs,
      int(fs * stretch_amount),
      filter='kaiser_fast',
    )

  # noise
  noise_level_db = np.random.randint(low=augmentation['noise_level_min'],
                                     high=augmentation['noise_level_max'])
  signal_float += np.random.randn(signal_float.shape[0]) * \
                  10.0 ** (noise_level_db / 20.0)

  return (signal_float * 32768.0).astype(np.int16)
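A minimal usage sketch for the function above (not from the original project): the augmentation dict follows the docstring, and resampy is assumed to be imported as `rs`, as in the source.

import numpy as np

augmentation = {
    'time_stretch_ratio': 0.05,  # stretch by up to +/-5%
    'noise_level_min': -90,      # noise level bounds, in dB
    'noise_level_max': -46,
}
signal = (np.random.randn(16000) * 1000).astype(np.int16)  # 1 s of fake int16 audio
augmented = augment_audio_signal(signal, fs=16000, augmentation=augmentation)
assert augmented.dtype == np.int16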
Example #3
def test_good_window():
    sr_orig = 100
    sr_new = 200
    x = np.random.randn(500)
    y = resampy.resample(x, sr_orig, sr_new, filter='sinc_window', window=scipy.signal.windows.blackman)

    assert len(y) == 2 * len(x)
Example #4
def update_max_len(file_path_list, max_len):
    tmp_max_len = 0
    # Update the max length based on the given dataset
    signal_set = set()
    for file_path in file_path_list:
        file_list = open(file_path)
        for line in file_list:
            line = line.strip().split()
            if len(line) < 2:
                print('Wrong audio list file record in the line:', line)
                continue
            file_str = line[0]
            if file_str in signal_set:
                continue
            signal_set.add(file_str)
            signal, rate = sf.read(file_str)  # signal: sample values,rate: sample rate
            if len(signal.shape) > 1:
                signal = signal[:, 0]
            if rate != FRAME_RATE:
                # up-sample or down-sample for predefined sample rate
                signal = resampy.resample(signal, rate, FRAME_RATE, filter='kaiser_fast')
            if len(signal) > tmp_max_len:
                tmp_max_len = len(signal)
        file_list.close()
    if tmp_max_len < max_len:
        max_len = tmp_max_len
    return max_len
Example #5
def create_tf_example(ds):
    
    waveform = ds[:]
    attrs = ds.attrs
    
    sample_rate = attrs['sample_rate']
    
    # Trim waveform.
    start_index = int(round(EXAMPLE_START_OFFSET * sample_rate))
    length = int(round(EXAMPLE_DURATION * sample_rate))
    waveform = waveform[start_index:start_index + length]
    
    # Resample if needed.
    if sample_rate != EXAMPLE_SAMPLE_RATE:
        waveform = resampy.resample(waveform, sample_rate, EXAMPLE_SAMPLE_RATE)
    
    waveform_feature = create_bytes_feature(waveform.tobytes())
    
    classification = attrs['classification']
    label = 1 if classification.startswith('Call') else 0
    label_feature = create_int64_feature(label)
    
    clip_id = attrs['clip_id']
    clip_id_feature = create_int64_feature(clip_id)
    
    features = tf.train.Features(
        feature={
            'waveform': waveform_feature,
            'label': label_feature,
            'clip_id': clip_id_feature
        })
    
    return tf.train.Example(features=features)
Example #6
def load_bgd_wav(file_path):
    signal, rate = sf.read(file_path)  # signal: sample values,rate: sample rate
    if len(signal.shape) > 1:
        signal = signal[:, 0]
    if rate != FRAME_RATE:
        # up-sample or down-sample for predefined sample rate
        signal = resampy.resample(signal, rate, FRAME_RATE, filter='kaiser_fast')
    return signal
Example #7
    def __test(sr_orig, sr_new, fil, rms, x, y):

        y_pred = resampy.resample(x, sr_orig, sr_new, filter=fil)

        idx = slice(sr_new // 2, - sr_new//2)
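        # The slice above trims sr_new // 2 samples (half a second at the
        # new rate) from each end, so filter edge effects are ignored.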

        err = np.mean(np.abs(y[idx] - y_pred[idx]))
        assert err <= rms, '{:g} > {:g}'.format(err, rms)
Example #8
    def __test(axis, sr_orig, sr_new, X):

        Y = resampy.resample(X, sr_orig, sr_new, axis=axis)

        target_shape = list(X.shape)
        target_shape[axis] = target_shape[axis] * sr_new // sr_orig

        eq_(target_shape, list(Y.shape))
Example #9
def test_resampy(samples, input_rate, output_rate, filter_name, pdf_file):
    
    # Resample chirp.
    samples = resampy.resample(
        samples, input_rate, output_rate, filter=filter_name)
    
    # Plot spectrogram of result.
    title = 'Resampy {}'.format(filter_name)
    plot_spectrogram(samples, output_rate, title, pdf_file)
Example #10
    async def to_f32_16k(wav: bytes) -> Tuple[int, numpy.ndarray]:
        # converting the wav to ndarray, which is much easier to use for DSP
        rate, data = wavfile.read(BytesIO(wav))
        # casting the data array to the right format (float32, for usage by pysndfx)
        data = (data / (2. ** 15)).astype('float32')
        if rate != BASE_SAMPLING_RATE:
            data = resample(data, rate, BASE_SAMPLING_RATE)

        return BASE_SAMPLING_RATE, data
Example #11
def test_shape(axis):
    sr_orig = 100
    sr_new = sr_orig // 2
    X = np.random.randn(sr_orig, sr_orig, sr_orig)
    Y = resampy.resample(X, sr_orig, sr_new, axis=axis)

    target_shape = list(X.shape)
    target_shape[axis] = target_shape[axis] * sr_new // sr_orig

    assert target_shape == list(Y.shape)
Example #12
def load_audio(path, sr):
    """
    Load audio file
    """
    data, sr_orig = sf.read(path, dtype='float32', always_2d=True)
    data = data.mean(axis=-1)

    if sr_orig != sr:
        data = resampy.resample(data, sr_orig, sr)

    return data
Example #13
def test_quality_sweep(sr_orig, sr_new, fil, rms):
    FREQ = 8192
    DURATION = 5.0
    x = make_sweep(FREQ, sr_orig, DURATION)
    y = make_sweep(FREQ, sr_new, DURATION)

    y_pred = resampy.resample(x, sr_orig, sr_new, filter=fil)

    idx = slice(sr_new // 2, - sr_new//2)

    err = np.mean(np.abs(y[idx] - y_pred[idx]))
    assert err <= rms, '{:g} > {:g}'.format(err, rms)
Example #14
    def __init__(self, sound_file, initial_begin=None, initial_end=None):
        self.path = sound_file.filepath
        self.mode = None

        self.duration = sound_file.duration
        self.num_channels = sound_file.n_channels
        if self.duration < self.cache_amount:
            self.mode = "short"
            self.signal, self.sr = librosa.load(self.path, sr=None)
            if len(self.signal.shape) == 1:
                self.signal = self.signal.reshape((self.signal.shape[0], 1))
            else:
                self.signal = self.signal.T
            self.preemph_signal = lfilter([1.0, -0.95], 1, self.signal, axis=0)
            self.downsampled_1000 = resample(self.signal, self.sr, 1000, filter="kaiser_fast", axis=0)
            self.downsampled_100 = resample(self.downsampled_1000, 1000, 100, filter="kaiser_fast", axis=0)
            self.cached_begin = 0
            self.cached_end = self.duration
        else:
            self.mode = "long"
            if initial_begin is not None:
                if initial_end is not None:
                    padding = self.cache_amount - (initial_end - initial_begin)
                    self.cached_begin = initial_begin - padding
                    self.cached_end = initial_end + padding
                else:
                    self.cached_begin = initial_begin - self.cache_amount / 2
                    self.cached_end = initial_begin + self.cache_amount / 2
                if self.cached_begin < 0:
                    self.cached_end -= self.cached_begin
                    self.cached_begin = 0
                if self.cached_end > self.duration:
                    diff = self.cached_end - self.duration
                    self.cached_end = self.duration
                    self.cached_begin -= diff
            else:
                self.cached_begin = 0
                self.cached_end = self.cache_amount
        self.refresh_cache()
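The clamping above keeps the cached window inside the file. A standalone sketch of the same arithmetic, assuming cache_amount is 60 seconds and the file lasts 100 seconds:

# A request centered at t=5 s would start at -25 s, so the window is
# shifted right until it begins at 0, preserving its 60 s length.
cache_amount, duration = 60.0, 100.0
begin = 5.0 - cache_amount / 2  # -25.0
end = 5.0 + cache_amount / 2    # 35.0
if begin < 0:
    end -= begin
    begin = 0.0
assert (begin, end) == (0.0, 60.0)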
Example #15
def waveform_to_examples(data, sample_rate, target_sample_rate=16000,
                         log_offset=0.01, stft_win_len_sec=0.025,
                         stft_hop_len_sec=0.010, num_mel_bins=64,
                         mel_min_hz=125, mel_max_hz=7500, frame_win_sec=0.96,
                         frame_hop_sec=0.96, **params):
  """Converts audio waveform into an array of examples for VGGish.

  Args:
    data: np.array of either one dimension (mono) or two dimensions
      (multi-channel, with the outer dimension representing channels).
      Each sample is generally expected to lie in the range [-1.0, +1.0],
      although this is not required.
    sample_rate: Sample rate of data.

  Returns:
    3-D np.array of shape [num_examples, num_frames, num_bands] which represents
    a sequence of examples, each of which contains a patch of log mel
    spectrogram, covering num_frames frames of audio and num_bands mel frequency
    bands, where the frame length is stft_hop_len_sec.
  """

  # Convert to mono.
  if len(data.shape) > 1:
    data = np.mean(data, axis=1)
  # Resample to the rate assumed by VGGish.
  if sample_rate != target_sample_rate:
    data = resampy.resample(data, sample_rate, target_sample_rate)

  # Compute log mel spectrogram features.
  log_mel = mel_features.log_mel_spectrogram(
      data,
      audio_sample_rate=target_sample_rate,
      log_offset=log_offset,
      window_length_secs=stft_win_len_sec,
      hop_length_secs=stft_hop_len_sec,
      num_mel_bins=num_mel_bins,
      lower_edge_hertz=mel_min_hz,
      upper_edge_hertz=mel_max_hz)

  # Frame features into examples.
  features_sample_rate = 1.0 / stft_hop_len_sec
  example_window_length = int(round(
      frame_win_sec * features_sample_rate))
  example_hop_length = int(round(
      frame_hop_sec * features_sample_rate))
  log_mel_examples = mel_features.frame(
      log_mel,
      window_length=example_window_length,
      hop_length=example_hop_length)

  return log_mel_examples
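A hedged usage sketch, assuming the VGGish `mel_features` helper used above is importable: with the default 0.96 s example window, 10 ms hop, and 64 mel bins, each example covers 96 frames.

import numpy as np

data = np.random.uniform(-1.0, 1.0, 3 * 44100)  # 3 s of mono noise
examples = waveform_to_examples(data, sample_rate=44100)
print(examples.shape)  # e.g. (3, 96, 64)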
Example #16
def wavfile_to_waveform(wav_file):

    wav_data, sr = wav_read(wav_file)
    assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
    data = wav_data / 32768.0  # Convert to [-1.0, +1.0]

    if len(data.shape) > 1:
        data = np.mean(data, axis=1)
    # Resample to the rate assumed by VGGish.
    if sr != vggish_params.SAMPLE_RATE:
        data = resampy.resample(data, sr, vggish_params.SAMPLE_RATE)
        sr = vggish_params.SAMPLE_RATE

    return data, sr
Example #17
def waveform_to_examples(data, sample_rate):
    """Converts audio waveform into an array of examples for VGGish.

    Args:
      data: np.array of either one dimension (mono) or two dimensions
        (multi-channel, with the outer dimension representing channels).
        Each sample is generally expected to lie in the range [-1.0, +1.0],
        although this is not required.
      sample_rate: Sample rate of data.

    Returns:
      3-D np.array of shape [num_examples, num_frames, num_bands] which represents
      a sequence of examples, each of which contains a patch of log mel
      spectrogram, covering num_frames frames of audio and num_bands mel frequency
      bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS.
    """

    import resampy

    # Convert to mono.
    if len(data.shape) > 1:
        data = np.mean(data, axis=1)
    # Resample to the rate assumed by VGGish.
    if sample_rate != vggish_params.SAMPLE_RATE:
        data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)

    # Compute log mel spectrogram features.
    log_mel = mel_features.log_mel_spectrogram(
        data,
        audio_sample_rate=vggish_params.SAMPLE_RATE,
        log_offset=vggish_params.LOG_OFFSET,
        window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS,
        hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS,
        num_mel_bins=vggish_params.NUM_MEL_BINS,
        lower_edge_hertz=vggish_params.MEL_MIN_HZ,
        upper_edge_hertz=vggish_params.MEL_MAX_HZ,
    )

    # Frame features into examples.
    features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS
    example_window_length = int(
        round(vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate)
    )
    example_hop_length = int(
        round(vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate)
    )
    log_mel_examples = mel_features.frame(
        log_mel, window_length=example_window_length, hop_length=example_hop_length
    )
    return log_mel_examples
Example #18
def resample_to_24000_hz(samples, input_rate):
    
    case = _24000_HZ_SPECIAL_CASES.get(float(input_rate))

    if case is not None:
        # input rate is a special case for which we can resample
        # efficiently using a polyphase filter
        
        up, down, filter_ = case
        # print('Resampling from {} Hz to 24000 Hz...'.format(input_rate))
        return signal.resample_poly(samples, up, down, window=filter_)
    
    else:
        return resampy.resample(samples, input_rate, 24000)
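The `_24000_HZ_SPECIAL_CASES` table is not shown in this snippet; a hypothetical entry would map an input rate to (up, down, filter) factors derived from the gcd of the two rates, for example:

# Hypothetical entry for 22050 Hz input: gcd(24000, 22050) == 150, so
# up == 24000 // 150 == 160 and down == 22050 // 150 == 147. The window
# spec ('kaiser', 5.0) is an illustrative placeholder.
_24000_HZ_SPECIAL_CASES = {
    22050.0: (160, 147, ('kaiser', 5.0)),
}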
Example #19
def sync_with_pilot(seed, constl, ts_symbol_length, sample, sps):
    from resampy import resample
    np.random.seed(seed)
    constl = np.atleast_2d(constl)
    constl = constl[0]
    sample = np.atleast_2d(sample)

    index = np.random.randint(0, len(constl) - 1, (1, ts_symbol_length))
    pilot = np.empty((sample.shape[0], ts_symbol_length * sps))

    for i in range(pilot.shape[0]):
        pilot[i] = resample(constl[index], 1, sps)
    sample_after_sync = sync_correlation(pilot, sample, sps)
    return sample_after_sync
Example #20
    def waveform_to_examples(self, data, sample_rate):
        """Converts audio waveform into an array of examples for VGGish.

        Args:
          data: np.array of either one dimension (mono) or two dimensions
            (multi-channel, with the outer dimension representing channels).
            Each sample is generally expected to lie in the range [-1.0, +1.0],
            although this is not required.
          sample_rate: Sample rate of data.

        Returns:
          3-D np.array of shape [num_examples, num_frames, num_bands] which represents
          a sequence of examples, each of which contains a patch of log mel
          spectrogram, covering num_frames frames of audio and num_bands mel frequency
          bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS.
        """
        from .vggish_example_helper import mel_features
        import resampy

        # Convert to mono.
        if len(data.shape) > 1:
            data = np.mean(data, axis=1)
        # Resample to the rate assumed by VGGish.
        if sample_rate != self.sample_rate:
            data = resampy.resample(data, sample_rate, self.sample_rate)

        # Compute log mel spectrogram features.
        log_mel = mel_features.log_mel_spectrogram(
            data,
            audio_sample_rate=self.sample_rate,
            log_offset=self.log_offset,
            window_length_secs=self.stft_window_length_seconds,
            hop_length_secs=self.stft_hop_length_seconds,
            num_mel_bins=self.num_mel_binds,
            lower_edge_hertz=self.mel_min_hz,
            upper_edge_hertz=self.mel_max_hz)

        # Frame features into examples.
        features_sample_rate = 1.0 / self.stft_hop_length_seconds
        example_window_length = int(
            round(self.example_window_seconds * features_sample_rate))
        example_hop_length = int(
            round(self.example_hop_seconds * features_sample_rate))
        log_mel_examples = mel_features.frame(
            log_mel,
            window_length=example_window_length,
            hop_length=example_hop_length)
        return log_mel_examples
Example #21
def resample(y,
             orig_sr,
             target_sr,
             res_type='kaiser_best',
             fix=True,
             scale=False,
             **kwargs):
    # First, validate the audio buffer
    utils.valid_audio(y, mono=False)

    if orig_sr == target_sr:
        return y

    ratio = float(target_sr) / orig_sr

    n_samples = int(np.ceil(y.shape[-1] * ratio))

    if res_type in ('scipy', 'fft'):
        y_hat = scipy.signal.resample(y, n_samples, axis=-1)
    elif res_type == 'polyphase':
        if int(orig_sr) != orig_sr or int(target_sr) != target_sr:
            raise ParameterError(
                'polyphase resampling is only supported for integer-valued sampling rates.'
            )

        # For polyphase resampling, we need up- and down-sampling ratios
        # We can get those from the greatest common divisor of the rates
        # as long as the rates are integer-valued
        orig_sr = int(orig_sr)
        target_sr = int(target_sr)
        gcd = np.gcd(orig_sr, target_sr)
        y_hat = scipy.signal.resample_poly(y,
                                           target_sr // gcd,
                                           orig_sr // gcd,
                                           axis=-1)
    else:
        y_hat = resampy.resample(y,
                                 orig_sr,
                                 target_sr,
                                 filter=res_type,
                                 axis=-1)

    if fix:
        y_hat = utils.fix_length(y_hat, n_samples, **kwargs)

    if scale:
        y_hat /= np.sqrt(ratio)

    return np.ascontiguousarray(y_hat, dtype=y.dtype)
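A usage sketch for the librosa-style resample above; the 'polyphase' mode requires integer-valued rates, per the ParameterError branch:

import numpy as np

y = np.random.randn(22050).astype(np.float32)  # 1 s at 22.05 kHz
y_8k = resample(y, 22050, 8000, res_type='polyphase')
print(y_8k.shape)  # (8000,)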
Example #22
def waveform_to_examples(data, sample_rate):
    """Converts audio waveform into an array of examples for VGGish.

    Args:
      data: np.array of either one dimension (mono) or two dimensions
        (multi-channel, with the outer dimension representing channels).
        Each sample is generally expected to lie in the range [-1.0, +1.0],
        although this is not required.
      sample_rate: Sample rate of data.

    Returns:
      3-D np.array of shape [num_examples, num_frames, num_bands] which represents
      a sequence of examples, each of which contains a patch of log mel
      spectrogram, covering num_frames frames of audio and num_bands mel frequency
      bands, where the frame length is params.STFT_HOP_LENGTH_SECONDS.
    """
    # Convert to mono.
    if len(data.shape) > 1:
        data = np.mean(data, axis=1)
    # Resample to the rate assumed by VGGish.
    if sample_rate != params.SAMPLE_RATE:
        data = resampy.resample(data, sample_rate, params.SAMPLE_RATE)

    # Compute log mel spectrogram features.
    log_mel = mel_features.log_mel_spectrogram(
        data,
        audio_sample_rate=params.SAMPLE_RATE,  # 16000
        log_offset=params.LOG_OFFSET,  # 0.01, offset for the stabilized log of the mel spectrogram
        window_length_secs=params.STFT_WINDOW_LENGTH_SECONDS,  # 25 ms
        hop_length_secs=params.STFT_HOP_LENGTH_SECONDS,  # 10 ms
        num_mel_bins=params.NUM_MEL_BINS,  # 64 frequency bands in the mel spectrogram patch
        lower_edge_hertz=params.MEL_MIN_HZ,  # 125 Hz
        upper_edge_hertz=params.MEL_MAX_HZ)  # 7500 Hz

    # Frame features into examples.
    features_sample_rate = 1.0 / params.STFT_HOP_LENGTH_SECONDS  # 100 frame/s
    example_window_length = int(round(
        params.EXAMPLE_WINDOW_SECONDS * features_sample_rate))  # 0.96 s -> 96 frames per example
    example_hop_length = int(round(
        params.EXAMPLE_HOP_SECONDS * features_sample_rate))  # 0.96 s -> zero overlap
    log_mel_examples = mel_features.frame(  # frame the log mel into ~1 s examples, e.g. [3, 96, 64]
        log_mel,
        window_length=example_window_length,  # 96
        hop_length=example_hop_length)  # 96
    return log_mel_examples
Example #23
    def recognize_array(self, frames, sr):
        t = time.time()
        if decoder.CONVERT_TO_MONO:
            frames = np.array([np.mean(frames, axis=0)], dtype=frames.dtype)
        if decoder.RESAMPLE and sr != fingerprint.DEFAULT_FS and len(frames[-1]) > 0:
            frames = resample(frames, sr, fingerprint.DEFAULT_FS, axis=-1)
            self.Fs = fingerprint.DEFAULT_FS
        if decoder.NORMALIZE and len(frames[-1]) > 0:
            gain = (-np.iinfo(frames.dtype).min) / np.max(np.abs(frames))
            frames = np.array(frames * gain, dtype=frames.dtype)
        match = self._recognize(*frames)
        t = time.time() - t
        if match:
            match['match_time'] = t
        return match
Example #24
def read_wav(fname, output_sr, use_rosa=False):

    if use_rosa:
        waveform, sr = librosa.load(fname, sr=output_sr)
    else:
        wav_data, sr = sf.read(fname, dtype=np.int16)

        if wav_data.ndim > 1:
            # (ns, 2)
            wav_data = wav_data.mean(1)
        if sr != output_sr:
            wav_data = resampy.resample(wav_data, sr, output_sr)
        waveform = wav_data / 32768.0

    return waveform.astype(np.float64)
Example #25
def read_wav(filename, target_sample_rate=16000, verbose=False):
    rate, data = wavfile.read(filename)
    assert (data.dtype == np.int16)

    data = np.float32(data) / 32768.0

    # stereo -> mono
    if len(data.shape) > 1:
        data = np.mean(data, axis=1)

    # Resample to the correct sample rate if necessary
    if rate != target_sample_rate:
        data = resampy.resample(data, rate, target_sample_rate)

    return data
Example #26
def test_nnresample():
    """ Compare matlab and nnresample resample : FAILING """
    from nnresample import resample
    from pystoi.stoi import FS
    import matlab_wrapper
    matlab = matlab_wrapper.MatlabSession()
    matlab.put('FS', float(FS))
    RTOL = 1e-4
    for fs in [8000, 11025, 16000, 22050, 32000, 44100, 48000]:
        x = np.random.randn(2*fs,)
        x_r = resample(x, FS, fs)
        matlab.put('x', x)
        matlab.put('fs', float(fs))
        matlab.eval('x_r = resample(x, FS, fs)')
        assert_allclose(x_r, matlab.get('x_r'), atol=ATOL, rtol=RTOL)
Example #27
    def resample(self, new_fs):
        """
        Resample time series to ``new_fs``.
        """
        if self.start_time != 0:
            raise NotImplementedError(
                'The method resample is implemented only for time series '
                'objects with start_time equal to 0.')
        if self.fs == new_fs:
            return self
        new = self.copy()
        new.data = resampy.resample(self.data, float(self.fs), float(new_fs))
        new.fs = new_fs

        return new
Example #28
    def handle(self, data):
        in_data, frame_count, time_info, status, divisor = data
        try:
            decoded = np.frombuffer(in_data, dtype=np.int16) / divisor
            decimated = resampy.resample(decoded,
                                         self.fs,
                                         self.new_fs,
                                         filter=load_resampy_filter())
            logger.info(
                f"Decimated {decoded.size} to {decimated.size} samples")
            self.protocol.sendLine(
                f"DAT|{'|'.join([str(d) for d in decimated])}".encode())
        except Exception:
            logger.exception(
                f"Unserialisable data type {data.__class__.__name__}")
Example #29
def yamnet_transform(waveform, sample_rate):
    '''
    Args:
        waveform: np tsr [num_steps, num_channels]
        sample_rate: per second sample rate
    '''
    import tensorflow as tf
    # tf.enable_eager_execution()
    data = waveform.mean(axis=0)
    if sample_rate != YAMNetParams.SAMPLE_RATE:
        data = resampy.resample(data, sample_rate, YAMNetParams.SAMPLE_RATE)
    spectrogram = features_lib.waveform_to_log_mel_spectrogram(
        data, YAMNetParams)
    patches = features_lib.spectrogram_to_patches(spectrogram, YAMNetParams)
    return patches
Example #30
def resample_data(y, orig_sr, target_sr, **kwargs):
    if orig_sr == target_sr:
        return y

    ratio = float(target_sr) / orig_sr
    n_samples = int(np.ceil(y.shape[-1] * ratio))

    y_hat = resampy.resample(y,
                             orig_sr,
                             target_sr,
                             filter='kaiser_best',
                             axis=-1)
    y_hat = fix_length(y_hat, n_samples, **kwargs)

    return np.ascontiguousarray(y_hat, dtype=y.dtype)
Example #31
def song_downsampler(input_song,
                     n_out_channels=256,
                     input_freq=44100,
                     output_freq=8000):
    '''
    Uses resampy's downsampler to convert to a lower number of data points
    without ruining the audio file.
    '''

    ds_ch = resampy.resample(input_song.astype(np.float64), input_freq,
                             output_freq)

    ds_song = ds_ch.transpose()

    return ds_song
Example #32
def augment_audio_signal(signal, sample_freq, augmentation):
    """Function that performs audio signal augmentation.

    Args:
      signal (np.array): np.array containing raw audio signal.
      sample_freq (float): sampling frequency (samples per second).
      augmentation (dict, optional): None or dictionary of augmentation parameters.
          If not None, has to have 'speed_perturbation_ratio',
          'noise_level_min', or 'noise_level_max' fields, e.g.::
            augmentation={
              'speed_perturbation_ratio': 0.2,
              'noise_level_min': -90,
              'noise_level_max': -46,
            }
          'speed_perturbation_ratio' can either be a list of possible speed
          perturbation factors or a float. If float, a random value is drawn
          from U[1-speed_perturbation_ratio, 1+speed_perturbation_ratio].
    Returns:
      np.array: np.array with augmented audio signal.
    """
    signal_float = normalize_signal(signal.astype(np.float32))

    if 'speed_perturbation_ratio' in augmentation:
        stretch_amount = -1
        if isinstance(augmentation['speed_perturbation_ratio'], list):
            stretch_amount = np.random.choice(
                augmentation['speed_perturbation_ratio'])
        elif augmentation['speed_perturbation_ratio'] > 0:
            # time stretch (might be slow)
            stretch_amount = 1.0 + (2.0 * np.random.rand() - 1.0) * \
                             augmentation['speed_perturbation_ratio']
        if stretch_amount > 0:
            signal_float = rs.resample(
                signal_float,
                sample_freq,
                int(sample_freq * stretch_amount),
                filter='kaiser_best',
            )

    # noise
    if 'noise_level_min' in augmentation and 'noise_level_max' in augmentation:
        noise_level_db = np.random.randint(
            low=augmentation['noise_level_min'],
            high=augmentation['noise_level_max'])
        signal_float += np.random.randn(signal_float.shape[0]) * \
                        10.0 ** (noise_level_db / 20.0)

    return normalize_signal(signal_float)
Example #33
def preprocess_audio_batch(audio, sr, center=True, hop_size=0.1, sampler="julian"):
    if audio.ndim == 3:
        audio = torch.mean(audio, axis=2)

    if sr != TARGET_SR:
        if sampler == "julian":
            audio = julius.resample_frac(audio, sr, TARGET_SR)

        elif sampler == "resampy":
            audio = torch.tensor(
                resampy.resample(
                    audio.detach().cpu().numpy(),
                    sr_orig=sr,
                    sr_new=TARGET_SR,
                    filter="kaiser_best",
                ),
                dtype=audio.dtype,
                device=audio.device,
            )

        else:
            raise ValueError("Only julian and resampy works!")

    frame_len = TARGET_SR
    hop_len = int(hop_size * TARGET_SR)
    if center:
        audio = center_audio(audio, frame_len)

    audio = pad_audio(audio, frame_len, hop_len)
    n_frames = 1 + int((audio.size()[1] - frame_len) / float(hop_len))
    x = []
    xframes_shape = None
    for i in range(audio.shape[0]):
        xframes = (
            torch.as_strided(
                audio[i],
                size=(frame_len, n_frames),
                stride=(1, hop_len),
            )
            .transpose(0, 1)
            .unsqueeze(1)
        )
        if xframes_shape is None:
            xframes_shape = xframes.shape
        assert xframes.shape == xframes_shape
        x.append(xframes)
    x = torch.vstack(x)
    return x
Example #34
def get_frames(audio, sr, center=True, step_size=10, normalize: bool = True):
    """Split the provided audio into frames

    Parameters
    ----------
    audio : np.ndarray [shape=(N,) or (N, C)]
        The audio samples. Multichannel audio will be downmixed.
    sr : int
        Sample rate of the audio samples. The audio will be resampled if
        the sample rate is not 16 kHz, which is expected by the model.
    center : boolean
        - If `True` (default), the signal `audio` is padded so that frame
          `D[:, t]` is centered at `audio[t * hop_length]`.
        - If `False`, then `D[:, t]` begins at `audio[t * hop_length]`
    step_size : int
        The step size in milliseconds for running pitch estimation.
    normalize : bool
        If True (default), normalize each frame to zero mean and unit
        standard deviation, as expected by the model.

    Returns
    -------
    frames : np.ndarray [shape=(T, 1024)]
    """
    if len(audio.shape) == 2:
        audio = audio.mean(1)  # make mono
    audio = audio.astype(np.float32)
    if sr != model_srate:
        # resample audio if necessary
        from resampy import resample
        audio = resample(audio, sr, model_srate)

    # pad so that frames are centered around their timestamps (i.e. first frame
    # is zero centered).
    if center:
        audio = np.pad(audio, 512, mode='constant', constant_values=0)

    # make 1024-sample frames of the audio with hop length of 10 milliseconds
    hop_length = int(model_srate * step_size / 1000)
    n_frames = 1 + int((len(audio) - 1024) / hop_length)
    frames = as_strided(audio,
                        shape=(1024, n_frames),
                        strides=(audio.itemsize, hop_length * audio.itemsize),
                        writeable=False).copy()
    frames = frames.transpose()

    if normalize:
        # normalize each frame -- this is expected by the model
        frames = frames - frames.mean(axis=-1, keepdims=True)
        frames = frames / (frames.std(axis=-1, keepdims=True) + 1e-6)
    return frames
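A quick shape check for the framing logic above, assuming model_srate == 16000 as in CREPE:

import numpy as np

audio = np.random.randn(16000).astype(np.float32)  # 1 s of audio
frames = get_frames(audio, sr=16000, step_size=10)
# center=True pads the signal to 17024 samples, giving
# 1 + (17024 - 1024) // 160 = 101 frames of 1024 samples each.
print(frames.shape)  # (101, 1024)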
Example #35
def use_audio_path_specified_audio(wav_path,
                                   wav_word,
                                   rms_normalize=None,
                                   SR=20000):
    """
    Loads an example wav specified by wav_path.

    Inputs:
      wav_path (string): filepath to the audio to load
      wav_word (string): label for the audio in wav_path
      rms_normalize (float): the rms value to set the audio to
      SR (int): sampling rate of the desired audio. The file at
        wav_path will be resampled to this value

    Output:
      audio_dict (dictionary): a dictionary containing the loaded
        audio and preprocessing parameters
    """

    metamer_word_encodings = pickle.load(
        open('assets/metamer_word_encodings.pckl', 'rb'))
    word_to_int = metamer_word_encodings['word_to_word_idx']

    print("Loading: %s" % wav_path)
    SR_loaded, wav_f = scipy.io.wavfile.read(wav_path)
    if SR_loaded != SR:
        wav_f = resampy.resample(wav_f, SR_loaded, SR)
        SR_loaded = SR

    if rms_normalize is not None:
        wav_f = wav_f - np.mean(wav_f.ravel())
        wav_f = wav_f / (np.sqrt(np.mean(wav_f.ravel()**2))) * rms_normalize
        rms = rms_normalize
    else:
        rms = np.sqrt(np.mean(wav_f.ravel()**2))

    audio_dict = {}

    audio_dict['wav'] = wav_f
    audio_dict['SR'] = SR
    audio_dict['word_int'] = word_to_int[wav_word]
    audio_dict['word'] = wav_word
    audio_dict['rms'] = rms
    audio_dict['filename'] = wav_path
    audio_dict['filename_short'] = wav_path.split('/')[-1]
    audio_dict['correct_response'] = wav_word

    return audio_dict
Example #36
    def _process_input_chunk(self, samples):

        input_length = len(samples)

        if self._classifier_sample_rate != self._input_sample_rate:

            # start_time = time.time()

            samples = resampy.resample(samples,
                                       self._input_sample_rate,
                                       self._classifier_sample_rate,
                                       filter='kaiser_fast')

            # processing_time = time.time() - start_time
            # input_duration = input_length / self._input_sample_rate
            # rate = input_duration / processing_time
            # print((
            #     'Resampled {:.1f} seconds of input in {:.1f} seconds, '
            #     'or {:.1f} times faster than real time.').format(
            #         input_duration, processing_time, rate))

        self._waveforms = _get_analysis_records(
            samples, self._classifier_waveform_length, self._hop_size)

        #         print('Scoring chunk waveforms...')
        #         start_time = time.time()

        scores = classifier_utils.score_dataset_examples(
            self._estimator, self._create_dataset)

        #         elapsed_time = time.time() - start_time
        #         num_waveforms = self._waveforms.shape[0]
        #         rate = num_waveforms / elapsed_time
        #         print((
        #             'Scored {} waveforms in {:.1f} seconds, a rate of {:.1f} '
        #             'waveforms per second.').format(
        #                 num_waveforms, elapsed_time, rate))

        if _SCORE_OUTPUT_ENABLED:
            self._score_file_writer.write(samples, scores)

        for threshold in self._thresholds:
            peak_indices = signal_utils.find_peaks(scores, threshold)
            peak_scores = scores[peak_indices]
            self._notify_listener_of_clips(peak_indices, peak_scores,
                                           input_length, threshold)

        self._input_chunk_start_index += input_length
Example #37
    def resample(self, target_sample_rate, filter='kaiser_best'):
        """Resample the audio to a target sample rate.

        Note that this is an in-place transformation.

        :param target_sample_rate: Target sample rate.
        :type target_sample_rate: int
        :param filter: The resampling filter to use one of {'kaiser_best',
                       'kaiser_fast'}.
        :type filter: str
        """
        self._samples = resampy.resample(self.samples,
                                         self.sample_rate,
                                         target_sample_rate,
                                         filter=filter)
        self._sample_rate = target_sample_rate
Example #38
def test(table, old_sr, new_sr, device="cpu"):
    x = th.randn(16, 8 * old_sr * int(math.ceil(44_100 / old_sr)), device=device)

    with Chrono() as chrono:
        y = resample_frac(x, old_sr, new_sr, zeros=56)
    dur_julius = int(1000 * chrono.duration)

    if device == "cpu":
        with Chrono() as chrono:
            y_resampy = th.from_numpy(resampy.resample(x.numpy(), old_sr, new_sr))
        dur_resampy = int(1000 * chrono.duration)

        delta = (y_resampy - y).abs().mean()
        table.line([old_sr, new_sr, dur_julius, dur_resampy, format(delta, ".1%")])
    else:
        table.line([old_sr, new_sr, dur_julius])
Example #39
def test_resample(input, input_sample_rate, output_sample_rate, method, axis):
    t = transforms.Resample(input_sample_rate, output_sample_rate,
                            method=method, axis=axis)
    output_length = int(input.shape[axis] * output_sample_rate
                        / float(input_sample_rate))
    print(input.shape)
    if method == 'scipy':
        expected_output = scipy.signal.resample(input, output_length,
                                                axis=axis)
    else:
        expected_output = resampy.resample(input, input_sample_rate,
                                           output_sample_rate, filter=method,
                                           axis=axis)
    transformed_input = t(input)
    assert transformed_input.shape[axis] == output_length
    assert np.array_equal(transformed_input, expected_output)
Example #40
    def __call__(self, x: np.ndarray, y: Optional[np.ndarray] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Resample `x` to a new sampling rate.

        :param x: Sample to resample of shape `(batch_size, length, channel)` or `(batch_size, channel, length)`.
        :param y: Labels of the sample `x`. This function does not affect them in any way.
        :return: Resampled audio sample.
        """
        import resampy

        if x.ndim != 3:
            raise ValueError("Resampling can only be applied to temporal data across at least one channel.")

        sample_index = 2 if self.channels_first else 1

        return resampy.resample(x, self.sr_original, self.sr_new, axis=sample_index, filter="sinc_window"), y
Example #41
def wav_stream(files):
    for file in files:
        srate, data = wavfile.read(os.path.join(args.input_path, file))
        if len(data.shape) == 2:
            data = data.mean(axis=1)
        if srate != 16000:
            data = resample(data, srate, 16000)
            srate = 16000
        hop_length = int(srate / 100)
        n_frames = 1 + int((len(data) - 1024) / hop_length)
        frames = as_strided(data,
                            shape=(1024, n_frames),
                            strides=(data.itemsize,
                                     hop_length * data.itemsize))
        frames = frames.transpose().astype(np.float32)
        yield (file, frames)
Example #42
def audio_to_midi_melodia(infile, outfile, bpm, smooth=0.25, minduration=0.1):
    # define analysis parameters
    fs = 44100
    hop = 128

    # load audio using librosa
    print("Loading audio...")
    data, sr = soundfile.read(infile)
    # mixdown to mono if needed
    if len(data.shape) > 1 and data.shape[1] > 1:
        data = data.mean(axis=1)
    # resample to 44100 if needed
    if sr != fs:
        data = resampy.resample(data, sr, fs)
        sr = fs

    # extract melody using melodia vamp plugin
    print("Extracting melody f0 with MELODIA...")
    melody = vamp.collect(data,
                          sr,
                          "mtg-melodia:melodia",
                          parameters={"voicing": 0.2})

    # hop = melody['vector'][0]
    pitch = melody['vector'][1]

    # impute missing 0's to compensate for starting timestamp
    pitch = np.insert(pitch, 0, [0] * 8)

    # debug
    # np.asarray(pitch).dump('f0.npy')
    # print(len(pitch))

    # convert f0 to midi notes
    print("Converting Hz to MIDI notes...")
    midi_pitch = hz2midi(pitch)

    # segment sequence into individual midi notes
    notes = midi_to_notes(midi_pitch, fs, hop, smooth, minduration)

    # print notes

    # save note sequence to a midi file
    print("Saving MIDI to disk...")
    save_midi(outfile, notes, bpm)

    print("Conversion complete.")
Example #43
def annotated_spectrogram(audio_file,
                          annotations_file,
                          fs_spec=16000,
                          figsize=(10, 5),
                          height_ratios=[3, 1]):

    n = np.genfromtxt(annotations_file, dtype=str)
    unique_labels = np.unique(n[:, 2])

    audio, fs = sf.read(audio_file)

    # extract mono from multichannel audio
    if audio.ndim > 1 and np.min(audio.shape) > 1:
        if np.where(audio.shape == np.min(audio.shape))[0][0] == 0:
            audio = audio[0, :]
        else:
            audio = audio[:, 0]

    audio_lowres = resampy.resample(audio, fs, fs_spec)
    f, t, sxx = spectrogram(audio_lowres, fs_spec, nperseg=1024)

    gs = gridspec.GridSpec(2, 1, height_ratios=height_ratios)
    gs.update(hspace=0.1)

    fig = plt.figure(figsize=figsize)
    ax = plt.subplot(gs[0])
    plt.xlim([0, 60])
    plt.ylabel('Frequency (Hz)')
    plt.xticks([])
    ax.spines['bottom'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.pcolormesh(t, f, np.log10(sxx), cmap='GnBu')

    ax2 = plt.subplot(gs[1])
    for i, label in enumerate(unique_labels):
        # get list of start/stop times for each label
        a = n[:, 0][n[:, 2] == label].astype(float)
        b = n[:, 1][n[:, 2] == label].astype(float)
        ax2.hlines(np.zeros(len(a)) + i, a, b, colors=plt.cm.tab20(i), lw=10)

    ax2.spines['right'].set_visible(False)
    ax2.spines['top'].set_visible(False)

    plt.xlim([0, 60])
    plt.yticks(np.arange(i + 1), unique_labels)
    plt.xlabel('Time (seconds)')
Example #44
File: core.py Project: jdasam/crepe
def predict_offline(model,
                    audio,
                    sr,
                    center=True,
                    step_size=10,
                    viterbi=False):
    if len(audio.shape) == 2:
        audio = audio.mean(1)  # make mono
    audio = audio.astype(np.float32)
    if sr != model_srate:
        # resample audio if necessary
        from resampy import resample
        audio = resample(audio, sr, model_srate)

    # pad so that frames are centered around their timestamps (i.e. first frame
    # is zero centered).
    if center:
        audio = np.pad(audio, 512, mode='constant', constant_values=0)

    # make 1024-sample frames of the audio with hop length of 10 milliseconds
    hop_length = int(model_srate * step_size / 1000)
    n_frames = 1 + int((len(audio) - 1024) / hop_length)
    frames = as_strided(audio,
                        shape=(1024, n_frames),
                        strides=(audio.itemsize, hop_length * audio.itemsize))
    frames = frames.transpose().copy()

    # normalize each frame -- this is expected by the model
    frames -= np.mean(frames, axis=1)[:, np.newaxis]
    frames /= np.std(frames, axis=1)[:, np.newaxis]

    # run prediction and convert the frequency bin weights to Hz
    activation = model.predict(frames, verbose=0)
    confidence = activation.max(axis=1)

    if viterbi:
        cents = to_viterbi_cents(activation)
    else:
        cents = to_local_average_cents(activation)

    frequency = 10 * 2**(cents / 1200)
    frequency[np.isnan(frequency)] = 0

    time = np.arange(confidence.shape[0]) * step_size / 1000.0

    return frequency, confidence
Example #45
def main(argv):
  assert argv, 'Usage: inference.py <wav file> <wav file> ...'

  model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'yamnet.h5')
  classes_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'yamnet_class_map.csv')
  event_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'event.json')
  
  params = yamnet_params.Params()
  yamnet = yamnet_model.yamnet_frames_model(params)
  yamnet.load_weights(model_path)
  yamnet_classes = yamnet_model.class_names(classes_path)

  for file_name in argv:
    # Decode the WAV file.
    wav_data, sr = sf.read(file_name, dtype=np.int16)
    assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
    waveform = wav_data / 32768.0  # Convert to [-1.0, +1.0]
    waveform = waveform.astype('float32')

    # Convert to mono and the sample rate expected by YAMNet.
    if len(waveform.shape) > 1:
      waveform = np.mean(waveform, axis=1)
    if sr != params.sample_rate:
      waveform = resampy.resample(waveform, sr, params.sample_rate)

    # Predict YAMNet classes.
    scores, embeddings, spectrogram = yamnet(waveform)
    # Scores is a matrix of (time_frames, num_classes) classifier scores.
    # Average them along time to get an overall classifier output for the clip.
    prediction = np.mean(scores, axis=0)
    # Report the highest-scoring classes and their scores.
    top5_i = np.argsort(prediction)[::-1][:5]
    print(file_name, ':\n' +
          '\n'.join('  {:12s}: {:.3f}'.format(yamnet_classes[i], prediction[i])
                    for i in top5_i))
    
    # print all classes
    b = prediction.tolist() # nested lists with same data, indices
    pred = []
    for (i,cls) in enumerate(yamnet_classes):
      item={}
      item['label']=cls
      item['value']=round(b[i], 6)
      pred.append(item)
    pred = sorted(pred, key=lambda x: x['value'], reverse=True)
    json.dump(pred, codecs.open(event_path, 'w', encoding='utf-8'),
              separators=(',', ':'), sort_keys=True, indent=4)  # save the predictions as JSON
Example #46
def pitchBend(frame, pitch_to_bend):
    if pitch_to_bend == 0:
        return frame
    freq_oitar = np.exp(-pitch_to_bend * 0.057762265046662105)
    # The inverse of 'ratio'
    frame = resample(frame, SR, SR * freq_oitar, filter=FILTER)
    left = frame[:CROSS_FADE_TAILS]
    left_mid = frame[CROSS_FADE_TAILS:CROSS_FADE_TAILS + CROSS_FADE_OVERLAP]
    right_mid = frame[-CROSS_FADE_TAILS - CROSS_FADE_OVERLAP:-CROSS_FADE_TAILS]
    right = frame[-CROSS_FADE_TAILS:]
    frame = np.concatenate((
        left,
        np.multiply(left_mid, FADE_OUT_WINDOW) +
        np.multiply(right_mid, FADE_IN_WINDOW),
        right,
    ))
    return frame
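A sanity check on the magic constant above, assuming pitch_to_bend is measured in semitones: 0.057762265046662105 is ln(2)/12, so exp(-p * ln(2) / 12) equals 2 ** (-p / 12).

import math

p = 3.0  # an arbitrary bend, in semitones (assumption)
assert math.isclose(0.057762265046662105, math.log(2) / 12)
assert math.isclose(math.exp(-p * math.log(2) / 12), 2 ** (-p / 12))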
Example #47
def waveform_to_examples(data, sample_rate):
    """Converts audio waveform into an array of examples for VGGish.

    Args:
      data: np.array of either one dimension (mono) or two dimensions
        (multi-channel, with the outer dimension representing channels).
        Each sample is generally expected to lie in the range [-1.0, +1.0],
        although this is not required.
      sample_rate: Sample rate of data.

    Returns:
      3-D np.array of shape [num_examples, num_frames, num_bands] which represents
      a sequence of examples, each of which contains a patch of log mel
      spectrogram, covering num_frames frames of audio and num_bands mel frequency
      bands, where the frame length is params.STFT_HOP_LENGTH_SECONDS.
    """
    # Convert to mono.
    if len(data.shape) > 1:
        data = np.mean(data, axis=1)
    # Resample to the rate assumed by VGGish.
    if sample_rate != params.SAMPLE_RATE:
        data = resampy.resample(data, sample_rate, params.SAMPLE_RATE)

    # Compute log mel spectrogram features.
    log_mel = mel_features.log_mel_spectrogram(
        data,
        audio_sample_rate=params.SAMPLE_RATE,
        log_offset=params.LOG_OFFSET,
        window_length_secs=params.STFT_WINDOW_LENGTH_SECONDS,
        hop_length_secs=params.STFT_HOP_LENGTH_SECONDS,
        num_mel_bins=params.NUM_MEL_BINS,
        lower_edge_hertz=params.MEL_MIN_HZ,
        upper_edge_hertz=params.MEL_MAX_HZ)

    # Frame features into examples.
    features_sample_rate = 1.0 / params.STFT_HOP_LENGTH_SECONDS
    example_window_length = int(round(
        params.EXAMPLE_WINDOW_SECONDS * features_sample_rate))
    example_hop_length = int(round(
        params.EXAMPLE_HOP_SECONDS * features_sample_rate))
    log_mel_examples = mel_features.frame(
        log_mel,
        window_length=example_window_length,
        hop_length=example_hop_length)
    return log_mel_examples
Example #48
    def _get_clip_samples(self, clip):
         
        clip_sample_rate = clip.sample_rate
        classifier_sample_rate = self._settings.waveform_sample_rate

        s2f = signal_utils.seconds_to_frames
        start_offset = s2f(self._waveform_start_time, clip_sample_rate)
        
        if clip_sample_rate != classifier_sample_rate:
            # need to resample
            
            # Get clip samples, including a millisecond of padding at
            # the end. I don't know what if any guarantees the
            # `resampy.resample` function offers about the relationship
            # between its input and output lengths, so we add the padding
            # to try to ensure that we don't wind up with too few samples
            # after resampling.
            length = s2f(self._waveform_duration + .001, clip_sample_rate)
            samples = self._clip_manager.get_samples(
                clip, start_offset=start_offset, length=length)
            
            # Resample clip samples to classifier sample rate.
            samples = resampy.resample(
                samples, clip_sample_rate, classifier_sample_rate)
            
            # Discard any extra trailing samples we wound up with.
            samples = samples[:self._waveform_length]
            
            if len(samples) < self._waveform_length:
                raise ValueError('Resampling produced too few samples.')
            
        else:
            # don't need to resample
            
            samples = self._clip_manager.get_samples(
                clip, start_offset=start_offset, length=self._waveform_length)
             
        return samples
Example #49
    def _process_input_chunk(self, samples):
        
        input_length = len(samples)
        
        if self._classifier_sample_rate != self._input_sample_rate:
             
            samples = resampy.resample(
                samples, self._input_sample_rate, self._classifier_sample_rate,
                filter='kaiser_fast')
            
        self._waveforms = _get_analysis_records(
            samples, self._classifier_waveform_length, self._hop_size)
        
#         print('Scoring chunk waveforms...')
#         start_time = time.time()
         
        scores = classifier_utils.score_dataset_examples(
            self._estimator, self._create_dataset)
        
#         elapsed_time = time.time() - start_time
#         num_waveforms = self._waveforms.shape[0]
#         rate = num_waveforms / elapsed_time
#         print((
#             'Scored {} waveforms in {:.1f} seconds, a rate of {:.1f} '
#             'waveforms per second.').format(
#                 num_waveforms, elapsed_time, rate))
        
        if _SCORE_OUTPUT_ENABLED:
            self._score_file_writer.write(samples, scores)
         
        for threshold in self._thresholds:
            peak_indices = signal_utils.find_peaks(scores, threshold)
            peak_scores = scores[peak_indices]
            self._notify_listener_of_clips(
                peak_indices, peak_scores, input_length, threshold)
        
        self._input_chunk_start_index += input_length
Example #50
def resample(y, orig_sr, target_sr, res_type='kaiser_best', fix=True, scale=False, **kwargs):
    """Resample a time series from orig_sr to target_sr

    Parameters
    ----------
    y : np.ndarray [shape=(n,) or shape=(2, n)]
        audio time series.  Can be mono or stereo.

    orig_sr : number > 0 [scalar]
        original sampling rate of `y`

    target_sr : number > 0 [scalar]
        target sampling rate

    res_type : str
        resample type (see note)

        .. note::
            By default, this uses `resampy`'s high-quality mode ('kaiser_best').

            To use `scipy.signal.resample`, set `res_type='scipy'`.

    fix : bool
        adjust the length of the resampled signal to be of size exactly
        `ceil(target_sr * len(y) / orig_sr)`

    scale : bool
        Scale the resampled signal so that `y` and `y_hat` have approximately
        equal total energy.

    kwargs : additional keyword arguments
        If `fix==True`, additional keyword arguments to pass to
        `librosa.util.fix_length`.

    Returns
    -------
    y_hat : np.ndarray [shape=(n * target_sr / orig_sr,)]
        `y` resampled from `orig_sr` to `target_sr`


    See Also
    --------
    librosa.util.fix_length
    scipy.signal.resample
    resampy.resample

    Examples
    --------
    Downsample from 22 kHz to 8 kHz

    >>> y, sr = librosa.load(librosa.util.example_audio_file(), sr=22050)
    >>> y_8k = librosa.resample(y, sr, 8000)
    >>> y.shape, y_8k.shape
    ((1355168,), (491671,))

    """

    # First, validate the audio buffer
    util.valid_audio(y, mono=False)

    if orig_sr == target_sr:
        return y

    ratio = float(target_sr) / orig_sr

    n_samples = int(np.ceil(y.shape[-1] * ratio))

    if res_type == 'scipy':
        y_hat = scipy.signal.resample(y, n_samples, axis=-1)
    else:
        y_hat = resampy.resample(y, orig_sr, target_sr, filter=res_type, axis=-1)

    if fix:
        y_hat = util.fix_length(y_hat, n_samples, **kwargs)

    if scale:
        y_hat /= np.sqrt(ratio)

    return np.ascontiguousarray(y_hat, dtype=y.dtype)
Example #51
def test_short_signal():

    x = np.zeros(2)
    resampy.resample(x, 4, 1)
Example #52
def test_dtype(dtype):
    x = np.random.randn(100).astype(dtype)

    y = resampy.resample(x, 100, 200)

    assert x.dtype == y.dtype
Example #53
def test_bad_window():
    x = np.zeros(100)

    resampy.resample(x, 100, 200, filter='sinc_window', window=np.ones(50))
Example #54
def resample(y, orig_sr, target_sr, res_type='kaiser_best', fix=True, scale=False, **kwargs):
    """Resample a time series from orig_sr to target_sr

    Parameters
    ----------
    y : np.ndarray [shape=(n,) or shape=(2, n)]
        audio time series.  Can be mono or stereo.

    orig_sr : number > 0 [scalar]
        original sampling rate of `y`

    target_sr : number > 0 [scalar]
        target sampling rate

    res_type : str
        resample type (see note)

        .. note::
            By default, this uses `resampy`'s high-quality mode ('kaiser_best').
            If `res_type` is not recognized by `resampy.resample`, it then
            falls back on `scikits.samplerate` (if it is installed)

            If both of those fail, it will fall back on `scipy.signal.resample`.

            To force use of `scipy.signal.resample`, set `res_type='scipy'`.

    fix : bool
        adjust the length of the resampled signal to be of size exactly
        `ceil(target_sr * len(y) / orig_sr)`

    scale : bool
        Scale the resampled signal so that `y` and `y_hat` have approximately
        equal total energy.

    kwargs : additional keyword arguments
        If `fix==True`, additional keyword arguments to pass to
        `librosa.util.fix_length`.

    Returns
    -------
    y_hat : np.ndarray [shape=(n * target_sr / orig_sr,)]
        `y` resampled from `orig_sr` to `target_sr`


    See Also
    --------
    librosa.util.fix_length
    scipy.signal.resample

    Examples
    --------
    Downsample from 22 kHz to 8 kHz

    >>> y, sr = librosa.load(librosa.util.example_audio_file(), sr=22050)
    >>> y_8k = librosa.resample(y, sr, 8000)
    >>> y.shape, y_8k.shape
    ((1355168,), (491671,))

    """

    # First, validate the audio buffer
    util.valid_audio(y, mono=False)

    if orig_sr == target_sr:
        return y

    ratio = float(target_sr) / orig_sr

    n_samples = int(np.ceil(y.shape[-1] * ratio))

    try:
        y_hat = resampy.resample(y, orig_sr, target_sr, filter=res_type, axis=-1)
    except NotImplementedError:
        if _HAS_SAMPLERATE and (res_type != 'scipy'):
            warnings.warn('scikits.samplerate resampling is deprecated as '
                          'of librosa version 0.4.3.\n\tSupport will be '
                          'removed in librosa version 0.5.',
                          category=DeprecationWarning)
            y_hat = samplerate.resample(y.T, ratio, res_type).T
        else:
            y_hat = scipy.signal.resample(y, n_samples, axis=-1)

    if fix:
        y_hat = util.fix_length(y_hat, n_samples, **kwargs)

    if scale:
        y_hat /= np.sqrt(ratio)

    return np.ascontiguousarray(y_hat, dtype=y.dtype)
Example #55
def test_bad_num_zeros():
    x = np.zeros(100)
    resampy.resample(x, 100, 50, filter='sinc_window', num_zeros=0)
Example #56
def test_bad_precision():
    x = np.zeros(100)
    resampy.resample(x, 100, 50, filter='sinc_window', precision=-1)
Example #57
def get_feature(audio_list, spk_to_idx, min_mix=2, max_mix=2, batch_size=1):
    """
    :param audio_list: audio file list
        path/to/1st.wav spk1
        path/to/2nd.wav spk2
        path/to/3rd.wav spk1
    :param spk_to_idx: dict, spk1:0, spk2:1, ...
    :param min_mix:
    :param max_mix:
    :param batch_size:
    :return:
    """
    speaker_audios = {}
    batch_input_mix_fea = []
    batch_input_mix_spec = []
    batch_input_spk = []
    batch_input_clean_fea = []
    batch_target_spec = []
    batch_input_len = []
    batch_count = 0
    while True:
        mix_k = np.random.randint(min_mix, max_mix+1)

        if mix_k > len(speaker_audios):
            speaker_audios = {}
            file_list = open(audio_list)
            for line in file_list:
                line = line.strip().split()
                if len(line) != 2:
                    print('Wrong audio list file record in the line:', line)
                    continue
                file_str, spk = line
                if spk not in speaker_audios:
                    speaker_audios[spk] = []
                speaker_audios[spk].append(file_str)
            file_list.close()

            for spk in speaker_audios:
                random.shuffle(speaker_audios[spk])

        wav_mix = None
        target_spk = None
        mix_len = 0
        target_sig = None

        for spk in random.sample(list(speaker_audios.keys()), mix_k):
            file_str = speaker_audios[spk].pop()
            if not speaker_audios[spk]:
                del speaker_audios[spk]
            signal, rate = sf.read(file_str)
            if len(signal.shape) > 1:
                signal = signal[:, 0]
            if rate != config.FRAME_RATE:
                signal = resampy.resample(signal, rate, config.FRAME_RATE, filter='kaiser_best')
            signal = list(signal)
            if len(signal) > config.MAX_LEN:
                signal = signal[:config.MAX_LEN]
            if len(signal) > mix_len:
                mix_len = len(signal)

            signal = np.array(signal)
            signal -= np.mean(signal)
            signal /= np.max(np.abs(signal))

            signal = list(signal)

            if config.AUGMENT_DATA:
                random_shift = random.sample(range(len(signal)), 1)[0]
                signal = signal[random_shift:] + signal[:random_shift]

            if len(signal) < config.MAX_LEN:
                signal.extend(np.zeros(config.MAX_LEN - len(signal)))

            signal = np.array(signal)

            if wav_mix is None:
                wav_mix = signal
                target_sig = signal
                target_spk = spk_to_idx[spk]
            else:
                wav_mix = wav_mix + signal

        if config.IS_LOG_SPECTRAL:
            feature_mix = np.log(np.transpose(np.abs(librosa.core.spectrum.stft(wav_mix, config.FRAME_LENGTH,
                                                                                config.FRAME_SHIFT,
                                                                                window=config.WINDOWS)))
                                 + np.spacing(1))
        else:
            feature_mix = np.transpose(np.abs(librosa.core.spectrum.stft(wav_mix, config.FRAME_LENGTH,
                                                                         config.FRAME_SHIFT,
                                                                         window=config.WINDOWS)))

        spec_mix = np.transpose(np.abs(librosa.core.spectrum.stft(wav_mix, config.FRAME_LENGTH,
                                                                  config.FRAME_SHIFT, window=config.WINDOWS)))

        if config.IS_LOG_SPECTRAL:
            feature_inp_clean = np.log(np.transpose(np.abs(librosa.core.spectrum.stft(target_sig, config.FRAME_LENGTH,
                                                                                      config.FRAME_SHIFT,
                                                                                      window=config.WINDOWS)))
                                       + np.spacing(1))
        else:
            feature_inp_clean = np.transpose(np.abs(librosa.core.spectrum.stft(target_sig, config.FRAME_LENGTH,
                                                                               config.FRAME_SHIFT,
                                                                               window=config.WINDOWS)))

        spec_clean = np.transpose(np.abs(librosa.core.spectrum.stft(target_sig, config.FRAME_LENGTH,
                                                                    config.FRAME_SHIFT, window=config.WINDOWS)))


        batch_input_mix_fea.append(feature_mix)
        batch_input_mix_spec.append(spec_mix)
        batch_input_spk.append(target_spk)
        batch_input_clean_fea.append(feature_inp_clean)
        batch_target_spec.append(spec_clean)
        batch_input_len.append(mix_len)
        batch_count += 1

        if batch_count == batch_size:
            # mix_input_fea (batch_size, time_steps, feature_dim)
            mix_input_fea = np.array(batch_input_mix_fea).reshape((batch_size, ) + feature_mix.shape)
            # mix_input_spec (batch_size, time_steps, spectrum_dim)
            mix_input_spec = np.array(batch_input_mix_spec).reshape((batch_size, ) + spec_mix.shape)
            # target_input_spk (batch_size, 1)
            target_input_spk = np.array(batch_input_spk, dtype=np.int32).reshape((batch_size, 1))
            # clean_input_fea (batch_size, time_steps, feature_dim)
            clean_input_fea = np.array(batch_input_clean_fea).reshape((batch_size, ) + feature_inp_clean.shape)
            # clean_target_spec (batch_size, time_steps, spectrum_dim)
            clean_target_spec = np.array(batch_target_spec).reshape((batch_size, ) + spec_clean.shape)

            yield ({'input_mix_feature': mix_input_fea, 'input_mix_spectrum': mix_input_spec,
                    'input_target_spk': target_input_spk, 'input_clean_feature': clean_input_fea},
                   {'target_clean_spectrum': clean_target_spec})
            batch_input_mix_fea = []
            batch_input_mix_spec = []
            batch_input_spk = []
            batch_input_clean_fea = []
            batch_target_spec = []
            batch_input_len = []
            batch_count = 0
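get_feature is an infinite generator yielding Keras-style (inputs, targets) dicts keyed by layer name, so it is meant to be consumed lazily. A consumption sketch; the list file name, speaker map, and the Keras model are all hypothetical stand-ins.

spk_to_idx = {'spk1': 0, 'spk2': 1}                            # hypothetical speaker map
gen = get_feature('train_list.txt', spk_to_idx, batch_size=8)  # hypothetical list file
inputs, targets = next(gen)                                    # pull a single batch
print(inputs['input_mix_feature'].shape)                       # (8, time_steps, feature_dim)
# model.fit_generator(gen, steps_per_epoch=100, epochs=10)     # with a matching Keras model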
Example #58
def resample(y, orig_sr, target_sr, res_type='kaiser_best', fix=True, scale=False, **kwargs):
    """Resample a time series from orig_sr to target_sr

    Parameters
    ----------
    y : np.ndarray [shape=(n,) or shape=(2, n)]
        audio time series.  Can be mono or stereo.

    orig_sr : number > 0 [scalar]
        original sampling rate of `y`

    target_sr : number > 0 [scalar]
        target sampling rate

    res_type : str
        resample type (see note)

        .. note::
            By default, this uses `resampy`'s high-quality mode ('kaiser_best').

            To use a faster method, set `res_type='kaiser_fast'`.

            To use `scipy.signal.resample`, set `res_type='fft'` or `res_type='scipy'`.

            To use `scipy.signal.resample_poly`, set `res_type='polyphase'`.

        .. note::
            When using `res_type='polyphase'`, only integer sampling rates are
            supported.

    fix : bool
        adjust the length of the resampled signal to be of size exactly
        `ceil(target_sr * len(y) / orig_sr)`

    scale : bool
        Scale the resampled signal so that `y` and `y_hat` have approximately
        equal total energy.

    kwargs : additional keyword arguments
        If `fix==True`, additional keyword arguments to pass to
        `librosa.util.fix_length`.

    Returns
    -------
    y_hat : np.ndarray [shape=(n * target_sr / orig_sr,)]
        `y` resampled from `orig_sr` to `target_sr`

    Raises
    ------
    ParameterError
        If `res_type='polyphase'` and `orig_sr` or `target_sr` are not both
        integer-valued.

    See Also
    --------
    librosa.util.fix_length
    scipy.signal.resample
    resampy.resample

    Notes
    -----
    This function caches at level 20.

    Examples
    --------
    Downsample from 22 kHz to 8 kHz

    >>> y, sr = librosa.load(librosa.util.example_audio_file(), sr=22050)
    >>> y_8k = librosa.resample(y, sr, 8000)
    >>> y.shape, y_8k.shape
    ((1355168,), (491671,))
    """

    # First, validate the audio buffer
    util.valid_audio(y, mono=False)

    if orig_sr == target_sr:
        return y

    ratio = float(target_sr) / orig_sr

    n_samples = int(np.ceil(y.shape[-1] * ratio))

    if res_type in ('scipy', 'fft'):
        y_hat = scipy.signal.resample(y, n_samples, axis=-1)
    elif res_type == 'polyphase':
        if int(orig_sr) != orig_sr or int(target_sr) != target_sr:
            raise ParameterError('polyphase resampling is only supported for integer-valued sampling rates.')

        # For polyphase resampling, we need up- and down-sampling ratios
        # We can get those from the greatest common divisor of the rates
        # as long as the rates are integer-valued
        orig_sr = int(orig_sr)
        target_sr = int(target_sr)
        gcd = np.gcd(orig_sr, target_sr)
        y_hat = scipy.signal.resample_poly(y, target_sr // gcd, orig_sr // gcd, axis=-1)
    else:
        y_hat = resampy.resample(y, orig_sr, target_sr, filter=res_type, axis=-1)

    if fix:
        y_hat = util.fix_length(y_hat, n_samples, **kwargs)

    if scale:
        y_hat /= np.sqrt(ratio)

    return np.ascontiguousarray(y_hat, dtype=y.dtype)
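The polyphase branch above reduces both rates by their greatest common divisor before calling scipy.signal.resample_poly. A standalone sketch of that arithmetic for the docstring's 22050 to 8000 case:

import numpy as np
import scipy.signal

orig_sr, target_sr = 22050, 8000
g = np.gcd(orig_sr, target_sr)            # 50
up, down = target_sr // g, orig_sr // g   # upsample by 160, then downsample by 441
y = np.random.randn(orig_sr)              # one second of noise, for illustration
y_8k = scipy.signal.resample_poly(y, up, down, axis=-1)
print(y_8k.shape)                         # (8000,)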
Example #59
def eval_loss(model, audio_list, valid_test, epoch_num, log_file, spk_to_idx, batch_size=1, unk_spk=False):
    if unk_spk:
        batch_size = 1
    batch_input_mix_fea = []
    batch_input_mix_spec = []
    batch_input_spk = []
    batch_target_spec = []
    batch_input_len = []
    batch_clean_wav = []
    batch_count = 0
    mse_loss = 0

    file_list_len = 0
    file_list = open(audio_list)
    for line in file_list:
        file_list_len += 1
    file_list.close()
    file_list = open(audio_list)
    time_start = time.time()
    for line_idx, line in enumerate(file_list):
        line = line.strip().split()
        if len(line) < 2:
            raise Exception('Wrong audio list file record in the line:', ' '.join(line))
        file_tar_sounds_str = None
        if not unk_spk:
            file_tar_str, file_bg_str, tar_spk_str = line
            file_bg_str = file_bg_str.strip().split(',')[0]
        else:
            file_tar_str, file_bg_str, tar_spk_str, file_tar_sounds_str = line

        wav_mix = None
        target_spk = None
        mix_len = 0
        target_sig = None
        tar_supp_sig = None

        for file_str in [file_tar_str, file_bg_str]:
            signal, rate = sf.read(file_str)
            if len(signal.shape) > 1:
                signal = signal[:, 0]
            if rate != config.FRAME_RATE:
                signal = resampy.resample(signal, rate, config.FRAME_RATE, filter='kaiser_best')
            signal = list(signal)
            if len(signal) > config.MAX_LEN:
                signal = signal[:config.MAX_LEN]
            if len(signal) > mix_len:
                mix_len = len(signal)

            signal = np.array(signal)
            signal -= np.mean(signal)
            signal /= np.max(np.abs(signal))

            signal = list(signal)
            if len(signal) < config.MAX_LEN:
                signal.extend(np.zeros(config.MAX_LEN - len(signal)))
            signal = np.array(signal)

            if wav_mix is None:
                wav_mix = signal
                target_sig = signal
                if not unk_spk:
                    tar_supp_sig = signal
                    batch_clean_wav.append(tar_supp_sig)
                    target_spk = spk_to_idx[tar_spk_str]
                else:
                    target_spk = 0
            else:
                wav_mix = wav_mix + signal

        if unk_spk:
            tmp_unk_spk_supp = 0
            for file_str in file_tar_sounds_str.strip().split(','):
                tmp_unk_spk_supp += 1
                if tmp_unk_spk_supp > config.UNK_SPK_SUPP:
                    break
                signal, rate = sf.read(file_str)
                if len(signal.shape) > 1:
                    signal = signal[:, 0]
                if rate != config.FRAME_RATE:
                    signal = resampy.resample(signal, rate, config.FRAME_RATE, filter='kaiser_best')
                if tar_supp_sig is None:
                    tar_supp_sig = signal
                else:
                    tar_supp_sig = tar_supp_sig + signal
            batch_clean_wav.append(tar_supp_sig)

        if config.IS_LOG_SPECTRAL:
            feature_mix = np.log(np.transpose(np.abs(librosa.core.spectrum.stft(wav_mix, config.FRAME_LENGTH,
                                                                                config.FRAME_SHIFT,
                                                                                window=config.WINDOWS)))
                                 + np.spacing(1))
        else:
            feature_mix = np.transpose(np.abs(librosa.core.spectrum.stft(wav_mix, config.FRAME_LENGTH,
                                                                         config.FRAME_SHIFT,
                                                                         window=config.WINDOWS)))

        spec_mix = np.transpose(librosa.core.spectrum.stft(wav_mix, config.FRAME_LENGTH,
                                                           config.FRAME_SHIFT,
                                                           window=config.WINDOWS))

        spec_clean = np.transpose(np.abs(librosa.core.spectrum.stft(target_sig, config.FRAME_LENGTH,
                                                                    config.FRAME_SHIFT, window=config.WINDOWS)))

        batch_input_mix_fea.append(feature_mix)
        batch_input_mix_spec.append(np.abs(spec_mix))
        batch_input_spk.append(target_spk)
        batch_target_spec.append(spec_clean)
        batch_input_len.append(mix_len)

        batch_count += 1

        if (batch_count == batch_size) or (line_idx == (file_list_len-1)):
            # mix_input_fea (batch_size, time_steps, feature_dim)
            _tmp_batch_size = len(batch_input_mix_fea)
            mix_input_fea = np.array(batch_input_mix_fea).reshape((_tmp_batch_size, ) + feature_mix.shape)
            # mix_input_spec (batch_size, time_steps, spectrum_dim)
            mix_input_spec = np.array(batch_input_mix_spec).reshape((_tmp_batch_size, ) + spec_mix.shape)
            # target_input_spk (batch_size, 1)
            target_input_spk = np.array(batch_input_spk).reshape((_tmp_batch_size, 1))
            # clean_input_fea (batch_size, time_steps, feature_dim)
            batch_input_clean_fea, inp_clean_shape = compute_batch_clean_fea(batch_clean_wav)
            clean_input_fea = np.array(batch_input_clean_fea).reshape((_tmp_batch_size, ) + inp_clean_shape)
            # clean_target_spec (batch_size, time_steps, spectrum_dim)
            clean_target_spec = np.array(batch_target_spec).reshape((_tmp_batch_size, ) + spec_clean.shape)

            if not unk_spk:
                clean_input_fea = np.zeros_like(clean_input_fea)

            mse_loss += model.evaluate({'input_mix_feature': mix_input_fea, 'input_mix_spectrum': mix_input_spec,
                                        'input_target_spk': target_input_spk, 'input_clean_feature': clean_input_fea},
                                       {'target_clean_spectrum': clean_target_spec}, batch_size=_tmp_batch_size,
                                       verbose=0)

            time_end = time.time()
            print('\rCurrent evaluate:' + str(line_idx+1) + ' of ' + audio_list +
                  ' and cost time: %.4f sec.' % (time_end - time_start), end='')

            batch_input_mix_fea = []
            batch_input_mix_spec = []
            batch_input_spk = []
            batch_target_spec = []
            batch_input_len = []
            batch_clean_wav = []
            batch_count = 0

    print('\n[Epoch-%s: %d] - MSE Loss:%f' %
          (valid_test, epoch_num+1, mse_loss))
    log_file.write('[Epoch-%s: %d] - MSE Loss:%f\n' %
                   (valid_test, epoch_num+1, mse_loss))
    log_file.flush()
    file_list.close()
    return mse_loss
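The load, resample, normalize, and pad block recurs almost verbatim in Examples #57, #59, and #60. A hedged refactoring sketch that factors it into a single helper; load_fixed_length is a name introduced here, and frame_rate/max_len stand in for config.FRAME_RATE and config.MAX_LEN. Tracking mix_len (the pre-padding length) would stay with the caller.

import numpy as np
import resampy
import soundfile as sf

def load_fixed_length(path, frame_rate, max_len):
    """Load mono audio at frame_rate, trim to max_len, zero-mean and
    peak-normalize, then zero-pad to exactly max_len samples."""
    signal, rate = sf.read(path)
    if signal.ndim > 1:
        signal = signal[:, 0]  # keep the first channel only
    if rate != frame_rate:
        signal = resampy.resample(signal, rate, frame_rate, filter='kaiser_best')
    signal = signal[:max_len]
    signal = signal - np.mean(signal)
    signal = signal / np.max(np.abs(signal))
    if len(signal) < max_len:
        signal = np.concatenate([signal, np.zeros(max_len - len(signal))])
    return signal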
Example #60
def eval_separation(model, audio_list, valid_test, epoch_num, log_file, spk_to_idx, batch_size=1, spk_num=2,
                    unk_spk=False, supp_time=1, add_bgd_noise=False):
    if unk_spk:
        batch_size = 1
    if spk_num < 2:
        spk_num = 2
    batch_input_mix_fea = []
    batch_input_mix_spec = []
    batch_input_spk = []
    batch_input_len = []
    batch_mix_spec = []
    batch_mix_wav = []
    batch_target_wav = []
    batch_noise_wav = []
    batch_clean_wav = []
    batch_count = 0

    batch_sdr_0 = []
    batch_sir_0 = []
    batch_sar_0 = []
    batch_nsdr_0 = []
    batch_sdr = []
    batch_sir = []
    batch_sar = []
    batch_nsdr = []

    file_list_len = 0
    file_list = open(audio_list)
    for line in file_list:
        file_list_len += 1
    file_list.close()
    file_list = open(audio_list)
    time_start = time.time()
    for line_idx, line in enumerate(file_list):
        line = line.strip().split()
        if len(line) < 2:
            raise Exception('Wrong audio list file record in the line:', ' '.join(line))
        file_tar_sounds_str = None
        if not unk_spk:
            # if not test unk_spk
            file_tar_str, file_bg_str, tar_spk_str = line
            file_bg_str = file_bg_str.strip().split(',')
        else:
            # if test unk_spk
            file_tar_str, file_bg_str, tar_spk_str, file_tar_sounds_str = line
            file_bg_str = [file_bg_str]
        wav_mix = None
        target_spk = None
        mix_len = 0
        target_sig = None
        noise_sig = None
        tar_supp_sig = None

        for file_str in ([file_tar_str]+file_bg_str)[:spk_num]:
            signal, rate = sf.read(file_str)
            if len(signal.shape) > 1:
                signal = signal[:, 0]
            if rate != config.FRAME_RATE:
                signal = resampy.resample(signal, rate, config.FRAME_RATE, filter='kaiser_best')
            signal = list(signal)
            if len(signal) > config.MAX_LEN:
                signal = signal[:config.MAX_LEN]
            if len(signal) > mix_len:
                mix_len = len(signal)

            signal = np.array(signal)
            signal -= np.mean(signal)
            signal /= np.max(np.abs(signal))

            signal = list(signal)
            if len(signal) < config.MAX_LEN:
                signal.extend(np.zeros(config.MAX_LEN - len(signal)))
            signal = np.array(signal)

            if wav_mix is None:
                wav_mix = signal
                target_sig = signal
                if not unk_spk:
                    tar_supp_sig = signal
                    batch_clean_wav.append(tar_supp_sig)
                    target_spk = spk_to_idx[tar_spk_str]
                else:
                    # idx of unk_spk: 0
                    target_spk = 0
            else:
                wav_mix = wav_mix + signal
                if noise_sig is None:
                    noise_sig = signal
                else:
                    noise_sig = noise_sig + signal

        if add_bgd_noise:
            bg_noise = config.BGD_NOISE_WAV[:config.MAX_LEN]
            bg_noise -= np.mean(bg_noise)
            bg_noise /= np.max(np.abs(bg_noise))

            wav_mix = wav_mix + bg_noise
            noise_sig = noise_sig + bg_noise

        if unk_spk:
            tmp_unk_spk_supp = 0
            for file_str in file_tar_sounds_str.strip().split(','):
                tmp_unk_spk_supp += 1
                if tmp_unk_spk_supp > config.UNK_SPK_SUPP:
                    break
                signal, rate = sf.read(file_str)
                if len(signal.shape) > 1:
                    signal = signal[:, 0]
                if rate != config.FRAME_RATE:
                    signal = resampy.resample(signal, rate, config.FRAME_RATE, filter='kaiser_best')
                signal = list(signal)
                if tar_supp_sig is None:
                    tar_supp_sig = signal
                else:
                    tar_supp_sig = tar_supp_sig + signal
            if len(tar_supp_sig) < supp_time*config.FRAME_RATE:
                raise Exception('supp_time is longer than the available target supplemental audio!')
            batch_clean_wav.append(tar_supp_sig[:int(supp_time * config.FRAME_RATE)])

        if config.IS_LOG_SPECTRAL:
            feature_mix = np.log(np.abs(np.transpose(librosa.core.spectrum.stft(wav_mix, config.FRAME_LENGTH,
                                                                                config.FRAME_SHIFT,
                                                                                window=config.WINDOWS)))
                                 + np.spacing(1))
        else:
            feature_mix = np.abs(np.transpose(librosa.core.spectrum.stft(wav_mix, config.FRAME_LENGTH,
                                                                         config.FRAME_SHIFT,
                                                                         window=config.WINDOWS)))

        spec_mix = np.transpose(librosa.core.spectrum.stft(wav_mix, config.FRAME_LENGTH,
                                                           config.FRAME_SHIFT,
                                                           window=config.WINDOWS))

        batch_input_mix_fea.append(feature_mix)
        batch_input_mix_spec.append(np.abs(spec_mix))
        batch_input_spk.append(target_spk)
        batch_input_len.append(mix_len)
        batch_mix_spec.append(spec_mix)
        batch_mix_wav.append(wav_mix)
        batch_target_wav.append(target_sig)
        batch_noise_wav.append(noise_sig)

        batch_count += 1

        if (batch_count == batch_size) or (line_idx == (file_list_len-1)):
            # mix_input_fea (batch_size, time_steps, feature_dim)
            _tmp_batch_size = len(batch_input_mix_fea)
            mix_input_fea = np.array(batch_input_mix_fea).reshape((_tmp_batch_size, ) + feature_mix.shape)
            # mix_input_spec (batch_size, time_steps, spectrum_dim)
            mix_input_spec = np.array(batch_input_mix_spec).reshape((_tmp_batch_size, ) + spec_mix.shape)
            # bg_input_mask = np.array(batch_input_silence_mask).reshape((_tmp_batch_size, ) + spec_mix.shape)
            # target_input_spk (batch_size, 1)
            target_input_spk = np.array(batch_input_spk).reshape((_tmp_batch_size, 1))
            # clean_input_fea (batch_size, time_steps, feature_dim)
            batch_input_clean_fea, inp_clean_shape = compute_batch_clean_fea(batch_clean_wav)
            # use the actual (possibly partial) batch size, as in eval_loss above
            clean_input_fea = np.array(batch_input_clean_fea).reshape((_tmp_batch_size, ) + inp_clean_shape)
            if not unk_spk:
                clean_input_fea = np.log(np.zeros_like(clean_input_fea)+np.spacing(1))
            target_pred = model.predict({'input_mix_feature': mix_input_fea, 'input_mix_spectrum': mix_input_spec,
                                         'input_target_spk': target_input_spk, 'input_clean_feature': clean_input_fea})
            batch_idx = 0
            for _pred_output in list(target_pred):
                _mix_spec = batch_mix_spec[batch_idx]
                phase_mix = np.angle(_mix_spec)
                _pred_spec = _pred_output * np.exp(1j * phase_mix)
                _pred_wav = librosa.core.spectrum.istft(np.transpose(_pred_spec), config.FRAME_SHIFT,
                                                        window=config.WINDOWS)
                _target_wav = batch_target_wav[batch_idx]
                min_len = np.min((len(_target_wav), len(_pred_wav), batch_input_len[batch_idx]))
                _pred_wav = _pred_wav[:min_len]
                batch_target_wav[batch_idx] = _target_wav[:min_len]
                batch_noise_wav[batch_idx] = batch_noise_wav[batch_idx][:min_len]
                batch_mix_wav[batch_idx] = batch_mix_wav[batch_idx][:min_len]

                mix_wav = matlab.double(batch_mix_wav[batch_idx].tolist())
                target_wav = matlab.double(batch_target_wav[batch_idx].tolist())
                noise_wav = matlab.double(batch_noise_wav[batch_idx].tolist())
                pred_wav = matlab.double(_pred_wav.tolist())
                if epoch_num == 0:
                    # Baseline BSS_EVAL: score the raw mixture itself as the prediction
                    # BSS_EVAL(truth_signal, truth_noise, pred_signal, mix)
                    bss_eval_results = config.MAT_ENG.BSS_EVAL(target_wav, noise_wav, mix_wav, mix_wav)
                    batch_sdr_0.append(bss_eval_results['SDR'])
                    batch_sir_0.append(bss_eval_results['SIR'])
                    batch_sar_0.append(bss_eval_results['SAR'])
                    batch_nsdr_0.append(bss_eval_results['NSDR'])
                if (line_idx < _tmp_batch_size) and (batch_idx == 0):
                    sf.write(config.TMP_PRED_WAV_FOLDER + '/test_pred_%s_ep%04d_bs%04d_idx%03d' %
                             (config.DATASET, (epoch_num+1), 1, (batch_idx+1)) +
                             '.wav', _pred_wav, config.FRAME_RATE)

                # BSS_EVAL (truth_signal, truth_noise, pred_signal, mix)
                bss_eval_results = config.MAT_ENG.BSS_EVAL(target_wav, noise_wav, pred_wav, mix_wav)
                batch_sdr.append(bss_eval_results['SDR'])
                batch_sir.append(bss_eval_results['SIR'])
                batch_sar.append(bss_eval_results['SAR'])
                batch_nsdr.append(bss_eval_results['NSDR'])
                batch_idx += 1

            time_end = time.time()
            sdr = float(np.mean(batch_sdr))
            sir = float(np.mean(batch_sir))
            sar = float(np.mean(batch_sar))
            nsdr = float(np.mean(batch_nsdr))
            print('\rCurrent predict:' + str(line_idx+1) + ' of ' + audio_list +
                  ' and cost time: %.4f sec. - GSDR:%f, GSIR:%f, GSAR:%f, GNSDR:%f'
                  % ((time_end - time_start), sdr, sir, sar, nsdr), end='')
            if (line_idx+1) % 200 == 0:
                log_file.write('Have evaluated %05d mixture wavs, and cost time: %.4f sec\n'
                               % ((line_idx+1), (time_end - time_start)))
                log_file.flush()
            batch_input_mix_fea = []
            batch_input_mix_spec = []
            batch_input_spk = []
            batch_input_len = []
            batch_mix_spec = []
            batch_mix_wav = []
            batch_target_wav = []
            batch_noise_wav = []
            batch_clean_wav = []
            batch_count = 0

    if epoch_num == 0:
        sdr_0 = float(np.mean(batch_sdr_0))
        sir_0 = float(np.mean(batch_sir_0))
        sar_0 = float(np.mean(batch_sar_0))
        nsdr_0 = float(np.mean(batch_nsdr_0))
        print('\n[Epoch-%s: %d] - GSDR:%f, GSIR:%f, GSAR:%f, GNSDR:%f' %
              (valid_test, epoch_num, sdr_0, sir_0, sar_0, nsdr_0))
        log_file.write('[Epoch-%s: %d] - GSDR:%f, GSIR:%f, GSAR:%f, GNSDR:%f\n' %
                       (valid_test, epoch_num, sdr_0, sir_0, sar_0, nsdr_0))
        log_file.flush()

    sdr = float(np.mean(batch_sdr))
    sir = float(np.mean(batch_sir))
    sar = float(np.mean(batch_sar))
    nsdr = float(np.mean(batch_nsdr))
    if epoch_num == 0:
        print('[Epoch-%s: %d] - GSDR:%f, GSIR:%f, GSAR:%f, GNSDR:%f' %
              (valid_test, epoch_num+1, sdr, sir, sar, nsdr))
    else:
        print('\n[Epoch-%s: %d] - GSDR:%f, GSIR:%f, GSAR:%f, GNSDR:%f' %
              (valid_test, epoch_num+1, sdr, sir, sar, nsdr))
    log_file.write('[Epoch-%s: %d] - GSDR:%f, GSIR:%f, GSAR:%f, GNSDR:%f\n' %
                   (valid_test, epoch_num+1, sdr, sir, sar, nsdr))
    log_file.flush()
    file_list.close()
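The prediction loop in Example #60 reconstructs time-domain audio by combining the model's magnitude estimate with the phase of the mixture before the inverse STFT. A standalone sketch of that step; the argument names are illustrative and follow the transposed (time_steps, freq_bins) layout used above.

import numpy as np
import librosa

def reconstruct_with_mix_phase(pred_mag_t, mix_spec_t, frame_shift, window):
    """pred_mag_t: predicted magnitudes, shape (time_steps, freq_bins);
    mix_spec_t: complex mixture STFT in the same layout."""
    phase = np.angle(mix_spec_t)
    pred_spec_t = pred_mag_t * np.exp(1j * phase)
    # istft expects (freq_bins, time_steps), hence the transpose back
    return librosa.istft(np.transpose(pred_spec_t), hop_length=frame_shift, window=window)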