Example #1
0
    def call(self, audio, target_audio):
        """Compute a multi-band, multi-resolution mel-spectrogram loss.

        The range ``[0, sample_rate / 2]`` is split into ``self.n_bands``
        bands whose edges are equally spaced on the mel scale (converted with
        ``self._hz_to_mel`` / ``self._mel_to_hz``). For each band a frequency
        resolution is derived and passed to ``self._get_closest_n_fft``. One
        ``compute_mel`` op is built per FFT size it returns, and the ``j``-th
        size uses ``n_mels_max // 4**j`` mel bins. Every op is applied to both
        signals, and the weighted magnitude and log-magnitude differences are
        summed.

        Args:
            audio: Generated audio, passed to every mel op.
            target_audio: Reference audio, passed to every mel op.

        Returns:
            The accumulated loss; ``0.0`` when no term is enabled.

        Raises:
            ValueError: If the configuration leaves fewer than one mel bin per
                band (``n_mels_max < 1``). That would otherwise divide by zero
                when deriving the per-band frequency resolution and build mel
                ops with zero bins.
        """
        f_max = self.sample_rate / 2  # Nyquist frequency.
        f_min = 0.0
        m_max = self._hz_to_mel(f_max)
        m_min = self._hz_to_mel(f_min)
        # Mel-bin budget per band. NOTE(review): the extra factor of 4 looks
        # like a heuristic; confirm intent.
        n_mels_max = int((m_max - m_min) / self.n_bands / 4)
        if n_mels_max < 1:
            raise ValueError(
                f"n_bands={self.n_bands} at sample_rate={self.sample_rate} "
                f"leaves no mel bins per band (n_mels_max={n_mels_max}).")

        # Band edges, equally spaced in mel.
        m_edges = np.linspace(m_min, m_max, self.n_bands + 1)
        m_los, m_his = m_edges[:-1], m_edges[1:]
        f_los = self._mel_to_hz(m_los)
        f_his = self._mel_to_hz(m_his)
        # Mel width of one bin in each band, scaled by self._df_dm at the
        # band's lower edge; the result is the resolution used to pick FFTs.
        d_fs = self._df_dm(m_los) * ((m_his - m_los) / n_mels_max)

        loss_ops = []
        for f_lo, f_hi, d_f in zip(f_los, f_his, d_fs):
            n_ffts = self._get_closest_n_fft(self.sample_rate, d_f,
                                             self.N_FFT_OPTIONS)
            for j, n_fft in enumerate(n_ffts):
                # Each successive FFT size gets 4x fewer mel bins.
                loss_ops.append(
                    functools.partial(compute_mel,
                                      sample_rate=self.sample_rate,
                                      lo_hz=f_lo,
                                      hi_hz=f_hi,
                                      bins=n_mels_max // 4**j,
                                      fft_size=n_fft))

        # Accumulate the loss over every (band, fft size) op.
        loss = 0.0
        for loss_op in loss_ops:
            target_mag = loss_op(target_audio)
            value_mag = loss_op(audio)

            # Magnitude term.
            if self.mag_weight > 0:
                loss += self.mag_weight * mean_difference(
                    target_mag, value_mag, self.loss_type)

            # Log-magnitude term, reusing the spectrograms computed above.
            if self.logmag_weight > 0:
                target = spectral_ops.safe_log(target_mag)
                value = spectral_ops.safe_log(value_mag)
                loss += self.logmag_weight * mean_difference(
                    target, value, self.loss_type)

        return loss
Example #2
0
    def call(self, audio, target_audio):
        """Compute a multi-resolution mel loss split into linear frequency bands.

        ``self.fft_sizes`` is divided, in order, into ``self.n_bands`` groups
        of equal length. Group ``b`` is restricted to the ``b``-th of
        ``self.n_bands`` equal-width (in Hz) bands covering
        ``[0, sample_rate / 2]``. Each FFT size builds a ``compute_mel`` op
        with ``n_fft / 16`` mel bins. Every op is applied to both signals, and
        the weighted magnitude and log-magnitude differences are summed.

        Args:
            audio: Generated audio, passed to every mel op.
            target_audio: Reference audio, passed to every mel op.

        Returns:
            The accumulated loss; ``0.0`` when no op or term contributes.

        Raises:
            ValueError: If ``len(self.fft_sizes)`` is not a multiple of
                ``self.n_bands`` (some FFT sizes would have no band).
        """
        n_layers = len(self.fft_sizes)
        # Integer split: every band gets the same whole number of consecutive
        # FFT sizes. A fractional count cannot be honored, so reject it.
        layers_per_band, leftover = divmod(n_layers, self.n_bands)
        if leftover:
            raise ValueError(
                f"len(fft_sizes)={n_layers} must be a multiple of "
                f"n_bands={self.n_bands}.")
        band_width = (self.sample_rate / 2) / self.n_bands

        loss_ops = []
        for i, n_fft in enumerate(self.fft_sizes):
            # Consecutive groups of `layers_per_band` FFT sizes share a band.
            f_lo = (i // layers_per_band) * band_width
            f_hi = f_lo + band_width
            # TODO: this is ad-hoc; change for something more motivated
            n_mels = int(n_fft / 16)
            loss_ops.append(
                functools.partial(compute_mel,
                                  sample_rate=self.sample_rate,
                                  lo_hz=f_lo,
                                  hi_hz=f_hi,
                                  bins=n_mels,
                                  fft_size=n_fft))

        # Accumulate the loss over every fft size.
        loss = 0.0
        for loss_op in loss_ops:
            target_mag = loss_op(target_audio)
            value_mag = loss_op(audio)

            # Magnitude term.
            if self.mag_weight > 0:
                loss += self.mag_weight * mean_difference(
                    target_mag, value_mag, self.loss_type)

            # Log-magnitude term, reusing the spectrograms computed above.
            if self.logmag_weight > 0:
                target = spectral_ops.safe_log(target_mag)
                value = spectral_ops.safe_log(value_mag)
                loss += self.logmag_weight * mean_difference(
                    target, value, self.loss_type)

        return loss
Example #3
0
    def call(self, audio, target_audio):
        """Sum weighted spectral differences over every op in ``self.loss_ops``.

        Each op is applied to ``target_audio`` first and then to ``audio``. A
        magnitude term (when ``self.mag_weight > 0``) and a log-magnitude term
        (when ``self.logmag_weight > 0``) are added, both computed with
        ``mean_difference`` and ``self.loss_type``.

        Args:
            audio: Generated audio.
            target_audio: Reference audio.

        Returns:
            The summed loss; ``0.0`` when there are no ops or no enabled terms.
        """
        total = 0.0
        for spectrogram in self.loss_ops:
            reference = spectrogram(target_audio)
            estimate = spectrogram(audio)

            if self.mag_weight > 0:
                mag_term = mean_difference(reference, estimate, self.loss_type)
                total += self.mag_weight * mag_term

            if self.logmag_weight > 0:
                # The log term reuses the spectrograms computed above.
                log_reference = spectral_ops.safe_log(reference)
                log_estimate = spectral_ops.safe_log(estimate)
                log_term = mean_difference(log_reference, log_estimate,
                                           self.loss_type)
                total += self.logmag_weight * log_term

        return total
Example #4
0
  def get_loss(self, target_mag, value_mag, weights=None, keep_batch=False):
    """Compute a weighted sum of spectral difference terms.

    Each term is included only when its weight is positive:
      * magnitude difference, reduced over axes [2, 3];
      * spectral-centroid difference, reduced over axis [2];
      * blurred-spectrum difference, reduced over axes [2, 3].
    All terms use ``self.loss_type`` and forward ``weights`` to
    ``mean_difference``. The intent is a per-timestep loss; presumably the
    unreduced axis is time — TODO confirm.

    Args:
      target_mag: Reference magnitudes.
      value_mag: Estimated magnitudes.
      weights: Optional weights forwarded to ``mean_difference``.
      keep_batch: Accepted for interface compatibility; never read here.

    Returns:
      The summed loss, or the integer ``0`` when every weight is non-positive.
    """
    def weighted_term(weight, target, value, axes):
      # One weighted mean difference reduced over `axes`.
      return weight * mean_difference(
        target, value, self.loss_type, weights=weights, axis=axes)

    total = 0

    if self.mag_weight > 0:
      total += weighted_term(self.mag_weight, target_mag, value_mag, [2, 3])

    if self.spectral_centroid_weight > 0:
      total += weighted_term(self.spectral_centroid_weight,
                             spectral_ops.spectral_centroid(target_mag),
                             spectral_ops.spectral_centroid(value_mag),
                             [2])

    if self.blurred_spectral_weight > 0.0:
      total += weighted_term(self.blurred_spectral_weight,
                             blur(target_mag),
                             blur(value_mag),
                             [2, 3])

    return total
Example #5
0
def error_heatmap(audio,
                  audio_gen,
                  step=12000,
                  name='',
                  tag='error_heatmap',
                  fft_sizes=(768, ),
                  log_smooth=('logmag', None, 1, 10, 100)):
    """Build per-item error-heatmap figures comparing two audio batches.

    For every FFT size in ``fft_sizes``, magnitude spectra of ``audio`` and
    ``audio_gen`` are computed with ``ddsp.spectral_ops.compute_mag``. Then,
    for every item index ``i`` and every entry ``s`` of ``log_smooth``, both
    spectra are transformed and passed through ``scale``:
      * ``None``: raw magnitude;
      * ``'logmag'``: ``safe_log``;
      * anything else: ``unskew(..., s)``.
    An L1 ``mean_difference`` is appended to the plot title. The heatmap from
    ``get_error_heatmap(t[i], v[i])`` is rotated 90 degrees and drawn on a new
    8x8 figure without ticks.

    Args:
        audio: Reference audio (converted with ``tf_float32``).
        audio_gen: Generated audio (converted with ``tf_float32``).
        step: Not read anywhere in the code shown here.
        name: Prefix inserted into the per-figure tag.
        tag: Base tag for each figure.
        fft_sizes: FFT sizes passed as ``size`` to ``compute_mag``.
        log_smooth: Spectrum transforms to plot (see above).

    NOTE(review): this function looks truncated. ``tag_i`` and ``fig`` are
    created but never used, nothing is saved or returned, and no figure is
    closed (``plt.close``). Confirm against the full source.
    """
    for size in fft_sizes:
        target_mag = ddsp.spectral_ops.compute_mag(tf_float32(audio),
                                                   size=size)
        value_mag = ddsp.spectral_ops.compute_mag(tf_float32(audio_gen),
                                                  size=size)

        # One set of figures per item; presumably axis 0 of target_mag is the
        # batch axis — TODO confirm.
        for i in range(len(target_mag)):
            for s in log_smooth:
                # Choose the spectrum transform for this log_smooth entry.
                if s is None:
                    t, v = scale(target_mag, value_mag)
                    title = f"Magnitude Spectrum ({size})"
                elif s == 'logmag':
                    t = safe_log(target_mag)
                    v = safe_log(value_mag)
                    t, v = scale(t, v)
                    title = f"Logmag Spectrum ({size})"
                else:
                    t = unskew(target_mag, s)
                    v = unskew(value_mag, s)
                    t, v = scale(t, v)
                    title = f"Magnitude Spectrum ({size}) with s={s}"
                # NOTE(review): this diff uses the full arrays t and v, not
                # item i. Nothing above depends on i, so it is presumably the
                # same for every i and could be hoisted out of the i loop.
                j = mean_difference(t, v, 'L1')
                title += f" (diff = {j:.3f})"
                img = get_error_heatmap(t[i], v[i])
                # A new figure is opened for every (i, s) pair.
                fig, ax = plt.subplots(1, 1, figsize=(8, 8))
                img = np.rot90(img)
                ax.imshow(img, aspect='auto')
                ax.set_title(title)
                ax.set_xticks([])
                ax.set_yticks([])
                # Format and save plot to image
                # NOTE(review): the save step itself is not present here.
                tag_i = f'{tag}/{name}{i+1}-s={s}'
Example #6
0
 def committment_loss(self, z, z_q):
     """Return the weighted commitment loss pulling ``z`` toward ``z_q``.

     ``z_q`` is wrapped in ``tf.stop_gradient``, so the L2 ``mean_difference``
     sends no gradient into ``z_q``. This encourages the encoder output ``z``
     to stay close to the current centroids. The distance is scaled by
     ``self.commitment_loss_weight``.
     """
     frozen_centroids = tf.stop_gradient(z_q)
     distance = losses.mean_difference(z, frozen_centroids, loss_type='L2')
     return self.commitment_loss_weight * distance