def wasserstein_distance(u_values, v_values, u_weights, v_weights, p=1.0):
  """Differentiable 1-D Wasserstein distance.

  Adapted from the scipy.stats implementation: integrates |U(x) - V(x)|**p
  over the merged, sorted support of both sample sets, where U and V are the
  (optionally weighted) empirical CDFs.

  Args:
    u_values: Samples from distribution `u`. Shape [batch_shape, n_samples].
    v_values: Samples from distribution `v`. Shape [batch_shape, n_samples].
    u_weights: Sample weights for `u`, or None for uniform weights.
      Shape [batch_shape, n_samples].
    v_weights: Sample weights for `v`, or None for uniform weights.
      Shape [batch_shape, n_samples].
    p: Degree of the distance norm. Wasserstein=1, Energy=2.

  Returns:
    The Wasserstein distance between samples. Shape [batch_shape].
  """
  u_sorter = tf.argsort(u_values, axis=-1)
  v_sorter = tf.argsort(v_values, axis=-1)

  all_values = tf.concat([u_values, v_values], axis=-1)
  all_values = tf.sort(all_values, axis=-1)

  # Compute the differences between pairs of successive values of u and v.
  deltas = spectral_ops.diff(all_values, axis=-1)

  # Gather along the sample axis, independently for every batch element.
  batch_dims = len(u_values.shape) - 1
  gather = lambda x, i: tf.gather(x, i, axis=-1, batch_dims=batch_dims)

  def _cdf(values, weights, sorter):
    """Empirical CDF of `values` at the left edge of each `deltas` interval."""
    # Position of each merged value among the sorted samples of this
    # distribution.
    cdf_indices = tf.searchsorted(gather(values, sorter),
                                  all_values[..., :-1],
                                  side='right')
    if weights is None:
      # Uniform weights: CDF = count / n. searchsorted returns integer
      # indices, so cast to the dtype of `deltas` before dividing (mixing an
      # int tensor with a float divisor is a dtype error in TF).
      counts = tf.cast(cdf_indices, deltas.dtype)
      n_samples = tf.cast(tf.shape(values)[-1], deltas.dtype)
      return counts / n_samples
    # Cumulative weights with a leading 0 so that index 0 means "no samples
    # yet".
    sorted_cumweights = tf.concat([
        tf.zeros_like(weights)[..., 0:1],
        tf.cumsum(gather(weights, sorter), axis=-1),
    ], axis=-1)
    cdf = gather(sorted_cumweights, cdf_indices)
    # Normalize by the total weight so the CDF ends at 1.
    return safe_divide(cdf, sorted_cumweights[..., -1:])

  u_cdf = _cdf(u_values, u_weights, u_sorter)
  v_cdf = _cdf(v_values, v_weights, v_sorter)

  # Compute the value of the integral based on the CDFs.
  return tf.reduce_sum(deltas * tf.abs(u_cdf - v_cdf)**p, axis=-1)**(1.0 / p)
def nll(self, amps, freqs, amps_target, freqs_target, scale_target):
  """Negative log-likelihood of source sinusoids under a target density.

  The target sinusoids define a kernel density estimate (built by
  `self.kernel_density_estimate`) on a MIDI frequency axis; each source
  sinusoid is scored against it.

  Args:
    amps: Amplitudes of source sinusoids, greater than 0.
      Shape [batch, time, freq].
    freqs: Frequencies of source sinusoids in hertz.
      Shape [batch, time, freq].
    amps_target: Amplitudes of target sinusoids, greater than 0.
      Shape [batch, time, freq].
    freqs_target: Frequencies of target sinusoids in hertz.
      Shape [batch, time, freq].
    scale_target: Scale of gaussian kernel in MIDI.

  Returns:
    -log(p(source|target)), weighted per sinusoid by normalized amplitude and
    then averaged over the freq axis. Shape [batch, time].
  """
  kde = self.kernel_density_estimate(amps_target, freqs_target, scale_target)

  # The KDE lives on a logarithmic (MIDI) axis, so query it in MIDI. tfp
  # wants [sample_shape, batch_shape, event_shape]: move the sinusoid axis to
  # the front for evaluation, then move it back.
  query = tf.transpose(hz_to_midi(freqs), perm=[2, 0, 1])  # [freq, batch, time]
  per_sinusoid = tf.transpose(-kde.log_prob(query),
                              perm=[1, 2, 0])  # [batch, time, freq]

  total_amp = tf.reduce_sum(amps, axis=-1, keepdims=True)
  weights = safe_divide(amps, total_amp)
  return tf.reduce_mean(per_sinusoid * weights, axis=-1)
def get_note_moments(x, note_mask, return_std=True):
  """Return the moments of value x, pooled over the length of each note.

  Args:
    x: Value to be pooled, [batch, time, dims] or [batch, time].
    note_mask: Binary mask of notes [batch, time, notes].
    return_std: Also return the standard deviation for each note.

  Returns:
    Values pooled over each note region, [batch, notes, dims] or
    [batch, notes]. Returns only mean if return_std=False, else mean and std.
    Division by note length goes through `core.safe_divide` (presumably
    yielding 0 for notes with an all-zero mask -- confirm against core).
  """
  is_2d = len(x.shape) == 2
  if is_2d:
    x = x[:, :, tf.newaxis]

  x_expanded = x[:, :, tf.newaxis, :]  # [b, t, 1, d]
  note_mask_d = note_mask[..., tf.newaxis]  # [b, t, n, 1]
  note_lengths = tf.reduce_sum(note_mask_d, axis=1)  # [b, n, 1]

  # Mean.
  x_masked = x_expanded * note_mask_d  # [b, t, n, d]
  x_mean = core.safe_divide(tf.reduce_sum(x_masked, axis=1),
                            note_lengths)  # [b, n, d]
  x_mean_out = x_mean[:, :, 0] if is_2d else x_mean

  if not return_std:
    # Skip the variance computation entirely when it is not requested.
    return x_mean_out

  # Standard Deviation.
  deviation = (x_expanded - x_mean[:, tf.newaxis, :, :]) * note_mask_d
  sum_sq = tf.reduce_sum(deviation**2.0, axis=1)  # [b, n, d]
  variance = core.safe_divide(sum_sq, note_lengths)
  # sqrt has an infinite derivative at 0 (constant or single-frame notes,
  # empty notes), which turns into NaN gradients. The double tf.where keeps
  # the forward value identical (sqrt(0) == 0) while routing a finite input
  # through the sqrt on the zero-variance entries.
  is_positive = variance > 0.0
  safe_variance = tf.where(is_positive, variance, tf.ones_like(variance))
  x_std = tf.where(is_positive, safe_variance**0.5, tf.zeros_like(variance))
  x_std_out = x_std[:, :, 0] if is_2d else x_std

  return x_mean_out, x_std_out
def get_p_harmonics_given_sinusoids(self, freqs, amps):
  """Builds P(candidate harmonics | sinusoids) as a gaussian mixture.

  A gaussian kernel is centred on every sinusoid (in MIDI) with scale
  `self.sinusoids_scale`; each kernel's mixture weight is that sinusoid's
  amplitude normalized over the last axis.

  Args:
    freqs: Frequencies of sinusoids in hertz.
    amps: Amplitudes of sinusoids, must be greater than 0.

  Returns:
    MixtureSameFamily, Gaussian distribution.
  """
  # An all-zero amplitude frame would give degenerate mixture weights (and a
  # NaN NLL), so exact zeros are lifted to a tiny floor first.
  floor = 1e-7 * tf.ones_like(amps)
  amps = tf.where(amps == 0.0, floor, amps)
  mixture_probs = safe_divide(amps,
                              tf.reduce_sum(amps, axis=-1, keepdims=True))

  kernels = tfd.Normal(loc=hz_to_midi(freqs), scale=self.sinusoids_scale)
  selector = tfd.Categorical(probs=mixture_probs)
  return tfd.MixtureSameFamily(selector, kernels)
def get_loss_tensors(self, f0_candidates, freqs, amps):
  """Computes two per-candidate loss traces for fundamental frequency search.

  Args:
    f0_candidates: Frequencies of candidates in hertz. [batch, time, freq].
    freqs: Frequencies of sinusoids in hertz. [batch, time, freq].
    amps: Amplitudes of sinusoids, greater than 0. [batch, time, freq].

  Returns:
    sinusoids_loss: -log p(sinusoids|harmonics), [batch, time, f0_candidate].
    harmonics_loss: -log p(harmonics|sinusoids), [batch, time, f0_candidate].
  """
  # ---- Term 1: P(sinusoids | candidate harmonics) ----
  sin_given_harm = self.get_p_sinusoids_given_harmonics()

  # Each partial is scored by its frequency ratio to each candidate.
  # -> [batch, time, candidate, partial]
  ratios = safe_divide(freqs[:, :, tf.newaxis, :],
                       f0_candidates[:, :, :, tf.newaxis])
  partial_nll = -sin_given_harm.log_prob(ratios)

  # Amplitude-weighted average over partials -> [batch, time, candidate].
  partial_weights = tf.convert_to_tensor(amps[:, :, tf.newaxis, :])
  weighted_total = tf.reduce_sum(partial_nll * partial_weights, axis=-1)
  weight_total = tf.reduce_sum(partial_weights, axis=-1)
  sinusoids_loss = safe_divide(weighted_total, weight_total)

  # ---- Term 2: P(candidate harmonics | sinusoids) ----
  harm_given_sin = self.get_p_harmonics_given_sinusoids(freqs, amps)
  harmonics = self.get_candidate_harmonics(f0_candidates, as_midi=True)

  # tfp expects [sample_shape, batch_shape, event_shape], so evaluate with
  # layout [candidate, harmonic, batch, time]; the same permutation maps it
  # back to [batch, time, candidate, harmonic].
  harm_nll = -harm_given_sin.log_prob(tf.transpose(harmonics, [2, 3, 0, 1]))
  harm_nll = tf.transpose(harm_nll, [2, 3, 0, 1])

  # Prior decreasing importance of upper harmonics: linear from 1 to 1/n.
  n_points = self.n_harmonic_points
  harmonic_prior = tf.linspace(1.0, 1.0 / n_points, n_points)
  weighted_nll = harm_nll * harmonic_prior[tf.newaxis, tf.newaxis,
                                           tf.newaxis, :]

  # Zero out harmonics above Nyquist and rescale by the fraction that lie
  # below it (so it doesn't just pick the highest frequency possible).
  nyquist_midi = hz_to_midi(self.sample_rate / 2.0)
  below_nyquist = tf.where(harmonics < nyquist_midi,
                           tf.ones_like(weighted_nll),
                           tf.zeros_like(weighted_nll))
  fraction_valid = tf.reduce_mean(below_nyquist, axis=-1, keepdims=True)
  weighted_nll *= safe_divide(below_nyquist, fraction_valid)

  # Mean over the harmonic axis. Together with the rescaling above this
  # equals the prior-weighted mean over harmonics below Nyquist.
  harmonics_loss = tf.reduce_mean(weighted_nll, axis=-1)
  return sinusoids_loss, harmonics_loss