Example #1
import tensorflow as tf

# `spectral_ops.diff` and `safe_divide` are helpers from the ddsp library
# (import paths assumed: ddsp.spectral_ops and ddsp.core).
from ddsp import spectral_ops
from ddsp.core import safe_divide


def wasserstein_distance(u_values, v_values, u_weights=None, v_weights=None,
                         p=1.0):
    """Differentiable 1-D Wasserstein distance.

  Adapted from the scipy.stats implementation.
  Args:
    u_values: Samples from distribution `u`. Shape [batch_shape, n_samples].
    v_values: Samples from distribution `v`. Shape [batch_shape, n_samples].
    u_weights: Sample weights. Shape [batch_shape, n_samples].
    v_weights: Sample weights. Shape [batch_shape, n_samples].
    p: Degree of the distance norm. Wasserstein=1, Energy=2.

  Returns:
    The Wasserstein distance between samples. Shape [batch_shape].
  """
    u_sorter = tf.argsort(u_values, axis=-1)
    v_sorter = tf.argsort(v_values, axis=-1)

    all_values = tf.concat([u_values, v_values], axis=-1)
    all_values = tf.sort(all_values, axis=-1)

    # Compute the differences between pairs of successive values of u and v.
    deltas = spectral_ops.diff(all_values, axis=-1)

    # Get the respective positions of the values of u and v among the values of
    # both distributions.
    batch_dims = len(u_values.shape) - 1
    gather = lambda x, i: tf.gather(x, i, axis=-1, batch_dims=batch_dims)
    u_cdf_indices = tf.searchsorted(gather(u_values, u_sorter),
                                    all_values[..., :-1],
                                    side='right')
    v_cdf_indices = tf.searchsorted(gather(v_values, v_sorter),
                                    all_values[..., :-1],
                                    side='right')

    # Calculate the CDFs of u and v using their weights, if specified.
    if u_weights is None:
        u_cdf = (tf.cast(u_cdf_indices, u_values.dtype) /
                 float(u_values.shape[-1]))
    else:
        u_sorted_cumweights = tf.concat(
            [tf.zeros_like(u_weights)[..., 0:1],
             tf.cumsum(gather(u_weights, u_sorter), axis=-1)],
            axis=-1)
        u_cdf = gather(u_sorted_cumweights, u_cdf_indices)
        # Normalize by the total weight so the CDF ends at 1.
        u_cdf = safe_divide(u_cdf, u_sorted_cumweights[..., -1:])

    if v_weights is None:
        v_cdf = (tf.cast(v_cdf_indices, v_values.dtype) /
                 float(v_values.shape[-1]))
    else:
        v_sorted_cumweights = tf.concat(
            [tf.zeros_like(v_weights)[..., 0:1],
             tf.cumsum(gather(v_weights, v_sorter), axis=-1)],
            axis=-1)
        v_cdf = gather(v_sorted_cumweights, v_cdf_indices)
        # Normalize by the total weight so the CDF ends at 1.
        v_cdf = safe_divide(v_cdf, v_sorted_cumweights[..., -1:])

    # Compute the value of the integral based on the CDFs.
    return tf.reduce_sum(deltas * tf.abs(u_cdf - v_cdf)**p, axis=-1)**(1.0 / p)
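A minimal usage sketch (not part of the library); the sample tensors and weights below are made up for illustration and assume the imports added above:

# Two batches of four samples each. The weights are uniform here, so passing
# None for the weight arguments would give the same result.
u = tf.constant([[0.0, 1.0, 3.0, 4.0], [0.0, 0.0, 1.0, 1.0]])
v = tf.constant([[1.0, 2.0, 3.0, 5.0], [2.0, 2.0, 3.0, 3.0]])
w = tf.ones_like(u) / 4.0

dist = wasserstein_distance(u, v, u_weights=w, v_weights=w, p=1.0)
print(dist.shape)  # (2,): one distance per batch element.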
Example #2
import tensorflow as tf

# `spectral_ops.diff` and `get_note_moments` are helpers from the ddsp library
# (import paths assumed: ddsp.spectral_ops and ddsp.core).
from ddsp import spectral_ops
from ddsp.core import get_note_moments


def get_note_mask(q_pitch, max_regions=100, note_on_only=True):
    """Get a binary mask for each note from a monophonic instrument.

  Each transition of the value creates a new region. Returns the mask of each
  region.
  Args:
    q_pitch: A quantized value, such as pitch or velocity. Shape
      [batch, n_timesteps] or [batch, n_timesteps, 1].
    max_regions: Maximum number of note regions to consider in the sequence.
      Also, the channel dimension of the output mask. Each value transition
      defines a new region, e.g. each note-on and note-off count as a separate
      region.
    note_on_only: Return a mask that is true only for regions where the pitch
      is greater than 0.

  Returns:
    A binary mask of each region [batch, n_timesteps, max_regions].
  """
    # Only batch and time dimensions.
    if len(q_pitch.shape) == 3:
        q_pitch = q_pitch[:, :, 0]

    # Get onset and offset points.
    edges = tf.abs(spectral_ops.diff(q_pitch, axis=1)) > 0

    # Count endpoints as starts/ends of regions.
    edges = edges[:, :-1, ...]
    edges = tf.pad(edges, [[0, 0], [1, 0]],
                   mode='constant',
                   constant_values=True)
    edges = tf.pad(edges, [[0, 0], [0, 1]],
                   mode='constant',
                   constant_values=False)
    edges = tf.cast(edges, tf.int32)

    # Count up onset and offsets for each timestep.
    # Assumes each onset has a corresponding offset.
    # The -1 ensures that the 0th index is the first note.
    edge_idx = tf.cumsum(edges, axis=1) - 1

    # Create masks of shape [batch, n_timesteps, max_regions].
    note_mask = edge_idx[..., None] == tf.range(max_regions)[None, None, :]
    note_mask = tf.cast(note_mask, tf.float32)

    if note_on_only:
        # [batch, notes]
        note_pitches = get_note_moments(q_pitch, note_mask, return_std=False)
        # [batch, time, notes]
        note_on = tf.cast(note_pitches > 0.0, tf.float32)[:, None, :]
        # [batch, time, notes]
        note_mask *= note_on

    return note_mask
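A minimal usage sketch (not part of the library); the pitch sequence below is made up and relies on the ddsp helpers imported above:

# One batch item with two notes (MIDI pitches 60 and 62) separated by a rest.
q_pitch = tf.constant([[0, 60, 60, 0, 62, 62, 62, 62]], dtype=tf.float32)

mask = get_note_mask(q_pitch, max_regions=8, note_on_only=True)
print(mask.shape)  # (1, 8, 8): [batch, n_timesteps, max_regions]; only the
                   # two note-on regions contain nonzero entries.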