Example #1
0
    def update(self,
               expert_dataset_iter,
               policy_dataset_iter,
               discount,
               replay_regularization=0.05,
               nu_reg=10.0):
        """Run one joint gradient step on the nu network and the actor.

        A single loss is built from expert transitions only; the nu network
        minimizes it (plus a gradient penalty) while the actor minimizes its
        negation (plus an orthogonal regularizer on `self.actor.trunk`).

        The replay-buffer branch of this method is currently commented out.
        As a result `policy_dataset_iter` is not read, and
        `replay_regularization` only rescales the linear term of the loss by
        `(1 - replay_regularization)`. With the replay branch enabled the
        intent (per the original documentation) was to learn
        (d_pi * (1 - replay_regularization) + d_rb * replay_regularization) /
        (d_expert * (1 - replay_regularization) + d_rb * replay_regularization)
        instead of the plain ratio.

        Side effects: applies gradients through `self.nu_optimizer` and
        `self.actor_optimizer`, updates the running metrics, and writes
        summaries whenever an optimizer's iteration count is a multiple of
        `self.log_interval`.

        Args:
          expert_dataset_iter: An iterator whose `get_next()` yields a
            `(states, actions, next_states)` tuple of expert data.
          policy_dataset_iter: An iterator over training policy data. Unused
            while the replay-buffer code is commented out.
          discount: An MDP discount.
          replay_regularization: A fraction of samples to add from a replay
            buffer; currently only scales the linear loss term.
          nu_reg: A grad penalty regularization coefficient.
        """

        (expert_states, expert_actions,
         expert_next_states) = expert_dataset_iter.get_next()

        # The expert states are reused as the initial-state sample.
        expert_initial_states = expert_states

        # rb_states, rb_actions, rb_next_states, _, _ = policy_dataset_iter.get_next(
        # )[0]

        # The tape is persistent because it is queried twice below (once per
        # network); only the explicitly watched variables are tracked.
        with tf.GradientTape(watch_accessed_variables=False,
                             persistent=True) as tape:
            tape.watch(self.actor.variables)
            tape.watch(self.nu_net.variables)

            # Only the second element of the actor output is used here.
            _, policy_next_actions, _ = self.actor(expert_next_states)
            # _, rb_next_actions, rb_log_prob = self.actor(rb_next_states)

            _, policy_initial_actions, _ = self.actor(expert_initial_states)

            # Inputs for the linear part of DualDICE loss.
            expert_init_inputs = tf.concat(
                [expert_initial_states, policy_initial_actions], 1)

            # Discrete actions are one-hot encoded before being concatenated
            # with the states; continuous actions are concatenated directly.
            if not self.discrete:
                expert_inputs = tf.concat([expert_states, expert_actions], 1)
            else:
                mat = tf.one_hot(tf.cast(expert_actions, tf.int32),
                                 depth=self.action_dim,
                                 axis=-1)
                expert_inputs = tf.concat([expert_states, mat], 1)
            expert_next_inputs = tf.concat(
                [expert_next_states, policy_next_actions], 1)

            # rb_inputs = tf.concat([rb_states, rb_actions], 1)
            # rb_next_inputs = tf.concat([rb_next_states, rb_next_actions], 1)

            expert_nu_0 = self.nu_net(expert_init_inputs)
            expert_nu = self.nu_net(expert_inputs)
            expert_nu_next = self.nu_net(expert_next_inputs)

            # rb_nu = self.nu_net(rb_inputs)
            # rb_nu_next = self.nu_net(rb_next_inputs)

            # Discounted difference between nu at the current and next inputs.
            expert_diff = expert_nu - discount * expert_nu_next
            # rb_diff = rb_nu - discount * rb_nu_next

            linear_loss_expert = tf.reduce_mean(expert_nu_0 * (1 - discount))

            # linear_loss_rb = tf.reduce_mean(rb_diff)

            # With the replay branch disabled, only expert samples are used
            # and they are weighted uniformly.
            rb_expert_diff = expert_diff  #tf.concat([expert_diff, rb_diff], 0)
            rb_expert_weights = tf.ones(expert_diff.shape)  #tf.concat([
            #     tf.ones(expert_diff.shape) * (1 - replay_regularization),
            #     tf.ones(rb_diff.shape) * replay_regularization
            # ], 0)

            # Normalize the weights to sum to one. The softmax weights are
            # wrapped in stop_gradient, so gradients reach the networks only
            # through `rb_expert_diff`.
            # NOTE(review): `weighted_softmax` is defined outside this block;
            # its exact semantics should be confirmed there.
            rb_expert_weights /= tf.reduce_sum(rb_expert_weights)
            non_linear_loss = tf.reduce_sum(
                tf.stop_gradient(
                    weighted_softmax(rb_expert_diff, rb_expert_weights,
                                     axis=0)) * rb_expert_diff)

            # The replay-buffer contribution to the linear term is replaced
            # by a literal 0 while that branch is disabled.
            linear_loss = (linear_loss_expert * (1 - replay_regularization) +
                           0)
            # linear_loss_rb * replay_regularization)

            loss = (non_linear_loss - linear_loss)

            # One interpolation coefficient per row of the batch.
            alpha = tf.random.uniform(shape=(expert_inputs.shape[0], 1))

            # nu_inter = alpha * expert_inputs + (1 - alpha) * expert_init_inputs #rb_inputs
            # nu_next_inter = alpha * expert_next_inputs + (1 - alpha) * #rb_next_inputs

            # nu_inter = tf.concat([nu_inter, nu_next_inter], 0)
            # Interpolate between expert inputs and a shuffled copy of the
            # next-step inputs; the shuffled term carries no gradient.
            nu_inter = alpha * expert_inputs + (1 - alpha) * tf.stop_gradient(
                tf.random.shuffle(expert_next_inputs))

            # Inner tape: gradient of the nu output w.r.t. the interpolated
            # inputs, used for the gradient penalty.
            with tf.GradientTape(watch_accessed_variables=False) as tape2:
                tape2.watch(nu_inter)
                nu_output = self.nu_net(nu_inter)
            # EPS is defined outside this block; presumably it keeps the norm
            # below away from an exact zero -- TODO confirm.
            nu_grad = tape2.gradient(nu_output, [nu_inter])[0] + EPS
            # Penalize deviation of the per-row gradient norm from 1.
            nu_grad_penalty = tf.reduce_mean(
                tf.square(tf.norm(nu_grad, axis=-1, keepdims=True) - 1))

            # The two networks optimize the shared loss with opposite signs.
            nu_loss = loss + nu_grad_penalty * nu_reg
            pi_loss = -loss + keras_utils.orthogonal_regularization(
                self.actor.trunk)

        nu_grads = tape.gradient(nu_loss, self.nu_net.variables)
        pi_grads = tape.gradient(pi_loss, self.actor.variables)

        self.nu_optimizer.apply_gradients(zip(nu_grads, self.nu_net.variables))
        self.actor_optimizer.apply_gradients(
            zip(pi_grads, self.actor.variables))

        # Drop the reference to the persistent tape once both gradients have
        # been taken.
        del tape

        # Update running metrics. NOTE(review): `avg_nu_rb` and
        # `avg_actor_entropy` are not updated in this method (their update
        # calls are commented out), yet both are still logged and reset below.
        self.avg_nu_expert(expert_nu)
        #self.avg_nu_rb(rb_nu)

        self.nu_reg_metric(nu_grad_penalty)
        self.avg_loss(loss)

        self.avg_actor_loss(pi_loss)
        #self.avg_actor_entropy(-rb_log_prob)

        # Periodically write the nu-side metrics and reset their accumulators.
        if tf.equal(self.nu_optimizer.iterations % self.log_interval, 0):
            tf.summary.scalar('train dual dice/loss',
                              self.avg_loss.result(),
                              step=self.nu_optimizer.iterations)
            keras_utils.my_reset_states(self.avg_loss)

            tf.summary.scalar('train dual dice/nu expert',
                              self.avg_nu_expert.result(),
                              step=self.nu_optimizer.iterations)
            keras_utils.my_reset_states(self.avg_nu_expert)

            tf.summary.scalar('train dual dice/nu rb',
                              self.avg_nu_rb.result(),
                              step=self.nu_optimizer.iterations)
            keras_utils.my_reset_states(self.avg_nu_rb)

            tf.summary.scalar('train dual dice/nu reg',
                              self.nu_reg_metric.result(),
                              step=self.nu_optimizer.iterations)
            keras_utils.my_reset_states(self.nu_reg_metric)

        # Periodically write the actor-side metrics and reset their
        # accumulators.
        if tf.equal(self.actor_optimizer.iterations % self.log_interval, 0):
            tf.summary.scalar('train sac/actor_loss',
                              self.avg_actor_loss.result(),
                              step=self.actor_optimizer.iterations)
            keras_utils.my_reset_states(self.avg_actor_loss)

            tf.summary.scalar('train sac/actor entropy',
                              self.avg_actor_entropy.result(),
                              step=self.actor_optimizer.iterations)
            keras_utils.my_reset_states(self.avg_actor_entropy)
Example #2
0
 def _zeros_like(x):
     """Return `x * sg(x - 1) - sg(x * (x - 1))`, where `sg` is stop_gradient.

     The two terms have equal forward values, so the result evaluates to zero
     for finite `x`, while only the first term contributes a gradient.
     """
     shifted = x - 1.
     differentiable_term = x * tf.stop_gradient(shifted)
     constant_term = tf.stop_gradient(x * shifted)
     return differentiable_term - constant_term
Example #3
0
 def committment_loss(self, z, z_q):
   """Weighted L2 difference between `z` and a gradient-stopped `z_q`.

   Encourages the encoder output `z` to stay close to the current centroids
   `z_q`; because `z_q` is wrapped in `tf.stop_gradient`, no gradient flows
   into it from this term.
   """
   frozen_z_q = tf.stop_gradient(z_q)
   unweighted = losses.mean_difference(z, frozen_z_q, loss_type='L2')
   return self.commitment_loss_weight * unweighted
Example #4
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Draw momenta, run the leapfrog integrator and package the results.

        Args:
          current_state: The current chain state, either a single tensor or a
            list-like of state parts.
          previous_kernel_results: Results from the previous call; supplies
            the cached target log-prob and its gradients, and (when
            parameters are stored in the results) the step size and number
            of leapfrog steps.
          seed: Optional seed; it is sanitized and stored in the returned
            kernel results.

        Returns:
          A `(next_state, kernel_results)` pair. `next_state` has the same
          structure as `current_state` (a bare tensor when `current_state`
          was not list-like).
        """
        with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
            # Hyper-parameters either ride along in the kernel results or
            # live on the kernel itself.
            params_source = (previous_kernel_results
                             if self._store_parameters_in_results else self)
            step_size = params_source.step_size
            num_leapfrog_steps = params_source.num_leapfrog_steps

            (state_parts, step_sizes, start_log_prob,
             start_grad_parts) = _prepare_args(
                 self.target_log_prob_fn,
                 current_state,
                 step_size,
                 previous_kernel_results.target_log_prob,
                 previous_kernel_results.grads_target_log_prob,
                 maybe_expand=True,
                 state_gradients_are_stopped=self.state_gradients_are_stopped)

            seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
            part_seeds = distribute_lib.fold_in_axis_index(
                list(samplers.split_seed(seed, n=len(state_parts))),
                self.experimental_shard_axis_names)

            # One standard-normal momentum per state part, each with its own
            # seed.
            start_momentum_parts = [
                samplers.normal(
                    shape=ps.shape(part),
                    dtype=(self._momentum_dtype
                           or dtype_util.base_dtype(part.dtype)),
                    seed=part_seed)
                for part_seed, part in zip(part_seeds, state_parts)
            ]

            leapfrog = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn, step_sizes, num_leapfrog_steps)
            (end_momentum_parts, end_state_parts, end_log_prob,
             end_grad_parts) = leapfrog(start_momentum_parts, state_parts,
                                        start_log_prob, start_grad_parts)

            if self.state_gradients_are_stopped:
                end_state_parts = [
                    tf.stop_gradient(part) for part in end_state_parts
                ]

            chain_ndims = ps.rank(start_log_prob)
            log_acceptance_correction = _compute_log_acceptance_correction(
                start_momentum_parts,
                end_momentum_parts,
                chain_ndims,
                shard_axis_names=self.experimental_shard_axis_names)

            updated_results = previous_kernel_results._replace(
                log_acceptance_correction=log_acceptance_correction,
                target_log_prob=end_log_prob,
                grads_target_log_prob=end_grad_parts,
                initial_momentum=start_momentum_parts,
                final_momentum=end_momentum_parts,
                seed=seed,
            )

            # Mirror the structure of the caller's state: list in, list out.
            if mcmc_util.is_list_like(current_state):
                next_state = end_state_parts
            else:
                next_state = end_state_parts[0]
            return next_state, updated_results
def log_concave_rejection_sampler(
    mode,
    prob_fn,
    dtype,
    sample_shape=(),
    distribution_minimum=None,
    distribution_maximum=None,
    seed=None):
  """Utility for rejection sampling from log-concave discrete distributions.

  This utility constructs an easy-to-sample-from upper bound for a discrete
  univariate log-concave distribution (for discrete univariate distributions, a
  necessary and sufficient condition is p_k^2 >= p_{k-1} p_{k+1} for all k).
  The method requires that the mode of the distribution is known. While a better
  method can likely be derived for any given distribution, this method is
  general and easy to implement. The expected number of iterations is bounded by
  4+m, where m is the probability of the mode. For details, see [(Devroye,
  1987)][1].

  Args:
    mode: Tensor, the mode[s] of the [batch of] distribution[s].
    prob_fn: Python callable, counts -> prob(counts).
    dtype: DType of the generated samples.
    sample_shape: 0D or 1D `int32` `Tensor`. Shape of the generated samples.
    distribution_minimum: Tensor of type `dtype`. The minimum value
      taken by the distribution. `prob_fn` will only be called on values
      greater than or equal to the specified minimum. The shape must broadcast
      with the batch shape of the distribution. If unspecified, the domain is
      treated as unbounded below.
    distribution_maximum: Tensor of type `dtype`. The maximum value
      taken by the distribution. See `distribution_minimum` for details.
    seed: Python integer or `Tensor` instance, for seeding PRNG.

  Returns:
    samples: a `Tensor` with prepended dimensions `sample_shape`. Gradients
      are stopped on the returned samples.

  #### References

  [1] Luc Devroye. A Simple Generator for Discrete Log-Concave
      Distributions. Computing, 1987.
  """
  # Prepend `sample_shape` to the mode so that all requested draws are handled
  # as one larger batch.
  mode = tf.broadcast_to(
      mode, ps.concat([sample_shape, ps.shape(mode)], axis=0))

  mode_height = prob_fn(mode)
  mode_shape = ps.shape(mode)

  top_width = 1. + mode_height / 2.  # w in ref [1].
  # Compared against a uniform draw in `proposal` to choose between the flat
  # top and the exponential tails of the envelope.
  top_fraction = top_width / (1 + top_width)
  exponential_distribution = exponential.Exponential(
      rate=tf.ones([], dtype=dtype))  # E in ref [1].

  # Missing bounds are treated as an unbounded domain on that side.
  if distribution_minimum is None:
    distribution_minimum = tf.constant(-np.inf, dtype)
  if distribution_maximum is None:
    distribution_maximum = tf.constant(np.inf, dtype)

  def proposal(seed):
    """Proposal for log-concave rejection sampler."""
    # Four independent seeds, one per random draw below.
    (top_lobe_fractions_seed,
     exponential_samples_seed,
     top_selector_seed,
     rademacher_seed) = samplers.split_seed(seed, n=4)

    # Candidate offset if the draw lands on the flat top of the envelope.
    top_lobe_fractions = samplers.uniform(
        mode_shape, seed=top_lobe_fractions_seed, dtype=dtype)  # V in ref [1].
    top_offsets = top_lobe_fractions * top_width / mode_height

    # Candidate offset (and envelope height) if the draw lands in a tail.
    exponential_samples = exponential_distribution.sample(
        mode_shape, seed=exponential_samples_seed)  # E in ref [1].
    exponential_height = (exponential_distribution.prob(exponential_samples) *
                          mode_height)
    exponential_offsets = (top_width + exponential_samples) / mode_height

    top_selector = samplers.uniform(
        mode_shape, seed=top_selector_seed, dtype=dtype)  # U in ref [1].
    on_top_mask = (top_selector <= top_fraction)

    # Pick the offset for the selected region, apply a random sign
    # (`tfp_random.rademacher` presumably draws +/-1 -- TODO confirm) and round
    # to an integer offset from the mode.
    unsigned_offsets = tf.where(on_top_mask, top_offsets, exponential_offsets)
    offsets = tf.round(
        tfp_random.rademacher(
            mode_shape, seed=rademacher_seed, dtype=dtype) *
        unsigned_offsets)

    potential_samples = mode + offsets
    envelope_height = tf.where(on_top_mask, mode_height, exponential_height)

    return potential_samples, envelope_height

  def target(values):
    """Evaluate `prob_fn`, returning zero for out-of-bounds `values`."""
    # Check for out of bounds rather than in bounds to avoid accidentally
    # masking a `nan` value.
    out_of_bounds_mask = (
        (values < distribution_minimum) | (values > distribution_maximum))
    # Out-of-bounds entries are replaced by 0 before calling `prob_fn`, then
    # their probabilities are overwritten with zero.
    in_bounds_values = tf.where(
        out_of_bounds_mask, tf.constant(0., dtype=values.dtype), values)
    probs = prob_fn(in_bounds_values)
    return tf.where(out_of_bounds_mask, tf.zeros([], probs.dtype), probs)

  return tf.stop_gradient(
      brs.batched_rejection_sampler(
          proposal, target, seed, dtype=dtype)[0])  # Discard `num_iters`.
Example #6
0
    def append_losses(self, outputs, self_supervised_features=None):
        """Populate `self._losses_dict` with the losses for one set of outputs.

        Without `self_supervised_features`, the unsupervised losses are
        computed from `outputs` alone. Otherwise every configured consistency
        loss compares `outputs` against the supplied features and is stored
        under an 'ss_'-prefixed key.

        Args:
          outputs: Mapping of model outputs, indexed by string keys.
          self_supervised_features: Optional mapping of target features. When
            `None`, the unsupervised branch is taken.
        """
        targets = self_supervised_features

        if targets is not None:
            # ---- Self-supervised losses --------------------------------
            # Sinusoidal self-supervision.
            for sin_loss in self.sinusoidal_consistency_losses or ():
                key = 'ss_' + sin_loss.name
                self._losses_dict[key] = sin_loss(
                    outputs['sin_amps'], outputs['sin_freqs'],
                    targets['sin_amps'], targets['sin_freqs'])

            # Filtered noise self-supervision.
            noise_loss = self.filtered_noise_consistency_loss
            if noise_loss is not None:
                key = 'ss_' + noise_loss.name
                self._losses_dict[key] = noise_loss(
                    outputs['noise_magnitudes'], targets['noise_magnitudes'])

            # Harmonic self-supervision.
            for harm_loss in self.harmonic_consistency_losses or ():
                if not isinstance(harm_loss,
                                  ddsp.losses.HarmonicConsistencyLoss):
                    # Same consistency loss as sinusoidal models.
                    key = 'ss_harm_' + harm_loss.name
                    self._losses_dict[key] = harm_loss(
                        outputs['harm_amp'], outputs['f0_hz'],
                        targets['harm_amp'], targets['f0_hz'])
                    continue
                # L1 loss of harmonic synth controls; yields several named
                # losses at once.
                per_control = harm_loss(
                    outputs['harm_amp'], targets['harm_amp'],
                    outputs['harm_dist'], targets['harm_dist'],
                    outputs['f0_hz'], targets['f0_hz'])
                self._losses_dict.update(
                    {'ss_' + key: value for key, value in per_control.items()})
            return

        # ---- Unsupervised losses ---------------------------------------
        # Sinusoidal autoencoder loss.
        for audio_loss in self.audio_loss_objs:
            key = f'sin_{audio_loss.name}'
            self._losses_dict[key] = audio_loss(outputs['audio'],
                                                outputs['sin_audio'])

        has_harmonic_encoder = self.harmonic_encoder is not None

        if has_harmonic_encoder:
            # Prior regularization on the harmonic distribution.
            prior = self.harmonic_distribution_prior
            if prior is not None:
                key = prior.name
                self._losses_dict[key] = prior(outputs['harm_dist'])

            # Harmonic autoencoder loss.
            for audio_loss in self.audio_loss_objs:
                key = f'harm_{audio_loss.name}'
                self._losses_dict[key] = audio_loss(outputs['audio'],
                                                    outputs['harm_audio'])

            # Sinusoidal<->Harmonic consistency loss.
            if self.sinusoidal_consistency_losses:
                sin_amps, sin_freqs = outputs['sin_amps'], outputs['sin_freqs']
                if self.stop_gradient:
                    # Keep harmonic errors from reaching the sinusoidal
                    # predictions.
                    sin_amps = tf.stop_gradient(sin_amps)
                    sin_freqs = tf.stop_gradient(sin_freqs)
                for consistency in self.sinusoidal_consistency_losses:
                    self._losses_dict[consistency.name] = consistency(
                        sin_amps, sin_freqs, outputs['harm_amps'],
                        outputs['harm_freqs'])

        # Two-way mismatch loss between sinusoids and harmonics. Without a
        # harmonic encoder the sinusoid frequencies stand in for f0.
        if self.twm_loss is not None:
            if has_harmonic_encoder:
                first_arg = outputs['f0_hz']
            else:
                first_arg = outputs['sin_freqs']
            twm_value = self.twm_loss(first_arg, outputs['sin_freqs'],
                                      outputs['sin_amps'])
            self._losses_dict[self.twm_loss.name] = twm_value
 def _center_previous_state(x):
   """Subtract the mean of `x` over `batch_axes`, treating it as a constant."""
   # The empirical mean stands in for the true mean, so no gradient is
   # allowed to flow through it.
   empirical_mean = tf.reduce_mean(x, axis=batch_axes)
   return x - tf.stop_gradient(empirical_mean)
Example #8
0
  def one_step(self, current_state, previous_kernel_results, seed=None):
    """Draw momenta from the momentum distribution and run the integrator.

    Args:
      current_state: The current chain state, either a single tensor or a
        list-like of state parts.
      previous_kernel_results: Results from the previous call; supplies the
        cached target log-prob and gradients, and (when parameters are stored
        in the results) the step size, number of leapfrog steps and momentum
        distribution.
      seed: Optional seed; it is sanitized and stored in the returned kernel
        results.

    Returns:
      A `(next_state, kernel_results)` pair. `next_state` has the same
      structure as `current_state` (a bare tensor when `current_state` was not
      list-like).
    """
    with tf.name_scope(mcmc_util.make_name(self.name, 'phmc', 'one_step')):
      # Parameters come from the kernel results when stored there, otherwise
      # from the kernel itself.
      params_source = (
          previous_kernel_results
          if self._store_parameters_in_results else self)
      step_size = params_source.step_size
      num_leapfrog_steps = params_source.num_leapfrog_steps
      momentum_distribution = params_source.momentum_distribution

      (state_parts, step_sizes, momentum_distribution, start_log_prob,
       start_grad_parts) = _prepare_args(
           self.target_log_prob_fn,
           current_state,
           step_size,
           momentum_distribution,
           previous_kernel_results.target_log_prob,
           previous_kernel_results.grads_target_log_prob,
           maybe_expand=True,
           state_gradients_are_stopped=self.state_gradients_are_stopped)

      seed = samplers.sanitize_seed(seed)
      start_momentum_parts = list(momentum_distribution.sample(seed=seed))

      # Prefer the unnormalized log-prob when the distribution exposes one.
      momentum_log_prob = getattr(
          momentum_distribution, '_log_prob_unnormalized',
          momentum_distribution.log_prob)

      def kinetic_energy_fn(*momentum_args):
        # Kinetic energy is the negative momentum log-density.
        return -momentum_log_prob(*momentum_args)

      leapfrog = leapfrog_impl.SimpleLeapfrogIntegrator(
          self.target_log_prob_fn, step_sizes, num_leapfrog_steps)
      (end_momentum_parts, end_state_parts, end_log_prob,
       end_grad_parts) = leapfrog(
           start_momentum_parts,
           state_parts,
           target=start_log_prob,
           target_grad_parts=start_grad_parts,
           kinetic_energy_fn=kinetic_energy_fn)

      if self.state_gradients_are_stopped:
        end_state_parts = [tf.stop_gradient(part) for part in end_state_parts]

      log_acceptance_correction = _compute_log_acceptance_correction(
          kinetic_energy_fn, start_momentum_parts, end_momentum_parts)
      updated_results = previous_kernel_results._replace(
          log_acceptance_correction=log_acceptance_correction,
          target_log_prob=end_log_prob,
          grads_target_log_prob=end_grad_parts,
          initial_momentum=start_momentum_parts,
          final_momentum=end_momentum_parts,
          seed=seed,
      )

      # Mirror the structure of the caller's state: list in, list out.
      if mcmc_util.is_list_like(current_state):
        next_state = end_state_parts
      else:
        next_state = end_state_parts[0]
      return next_state, updated_results
Example #9
0
 def __call__(self, x):
     """Stochastically map `x` to 0 or `self.alpha`, straight-through.

     `_sigmoid(x / self.alpha)` is compared against uniform noise to pick the
     output level. The quantization residual sits inside `tf.stop_gradient`,
     so the gradient with respect to `x` is the identity.
     """
     threshold_prob = _sigmoid(x / self.alpha)
     noise = tf.random.uniform(tf.shape(threshold_prob))
     sign = tf.sign(threshold_prob - noise)
     # tf.sign gives 0 on an exact tie; this bumps that case up to +1.
     sign = sign + (1.0 - tf.abs(sign))
     level = self.alpha * (sign + 1.0) / 2.0
     return x + tf.stop_gradient(-x + level)
Example #10
0
 def __call__(self, x):
     """Map `x` to `{-alpha, 0, +alpha}` with a straight-through gradient.

     Entries with `|x| < self.threshold` become 0; the rest become
     `self.alpha * sign(x)`. When stochastic rounding is enabled, `x` is first
     passed through `_round_through`.
     """
     stochastic = self.use_stochastic_rounding
     if stochastic:
         x = _round_through(x, use_stochastic_rounding=stochastic)
     below_threshold = tf.abs(x) < self.threshold
     ternary = tf.where(below_threshold, tf.zeros_like(x), tf.sign(x))
     return x + tf.stop_gradient(-x + self.alpha * ternary)
Example #11
0
def _ceil_through(x):
    """Computes the ceiling operation using straight through estimator.

    The forward value equals `ceil(x)`. The rounding residual is wrapped in
    `tf.stop_gradient`, so the gradient with respect to `x` is the identity.

    Args:
      x: A floating-point tensor.

    Returns:
      A tensor with the same shape as `x` holding `ceil(x)`.
    """
    # Use `tf.math.ceil`: it exists in both TF1 and TF2, whereas the bare
    # `tf.ceil` alias is not exported from the TF2 top-level namespace.
    return x + tf.stop_gradient(-x + tf.math.ceil(x))
Example #12
0
 def _fn(x):
   """Return the enclosing scope's `f_x` with a gradient of 1 w.r.t. `x`."""
   # `x - stop_gradient(x)` is zero in the forward pass but has gradient `1`
   # regardless of input.
   unit_gradient_zero = x - tf.stop_gradient(x)
   return f_x + unit_gradient_zero
Example #13
0
def contrastive_loss(features,
                     labels=None,
                     temperature=1.0,
                     contrast_mode=enums.LossContrastMode.ALL_VIEWS,
                     summation_location=enums.LossSummationLocation.OUTSIDE,
                     denominator_mode=enums.LossDenominatorMode.ALL,
                     positives_cap=-1,
                     scale_by_temperature=True):
    r"""Contrastive loss over features.

  Implemented as described in: https://arxiv.org/abs/2004.11362, Equation 2.

  Given `num_views` different views of each of `batch_size` samples, let `f_i`
  (i \in [1, 2 ... (num_views * batch_size)]) denote each respective feature
  vector. The contrastive loss then takes the following form:

    L = \sum_{i} L_i

  where each L_i is computed as:

    L_i = -\tau * \sum_{k \in P(i)} \log(p_{ik})    (1)

  where P(i) is the set of positives for entry i (distinct from i) and where:

                       \exp(f_i^T f_k / \tau)
    p_{ik} = ----------------------------------------                        (2)
             \sum_{j \in A(i)} \exp(f_i^T f_j / \tau)

  where A(i) is the set of all positives or negatives (distinct from i). `i` is
  the anchor, and \tau is the temperature.

  This maximizes the likelihood of a given (anchor, positive) pair with
  respect to all possible pairs where the first member is the anchor and the
  second member is a positive or a negative.

  A typical way to define a positive is to define samples from the
  same class (but not the anchor itself) regardless of what view they are from.
  Similarly, a typical way to define a negative is for it to be any view of a
  sample from a different class.

  There are two ways to define which feature pairs should be treated as
  positives and negatives. All views of the same sample are always treated as
  positives. You can declare other samples to be positives by providing `labels`
  such that all samples with the same label will be positives for each other.

  If `labels` is not provided then we default to every sample belonging to its
  own unique class. Therefore, the only positive used is another view of the
  anchor itself. This implements the loss as described in:

    https://arxiv.org/pdf/2002.05709.pdf
    A Simple Framework for Contrastive Learning of Visual Representations
    Chen T., Kornblith S., Norouzi M., Hinton G.

  It is recommended to use features whose L_2 norm is 1. since that ensures
  that the loss does not return NaN values without changing the intended
  behaviour of the loss function.

  In (1) above, note that the summation over positives is located outside of the
  \log(). However, one can permute these two operations. The result is Eq. 3 in
  https://arxiv.org/abs/2004.11362. Users can specify the location of the
  summation relative to the \log() via the `summation_location` argument:
   - 'out': Eq. 2 in https://arxiv.org/abs/2004.11362.
   - 'in' : Eq. 3 in https://arxiv.org/abs/2004.11362.

  Additionally, in (2) above, note that the denominator sums over *all* entries
  distinct from i. One can change which terms are included in the denominator
  via the `denominator_mode` argument:
   - LossDenominatorMode.ALL : All entries (i.e., all negatives and all
             positives) distinct from i are included.
   - LossDenominatorMode.ONE_POSITIVE : All negatives are included but only the
             single positive in the numerator of (2) is included. Any other
             positives are excluded.
   - LossDenominatorMode.ONLY_NEGATIVES: All negatives are included but no
             positives are, not even the single positive in the numerator of
             (2).

  On TPUs, this method will internally perform the cross-replica operations that
  enable using the samples from all cores in computing the loss. The inputs to
  this function should be the features and labels from a single core and each
  core will compute the loss using just these features as anchors, but will use
  positives and negatives from the full global batch. Since the loss for each
  anchor is only computed on one TPU core, it's still necessary to have a
  cross-replica reduction in the final loss computation.

  Also, though it is not applicable to multiview contrastive learning, this
  function will work if |features| contains only 1 view. In the high batch size
  limit, the implemented contrastive loss with only 1 view, positives_cap = 1,
  and temperature = 1.0 is equivalent to the N-pairs loss
  (https://papers.nips.cc/paper/6200-improved-deep-metric-learning-with-multi-class-n-pair-loss-objective.pdf)

  Args:
    features: A Tensor of rank at least 3, where the first 2 dimensions are
      batch_size and num_views, and the remaining dimensions are the feature
      shape. Note that when running on TPU, batch_size is the per-core batch
      size.
    labels: One-hot labels to be used to construct the supervised contrastive
      loss. Samples with the same labels are used as positives for each other.
      Labels must have shape [batch_size, num_labels] with numeric dtype and be
      0-1 valued. Note that when running on TPU, batch_size is the per-core
      batch size.
    temperature: Temperature at which softmax evaluation is done. Temperature
      must be a python scalar or scalar Tensor of numeric dtype.
    contrast_mode: LossContrastMode specifying which views get used as anchors
      (f_i in the expression above)
      'ALL_VIEWS': All the views of all samples are used as anchors (f_i in the
        expression above).
      'ONE_VIEW': Just the first view of each sample is used as an anchor (f_i
        in the expression above). This view is called the `core` view against
        which other views are contrasted.
    summation_location: LossSummationLocation specifying location of positives
      summation. See documentation above for more details.
    denominator_mode: LossDenominatorMode specifying which positives to include
      in contrastive denominator. See documentation above for more details.
    positives_cap: Integer maximum number of positives *other* than
      augmentations of anchor. Infinite if < 0. Must be multiple of num_views.
      Including augmentations, a maximum of (positives_cap + num_views - 1)
      positives is possible. This parameter modifies the contrastive numerator
      by selecting which positives are present in the summation, and which
      positives contribute to the denominator if denominator_mode ==
      enums.LossDenominatorMode.ALL.
    scale_by_temperature: Boolean. Whether to scale the loss by `temperature`.
      The loss gradient naturally has a 1/temperature scaling factor, so this
      counteracts it.

  Returns:
    Scalar tensor with contrastive loss value with shape [batch_size] and dtype
    tf.float32. The loss for each batch element is the mean over all views.

  Raises:
    ValueError if the shapes of any of the Tensors are unexpected, or if both
    `labels` and `mask` are not `None`.
  """
    features = tf.convert_to_tensor(features)
    labels = tf.convert_to_tensor(labels) if labels is not None else None

    local_batch_size, num_views = _validate_contrastive_loss_inputs(
        features, labels, contrast_mode, summation_location, denominator_mode,
        positives_cap)

    # Flatten `features` to a single dimension per view per sample so it has shape
    # [local_batch_size, num_views, num_features].
    if features.shape.rank > 3:
        features = tf.reshape(
            features, tf.concat([tf.shape(features)[:2], [-1]], axis=0),
            'flattened_features')
    if features.dtype != tf.float32:
        features = tf.cast(features, tf.float32)

    # Grab the features from all TPU cores. We use the local batch as anchors and
    # the full global batch as contrastives. If not on TPU, global_features is the
    # same as features.
    global_features = utils.cross_replica_concat(features)
    global_batch_size = tf.compat.dimension_at_index(global_features.shape,
                                                     0).value
    local_replica_id = utils.local_tpu_replica_id()

    # Generate the [local_batch_size, global_batch_size] slice of the
    # [global_batch_size, global_batch_size] identity matrix that corresponds to
    # the current replica.
    diagonal_mask = tf.one_hot(
        tf.range(local_batch_size) + (local_replica_id * local_batch_size),
        global_batch_size)

    # Generate `mask` with shape [local_batch_size, global_batch_size] that
    # indicates which samples should be considered positives for each other.
    if labels is None:
        # Defaults to every sample belonging to its own unique class, containing
        # just that sample and other views of it.
        mask = diagonal_mask
    else:
        labels = tf.cast(labels,
                         tf.float32)  # TPU matmul op unsupported for ints.
        global_labels = utils.cross_replica_concat(labels)
        mask = tf.linalg.matmul(labels, global_labels, transpose_b=True)
    mask = tf.ensure_shape(mask, [local_batch_size, global_batch_size])

    # To streamline the subsequent TF, the first two dimensions of
    # `global_features` (i.e., global_batch_size and num_views) should be
    # transposed and then flattened. The result has shape
    # [num_views * global_batch_size, num_features], and its first dimension
    # elements are grouped by view, not by sample.
    all_global_features = tf.reshape(
        tf.transpose(global_features, perm=[1, 0, 2]),
        [num_views * global_batch_size, -1])

    if contrast_mode == enums.LossContrastMode.ONE_VIEW:
        anchor_features = features[:, 0]
        num_anchor_views = 1
    else:  # contrast_mode == enums.LossContrastMode.ALL_VIEWS
        # Reshape features to match how global_features is reshaped above.
        anchor_features = tf.reshape(tf.transpose(features, perm=[1, 0, 2]),
                                     [num_views * local_batch_size, -1])
        num_anchor_views = num_views

    # Generate `logits`, the tensor of (temperature-scaled) dot products of the
    # anchor features with all features. It has shape
    # [local_batch_size * num_anchor_views, global_batch_size * num_views]. To
    # improve numerical stability, subtract out the largest |logits| element in
    # each row from all elements in that row. Since |logits| is only ever used as
    # a ratio of exponentials of |logits| values, this subtraction does not change
    # the results correctness. A stop_gradient() is needed because this change is
    # just for numerical precision.
    logits = tf.linalg.matmul(anchor_features,
                              all_global_features,
                              transpose_b=True)
    temperature = tf.cast(temperature, tf.float32)
    logits = logits / temperature
    logits = (logits -
              tf.reduce_max(tf.stop_gradient(logits), axis=1, keepdims=True))
    exp_logits = tf.exp(logits)

    # The following masks are all tiled by the number of views, i.e., they have
    # shape [local_batch_size * num_anchor_views, global_batch_size * num_views].
    positives_mask, negatives_mask = (_create_tiled_masks(
        mask, diagonal_mask, num_views, num_anchor_views, positives_cap))
    num_positives_per_row = tf.reduce_sum(positives_mask, axis=1)

    if denominator_mode == enums.LossDenominatorMode.ALL:
        denominator = tf.reduce_sum(
            exp_logits * negatives_mask, axis=1,
            keepdims=True) + tf.reduce_sum(
                exp_logits * positives_mask, axis=1, keepdims=True)
    elif denominator_mode == enums.LossDenominatorMode.ONE_POSITIVE:
        denominator = exp_logits + tf.reduce_sum(
            exp_logits * negatives_mask, axis=1, keepdims=True)
    else:  # denominator_mode == enums.LossDenominatorMode.ONLY_NEGATIVES
        denominator = tf.reduce_sum(exp_logits * negatives_mask,
                                    axis=1,
                                    keepdims=True)

    # Note that num_positives_per_row can be zero only if 1 view is used. The
    # various tf.math.divide_no_nan() calls below are to handle this case.
    if summation_location == enums.LossSummationLocation.OUTSIDE:
        log_probs = (logits - tf.math.log(denominator)) * positives_mask
        log_probs = tf.reduce_sum(log_probs, axis=1)
        log_probs = tf.math.divide_no_nan(log_probs, num_positives_per_row)
    else:  # summation_location == enums.LossSummationLocation.INSIDE
        log_probs = exp_logits / denominator * positives_mask
        log_probs = tf.reduce_sum(log_probs, axis=1)
        log_probs = tf.math.divide_no_nan(log_probs, num_positives_per_row)
        log_probs = tf.math.log(log_probs)

    loss = -log_probs
    if scale_by_temperature:
        loss *= temperature
    loss = tf.reshape(loss, [num_anchor_views, local_batch_size])

    if num_views != 1:
        loss = tf.reduce_mean(loss, axis=0)
    else:
        # The 1 view case requires special handling bc, unlike in the > 1 view case,
        # not all samples are guaranteed to have a positive. Also, no reduction over
        # views is needed.
        num_valid_views_per_sample = (tf.reshape(num_positives_per_row,
                                                 [1, local_batch_size]))
        loss = tf.squeeze(
            tf.math.divide_no_nan(loss, num_valid_views_per_sample))

    return loss
Example #14
0
 def fn(x, y):
     """Return x**2 + y**2 with the gradient through `y` blocked."""
     x_squared = x**2
     # `y` still contributes to the value, but tf.stop_gradient keeps it out
     # of backpropagation.
     y_squared = tf.stop_gradient(y)**2
     return x_squared + y_squared
Example #15
0
 def _center_previous_state(x):
     """Subtract the (gradient-stopped) empirical mean from `x`.

     The mean is taken with `_reduce_mean_with_axes` over `batch_axes` and
     `reduce_chain_axis_names`, both captured from the enclosing scope.
     """
     # The empirical mean only approximates the true mean, so no gradient is
     # allowed to flow back through it.
     return x - tf.stop_gradient(
         _reduce_mean_with_axes(x, batch_axes, reduce_chain_axis_names))
Example #16
0
  def one_step(self, current_state, previous_kernel_results):
    """Propose a new state by integrating leapfrog dynamics once.

    A fresh normal momentum is sampled for every state part, the
    `SimpleLeapfrogIntegrator` is run from the current state, and the kernel
    results are updated with the proposal's target log-prob, its gradients and
    the log acceptance correction.

    Args:
      current_state: The current chain state; either a single value or a
        list-like of state parts.
      previous_kernel_results: Results from the previous step. Must provide
        `target_log_prob` and `grads_target_log_prob`, plus `step_size` and
        `num_leapfrog_steps` when parameters are stored in the results.

    Returns:
      A pair `(next_state, new_kernel_results)`. `next_state` is list-like iff
      `current_state` is list-like.
    """
    with tf.name_scope(
        mcmc_util.make_name(self.name, 'hmc', 'one_step')):
      # Integration parameters come either from the results structure or from
      # the kernel itself; both expose the same two attribute names.
      param_source = (previous_kernel_results
                      if self._store_parameters_in_results else self)
      step_size = param_source.step_size
      num_leapfrog_steps = param_source.num_leapfrog_steps

      (state_parts,
       step_sizes,
       target_log_prob,
       target_log_prob_grad_parts) = _prepare_args(
           self.target_log_prob_fn,
           current_state,
           step_size,
           previous_kernel_results.target_log_prob,
           previous_kernel_results.grads_target_log_prob,
           maybe_expand=True,
           state_gradients_are_stopped=self.state_gradients_are_stopped)

      # One momentum tensor per state part, shaped like that part.
      momentum_parts = [
          tf.random.normal(
              shape=tf.shape(input=part),
              dtype=self._momentum_dtype or part.dtype.base_dtype,
              seed=self._seed_stream())
          for part in state_parts
      ]

      integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
          self.target_log_prob_fn, step_sizes, num_leapfrog_steps)
      (proposed_momentum_parts,
       proposed_state_parts,
       proposed_target_log_prob,
       proposed_target_log_prob_grad_parts) = integrator(
           momentum_parts,
           state_parts,
           target_log_prob,
           target_log_prob_grad_parts)

      if self.state_gradients_are_stopped:
        proposed_state_parts = [
            tf.stop_gradient(part) for part in proposed_state_parts]

      independent_chain_ndims = distribution_util.prefer_static_rank(
          target_log_prob)
      log_acceptance_correction = _compute_log_acceptance_correction(
          momentum_parts, proposed_momentum_parts, independent_chain_ndims)

      new_kernel_results = previous_kernel_results._replace(
          log_acceptance_correction=log_acceptance_correction,
          target_log_prob=proposed_target_log_prob,
          grads_target_log_prob=proposed_target_log_prob_grad_parts,
      )

      # Hand back the same structure the caller passed in: a list of parts
      # for list-like states, otherwise the single part.
      if mcmc_util.is_list_like(current_state):
        next_state = proposed_state_parts
      else:
        next_state = proposed_state_parts[0]
      return next_state, new_kernel_results
Example #17
0
def regression_loss(logits, labels, num_steps, steps, seq_lens, loss_type,
                    normalize_indices, variance_lambda, huber_delta):
  """Loss function based on regressing to the correct indices.

  In the paper, this is called Cycle-back Regression. There are 3 variants
  of this loss:
  i) regression_mse: MSE of the predicted indices and ground truth indices.
  ii) regression_mse_var: MSE of the predicted indices that takes into account
  the variance of the similarities. This is important when the rate at which
  sequences go through different phases changes a lot. The variance scaling
  allows dynamic weighting of the MSE loss based on the similarities.
  iii) regression_huber: Huber loss between the predicted indices and ground
  truth indices.

  Args:
    logits: Tensor, Pre-softmax similarity scores after cycling back to the
      starting sequence.
    labels: Tensor, One hot labels containing the ground truth. The index where
      the cycle started is 1.
    num_steps: Integer, Number of steps in the sequence embeddings.
    steps: Tensor, step indices/frame indices of the embeddings of the shape
      [N, T] where N is the batch size, T is the number of the timesteps.
    seq_lens: Tensor, Lengths of the sequences from which the sampling was done.
      This can provide additional temporal information to the alignment loss.
    loss_type: String, This specifies the kind of regression loss function.
      Currently supported loss functions: regression_mse, regression_mse_var,
      regression_huber.
    normalize_indices: Boolean, If True, normalizes indices by sequence lengths.
      Useful for ensuring numerical instabilities don't arise as sequence
      indices can be large numbers.
    variance_lambda: Float, Weight of the variance of the similarity
      predictions while cycling back. If this is high then the low variance
      similarities are preferred by the loss while making this term low results
      in high variance of the similarities (more uniform/random matching).
    huber_delta: float, Huber delta described in tf.keras.losses.huber_loss.

  Returns:
    loss: Tensor, A scalar loss calculated using a variant of regression.

  Raises:
    ValueError: If `loss_type` is not one of the supported loss names.
  """
  # Reject unknown loss types up front, before any ops are created.
  if loss_type not in ('regression_mse', 'regression_mse_var',
                       'regression_huber'):
    raise ValueError('Unsupported regression loss %s. Supported losses are: '
                     'regression_mse, regression_mse_var and regression_huber.'
                     % loss_type)

  # Just to be safe, we stop gradients from labels as we are generating labels.
  labels = tf.stop_gradient(labels)
  steps = tf.stop_gradient(steps)

  if normalize_indices:
    float_seq_lens = tf.cast(seq_lens, tf.float32)
    tile_seq_lens = tf.tile(
        tf.expand_dims(float_seq_lens, axis=1), [1, num_steps])
    steps = tf.cast(steps, tf.float32) / tile_seq_lens
  else:
    steps = tf.cast(steps, tf.float32)

  # `beta` is the soft assignment over time steps; the predicted time is its
  # expectation, the true time is picked out by the one-hot labels.
  beta = tf.nn.softmax(logits)
  true_time = tf.reduce_sum(steps * labels, axis=1)
  pred_time = tf.reduce_sum(steps * beta, axis=1)

  if loss_type == 'regression_mse_var':
    # Variance aware regression.
    pred_time_tiled = tf.tile(tf.expand_dims(pred_time, axis=1),
                              [1, num_steps])

    pred_time_variance = tf.reduce_sum(
        tf.square(steps - pred_time_tiled) * beta, axis=1)

    # Using log of variance as it is numerically stabler.
    pred_time_log_var = tf.math.log(pred_time_variance)
    squared_error = tf.square(true_time - pred_time)
    return tf.reduce_mean(tf.math.exp(-pred_time_log_var) * squared_error
                          + variance_lambda * pred_time_log_var)

  if loss_type == 'regression_mse':
    return tf.reduce_mean(
        tf.keras.losses.mean_squared_error(y_true=true_time,
                                           y_pred=pred_time))

  # Only 'regression_huber' remains after the validation above.
  return tf.reduce_mean(tf.keras.losses.huber_loss(
      y_true=true_time, y_pred=pred_time,
      delta=huber_delta))
    def testDistribution(self, dist_name, data):
        """Checks variable handling and gradients of a drawn distribution.

        Draws two distributions of the named family (the second with a
        broadcast-compatible batch shape) and verifies that:
          * Variable parameters are exposed unchanged through the accessors,
          * statistics, `sample`, `kl_divergence` and the probability
            evaluators do not read variables more often than permitted,
          * `sample`, `kl_divergence` and `log_prob` yield non-None gradients
            w.r.t. the variables, unless the family is listed in the
            corresponding NO_*_PARAM_GRADS exemption.

        Args:
          dist_name: Name of the distribution family to draw.
          data: Hypothesis data object used for all `draw` calls.
        """
        # Skip when the current execution mode does not match the requested one.
        if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
            return
        seed = tfp_test_util.test_seed()
        dist = data.draw(distributions(dist_name=dist_name, enable_vars=True))
        batch_shape = dist.batch_shape
        batch_shape2 = data.draw(
            tfp_hps.broadcast_compatible_shape(batch_shape))
        dist2 = data.draw(
            distributions(dist_name=dist_name,
                          batch_shape=batch_shape2,
                          event_dim=get_event_dim(dist),
                          enable_vars=True))
        self.evaluate([var.initializer for var in dist.variables])

        # Check that the distribution passes Variables through to the accessor
        # properties (without converting them to Tensor or anything like that).
        for k, v in six.iteritems(dist.parameters):
            if not tensor_util.is_ref(v):
                continue
            self.assertIs(getattr(dist, k), v)

        # Check that standard statistics do not read distribution parameters more
        # than twice (once in the stat itself and up to once in any validation
        # assertions).
        # Exactly three distinct statistics are drawn per example.
        for stat in data.draw(
                hps.sets(hps.one_of(
                    map(hps.just, [
                        'covariance', 'entropy', 'mean', 'mode', 'stddev',
                        'variance'
                    ])),
                         min_size=3,
                         max_size=3)):
            hp.note('Testing excessive var usage in {}.{}'.format(
                dist_name, stat))
            try:
                with tfp_hps.assert_no_excessive_var_usage(
                        'statistic `{}` of `{}`'.format(stat, dist)):
                    getattr(dist, stat)()

            except NotImplementedError:
                # Statistics a distribution does not implement are ignored.
                pass

        # Check that `sample` doesn't read distribution parameters more than twice,
        # and that it produces non-None gradients (if the distribution is fully
        # reparameterized).
        with tf.GradientTape() as tape:
            # TDs do bijector assertions twice (once by distribution.sample, and once
            # by bijector.forward).
            max_permissible = (3 if isinstance(
                dist, tfd.TransformedDistribution) else 2)
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `sample` of `{}`'.format(dist),
                    max_permissible=max_permissible):
                sample = dist.sample(seed=seed)
        if dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
            grads = tape.gradient(sample, dist.variables)
            for grad, var in zip(grads, dist.variables):
                # Strip trailing digits, underscores and colons from the variable
                # name before looking it up in the exemption table.
                var_name = var.name.rstrip('_0123456789:')
                if var_name in NO_SAMPLE_PARAM_GRADS.get(dist_name, ()):
                    continue
                if grad is None:
                    raise AssertionError(
                        'Missing sample -> {} grad for distribution {}'.format(
                            var_name, dist_name))

        # Turn off validations, since TODO(b/129271256) log_prob can choke on dist's
        # own samples.  Also, to relax conversion counts for KL (might do >2 w/
        # validate_args).
        dist = dist.copy(validate_args=False)
        dist2 = dist2.copy(validate_args=False)

        # Test that KL divergence reads distribution parameters at most once, and
        # that it produces non-None gradients.
        try:
            # Both directions are checked: KL(dist || dist2) and KL(dist2 || dist).
            for d1, d2 in (dist, dist2), (dist2, dist):
                with tf.GradientTape() as tape:
                    with tfp_hps.assert_no_excessive_var_usage(
                            '`kl_divergence` of (`{}` (vars {}), `{}` (vars {}))'
                            .format(d1, d1.variables, d2, d2.variables),
                            max_permissible=1
                    ):  # No validation => 1 convert per var.
                        kl = d1.kl_divergence(d2)
                wrt_vars = list(d1.variables) + list(d2.variables)
                grads = tape.gradient(kl, wrt_vars)
                for grad, var in zip(grads, wrt_vars):
                    if grad is None and dist_name not in NO_KL_PARAM_GRADS:
                        raise AssertionError(
                            'Missing KL({} || {}) -> {} grad:\n'
                            '{} vars: {}\n{} vars: {}'.format(
                                d1, d2, var, d1, d1.variables, d2,
                                d2.variables))
        except NotImplementedError:
            # A missing KL implementation for this pair is not a failure.
            pass

        # Test that log_prob produces non-None gradients, except for distributions
        # on the NO_LOG_PROB_PARAM_GRADS blacklist.
        if dist_name not in NO_LOG_PROB_PARAM_GRADS:
            with tf.GradientTape() as tape:
                # The sample is detached so that only the parameter path can carry
                # gradient.
                lp = dist.log_prob(tf.stop_gradient(sample))
            grads = tape.gradient(lp, dist.variables)
            for grad, var in zip(grads, dist.variables):
                if grad is None:
                    raise AssertionError(
                        'Missing log_prob -> {} grad for distribution {}'.
                        format(var, dist_name))

        # Test that all forms of probability evaluation avoid reading distribution
        # parameters more than once.
        # Exactly three distinct evaluators are drawn per example.
        for evaluative in data.draw(
                hps.sets(hps.one_of(
                    map(hps.just, [
                        'log_prob', 'prob', 'log_cdf', 'cdf',
                        'log_survival_function', 'survival_function'
                    ])),
                         min_size=3,
                         max_size=3)):
            hp.note('Testing excessive var usage in {}.{}'.format(
                dist_name, evaluative))
            try:
                # No validation => 1 convert. But for TD we allow 2:
                # dist.log_prob(bijector.inverse(samp)) + bijector.ildj(samp)
                max_permissible = (2 if isinstance(
                    dist, tfd.TransformedDistribution) else 1)
                with tfp_hps.assert_no_excessive_var_usage(
                        'evaluative `{}` of `{}`'.format(evaluative, dist),
                        max_permissible=max_permissible):
                    getattr(dist, evaluative)(sample)
            except NotImplementedError:
                # Evaluators a distribution does not implement are ignored.
                pass
Example #19
0
  def call(self, x, training=False):
    """Quantizes `x` against the running centroids, one codebook per head.

    Each input vector of size `self.depth` is split into `self.num_heads`
    segments, and every segment is replaced by its nearest centroid. Gradients
    pass straight through the quantization. In training mode, rarely used
    centroids are restarted from batch elements, and the running counts and
    sums that define the centroids are updated with an exponential moving
    average.

    Args:
      x: Tensor whose last dimension is `self.depth`.
      training: Whether to restart centroids and update `self.counts` /
        `self.sums`.

    Returns:
      A pair `(z, c)`: `z` is the quantized tensor with the same shape as `x`;
      `c` holds the selected centroid indices with shape
      `x.shape[:-1] + [self.num_heads]`.
    """
    x_flat = tf.reshape(x, shape=(-1, self.depth))

    # Split each input vector into one segment per head.
    x_flat_split = tf.split(x_flat, self.num_heads, axis=1)
    x_flat = tf.concat(x_flat_split, axis=0)

    if training:
      # Figure out which centroids we want to keep, and which we want to
      # restart.
      # The static row count is None when the batch size is unknown at trace
      # time; in that case fall back to the dynamic shape so the threshold
      # arithmetic below still works.
      n_dynamic = tf.shape(x_flat)[0]
      n = x_flat.shape[0]
      if n is None:
        n = tf.cast(n_dynamic, self.counts.dtype)
      keep = self.counts * self.k > self.restart_threshold * n
      restart = tf.math.logical_not(keep)

      # Replace centroids to restart with elements from the batch, using samples
      # from a uniform distribution as a fallback in case we need to restart
      # more centroids than we have elements in the batch.
      restart_idx = tf.squeeze(tf.where(restart), -1)
      n_replace = tf.minimum(tf.shape(restart_idx)[0], n_dynamic)
      e_restart = tf.tensor_scatter_nd_update(
          tf.random.uniform([self.k, self.depth // self.num_heads]),
          tf.expand_dims(restart_idx[:n_replace], 1),
          tf.random.shuffle(x_flat)[:n_replace]
      )

      # Compute the values of the centroids we want to keep by dividing the
      # summed vectors by the corresponding counts.
      e = tf.where(
          tf.expand_dims(keep, 1),
          tf.math.divide_no_nan(self.sums, tf.expand_dims(self.counts, 1)),
          e_restart
      )

    else:
      # If not training, just use the centroids as is with no restarts.
      e = tf.math.divide_no_nan(self.sums, tf.expand_dims(self.counts, 1))

    # Compute distance between each input vector and each cluster center,
    # expanded as |x|^2 - 2 x.e + |e|^2.
    distances = (
        tf.expand_dims(tf.reduce_sum(x_flat**2, axis=1), 1) -
        2 * tf.matmul(x_flat, tf.transpose(e)) +
        tf.expand_dims(tf.reduce_sum(e**2, axis=1), 0)
    )

    # Find nearest cluster center for each input vector.
    c = tf.argmin(distances, axis=1)

    # Quantize input vectors with straight-through estimator.
    z = tf.nn.embedding_lookup(e, c)
    z_split = tf.split(z, self.num_heads, axis=0)
    z = tf.concat(z_split, axis=1)
    z = tf.reshape(z, tf.shape(x))
    z = x + tf.stop_gradient(z - x)

    if training:
      # Compute cluster counts and vector sums over the batch.
      oh = tf.one_hot(indices=c, depth=self.k)
      counts = tf.reduce_sum(oh, axis=0)
      sums = tf.matmul(oh, x_flat, transpose_a=True)

      # Apply exponential moving average to cluster counts and vector sums.
      self.counts.assign_sub((1 - self.gamma) * (self.counts - counts))
      self.sums.assign_sub((1 - self.gamma) * (self.sums - sums))

    # Undo the per-head stacking so the indices line up with the input layout.
    c_split = tf.split(c, self.num_heads, axis=0)
    c = tf.stack(c_split, axis=1)
    c = tf.reshape(c, tf.concat([tf.shape(x)[:-1], [self.num_heads]], axis=0))

    return z, c
    def _reparameterize_sample(self, x, event_shape):
        """Attaches implicit reparameterization gradients to mixture samples.

        The implicit reparameterization gradient is
          dx/dphi = -(d transform(x, phi) / dx)^-1 * d transform(x, phi) / dphi,
        where transform(x, phi) is the distributional transform, which removes
        all parameters from the samples x.

        It is obtained by substituting, for the backward pass only,
          -stop_gradient(d transform(x, phi) / dx)^-1 * transform(x, phi)
        in place of x; differentiating that surrogate w.r.t. phi gives the
        gradient above. This overrides the gradients w.r.t. both the mixture
        parameters and the component parameters.

        Limitations:
          1. Fundamental: components must be fully reparameterized.
          2. The distributional transform is currently only implemented for
             factorized components.
          3. The distributional transform currently requires the rank of the
             batch tensor to be known.

        Args:
          x: Sample of the mixture distribution.
          event_shape: The event shape of this distribution.

        Returns:
          Tensor with the same value as `x`, carrying the reparameterization
          gradients.
        """
        # Discard whatever gradients x already has w.r.t. component parameters.
        x = tf.stop_gradient(x)

        flat_event_size = ps.cast(ps.reduce_prod(event_shape), dtype=tf.int32)
        flat_shape = [-1, flat_event_size]  # [S*prod(B), prod(E)]

        def _transform_on_flat(flat_sample):
            # Apply the distributional transform in the original [S, B, E]
            # layout, then flatten again so that the Jacobian comes out as
            # [S*prod(B), prod(E), prod(E)].
            unflattened = tf.reshape(flat_sample, ps.shape(x))
            transformed = self._distributional_transform(
                unflattened, event_shape)
            return tf.reshape(transformed, flat_shape)

        # flat_transform: [S*prod(B), prod(E)]
        # jac: [S*prod(B), prod(E), prod(E)]
        flat_x = tf.reshape(x, flat_shape)
        flat_transform, jac = value_and_batch_jacobian(
            _transform_on_flat, flat_x)

        # Only the first derivative is valid here; autodiff's second derivative
        # would be wrong, so requesting it must raise.
        flat_transform = _prevent_2nd_derivative(flat_transform)

        # Solve the linear system for stop_gradient(jac)^-1 * transform. The
        # Jacobian is lower triangular since the transform of the i-th event
        # dimension does not depend on later dimensions.
        solved = tf.linalg.triangular_solve(
            tf.stop_gradient(jac),
            flat_transform[..., tf.newaxis],
            lower=True)  # [S*prod(B), prod(E), 1]
        surrogate = tf.reshape(-solved, ps.shape(x))

        # Value of x, gradient of the surrogate.
        return x + (surrogate - tf.stop_gradient(surrogate))
Example #21
0
    def forward(self, features, training=True):
        """Compute the model outputs (no losses) from a dictionary of features.

        Args:
          features: Mapping with at least an 'audio' entry; the whole mapping
            is handed to the sinusoidal encoder.
          training: Passed through to the sinusoidal encoder.

        Returns:
          Dict with 'audio', 'noise_magnitudes', 'sin_audio', 'sin_amps' and
          'sin_freqs'. When a harmonic encoder is configured it additionally
          holds 'harm_audio', 'harm_amp', 'harm_dist', 'f0_hz', 'harm_freqs'
          and 'harm_amps'.
        """
        # ---- Audio -> Sinusoids ----------------------------------------------
        audio = features['audio']
        pg_in = self.sinusoidal_encoder(features, training=training)

        # The scaling nonlinearities are applied by hand and written back into
        # the processor-group inputs.
        sin_freqs = self.freq_scale_fn(pg_in['frequencies'])
        sin_amps = self.amps_scale_fn(pg_in['amplitudes'])
        noise_magnitudes = self.amps_scale_fn(pg_in['noise_magnitudes'])
        for key, scaled in (('frequencies', sin_freqs),
                            ('amplitudes', sin_amps),
                            ('noise_magnitudes', noise_magnitudes)):
            pg_in[key] = scaled

        # Sinusoidal reconstruction of the audio.
        sin_audio = self.processor_group(pg_in)

        outputs = {
            # Input signal.
            'audio': audio,
            # Filtered noise signal.
            'noise_magnitudes': noise_magnitudes,
            # Sinusoidal signal.
            'sin_audio': sin_audio,
            'sin_amps': sin_amps,
            'sin_freqs': sin_freqs,
        }

        # ---- Sinusoids -> Harmonics ------------------------------------------
        if self.stop_gradient:
            # Detach the sinusoidal controls before the harmonic stage.
            sin_freqs, sin_amps, noise_magnitudes = [
                tf.stop_gradient(control)
                for control in (sin_freqs, sin_amps, noise_magnitudes)
            ]

        if self.harmonic_encoder is None:
            return outputs

        # Encode the sinusoids into harmonics.
        harm_amp, harm_dist, f0_hz = self.harmonic_encoder(sin_freqs, sin_amps)

        # Decode the harmonics back to sinusoids.
        n_harmonics = int(harm_dist.shape[-1])
        harm_freqs = ddsp.core.get_harmonic_frequencies(f0_hz, n_harmonics)
        harm_amps = harm_amp * harm_dist

        # Harmonic reconstruction reuses the same processor-group inputs.
        for key, control in (('frequencies', harm_freqs),
                             ('amplitudes', harm_amps),
                             ('noise_magnitudes', noise_magnitudes)):
            pg_in[key] = control
        harm_audio = self.processor_group(pg_in)

        # Harmonic signal.
        outputs['harm_audio'] = harm_audio
        outputs['harm_amp'] = harm_amp
        outputs['harm_dist'] = harm_dist
        outputs['f0_hz'] = f0_hz
        # Harmonic sinusoids.
        outputs['harm_freqs'] = harm_freqs
        outputs['harm_amps'] = harm_amps

        return outputs