def call(self, conditioning):
    """Run the decoder on a dictionary of conditioning signals.

    A bank of unit-amplitude sinusoids at the harmonic frequencies of
    `f0_hz` and a channel of standard-normal noise are concatenated with
    the preprocessed conditioning inputs, passed through the conv layers,
    and the scaled sum of the skip outputs is projected by `dense_out`.

    Args:
      conditioning: Mapping that must contain 'f0_hz' and every key listed
        in `self.input_keys`. The leading axis of 'f0_hz' is used as the
        batch size. NOTE(review): tensors are presumably
        [batch, time, channels] -- confirm against the caller.

    Returns:
      Dictionary with the single key 'audio_tensor' holding the output of
      `self.dense_out`.
    """
    n_batch = conditioning['f0_hz'].shape[0]

    # Noise is drawn first so the random op keeps its original position.
    noise_channel = tf.random.normal([n_batch, self.n_total, 1])

    # Sinusoidal excitation: one unit-amplitude oscillator per harmonic.
    f0_resampled = core.resample(conditioning['f0_hz'], self.n_total)
    harmonic_freqs = core.get_harmonic_frequencies(
        f0_resampled, self.n_harmonics)
    unit_amplitudes = tf.ones_like(harmonic_freqs)
    sinusoids = core.oscillator_bank(
        frequency_envelopes=harmonic_freqs,
        amplitude_envelopes=unit_amplitudes,
        sample_rate=self.sample_rate,
        sum_sinusoids=False)

    # Look up every conditioning input, then run each through its stack.
    raw_features = []
    for key in self.input_keys:
        raw_features.append(conditioning[key])

    stacked_features = []
    for stack, feature in zip(self.input_stacks, raw_features):
        stacked_features.append(stack(feature))

    # Resample all inputs to the target sample rate.
    resampled_features = [
        core.resample(feature, self.n_total) for feature in stacked_features
    ]

    context = tf.concat(
        resampled_features + [sinusoids, noise_channel], axis=-1)

    # Conv layers: each layer gets the running activation plus the shared
    # context, and contributes one skip output to the running sum.
    hidden = self.first_conv(context)
    skip_total = 0
    for layer in self.conv_layers:
        hidden, skip_out = layer(hidden, context)
        skip_total += skip_out

    # Scale the summed skips by 1 / sqrt(number of conv layers).
    skip_scale = tf.sqrt(1.0 / len(self.conv_layers))
    skip_total *= skip_scale

    return {'audio_tensor': self.dense_out(skip_total)}
def get_controls(self, amplitudes, harmonic_distribution, f0_hz):
    """Turn raw network outputs into synthesizer controls.

    Optionally rescales the amplitudes and harmonic distribution with
    `self.scale_fn`, optionally zeroes harmonics above Nyquist, normalizes
    the harmonic distribution to sum to (approximately) one over its last
    axis, and optionally halves the fundamental frequency.

    Args:
      amplitudes: 3-D Tensor of synthesizer controls, of shape
        [batch, time, 1].
      harmonic_distribution: 3-D Tensor of synthesizer controls, of shape
        [batch, time, n_harmonics].
      f0_hz: Fundamental frequencies in hertz. Shape [batch, time, 1].

    Returns:
      controls: Dictionary of tensors of synthesizer controls, with keys
        'amplitudes', 'harmonic_distribution' and 'f0_hz'.
    """
    # Scale the amplitudes (the same function is applied to both controls).
    scale = self.scale_fn
    if scale is not None:
        amplitudes = scale(amplitudes)
        harmonic_distribution = scale(harmonic_distribution)

    # Bandlimit the harmonic distribution.
    # NOTE(review): this uses f0_hz before any `half_f0` halving below --
    # confirm that is the intended reference frequency.
    if self.normalize_below_nyquist:
        harmonic_count = int(harmonic_distribution.shape[-1])
        harmonic_distribution = core.remove_above_nyquist(
            core.get_harmonic_frequencies(f0_hz, harmonic_count),
            harmonic_distribution,
            self.sample_rate)

    # Normalize; the small epsilon guards against division by zero when
    # every harmonic has been removed or scaled to zero.
    distribution_sum = (
        tf.reduce_sum(harmonic_distribution, axis=-1, keepdims=True) + 1e-9)
    harmonic_distribution /= distribution_sum

    # Optionally drop the fundamental by an octave.
    f0_out = f0_hz / 2 if self.half_f0 else f0_hz

    controls = {
        'amplitudes': amplitudes,
        'harmonic_distribution': harmonic_distribution,
        'f0_hz': f0_out,
    }
    return controls