# Assumes the DDSP library for `spectral_ops.diff` and `safe_divide`.
import tensorflow as tf

from ddsp import spectral_ops
from ddsp.core import safe_divide


def wasserstein_distance(u_values, v_values, u_weights=None, v_weights=None,
                         p=1.0):
  """Differentiable 1-D Wasserstein distance.

  Adapted from the scipy.stats implementation.

  Args:
    u_values: Samples from distribution `u`. Shape [batch_shape, n_samples].
    v_values: Samples from distribution `v`. Shape [batch_shape, n_samples].
    u_weights: Sample weights. Shape [batch_shape, n_samples]. If None,
      uniform weights are used.
    v_weights: Sample weights. Shape [batch_shape, n_samples]. If None,
      uniform weights are used.
    p: Degree of the distance norm. Wasserstein=1, Energy=2.

  Returns:
    The Wasserstein distance between samples. Shape [batch_shape].
  """
  u_sorter = tf.argsort(u_values, axis=-1)
  v_sorter = tf.argsort(v_values, axis=-1)

  all_values = tf.concat([u_values, v_values], axis=-1)
  all_values = tf.sort(all_values, axis=-1)

  # Compute the differences between pairs of successive values of u and v.
  deltas = spectral_ops.diff(all_values, axis=-1)

  # Get the respective positions of the values of u and v among the values of
  # both distributions.
  batch_dims = len(u_values.shape) - 1
  gather = lambda x, i: tf.gather(x, i, axis=-1, batch_dims=batch_dims)
  u_cdf_indices = tf.searchsorted(
      gather(u_values, u_sorter), all_values[..., :-1], side='right')
  v_cdf_indices = tf.searchsorted(
      gather(v_values, v_sorter), all_values[..., :-1], side='right')

  # Calculate the CDFs of u and v using their weights, if specified.
  if u_weights is None:
    u_cdf = u_cdf_indices / float(u_values.shape[-1])
  else:
    u_sorted_cumweights = tf.concat(
        [tf.zeros_like(u_weights)[..., 0:1],
         tf.cumsum(gather(u_weights, u_sorter), axis=-1)],
        axis=-1)
    # Normalize by the total weight (as in the scipy implementation).
    u_cdf = safe_divide(gather(u_sorted_cumweights, u_cdf_indices),
                        u_sorted_cumweights[..., -1:])

  if v_weights is None:
    v_cdf = v_cdf_indices / float(v_values.shape[-1])
  else:
    v_sorted_cumweights = tf.concat(
        [tf.zeros_like(v_weights)[..., 0:1],
         tf.cumsum(gather(v_weights, v_sorter), axis=-1)],
        axis=-1)
    # Normalize by the total weight (as in the scipy implementation).
    v_cdf = safe_divide(gather(v_sorted_cumweights, v_cdf_indices),
                        v_sorted_cumweights[..., -1:])

  # Compute the value of the integral based on the CDFs.
  return tf.reduce_sum(deltas * tf.abs(u_cdf - v_cdf)**p, axis=-1)**(1.0 / p)
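
# Illustrative usage sketch (hypothetical helper, not part of the original
# module): shows the expected shapes and the weighted code path of
# `wasserstein_distance`.
def _example_wasserstein_usage():
  # Two batches of samples from slightly shifted Gaussians, [batch, n_samples].
  u = tf.random.normal([8, 256], mean=0.0)
  v = tf.random.normal([8, 256], mean=0.5)
  # Explicit uniform weights; the function normalizes by their sum.
  weights = tf.ones([8, 256])
  # Returns one distance per batch element, shape [8]; each entry roughly
  # matches the 0.5 mean shift between the two distributions.
  return wasserstein_distance(u, v, weights, weights, p=1.0)
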
def get_note_mask(q_pitch, max_regions=100, note_on_only=True):
  """Get a binary mask for each note from a monophonic instrument.

  Each transition of the value creates a new region. Returns the mask of each
  region.

  Args:
    q_pitch: A quantized value, such as pitch or velocity. Shape
      [batch, n_timesteps] or [batch, n_timesteps, 1].
    max_regions: Maximum number of note regions to consider in the sequence.
      Also, the channel dimension of the output mask. Each value transition
      defines a new region, e.g. each note-on and note-off count as a
      separate region.
    note_on_only: Return a mask that is true only for regions where the pitch
      is greater than 0.

  Returns:
    A binary mask of each region [batch, n_timesteps, max_regions].
  """
  # Only batch and time dimensions.
  if len(q_pitch.shape) == 3:
    q_pitch = q_pitch[:, :, 0]

  # Get onset and offset points.
  edges = tf.abs(spectral_ops.diff(q_pitch, axis=1)) > 0

  # Count endpoints as starts/ends of regions.
  edges = edges[:, :-1, ...]
  edges = tf.pad(edges,
                 [[0, 0], [1, 0]], mode='constant', constant_values=True)
  edges = tf.pad(edges,
                 [[0, 0], [0, 1]], mode='constant', constant_values=False)
  edges = tf.cast(edges, tf.int32)

  # Count up onset and offsets for each timestep.
  # Assumes each onset has a corresponding offset.
  # The -1 ensures that the 0th index is the first note.
  edge_idx = tf.cumsum(edges, axis=1) - 1

  # Create masks of shape [batch, n_timesteps, max_regions].
  note_mask = edge_idx[..., None] == tf.range(max_regions)[None, None, :]
  note_mask = tf.cast(note_mask, tf.float32)

  if note_on_only:
    # [batch, notes]
    note_pitches = get_note_moments(q_pitch, note_mask, return_std=False)
    # [batch, time, notes]
    note_on = tf.cast(note_pitches > 0.0, tf.float32)[:, None, :]
    # [batch, time, notes]
    note_mask *= note_on

  return note_mask
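
# Illustrative usage sketch (hypothetical helper, not part of the original
# module): builds note-region masks for a toy monophonic pitch track.
# `note_on_only=False` is used so the sketch does not depend on
# `get_note_moments`, which is defined elsewhere in the library.
def _example_note_mask_usage():
  # Silence, a note at MIDI 60, silence, then a note at MIDI 62.
  q_pitch = tf.constant([[0, 0, 60, 60, 60, 0, 62, 62]], dtype=tf.float32)
  # Four constant-value segments -> four active mask channels.
  mask = get_note_mask(q_pitch, max_regions=8, note_on_only=False)
  # mask.shape == (1, 8, 8): [batch, n_timesteps, max_regions].
  # Region index per timestep (argmax over channels): [[0, 0, 1, 1, 1, 2, 3, 3]].
  region_idx = tf.argmax(mask, axis=-1)
  return mask, region_idx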