示例#1
0
def _freq_feat_graph(feat_name, **kwargs):
    winlen = kwargs.get('winlen')
    winstep = kwargs.get('winstep')
    feature_size = kwargs.get('feature_size')
    sr = kwargs.get('sr')  #pylint: disable=invalid-name
    nfft = kwargs.get('nfft')
    del nfft

    assert feat_name in ('fbank', 'spec')

    params = speech_ops.speech_params(sr=sr,
                                      bins=feature_size,
                                      add_delta_deltas=False,
                                      audio_frame_length=winlen,
                                      audio_frame_step=winstep)

    graph = None
    if feat_name == 'fbank':
        # get session
        if feat_name not in _global_sess:
            graph = tf.Graph()
            #pylint: disable=not-context-manager
            with graph.as_default():
                # fbank
                filepath = tf.placeholder(dtype=tf.string,
                                          shape=[],
                                          name='wavpath')
                waveforms, sample_rate = speech_ops.read_wav(filepath, params)
                del sample_rate
                fbank = speech_ops.extract_feature(waveforms, params)
                # shape must be [T, D, C]
                feat = tf.identity(fbank, name=feat_name)
    elif feat_name == 'spec':
        # magnitude spec
        if feat_name not in _global_sess:
            graph = tf.Graph()
            #pylint: disable=not-context-manager
            with graph.as_default():
                filepath = tf.placeholder(dtype=tf.string,
                                          shape=[],
                                          name='wavpath')
                waveforms, sample_rate = speech_ops.read_wav(filepath, params)

                spec = py_x_ops.spectrum(
                    waveforms[:, 0],
                    tf.cast(sample_rate, tf.dtypes.float32),
                    output_type=1
                )  #output_type: 1, power spec; 2 log power spec
                spec = tf.sqrt(spec)
                # shape must be [T, D, C]
                spec = tf.expand_dims(spec, -1)
                feat = tf.identity(spec, name=feat_name)
    else:
        raise ValueError(f"Not support freq feat: {feat_name}.")

    return graph, (_get_out_tensor_name('wavpath',
                                        0), _get_out_tensor_name(feat_name, 0))
示例#2
0
def extract_filterbank(*args, **kwargs):
  ''' tensorflow fbank feat

  Compute a filterbank feature for every wav path in ``args`` and save it
  next to the input as ``<path without extension>.npy``.

  Args:
    *args: wav file paths.
    **kwargs: 'winlen', 'winstep', 'feature_size' and 'sr' are forwarded to
      speech_ops.speech_params; 'nfft' is accepted but ignored; when
      'dry_run' is truthy the feature is only logged, not written to disk.
  '''
  winlen = kwargs.get('winlen')
  winstep = kwargs.get('winstep')
  feature_size = kwargs.get('feature_size')
  sr = kwargs.get('sr')  #pylint: disable=invalid-name
  dry_run = kwargs.get('dry_run')
  # 'nfft' is accepted for interface compatibility but not used here.

  feat_name = 'fbank'
  graph = None
  # Build the graph only if no session has been cached for this feature yet.
  if feat_name not in _global_sess:
    graph = tf.Graph()
    #pylint: disable=not-context-manager
    with graph.as_default():
      params = speech_ops.speech_params(
          sr=sr,
          bins=feature_size,
          add_delta_deltas=False,
          audio_frame_length=winlen,
          audio_frame_step=winstep)

      filepath = tf.placeholder(dtype=tf.string, shape=[], name='wavpath')
      waveforms, _ = speech_ops.read_wav(filepath, params)
      fbank = speech_ops.extract_feature(waveforms, params)
      # Named so the session below can fetch it as '<feat_name>:0'.
      tf.identity(fbank, name=feat_name)

  sess = _get_session(_get_out_tensor_name(feat_name, 0), graph)

  for wavpath in args:
    savepath = os.path.splitext(wavpath)[0] + '.npy'
    # Lazy %-style args: the message is only formatted if it is emitted.
    logging.debug('input: %s, output: %s', wavpath, savepath)

    feat = sess.run(feat_name + ":0", feed_dict={'wavpath:0': wavpath})

    # save feat
    if dry_run:
      logging.info('save feat: path %s shape:%s dtype:%s', savepath,
                   feat.shape, feat.dtype)
    else:
      np.save(savepath, feat)
    def test_extract_feature(self):
        ''' test logfbank with delta and delta-delta (cmvn disabled) '''
        #pylint: disable=invalid-name
        hp = tffeat.speech_params(sr=self.sr_true,
                                  bins=40,
                                  cmvn=False,
                                  audio_desired_samples=1000,
                                  add_delta_deltas=True)

        with self.session():
            wavfile = tf.constant(self.wavpath)
            # read wav; the returned sample rate is not needed here
            audio, _ = tffeat.read_wav(wavfile, hp)

            # fbank with delta and delta-delta; cmvn is off (cmvn=False)
            feature = tffeat.extract_feature(audio, hp)

            # 40 bins per frame; 3 channels presumably static + delta +
            # delta-delta given add_delta_deltas=True. The 11-frame count
            # depends on audio_desired_samples, sr_true and the default frame
            # settings -- TODO confirm.
            self.assertEqual(feature.eval().shape, (11, 40, 3))