Example #1
0
def wasserstein_distance(u_values, v_values, u_weights, v_weights, p=1.0):
    """Differentiable 1-D Wasserstein distance.

    Adapted from the scipy.stats implementation: the distance is the L^p
    distance between the (optionally weighted) empirical CDFs of `u` and `v`,
    integrated over the merged, sorted support of both sample sets.

    Args:
      u_values: Samples from distribution `u`. Shape [batch_shape, n_u].
      v_values: Samples from distribution `v`. Shape [batch_shape, n_v].
      u_weights: Sample weights for `u`, shape [batch_shape, n_u], or None for
        uniform weights. Weights are normalized by their total per batch entry.
      v_weights: Sample weights for `v`, shape [batch_shape, n_v], or None for
        uniform weights. Weights are normalized by their total per batch entry.
      p: Degree of the distance norm. p=1 is the Wasserstein-1 distance; p=2
        is the L2 distance between the CDFs (scipy's energy distance divided
        by sqrt(2)).

    Returns:
      The distance between the two sample sets. Shape [batch_shape].
    """
    all_values = tf.sort(tf.concat([u_values, v_values], axis=-1), axis=-1)
    # Left edges of the integration intervals.
    left_edges = all_values[..., :-1]

    # Widths of the intervals between successive values of the merged support.
    deltas = spectral_ops.diff(all_values, axis=-1)

    batch_dims = len(u_values.shape) - 1

    def gather(x, i):
        return tf.gather(x, i, axis=-1, batch_dims=batch_dims)

    def empirical_cdf(values, weights):
        """CDF of `values` (weighted by `weights`) at each interval's left edge."""
        sorter = tf.argsort(values, axis=-1)
        # Number of samples <= each left edge. This is an index into the
        # cumulative weights below, which carry a leading zero.
        cdf_indices = tf.searchsorted(gather(values, sorter), left_edges,
                                      side='right')
        if weights is None:
            # TF does not implicitly promote the int32 indices to the values'
            # float dtype, so cast both terms explicitly. tf.shape also works
            # when the sample dimension is only known at run time.
            dtype = all_values.dtype
            n_samples = tf.cast(tf.shape(values)[-1], dtype)
            return tf.cast(cdf_indices, dtype) / n_samples
        sorted_cumweights = tf.concat([
            tf.zeros_like(weights)[..., 0:1],
            tf.cumsum(gather(weights, sorter), axis=-1),
        ], axis=-1)
        cdf = gather(sorted_cumweights, cdf_indices)
        # Normalize by the total weight so the CDF ends at 1, as scipy does.
        return safe_divide(cdf, sorted_cumweights[..., -1:])

    u_cdf = empirical_cdf(u_values, u_weights)
    v_cdf = empirical_cdf(v_values, v_weights)

    # Integrate |F_u - F_v|^p over the merged support, then take the p-th root.
    return tf.reduce_sum(deltas * tf.abs(u_cdf - v_cdf)**p, axis=-1)**(1.0 / p)
Example #2
0
  def nll(self, amps, freqs, amps_target, freqs_target, scale_target):
    """Negative log-likelihood of the source sinusoids under a target KDE.

    A kernel density estimate is built from the target sinusoids with
    `self.kernel_density_estimate`. Each source sinusoid's frequency,
    converted to MIDI, is then scored against that estimate.

    Args:
      amps: Amplitudes of source sinusoids, greater than 0.
        Shape [batch, time, freq].
      freqs: Frequencies of source sinusoids in hertz.
        Shape [batch, time, freq].
      amps_target: Amplitudes of target sinusoids, greater than 0.
        Shape [batch, time, freq].
      freqs_target: Frequencies of target sinusoids in hertz.
        Shape [batch, time, freq].
      scale_target: Scale of the Gaussian kernel, in MIDI units.

    Returns:
      -log(p(source | target)), weighted by normalized source amplitude.
      Shape [batch, time].
    """
    target_kde = self.kernel_density_estimate(
        amps_target, freqs_target, scale_target)

    # The KDE lives on a logarithmic (MIDI) frequency axis.
    source_midi = hz_to_midi(freqs)

    # TFP expects [sample_shape, batch_shape, event_shape], so the sinusoid
    # axis is moved to the front for scoring: [freq, batch, time].
    samples = tf.transpose(source_midi, [2, 0, 1])
    neg_log_prob = -target_kde.log_prob(samples)
    per_sinusoid_nll = tf.transpose(neg_log_prob, [1, 2, 0])  # [batch, time, freq]

    # Amplitude weights normalized to sum to 1 over the sinusoid axis.
    total_amp = tf.reduce_sum(amps, axis=-1, keepdims=True)
    weights = safe_divide(amps, total_amp)
    # NOTE(review): reduce_mean of normalized weights equals the weighted
    # average divided by the number of sinusoids; confirm this scale is intended.
    return tf.reduce_mean(per_sinusoid_nll * weights, axis=-1)
Example #3
0
def get_note_moments(x, note_mask, return_std=True):
    """Return the moments of `x`, pooled over the length of each note.

    Args:
      x: Value to be pooled, [batch, time, dims] or [batch, time].
      note_mask: Binary mask of notes [batch, time, notes].
      return_std: Also return the standard deviation for each note.

    Returns:
      Values pooled over each note region, [batch, notes, dims] or
      [batch, notes]. Returns only the mean if return_std=False, else a
      (mean, std) tuple. The std is the population std (sum of squared
      deviations divided by the note length).
    """
    is_2d = len(x.shape) == 2
    if is_2d:
        x = x[:, :, tf.newaxis]

    note_mask_d = note_mask[..., tf.newaxis]  # [b, t, n, 1]
    note_lengths = tf.reduce_sum(note_mask_d, axis=1)  # [b, n, 1]
    x_per_note = x[:, :, tf.newaxis, :]  # [b, t, 1, d]

    # Mean.
    x_masked = x_per_note * note_mask_d  # [b, t, n, d]
    x_mean = core.safe_divide(tf.reduce_sum(x_masked, axis=1),
                              note_lengths)  # [b, n, d]

    if not return_std:
        # Skip the std computation entirely when it is not requested.
        return x_mean[:, :, 0] if is_2d else x_mean

    # Standard deviation.
    deviations = (x_per_note - x_mean[:, tf.newaxis, :, :]) * note_mask_d
    sum_sq = tf.reduce_sum(deviations**2.0, axis=1)  # [b, n, d]
    x_var = core.safe_divide(sum_sq, note_lengths)
    # The square root has an infinite gradient at 0 (constant or empty notes),
    # which becomes NaN once multiplied by a zero upstream gradient. Only take
    # the root of positive variances; forward values are unchanged (0 -> 0).
    is_zero = x_var <= 0.0
    safe_var = tf.where(is_zero, tf.ones_like(x_var), x_var)
    x_std = tf.where(is_zero, tf.zeros_like(x_var), safe_var**0.5)

    if is_2d:
        x_mean, x_std = x_mean[:, :, 0], x_std[:, :, 0]
    return x_mean, x_std
Example #4
0
  def get_p_harmonics_given_sinusoids(self, freqs, amps):
    """Build the distribution of candidate harmonics given observed sinusoids.

    This is a Gaussian kernel density estimate over the sinusoids in MIDI
    space. Each sinusoid contributes one Normal component centred on its
    pitch, with scale `self.sinusoids_scale`. Its mixture weight is its
    amplitude normalized over the last axis.

    Args:
      freqs: Frequencies of sinusoids in hertz.
      amps: Amplitudes of sinusoids, expected non-negative. Exact zeros are
        replaced by 1e-7 before normalization.

    Returns:
      tfd.MixtureSameFamily with a Categorical mixture over Normal components.
    """
    centers_midi = hz_to_midi(freqs)

    # If every amplitude in a frame is zero the mixture weights are all zero
    # and the NLL becomes NaN, so zero amplitudes are floored to a tiny value.
    floored_amps = tf.where(amps == 0.0, 1e-7 * tf.ones_like(amps), amps)
    total_amp = tf.reduce_sum(floored_amps, axis=-1, keepdims=True)
    mixture_weights = safe_divide(floored_amps, total_amp)

    # P(candidate_harmonics | sinusoids)
    selector = tfd.Categorical(probs=mixture_weights)
    components = tfd.Normal(loc=centers_midi, scale=self.sinusoids_scale)
    return tfd.MixtureSameFamily(selector, components)
Example #5
0
  def get_loss_tensors(self, f0_candidates, freqs, amps):
    """Get traces of loss to estimate fundamental frequency.

    Two complementary losses are computed for every f0 candidate:
      * sinusoids_loss: how well the observed sinusoids are explained as
        harmonics of the candidate. This is the amplitude-weighted NLL of each
        sinusoid's frequency ratio to the candidate.
      * harmonics_loss: how well the candidate's harmonics are explained by a
        kernel density estimate over the observed sinusoids. This is a
        prior-weighted NLL, averaged over harmonics below Nyquist.

    Args:
      f0_candidates: Frequencies of candidates in hertz.
        [batch, time, f0_candidate].
      freqs: Frequencies of sinusoids in hertz. [batch, time, freq].
      amps: Amplitudes of sinusoids, greater than 0. [batch, time, freq].

    Returns:
      sinusoids_loss: -log p(sinusoids|harmonics), [batch, time, f0_candidate].
      harmonics_loss: -log p(harmonics|sinusoids), [batch, time, f0_candidate].
    """
    # ==========================================================================
    # P(sinusoids | candidate_harmonics).
    # ==========================================================================
    p_sinusoids_given_harmonics = self.get_p_sinusoids_given_harmonics()

    # Treat each partial as a candidate.
    # Get the ratio of each partial to each candidate.
    # -> [batch, time, candidate, partial]
    freq_ratios = safe_divide(freqs[:, :, tf.newaxis, :],
                              f0_candidates[:, :, :, tf.newaxis])
    nll_sinusoids = - p_sinusoids_given_harmonics.log_prob(freq_ratios)

    # Amplitudes broadcast over candidates -> [batch, time, 1, partial].
    a = tf.convert_to_tensor(amps[:, :, tf.newaxis, :])

    # Weighted average by sinusoid amplitude.
    # -> [batch, time, candidate]
    sinusoids_loss = safe_divide(tf.reduce_sum(nll_sinusoids * a, axis=-1),
                                 tf.reduce_sum(a, axis=-1))

    # ==========================================================================
    # P(candidate_harmonics | sinusoids)
    # ==========================================================================
    p_harm_given_sin = self.get_p_harmonics_given_sinusoids(freqs, amps)
    harmonics = self.get_candidate_harmonics(f0_candidates, as_midi=True)

    # Rearrange to the shape TFP expects, [sample_sh, batch_sh, event_sh].
    # -> [candidate, harmonic, batch, time]
    harmonics_transpose = tf.transpose(harmonics, [2, 3, 0, 1])
    nll_harmonics_transpose = - p_harm_given_sin.log_prob(harmonics_transpose)
    # -> [batch, time, candidate, harm]
    nll_harmonics = tf.transpose(nll_harmonics_transpose, [2, 3, 0, 1])

    # Prior decreasing importance of upper harmonics.
    amps_prior = tf.linspace(
        1.0, 1.0 / self.n_harmonic_points, self.n_harmonic_points)
    harmonics_loss = (nll_harmonics *
                      amps_prior[tf.newaxis, tf.newaxis, tf.newaxis, :])

    # Don't count loss for harmonics above nyquist.
    # Reweight by the number of harmonics below nyquist,
    # (so it doesn't just pick the highest frequency possible).
    nyquist_midi = hz_to_midi(self.sample_rate / 2.0)
    nyquist_mask = tf.where(harmonics < nyquist_midi,
                            tf.ones_like(harmonics_loss),
                            tf.zeros_like(harmonics_loss))
    harmonics_loss *= safe_divide(
        nyquist_mask, tf.reduce_mean(nyquist_mask, axis=-1, keepdims=True))

    # Average over harmonics. Because of the reweighted mask above, this equals
    # the sum of the below-Nyquist losses divided by their count.
    harmonics_loss = tf.reduce_mean(harmonics_loss, axis=-1)

    return sinusoids_loss, harmonics_loss