Esempio n. 1
0
    def call(self, target_audio, audio):
        """Return the weighted multi-resolution spectral loss of two signals.

        One magnitude spectrogram per entry of ``self.fft_sizes`` is computed
        for both inputs with ``spectral_ops.compute_mag``. For each resolution,
        every term whose weight attribute is positive is added to the loss:
        raw magnitudes, first and second differences along axis 1
        (``delta_time`` / ``delta_delta_time``), first and second differences
        along axis 2 (``delta_freq`` / ``delta_delta_freq``) and log
        magnitudes. A loudness term computed from the raw audio is added once,
        independently of the FFT sizes.

        Args:
          target_audio: Reference audio.
          audio: Audio compared against the reference.

        Returns:
          The accumulated loss; the float ``0.0`` if no term is enabled.
        """
        diff = spectral_ops.diff

        def distance(reference, estimate):
            # Every term uses the same reduction, selected by self.loss_type.
            return mean_difference(reference, estimate, self.loss_type)

        spectrogram_fns = [
            functools.partial(spectral_ops.compute_mag, size=fft_size)
            for fft_size in self.fft_sizes
        ]

        total = 0.0
        for to_mag in spectrogram_fns:
            # Both signals go through the same transform so that all the
            # terms below can share the two spectrograms.
            ref_mag = to_mag(target_audio)
            est_mag = to_mag(audio)

            # Plain magnitude term.
            if self.mag_weight > 0:
                total += self.mag_weight * distance(ref_mag, est_mag)

            # First difference along axis 1.
            if self.delta_time_weight > 0:
                ref_term = diff(ref_mag, axis=1)
                est_term = diff(est_mag, axis=1)
                total += self.delta_time_weight * distance(ref_term, est_term)

            # Second difference along axis 1.
            if self.delta_delta_time_weight > 0:
                ref_term = diff(diff(ref_mag, axis=1), axis=1)
                est_term = diff(diff(est_mag, axis=1), axis=1)
                total += self.delta_delta_time_weight * distance(
                    ref_term, est_term)

            # First difference along axis 2.
            if self.delta_freq_weight > 0:
                ref_term = diff(ref_mag, axis=2)
                est_term = diff(est_mag, axis=2)
                total += self.delta_freq_weight * distance(ref_term, est_term)

            # Second difference along axis 2.
            if self.delta_delta_freq_weight > 0:
                ref_term = diff(diff(ref_mag, axis=2), axis=2)
                est_term = diff(diff(est_mag, axis=2), axis=2)
                total += self.delta_delta_freq_weight * distance(
                    ref_term, est_term)

            # Log-magnitude term, reusing the spectrograms computed above.
            if self.logmag_weight > 0:
                ref_term = spectral_ops.safe_log(ref_mag)
                est_term = spectral_ops.safe_log(est_mag)
                total += self.logmag_weight * distance(ref_term, est_term)

        # Loudness is computed from the waveforms with a fixed FFT size.
        if self.loudness_weight > 0:
            ref_term = spectral_ops.compute_loudness(target_audio, n_fft=2048)
            est_term = spectral_ops.compute_loudness(audio, n_fft=2048)
            total += self.loudness_weight * distance(ref_term, est_term)

        return total
Esempio n. 2
0
    def call(self, target_audio, audio):
        """Return the weighted spectral loss between two audio signals.

        Each callable in ``self.spectrogram_ops`` maps audio to a magnitude
        spectrogram. For every such op, the terms whose weight attribute is
        positive are accumulated: raw magnitudes, first differences along
        axis 1 (``delta_time``) and axis 2 (``delta_freq``), a cumulative sum
        along axis 2 (``cumsum_freq``) and log magnitudes. A loudness term
        computed from the raw audio is added once at the end.

        Args:
          target_audio: Reference audio.
          audio: Audio compared against the reference.

        Returns:
          The accumulated loss; the float ``0.0`` if no term is enabled.
        """
        diff = spectral_ops.diff
        cumsum = tf.math.cumsum

        def distance(reference, estimate):
            # Every term uses the same reduction, selected by self.loss_type.
            return mean_difference(reference, estimate, self.loss_type)

        total = 0.0
        for to_mag in self.spectrogram_ops:
            # One spectrogram per signal, shared by all the terms below.
            ref_mag = to_mag(target_audio)
            est_mag = to_mag(audio)

            # Plain magnitude term.
            if self.mag_weight > 0:
                total += self.mag_weight * distance(ref_mag, est_mag)

            # First difference along axis 1.
            if self.delta_time_weight > 0:
                ref_term = diff(ref_mag, axis=1)
                est_term = diff(est_mag, axis=1)
                total += self.delta_time_weight * distance(ref_term, est_term)

            # First difference along axis 2.
            if self.delta_freq_weight > 0:
                ref_term = diff(ref_mag, axis=2)
                est_term = diff(est_mag, axis=2)
                total += self.delta_freq_weight * distance(ref_term, est_term)

            # TODO(kyriacos) normalize cumulative spectrogram
            if self.cumsum_freq_weight > 0:
                ref_term = cumsum(ref_mag, axis=2)
                est_term = cumsum(est_mag, axis=2)
                total += self.cumsum_freq_weight * distance(ref_term, est_term)

            # Log-magnitude term, reusing the spectrograms computed above.
            if self.logmag_weight > 0:
                ref_term = spectral_ops.safe_log(ref_mag)
                est_term = spectral_ops.safe_log(est_mag)
                total += self.logmag_weight * distance(ref_term, est_term)

        # Loudness is computed from the waveforms with a fixed FFT size.
        if self.loudness_weight > 0:
            ref_term = spectral_ops.compute_loudness(
                target_audio, n_fft=2048, use_tf=True)
            est_term = spectral_ops.compute_loudness(
                audio, n_fft=2048, use_tf=True)
            total += self.loudness_weight * distance(ref_term, est_term)

        return total
Esempio n. 3
0
    def call(self, audio, target_audio):
        """Return a mel-spectrogram loss computed over several bands.

        The mel range between 0 Hz and ``self.sample_rate / 2`` is split into
        ``self.n_bands`` bands of equal mel width. For each band,
        ``self._get_closest_n_fft`` yields the FFT sizes to use given the
        band's frequency step, and one ``compute_mel`` op is created per FFT
        size. The j-th FFT size of a band uses ``n_mels_max / 4**j`` mel bins
        (truncated to int). Magnitude and log-magnitude differences are then
        accumulated over all ops, each gated by its weight.

        Args:
          audio: Audio compared against the reference.
          target_audio: Reference audio.

        Returns:
          The accumulated loss; the float ``0.0`` if no term is enabled.
        """
        # Band layout in mel units, from 0 Hz up to half the sample rate.
        mel_top = self._hz_to_mel(self.sample_rate / 2)
        mel_bottom = self._hz_to_mel(0.0)
        max_bins = int((mel_top - mel_bottom) / self.n_bands / 4)

        mel_edges = np.linspace(mel_bottom, mel_top, self.n_bands + 1)
        band_mel_lo = mel_edges[:-1]
        band_mel_hi = mel_edges[1:]
        band_hz_lo = self._mel_to_hz(band_mel_lo)
        band_hz_hi = self._mel_to_hz(band_mel_hi)

        # Per-band frequency step: mel step scaled by _df_dm at the low edge.
        mel_step = (band_mel_hi - band_mel_lo) / max_bins
        hz_step = self._df_dm(band_mel_lo) * mel_step

        mel_ops = []
        for lo_hz, hi_hz, step_hz in zip(band_hz_lo, band_hz_hi, hz_step):
            fft_choices = self._get_closest_n_fft(self.sample_rate, step_hz,
                                                  self.N_FFT_OPTIONS)
            for rank, n_fft in enumerate(fft_choices):
                # Each further FFT size gets a quarter of the previous bins.
                mel_ops.append(
                    functools.partial(compute_mel,
                                      sample_rate=self.sample_rate,
                                      lo_hz=lo_hz,
                                      hi_hz=hi_hz,
                                      bins=int(max_bins / 2**(2 * rank)),
                                      fft_size=n_fft))

        def distance(reference, estimate):
            # Every term uses the same reduction, selected by self.loss_type.
            return mean_difference(reference, estimate, self.loss_type)

        total = 0.0
        for to_mel in mel_ops:
            ref_mag = to_mel(target_audio)
            est_mag = to_mel(audio)

            # Plain magnitude term.
            if self.mag_weight > 0:
                total += self.mag_weight * distance(ref_mag, est_mag)

            # Log-magnitude term, reusing the spectrograms computed above.
            if self.logmag_weight > 0:
                ref_log = spectral_ops.safe_log(ref_mag)
                est_log = spectral_ops.safe_log(est_mag)
                total += self.logmag_weight * distance(ref_log, est_log)

        return total
Esempio n. 4
0
    def call(self, audio, target_audio):
        """Return a mel-spectrogram loss with FFT sizes assigned to bands.

        The range from 0 Hz to ``self.sample_rate / 2`` is split into
        ``self.n_bands`` bands of equal width in Hz. The entries of
        ``self.fft_sizes`` are assigned to the bands in order, the same number
        of consecutive FFT sizes per band, and one ``compute_mel`` op is
        created per FFT size. Magnitude and log-magnitude differences are then
        accumulated over all ops, each gated by its weight.

        Args:
          audio: Audio compared against the reference.
          target_audio: Reference audio.

        Returns:
          The accumulated loss; the float ``0.0`` if no term is enabled or
          ``self.fft_sizes`` is empty.

        Raises:
          ValueError: If ``len(self.fft_sizes)`` is not a multiple of
            ``self.n_bands``, since the FFT sizes could not be spread evenly
            over the bands.
        """
        n_layers = len(self.fft_sizes)
        if n_layers % self.n_bands != 0:
            raise ValueError(
                'len(fft_sizes) must be a multiple of n_bands, got %d and %s.'
                % (n_layers, self.n_bands))

        # ndarray.repeat needs an integer count; true division would hand it
        # a float, which NumPy refuses to cast.
        layers_per_band = int(n_layers // self.n_bands)
        band_ids = np.arange(0, self.n_bands).repeat(layers_per_band)
        band_width = (self.sample_rate / 2) / self.n_bands

        loss_ops = []
        for band_id, n_fft in zip(band_ids, self.fft_sizes):
            # TODO: this is ad-hoc; change for something more motivated
            n_mels = int(n_fft / 16)
            f_lo = band_id * band_width
            f_hi = f_lo + band_width
            loss_ops.append(
                functools.partial(compute_mel,
                                  sample_rate=self.sample_rate,
                                  lo_hz=f_lo,
                                  hi_hz=f_hi,
                                  bins=n_mels,
                                  fft_size=n_fft))

        loss = 0.0
        for loss_op in loss_ops:
            target_mag = loss_op(target_audio)
            value_mag = loss_op(audio)

            # Add magnitude loss.
            if self.mag_weight > 0:
                loss += self.mag_weight * mean_difference(
                    target_mag, value_mag, self.loss_type)

            # Add logmagnitude loss, reusing spectrogram.
            if self.logmag_weight > 0:
                target = spectral_ops.safe_log(target_mag)
                value = spectral_ops.safe_log(value_mag)
                loss += self.logmag_weight * mean_difference(
                    target, value, self.loss_type)

        return loss
Esempio n. 5
0
    def call(self, audio, target_audio):
        """Return the weighted spectral loss over ``self.loss_ops``.

        Each callable in ``self.loss_ops`` is applied first to the reference
        and then to the estimate. The difference of the two results is added
        when ``self.mag_weight`` is positive, and the difference of their
        ``spectral_ops.safe_log`` values when ``self.logmag_weight`` is
        positive.

        Args:
          audio: Audio compared against the reference.
          target_audio: Reference audio.

        Returns:
          The accumulated loss; the float ``0.0`` if no term is enabled or
          ``self.loss_ops`` is empty.
        """

        def distance(reference, estimate):
            # Every term uses the same reduction, selected by self.loss_type.
            return mean_difference(reference, estimate, self.loss_type)

        total = 0.0
        for spectrum_fn in self.loss_ops:
            ref_mag = spectrum_fn(target_audio)
            est_mag = spectrum_fn(audio)

            # Plain magnitude term.
            if self.mag_weight > 0:
                total += self.mag_weight * distance(ref_mag, est_mag)

            # Log-magnitude term, reusing the spectrograms computed above.
            if self.logmag_weight > 0:
                ref_log = spectral_ops.safe_log(ref_mag)
                est_log = spectral_ops.safe_log(est_mag)
                total += self.logmag_weight * distance(ref_log, est_log)

        return total
Esempio n. 6
0
 def get_log_mel_spectrum(*args, **kwargs):
     """Return ``spectral_ops.safe_log`` of ``get_mel_spectrum``'s result.

     All positional and keyword arguments are forwarded unchanged to
     ``get_mel_spectrum``.
     """
     return spectral_ops.safe_log(get_mel_spectrum(*args, **kwargs))
Esempio n. 7
0
    def call(self, target_audio, audio, weights=None):
        """Return the weighted spectral loss between two audio signals.

        Each callable in ``self.spectrogram_ops`` maps audio to a magnitude
        spectrogram. For every such op, the terms whose weight attribute is
        positive are accumulated:

        * raw magnitudes (``mag_weight``);
        * first differences along axis 1 (``delta_time_weight``) and along
          axis 2 (``delta_freq_weight``);
        * cumulative sum along the last axis (``cumsum_freq_weight``);
        * sum over the last axis (``bin_time_weight``);
        * safe log of the maximum over axis 2 (``max_power_weight``);
        * log magnitudes (``logmag_weight``);
        * mel and log-mel spectrograms derived from the magnitudes
          (``mel_weight`` / ``logmel_weight``).

        A loudness term computed from the raw audio is added once at the end.
        The mel terms read the FFT size from ``op.keywords['size']``, so the
        ops must expose that mapping (as a ``functools.partial`` built with a
        ``size`` keyword does).

        Args:
          target_audio: Reference audio.
          audio: Audio compared against the reference.
          weights: Optional value forwarded as ``weights`` to every
            ``mean_difference`` call.

        Returns:
          The accumulated loss; the float ``0.0`` if no term is enabled.
        """
        diff = spectral_ops.diff
        cumsum = tf.math.cumsum
        safe_log = spectral_ops.safe_log

        def distance(reference, estimate):
            # Every term uses the same reduction and the same weights.
            return mean_difference(
                reference, estimate, self.loss_type, weights=weights)

        total = 0.0
        for to_mag in self.spectrogram_ops:
            # One spectrogram per signal, shared by all the terms below.
            ref_mag = to_mag(target_audio)
            est_mag = to_mag(audio)

            # Plain magnitude term.
            if self.mag_weight > 0:
                total += self.mag_weight * distance(ref_mag, est_mag)

            # First difference along axis 1.
            if self.delta_time_weight > 0:
                ref_term = diff(ref_mag, axis=1)
                est_term = diff(est_mag, axis=1)
                total += self.delta_time_weight * distance(ref_term, est_term)

            # First difference along axis 2.
            if self.delta_freq_weight > 0:
                ref_term = diff(ref_mag, axis=2)
                est_term = diff(est_mag, axis=2)
                total += self.delta_freq_weight * distance(ref_term, est_term)

            # TODO(kyriacos) normalize cumulative spectrogram
            if self.cumsum_freq_weight > 0:
                ref_term = cumsum(ref_mag, axis=-1)
                est_term = cumsum(est_mag, axis=-1)
                total += self.cumsum_freq_weight * distance(ref_term, est_term)

            # Sum over the last axis. NOTE: a normalized variant based on a
            # Wasserstein distance was prototyped here earlier but is not
            # active.
            if self.bin_time_weight > 0:
                ref_term = tf.reduce_sum(ref_mag, axis=-1)
                est_term = tf.reduce_sum(est_mag, axis=-1)
                total += self.bin_time_weight * distance(ref_term, est_term)

            # Safe log of the maximum over axis 2.
            if self.max_power_weight > 0:
                ref_term = safe_log(tf.reduce_max(ref_mag, axis=2))
                est_term = safe_log(tf.reduce_max(est_mag, axis=2))
                total += self.max_power_weight * distance(ref_term, est_term)

            # Log-magnitude term, reusing the spectrograms computed above.
            if self.logmag_weight > 0:
                ref_term = safe_log(ref_mag)
                est_term = safe_log(est_mag)
                total += self.logmag_weight * distance(ref_term, est_term)

            # The mel projection is shared by the mel and log-mel terms, so
            # it is computed once if either of them is enabled.
            if self.mel_weight > 0 or self.logmel_weight > 0:
                fft_size = to_mag.keywords['size']
                ref_mel = spectral_ops.compute_mel_from_mag(
                    ref_mag, lo_hz=2.0, bins=None, fft_size=fft_size)
                est_mel = spectral_ops.compute_mel_from_mag(
                    est_mag, lo_hz=2.0, bins=None, fft_size=fft_size)

                if self.mel_weight > 0:
                    total += self.mel_weight * distance(ref_mel, est_mel)

                if self.logmel_weight > 0:
                    ref_logmel = safe_log(ref_mel)
                    est_logmel = safe_log(est_mel)
                    total += self.logmel_weight * distance(
                        ref_logmel, est_logmel)

        # Loudness is computed from the waveforms with a fixed FFT size.
        if self.loudness_weight > 0:
            ref_term = spectral_ops.compute_loudness(
                target_audio, n_fft=2048, use_tf=True)
            est_term = spectral_ops.compute_loudness(
                audio, n_fft=2048, use_tf=True)
            total += self.loudness_weight * distance(ref_term, est_term)

        return total