示例#1
0
def error_heatmap_summary(audio,
                          audio_gen,
                          fft_sizes=(768, ),
                          log_smooth=('logmag', None, 100),
                          **data):
    """Plot one error heatmap per (FFT size, sample, smoothing mode).

    For every FFT size the magnitude spectra of ``audio`` and ``audio_gen``
    are computed with ``ddsp.spectral_ops.compute_mag``. Each entry of
    ``log_smooth`` then selects how the two spectra are transformed before
    being compared with ``get_error_heatmap``:

    * ``None``      -- the magnitudes are only passed through ``scale``.
    * ``'logmag'``  -- ``safe_log`` is applied, then ``scale``.
    * anything else -- ``unskew(mag, s)`` is applied, then ``scale``.

    Every resulting figure is handed to ``fig_summary`` together with
    ``plot_type='error_heatmap'``, the 1-based sample index, the smoothing
    mode ``s`` and all extra keyword arguments.

    Args:
        audio: Target audio. Assumed to be batched along the first axis
            (the code iterates ``len(target_mag)`` and indexes per sample)
            -- TODO confirm exact shape against callers.
        audio_gen: Generated audio, same layout as ``audio``.
        fft_sizes: Iterable of FFT sizes to plot.
        log_smooth: Iterable of smoothing modes, see above.
        **data: Forwarded unchanged to ``fig_summary``.
    """
    # Per-sample loss shown in each plot title. The loss uses its own fixed
    # set of FFT sizes and smoothing values, independent of the arguments.
    unskewed_loss = UnskewedSpectralLoss(fft_sizes=(2048, 1024, 512, 256, 128,
                                                    64),
                                         loss_type='L1',
                                         log_smooth=[None, 1, 10, 100])
    total_loss = unskewed_loss(audio, audio_gen, keep_batch_dim=True)
    for size in fft_sizes:
        target_mag = ddsp.spectral_ops.compute_mag(tf_float32(audio),
                                                   size=size)
        value_mag = ddsp.spectral_ops.compute_mag(tf_float32(audio_gen),
                                                  size=size)

        num_samples = len(target_mag)
        if not num_samples:
            # Nothing to plot; also avoids transforming an empty batch.
            continue

        # The transformed spectra depend only on the smoothing mode, not on
        # the sample index, so compute them once per mode instead of once
        # per (sample, mode) pair.
        views = []
        for s in log_smooth:
            if s is None:
                t, v = scale(target_mag, value_mag)
                title = "Magnitude Spectrum"
            elif s == 'logmag':
                t, v = scale(safe_log(target_mag), safe_log(value_mag))
                title = "Logmag Spectrum"
            else:
                t, v = scale(unskew(target_mag, s), unskew(value_mag, s))
                title = "Magnitude Spectrum"
            views.append((s, t, v, title))

        # Sample-major order is kept so figures are emitted in the same
        # sequence as before.
        for i in range(num_samples):
            loss_suffix = f" | Loss = {float(total_loss[i]):.3f}"
            for s, t, v, title in views:
                img = get_error_heatmap(t[i], v[i])
                fig, ax = plt.subplots(1, 1, figsize=(8, 8))
                img = np.rot90(img)
                ax.imshow(img, aspect='auto')
                ax.set_title(title + loss_suffix)
                ax.set_xticks([])
                ax.set_yticks([])
                # Format and save plot to image
                fig_summary(fig,
                            plot_type='error_heatmap',
                            sample_idx=i + 1,
                            s=s,
                            **data)
示例#2
0
def compute_logmel(audio,
                   lo_hz=80.0,
                   hi_hz=7600.0,
                   bins=64,
                   fft_size=2048,
                   overlap=0.75,
                   pad_end=True):
    """Return ``safe_log`` of the mel spectrogram of ``audio``.

    All arguments are forwarded positionally, in declaration order, to
    ``compute_mel``; its result is then passed through ``safe_log``.
    """
    mel_args = (audio, lo_hz, hi_hz, bins, fft_size, overlap, pad_end)
    return safe_log(compute_mel(*mel_args))
示例#3
0
def compute_logmel(audio,
                   lo_hz=80.0,
                   hi_hz=7600.0,
                   bins=64,
                   fft_size=2048,
                   overlap=0.75,
                   pad_end=True,
                   sample_rate=16000):
  """Log-amplitude mel spectrogram.

  Calls ``compute_mel`` with every argument passed positionally in
  declaration order and returns ``safe_log`` of the result.
  """
  mel_spectrogram = compute_mel(
      audio, lo_hz, hi_hz, bins, fft_size, overlap, pad_end, sample_rate)
  log_mel = safe_log(mel_spectrogram)
  return log_mel
示例#4
0
def error_heatmap(audio,
                  audio_gen,
                  step=12000,
                  name='',
                  tag='error_heatmap',
                  fft_sizes=(768, ),
                  log_smooth=('logmag', None, 1, 10, 100)):
    """Draw an error heatmap for every (FFT size, sample, smoothing mode).

    For each FFT size the magnitude spectra of ``audio`` and ``audio_gen``
    are computed; each entry of ``log_smooth`` selects the transform applied
    to both spectra before ``get_error_heatmap`` compares one sample:

    * ``None``      -- ``scale`` only.
    * ``'logmag'``  -- ``safe_log`` then ``scale``.
    * anything else -- ``unskew(mag, s)`` then ``scale``.

    NOTE(review): as visible here the function builds each figure and a
    ``tag_i`` string but never saves, logs or returns them, and ``step`` is
    unused -- the block looks truncated; confirm against the full source.

    Args:
        audio: Target audio; presumably batched along the first axis, since
            the spectra are indexed per sample -- TODO confirm.
        audio_gen: Generated audio, same layout as ``audio``.
        step: Not used in the visible code.
        name: Prefix used when building the per-figure tag.
        tag: Tag namespace used when building the per-figure tag.
        fft_sizes: Iterable of FFT sizes to plot.
        log_smooth: Iterable of smoothing modes, see above.
    """
    for size in fft_sizes:
        target_mag = ddsp.spectral_ops.compute_mag(tf_float32(audio),
                                                   size=size)
        value_mag = ddsp.spectral_ops.compute_mag(tf_float32(audio_gen),
                                                  size=size)

        # One figure per sample and per smoothing mode.
        for i in range(len(target_mag)):
            for s in log_smooth:
                if s is None:
                    t, v = scale(target_mag, value_mag)
                    title = f"Magnitude Spectrum ({size})"
                elif s == 'logmag':
                    t = safe_log(target_mag)
                    v = safe_log(value_mag)
                    t, v = scale(t, v)
                    title = f"Logmag Spectrum ({size})"
                else:
                    t = unskew(target_mag, s)
                    v = unskew(value_mag, s)
                    t, v = scale(t, v)
                    title = f"Magnitude Spectrum ({size}) with s={s}"
                # NOTE(review): the difference is taken over the full t / v
                # arrays, not t[i] / v[i], so every sample's title appears to
                # show the same value -- confirm this is intended.
                j = mean_difference(t, v, 'L1')
                title += f" (diff = {j:.3f})"
                img = get_error_heatmap(t[i], v[i])
                fig, ax = plt.subplots(1, 1, figsize=(8, 8))
                img = np.rot90(img)
                ax.imshow(img, aspect='auto')
                ax.set_title(title)
                ax.set_xticks([])
                ax.set_yticks([])
                # Format and save plot to image
                # (only the tag is built in the visible code; see NOTE above)
                tag_i = f'{tag}/{name}{i+1}-s={s}'
示例#5
0
def compute_mfcc_from_mag(
    mag,
    lo_hz=20.0,
    hi_hz=8000.0,
    mel_bins=128,
    mfcc_bins=13,
    sample_rate=16000,
):
    """Compute MFCCs from a magnitude spectrogram.

    The spectrogram is projected onto a mel filterbank, passed through
    ``safe_log`` and converted with
    ``tf.signal.mfccs_from_log_mel_spectrograms``; only the first
    ``mfcc_bins`` coefficients along the last axis are returned.

    Args:
        mag: Magnitude spectrogram whose last axis holds the linear
            frequency bins (its static size is read via ``mag.shape[-1]``).
        lo_hz: Lower edge of the mel filterbank, in Hz.
        hi_hz: Upper edge of the mel filterbank, in Hz.
        mel_bins: Number of mel bands.
        mfcc_bins: Number of leading MFCC coefficients to keep.
        sample_rate: Sample rate of the audio that produced ``mag``.
            Defaults to 16000, the value that was previously hard-coded.
            NOTE(review): ``hi_hz`` is expected to stay at or below
            ``sample_rate / 2``; adjust it when lowering the sample rate.

    Returns:
        Tensor with the leading dimensions of ``mag`` and at most
        ``mfcc_bins`` entries on the last axis.
    """
    num_spectrogram_bins = int(mag.shape[-1])
    linear_to_mel_matrix = tf.signal.linear_to_mel_weight_matrix(
        mel_bins, num_spectrogram_bins, sample_rate, lo_hz, hi_hz)
    # Contract the frequency axis of `mag` with the filterbank.
    mel = tf.tensordot(mag, linear_to_mel_matrix, 1)
    # Set the static shape explicitly from the two operands.
    mel.set_shape(mag.shape[:-1].concatenate(linear_to_mel_matrix.shape[-1:]))
    logmel = safe_log(mel)
    mfccs = tf.signal.mfccs_from_log_mel_spectrograms(logmel)
    return mfccs[..., :mfcc_bins]
示例#6
0
def compute_logmag(audio, size=2048, overlap=0.75, pad_end=True):
    """Return ``safe_log`` of the magnitude spectrogram of ``audio``.

    The arguments are forwarded positionally to ``compute_mag``.
    """
    magnitudes = compute_mag(audio, size, overlap, pad_end)
    return safe_log(magnitudes)
示例#7
0
        # Multi-scale spectra of the reference signal `s` and of `y`, both
        # computed with the scales and overlap from the training config.
        # (presumably `y` is the model reconstruction -- defined outside this
        # fragment, TODO confirm)
        ori_stft = multiscale_fft(
            s,
            config["train"]["scales"],
            config["train"]["overlap"],
        )
        rec_stft = multiscale_fft(
            y,
            config["train"]["scales"],
            config["train"]["overlap"],
        )

        # Sum, over all scales, of the mean absolute difference of the
        # spectra and of their `safe_log` values.
        loss = 0
        for s_x, s_y in zip(ori_stft, rec_stft):
            lin_loss = (s_x - s_y).abs().mean()
            log_loss = (safe_log(s_x) - safe_log(s_y)).abs().mean()
            loss = loss + lin_loss + log_loss

        # Standard optimizer update for this batch.
        opt.zero_grad()
        loss.backward()
        opt.step()

        writer.add_scalar("loss", loss.item(), step)

        step += 1

        # Incremental running mean of the loss over the batches seen so far.
        n_element += 1
        mean_loss += (loss.item() - mean_loss) / n_element

    # Log the scheduled value every 10th `e` (when `e` is a multiple of 10).
    # (presumably `e` is the epoch index -- its loop header is outside this
    # fragment)
    if not e % 10:
        writer.add_scalar("lr", schedule(e), e)