def build(self, input_shape):
    """Build the filterbank weight of this layer from ``input_shape``.

    Args:
        input_shape: 4-element shape. When ``self.image_data_format`` is
            ``'channels_first'`` it is read as (?, channel, freq, time);
            otherwise as (?, freq, time, channel). Index 0 is not read
            (presumably the batch axis).

    Side effects:
        Sets ``self.n_ch``, ``self.n_freq``, ``self.n_time`` and
        ``self.filterbank``. Registers the filterbank as a trainable or
        non-trainable weight depending on ``self.trainable_fb``.

    Raises:
        ValueError: if ``self.init`` is neither ``'mel'`` nor ``'log'``.
            No other initialisation is implemented here. Previously such a
            value surfaced later as an opaque ``AttributeError`` on
            ``self.filterbank``.
    """
    if self.image_data_format == 'channels_first':
        # (?, channel, freq, time)
        self.n_ch = input_shape[1]
        self.n_freq = input_shape[2]
        self.n_time = input_shape[3]
    else:
        # (?, freq, time, channel)
        self.n_ch = input_shape[3]
        self.n_freq = input_shape[1]
        self.n_time = input_shape[2]

    if self.init == 'mel':
        fb = backend.filterbank_mel(sr=self.sr, n_freq=self.n_freq,
                                    n_mels=self.n_fbs, fmin=self.fmin,
                                    fmax=self.fmax)
    elif self.init == 'log':
        fb = backend.filterbank_log(sr=self.sr, n_freq=self.n_freq,
                                    n_bins=self.n_fbs,
                                    bins_per_octave=self.bins_per_octave,
                                    fmin=self.fmin)
    else:
        # Fail fast with a clear message instead of leaving self.filterbank
        # unset and crashing on the weight-registration lines below.
        raise ValueError(
            "Filterbank init must be 'mel' or 'log', got {!r}".format(self.init))

    # The backend matrix is transposed before being wrapped as a variable.
    self.filterbank = K.variable(fb.transpose(), dtype=K.floatx())

    # NOTE(review): appending to trainable_weights / non_trainable_weights
    # assumes these are mutable lists on the installed Keras version; newer
    # Keras exposes them as computed properties -- confirm, and consider
    # self.add_weight(...) instead.
    if self.trainable_fb:
        self.trainable_weights.append(self.filterbank)
    else:
        self.non_trainable_weights.append(self.filterbank)
    super(Filterbank, self).build(input_shape)
    self.built = True
def test_filterbank_log():
    """Compare ``KPB.filterbank_log`` with a stored reference matrix.

    The reference is ``fblog_8000_512.npy`` next to this test module. It is
    compared against the output for ``sr=8000`` and ``n_freq=512``, matching
    the file name. Shapes must agree, and values must agree within ``TOL``.
    """
    test_dir = os.path.dirname(__file__)
    reference_path = os.path.join(test_dir, 'fblog_8000_512.npy')
    expected = np.load(reference_path)

    actual = KPB.filterbank_log(sr=8000, n_freq=512)

    assert actual.shape == expected.shape
    assert np.allclose(actual, expected, atol=TOL)
def test_filterbank_log(sample_rate, n_freq, n_bins, bins_per_octave, f_min, spread):
    """Smoke-test ``KPB.filterbank_log`` as a wrapper.

    Only the output dtype (``K.floatx()``) and shape ``(n_freq, n_bins)``
    are checked. The filter coefficients themselves are not validated here.
    """
    fb_kwargs = dict(
        sample_rate=sample_rate,
        n_freq=n_freq,
        n_bins=n_bins,
        bins_per_octave=bins_per_octave,
        f_min=f_min,
        spread=spread,
    )
    fb = KPB.filterbank_log(**fb_kwargs)

    expected_shape = (n_freq, n_bins)
    assert fb.dtype == K.floatx()
    assert fb.shape == expected_shape
def test_fb_log_fail():
    """Check that ``KPB.filterbank_log`` rejects an impossible configuration.

    300 bins at 12 bins per octave span roughly 25 octaves above ``f_min``
    (assuming bins are spaced ``bins_per_octave`` per octave). That is far
    above the 11025 Hz Nyquist frequency of a 22050 Hz signal, so the call
    must raise. Without an explicit expectation the test would only pass
    when the function silently succeeds, which is the opposite of its intent.
    """
    import pytest

    # NOTE(review): RuntimeError is assumed to be what filterbank_log raises
    # when its top bin exceeds Nyquist -- confirm against the backend.
    with pytest.raises(RuntimeError):
        KPB.filterbank_log(sample_rate=22050, n_freq=513, n_bins=300,
                           bins_per_octave=12)