def update(self,
           expert_dataset_iter,
           policy_dataset_iter,
           discount,
           replay_regularization=0.05,
           nu_reg=10.0):
  """Performs one joint update of the nu network and the actor.

  The full formulation mixes expert and replay-buffer samples and learns
  (d_pi * (1 - replay_regularization) + d_rb * replay_regularization) /
  (d_expert * (1 - replay_regularization) + d_rb * replay_regularization).
  In this implementation the replay-buffer branch is disabled: only expert
  samples are drawn, `policy_dataset_iter` is never read, and
  `replay_regularization` only scales the linear term of the loss.

  Args:
    expert_dataset_iter: A TensorFlow iterator over expert data whose
      `get_next()` returns a `(states, actions, next_states)` triple.
    policy_dataset_iter: A TensorFlow iterator over training-policy data.
      Currently unused (see above); kept for interface compatibility.
    discount: An MDP discount.
    replay_regularization: Fraction of samples intended to come from a replay
      buffer. With the replay branch disabled, the linear loss term is scaled
      by `1 - replay_regularization`.
    nu_reg: Gradient-penalty regularization coefficient.
  """
  expert_states, expert_actions, expert_next_states = (
      expert_dataset_iter.get_next())
  # Expert states double as the initial-state samples.
  expert_initial_states = expert_states

  with tf.GradientTape(
      watch_accessed_variables=False, persistent=True) as tape:
    tape.watch(self.actor.variables)
    tape.watch(self.nu_net.variables)

    _, policy_next_actions, _ = self.actor(expert_next_states)
    _, policy_initial_actions, _ = self.actor(expert_initial_states)

    # Inputs for the linear part of DualDICE loss.
    expert_init_inputs = tf.concat(
        [expert_initial_states, policy_initial_actions], 1)

    if not self.discrete:
      expert_inputs = tf.concat([expert_states, expert_actions], 1)
    else:
      one_hot_actions = tf.one_hot(
          tf.cast(expert_actions, tf.int32), depth=self.action_dim, axis=-1)
      expert_inputs = tf.concat([expert_states, one_hot_actions], 1)

    # NOTE(review): in the discrete case only the expert actions are one-hot
    # encoded; policy actions from `self.actor` are concatenated as returned.
    # Confirm the two encodings agree.
    expert_next_inputs = tf.concat(
        [expert_next_states, policy_next_actions], 1)

    expert_nu_0 = self.nu_net(expert_init_inputs)
    expert_nu = self.nu_net(expert_inputs)
    expert_nu_next = self.nu_net(expert_next_inputs)

    expert_diff = expert_nu - discount * expert_nu_next

    linear_loss_expert = tf.reduce_mean(expert_nu_0 * (1 - discount))

    # Uniform, normalized weights over the expert batch. `tf.ones_like` stays
    # valid when the batch dimension is not statically known (e.g. inside
    # tf.function) and matches the dtype of `expert_diff`.
    expert_weights = tf.ones_like(expert_diff)
    expert_weights /= tf.reduce_sum(expert_weights)

    non_linear_loss = tf.reduce_sum(
        tf.stop_gradient(
            weighted_softmax(expert_diff, expert_weights, axis=0)) *
        expert_diff)

    linear_loss = linear_loss_expert * (1 - replay_regularization)

    loss = non_linear_loss - linear_loss

    # Gradient penalty evaluated on random interpolations between the expert
    # inputs and a shuffled batch of expert next inputs. The batch size is
    # read dynamically so unknown static shapes are supported.
    alpha = tf.random.uniform(shape=(tf.shape(expert_inputs)[0], 1))
    nu_inter = alpha * expert_inputs + (1 - alpha) * tf.stop_gradient(
        tf.random.shuffle(expert_next_inputs))

    with tf.GradientTape(watch_accessed_variables=False) as tape2:
      tape2.watch(nu_inter)
      nu_output = self.nu_net(nu_inter)
    # EPS presumably keeps the vector away from exactly zero, where the
    # gradient of tf.norm is undefined.
    nu_grad = tape2.gradient(nu_output, [nu_inter])[0] + EPS
    nu_grad_penalty = tf.reduce_mean(
        tf.square(tf.norm(nu_grad, axis=-1, keepdims=True) - 1))

    nu_loss = loss + nu_grad_penalty * nu_reg
    # The actor maximizes the loss that nu minimizes (adversarial objective).
    pi_loss = -loss + keras_utils.orthogonal_regularization(self.actor.trunk)

  nu_grads = tape.gradient(nu_loss, self.nu_net.variables)
  pi_grads = tape.gradient(pi_loss, self.actor.variables)

  self.nu_optimizer.apply_gradients(zip(nu_grads, self.nu_net.variables))
  self.actor_optimizer.apply_gradients(zip(pi_grads, self.actor.variables))

  del tape

  self.avg_nu_expert(expert_nu)
  self.nu_reg_metric(nu_grad_penalty)
  self.avg_loss(loss)
  self.avg_actor_loss(pi_loss)

  if tf.equal(self.nu_optimizer.iterations % self.log_interval, 0):
    tf.summary.scalar(
        'train dual dice/loss',
        self.avg_loss.result(),
        step=self.nu_optimizer.iterations)
    keras_utils.my_reset_states(self.avg_loss)

    tf.summary.scalar(
        'train dual dice/nu expert',
        self.avg_nu_expert.result(),
        step=self.nu_optimizer.iterations)
    keras_utils.my_reset_states(self.avg_nu_expert)

    # NOTE(review): `avg_nu_rb` is never updated in this method because the
    # replay-buffer branch is disabled, so this summary carries no new data.
    tf.summary.scalar(
        'train dual dice/nu rb',
        self.avg_nu_rb.result(),
        step=self.nu_optimizer.iterations)
    keras_utils.my_reset_states(self.avg_nu_rb)

    tf.summary.scalar(
        'train dual dice/nu reg',
        self.nu_reg_metric.result(),
        step=self.nu_optimizer.iterations)
    keras_utils.my_reset_states(self.nu_reg_metric)

  if tf.equal(self.actor_optimizer.iterations % self.log_interval, 0):
    tf.summary.scalar(
        'train sac/actor_loss',
        self.avg_actor_loss.result(),
        step=self.actor_optimizer.iterations)
    keras_utils.my_reset_states(self.avg_actor_loss)

    # NOTE(review): `avg_actor_entropy` is likewise never updated here.
    tf.summary.scalar(
        'train sac/actor entropy',
        self.avg_actor_entropy.result(),
        step=self.actor_optimizer.iterations)
    keras_utils.my_reset_states(self.avg_actor_entropy)
def _zeros_like(x):
  """Returns zeros shaped like `x` that remain differentiable w.r.t. `x`.

  The forward value is `x * (x - 1) - x * (x - 1)`, i.e. zero for finite `x`.
  Only the leading `x` factor carries a gradient, so d/dx equals `x - 1`.
  """
  shifted = tf.stop_gradient(x - 1.)
  constant_product = tf.stop_gradient(x * (x - 1.))
  return x * shifted - constant_product
def committment_loss(self, z, z_q):
  """Pulls encoder outputs toward the current (fixed) centroids.

  Args:
    z: Encoder embeddings; these receive the gradient.
    z_q: Quantized embeddings; treated as a constant target.

  Returns:
    `self.commitment_loss_weight` times the L2 `losses.mean_difference`
    between `z` and `z_q`.
  """
  frozen_centroids = tf.stop_gradient(z_q)
  distance = losses.mean_difference(z, frozen_centroids, loss_type='L2')
  return self.commitment_loss_weight * distance
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Runs one HMC leapfrog proposal starting at `current_state`.

  Args:
    current_state: A `Tensor` or list of `Tensor`s with the current state.
    previous_kernel_results: Results of the previous step. Supplies the cached
      target log-prob and its gradients and, when
      `self._store_parameters_in_results` is set, the step size and number of
      leapfrog steps.
    seed: Optional seed; sanitized with `samplers.sanitize_seed` and stored in
      the returned kernel results.

  Returns:
    A `(next_state, kernel_results)` pair; `next_state` mirrors the list /
    non-list structure of `current_state`.
  """
  with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
    params = (previous_kernel_results
              if self._store_parameters_in_results else self)
    step_size = params.step_size
    num_leapfrog_steps = params.num_leapfrog_steps

    (state_parts, step_sizes, target_log_prob,
     target_log_prob_grad_parts) = _prepare_args(
         self.target_log_prob_fn,
         current_state,
         step_size,
         previous_kernel_results.target_log_prob,
         previous_kernel_results.grads_target_log_prob,
         maybe_expand=True,
         state_gradients_are_stopped=self.state_gradients_are_stopped)

    seed = samplers.sanitize_seed(seed)  # Retained for diagnostics.
    part_seeds = distribute_lib.fold_in_axis_index(
        list(samplers.split_seed(seed, n=len(state_parts))),
        self.experimental_shard_axis_names)

    # One standard-normal momentum draw per state part, each with its own seed.
    initial_momentum = [
        samplers.normal(
            shape=ps.shape(part),
            dtype=self._momentum_dtype or dtype_util.base_dtype(part.dtype),
            seed=part_seed)
        for part_seed, part in zip(part_seeds, state_parts)
    ]

    leapfrog = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn, step_sizes, num_leapfrog_steps)
    (final_momentum, next_state_parts, next_target_log_prob,
     next_target_log_prob_grad_parts) = leapfrog(
         initial_momentum, state_parts, target_log_prob,
         target_log_prob_grad_parts)

    if self.state_gradients_are_stopped:
      next_state_parts = [tf.stop_gradient(part) for part in next_state_parts]

    chain_ndims = ps.rank(target_log_prob)
    new_kernel_results = previous_kernel_results._replace(
        log_acceptance_correction=_compute_log_acceptance_correction(
            initial_momentum,
            final_momentum,
            chain_ndims,
            shard_axis_names=self.experimental_shard_axis_names),
        target_log_prob=next_target_log_prob,
        grads_target_log_prob=next_target_log_prob_grad_parts,
        initial_momentum=initial_momentum,
        final_momentum=final_momentum,
        seed=seed,
    )

    if mcmc_util.is_list_like(current_state):
      return next_state_parts, new_kernel_results
    return next_state_parts[0], new_kernel_results
def log_concave_rejection_sampler(
    mode,
    prob_fn,
    dtype,
    sample_shape=(),
    distribution_minimum=None,
    distribution_maximum=None,
    seed=None):
  """Utility for rejection sampling from log-concave discrete distributions.

  This utility constructs an easy-to-sample-from upper bound for a discrete
  univariate log-concave distribution (for discrete univariate distributions,
  a necessary and sufficient condition is p_k^2 >= p_{k-1} p_{k+1} for all k).
  The method requires that the mode of the distribution is known. While a
  better method can likely be derived for any given distribution, this method
  is general and easy to implement. The expected number of iterations is
  bounded by 4+m, where m is the probability of the mode. For details, see
  [(Devroye, 1987)][1].

  Args:
    mode: Tensor, the mode[s] of the [batch of] distribution[s].
    prob_fn: Python callable, counts -> prob(counts).
    dtype: DType of the generated samples.
    sample_shape: 1D `int32` `Tensor` (or list/tuple). Shape of the generated
      samples; it is concatenated in front of the shape of `mode`.
    distribution_minimum: Tensor of type `dtype`. The minimum value
      taken by the distribution. The `prob` method will only be called on
      values greater than or equal to the specified minimum. The shape must
      broadcast with the batch shape of the distribution. If unspecified, the
      domain is treated as unbounded below.
    distribution_maximum: Tensor of type `dtype`. The maximum value
      taken by the distribution. See `distribution_minimum` for details.
    seed: Python integer or `Tensor` instance, for seeding PRNG.

  Returns:
    samples: a `Tensor` with prepended dimensions `sample_shape`.

  #### References

  [1] Luc Devroye. A Simple Generator for Discrete Log-Concave
      Distributions. Computing, 1987.
  """
  mode = tf.broadcast_to(
      mode, ps.concat([sample_shape, ps.shape(mode)], axis=0))

  mode_height = prob_fn(mode)
  mode_shape = ps.shape(mode)

  top_width = 1. + mode_height / 2.  # w in ref [1].
  top_fraction = top_width / (1 + top_width)
  exponential_distribution = exponential.Exponential(
      rate=tf.ones([], dtype=dtype))  # E in ref [1].

  if distribution_minimum is None:
    distribution_minimum = tf.constant(-np.inf, dtype)
  if distribution_maximum is None:
    distribution_maximum = tf.constant(np.inf, dtype)

  def proposal(seed):
    """Proposal for log-concave rejection sampler."""
    (top_lobe_fractions_seed,
     exponential_samples_seed,
     top_selector_seed,
     rademacher_seed) = samplers.split_seed(seed, n=4)

    top_lobe_fractions = samplers.uniform(
        mode_shape, seed=top_lobe_fractions_seed, dtype=dtype)  # V in ref [1].
    top_offsets = top_lobe_fractions * top_width / mode_height

    exponential_samples = exponential_distribution.sample(
        mode_shape, seed=exponential_samples_seed)  # E in ref [1].
    exponential_height = (exponential_distribution.prob(exponential_samples) *
                          mode_height)
    exponential_offsets = (top_width + exponential_samples) / mode_height

    top_selector = samplers.uniform(
        mode_shape, seed=top_selector_seed, dtype=dtype)  # U in ref [1].
    on_top_mask = (top_selector <= top_fraction)

    unsigned_offsets = tf.where(on_top_mask, top_offsets, exponential_offsets)
    offsets = tf.round(
        tfp_random.rademacher(
            mode_shape, seed=rademacher_seed, dtype=dtype) *
        unsigned_offsets)

    potential_samples = mode + offsets
    envelope_height = tf.where(on_top_mask, mode_height, exponential_height)

    return potential_samples, envelope_height

  def target(values):
    # Check for out of bounds rather than in bounds to avoid accidentally
    # masking a `nan` value.
    out_of_bounds_mask = (
        (values < distribution_minimum) | (values > distribution_maximum))
    # Replace out-of-bounds entries by the mode, a point `prob_fn` has already
    # been evaluated on above, so `prob_fn` is never queried outside
    # [distribution_minimum, distribution_maximum]. (Substituting 0 could
    # itself be out of bounds.) These entries are zeroed out below.
    in_bounds_values = tf.where(out_of_bounds_mask, mode, values)
    probs = prob_fn(in_bounds_values)
    return tf.where(out_of_bounds_mask, tf.zeros([], probs.dtype), probs)

  return tf.stop_gradient(
      brs.batched_rejection_sampler(
          proposal, target, seed, dtype=dtype)[0])  # Discard `num_iters`.
def append_losses(self, outputs, self_supervised_features=None):
  """Computes losses from `outputs` and records them in `self._losses_dict`.

  Args:
    outputs: Dict of model outputs (keys such as 'audio', 'sin_audio',
      'sin_amps', 'sin_freqs', plus harmonic entries when a harmonic encoder
      is configured).
    self_supervised_features: Optional dict of reference synthesizer features.
      When None, the unsupervised (autoencoding / consistency) losses are
      recorded. Otherwise only the 'ss_'-prefixed self-supervised losses are
      recorded.
  """
  out = outputs
  feats = self_supervised_features

  if feats is not None:
    # Sinusoidal self-supervision.
    if self.sinusoidal_consistency_losses:
      for loss_obj in self.sinusoidal_consistency_losses:
        self._losses_dict['ss_' + loss_obj.name] = loss_obj(
            out['sin_amps'], out['sin_freqs'],
            feats['sin_amps'], feats['sin_freqs'])

    # Filtered noise self-supervision.
    noise_loss = self.filtered_noise_consistency_loss
    if noise_loss is not None:
      self._losses_dict['ss_' + noise_loss.name] = noise_loss(
          out['noise_magnitudes'], feats['noise_magnitudes'])

    # Harmonic self-supervision.
    if self.harmonic_consistency_losses:
      for loss_obj in self.harmonic_consistency_losses:
        if isinstance(loss_obj, ddsp.losses.HarmonicConsistencyLoss):
          # L1 loss of harmonic synth controls; yields a dict of losses.
          harm_losses = loss_obj(out['harm_amp'], feats['harm_amp'],
                                 out['harm_dist'], feats['harm_dist'],
                                 out['f0_hz'], feats['f0_hz'])
          self._losses_dict.update(
              {'ss_' + key: value for key, value in harm_losses.items()})
        else:
          # Same consistency loss as sinusoidal models.
          self._losses_dict['ss_harm_' + loss_obj.name] = loss_obj(
              out['harm_amp'], out['f0_hz'], feats['harm_amp'], feats['f0_hz'])
    return

  # Sinusoidal autoencoder loss.
  for loss_obj in self.audio_loss_objs:
    self._losses_dict['sin_{}'.format(loss_obj.name)] = loss_obj(
        out['audio'], out['sin_audio'])

  if self.harmonic_encoder is not None:
    # Prior regularization on the harmonic distribution.
    prior = self.harmonic_distribution_prior
    if prior is not None:
      self._losses_dict.update({prior.name: prior(out['harm_dist'])})

    # Harmonic autoencoder loss.
    for loss_obj in self.audio_loss_objs:
      self._losses_dict['harm_{}'.format(loss_obj.name)] = loss_obj(
          out['audio'], out['harm_audio'])

    # Sinusoidal <-> harmonic consistency loss.
    if self.sinusoidal_consistency_losses:
      sin_amps, sin_freqs = out['sin_amps'], out['sin_freqs']
      if self.stop_gradient:
        # Don't propagate harmonic errors to sinusoidal predictions.
        sin_amps = tf.stop_gradient(sin_amps)
        sin_freqs = tf.stop_gradient(sin_freqs)
      for loss_obj in self.sinusoidal_consistency_losses:
        self._losses_dict[loss_obj.name] = loss_obj(
            sin_amps, sin_freqs, out['harm_amps'], out['harm_freqs'])

  # Two-way mismatch loss between sinusoids and harmonics. Without a harmonic
  # encoder, 'sin_freqs' is passed in place of 'f0_hz'.
  if self.twm_loss is not None:
    first_arg = (out['f0_hz'] if self.harmonic_encoder is not None
                 else out['sin_freqs'])
    self._losses_dict[self.twm_loss.name] = self.twm_loss(
        first_arg, out['sin_freqs'], out['sin_amps'])
def _center_previous_state(x):
  """Centers `x` by its empirical mean over `batch_axes` (from enclosing scope).

  The empirical mean here is a stand-in for the true mean, so we drop the
  gradient that flows through this term. The mean is kept with reduced
  dimensions (`keepdims=True`) so it broadcasts against `x` for any choice of
  `batch_axes`. For leading batch axes the result is identical to subtracting
  the squeezed mean.
  """
  empirical_mean = tf.reduce_mean(x, axis=batch_axes, keepdims=True)
  return x - tf.stop_gradient(empirical_mean)
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Runs one preconditioned-HMC proposal starting at `current_state`.

  Momentum is drawn from `momentum_distribution`. Its negative log-density
  (`_log_prob_unnormalized` when the distribution defines it, else `log_prob`)
  serves as the kinetic energy.

  Args:
    current_state: A `Tensor` or list of `Tensor`s with the current state.
    previous_kernel_results: Results of the previous step. Supplies the cached
      target log-prob and gradients and, when
      `self._store_parameters_in_results` is set, the step size, number of
      leapfrog steps and momentum distribution.
    seed: Optional seed; sanitized with `samplers.sanitize_seed` and stored in
      the returned kernel results.

  Returns:
    A `(next_state, kernel_results)` pair; `next_state` mirrors the list /
    non-list structure of `current_state`.
  """
  with tf.name_scope(mcmc_util.make_name(self.name, 'phmc', 'one_step')):
    params = (previous_kernel_results
              if self._store_parameters_in_results else self)
    step_size = params.step_size
    num_leapfrog_steps = params.num_leapfrog_steps
    momentum_distribution = params.momentum_distribution

    (state_parts, step_sizes, momentum_distribution, target_log_prob,
     target_log_prob_grad_parts) = _prepare_args(
         self.target_log_prob_fn,
         current_state,
         step_size,
         momentum_distribution,
         previous_kernel_results.target_log_prob,
         previous_kernel_results.grads_target_log_prob,
         maybe_expand=True,
         state_gradients_are_stopped=self.state_gradients_are_stopped)

    seed = samplers.sanitize_seed(seed)
    initial_momentum = list(momentum_distribution.sample(seed=seed))

    log_prob_momentum = getattr(momentum_distribution,
                                '_log_prob_unnormalized',
                                momentum_distribution.log_prob)

    def kinetic_energy_fn(*momentum_parts):
      return -log_prob_momentum(*momentum_parts)

    leapfrog = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn, step_sizes, num_leapfrog_steps)
    (final_momentum, next_state_parts, next_target_log_prob,
     next_target_log_prob_grad_parts) = leapfrog(
         initial_momentum,
         state_parts,
         target=target_log_prob,
         target_grad_parts=target_log_prob_grad_parts,
         kinetic_energy_fn=kinetic_energy_fn)

    if self.state_gradients_are_stopped:
      next_state_parts = [tf.stop_gradient(part) for part in next_state_parts]

    new_kernel_results = previous_kernel_results._replace(
        log_acceptance_correction=_compute_log_acceptance_correction(
            kinetic_energy_fn, initial_momentum, final_momentum),
        target_log_prob=next_target_log_prob,
        grads_target_log_prob=next_target_log_prob_grad_parts,
        initial_momentum=initial_momentum,
        final_momentum=final_momentum,
        seed=seed,
    )

    if mcmc_util.is_list_like(current_state):
      return next_state_parts, new_kernel_results
    return next_state_parts[0], new_kernel_results
def __call__(self, x):
  """Stochastically binarizes `x` to {0, alpha} with a straight-through gradient.

  Each element becomes `alpha` when a uniform sample falls below
  `p = _sigmoid(x / alpha)`, and 0 otherwise. In the backward pass the op acts
  as the identity.

  Args:
    x: Floating-point input tensor.

  Returns:
    A tensor like `x` whose forward values lie in {0, alpha}.
  """
  p = _sigmoid(x / self.alpha)
  # Draw the noise in p's dtype so non-float32 inputs (e.g. float16/float64)
  # do not fail with a dtype mismatch in the subtraction.
  k_sign = tf.sign(p - tf.random.uniform(tf.shape(p), dtype=p.dtype))
  # tf.sign yields 0 on an exact tie; map that case to +1.
  k_sign += (1.0 - tf.abs(k_sign))
  # Forward: alpha * (k_sign + 1) / 2 in {0, alpha}; backward: identity.
  return x + tf.stop_gradient(-x + self.alpha * (k_sign + 1.0) / 2.0)
def __call__(self, x):
  """Ternarizes `x` to {-alpha, 0, alpha} with a straight-through gradient.

  Elements with |x| below `self.threshold` map to 0; the rest map to
  `alpha * sign(x)`. When `self.use_stochastic_rounding` is set, `x` is first
  passed through `_round_through`. In the backward pass the final
  quantization acts as the identity.
  """
  if self.use_stochastic_rounding:
    x = _round_through(
        x, use_stochastic_rounding=self.use_stochastic_rounding)
  below_threshold = tf.abs(x) < self.threshold
  ternary = tf.where(below_threshold, tf.zeros_like(x), tf.sign(x))
  quantized = self.alpha * ternary
  return x + tf.stop_gradient(-x + quantized)
def _ceil_through(x):
  """Computes the ceiling operation using straight through estimator.

  The forward pass returns `ceil(x)`; the backward pass treats the op as the
  identity. Uses `tf.math.ceil`, which exists in both TF1 and TF2 (the bare
  `tf.ceil` alias was removed in TF2).
  """
  return x + tf.stop_gradient(-x + tf.math.ceil(x))
def _fn(x):
  """Returns `f_x` (from the enclosing scope) with a unit gradient w.r.t. `x`."""
  # `x - stop_gradient(x)` is zero in value but differentiates to 1, so the
  # gradient is `1` regardless of input.
  unit_grad_zero = x - tf.stop_gradient(x)
  return f_x + unit_grad_zero
def contrastive_loss(features,
                     labels=None,
                     temperature=1.0,
                     contrast_mode=enums.LossContrastMode.ALL_VIEWS,
                     summation_location=enums.LossSummationLocation.OUTSIDE,
                     denominator_mode=enums.LossDenominatorMode.ALL,
                     positives_cap=-1,
                     scale_by_temperature=True):
  r"""Contrastive loss over features.

  Implemented as described in: https://arxiv.org/abs/2004.11362, Equation 2.

  Given `num_views` different views of each of `batch_size` samples, let `f_i`
  (i \in [1, 2 ... (num_views * batch_size)]) denote each respective feature
  vector. The contrastive loss then takes the following form:

    L = \sum_{i} L_i

  where each L_i is computed as:

    L_i = -\tau * \frac{1}{|P(i)|} \sum_{k \in P(i)} \log(p_{ik})    (1)

  where P(i) is the set of positives for entry i (distinct from i) and where:

                       \exp(f_i^T f_k / \tau)
    p_{ik} = ----------------------------------------                        (2)
             \sum_{j \in A(i)} \exp(f_i^T f_j / \tau)

  where A(i) is the set of all positives or negatives (distinct from i). `i` is
  the anchor, and \tau is the temperature. The leading \tau factor is applied
  only when `scale_by_temperature` is True.

  This maximizes the likelihood of a given (anchor, positive) pair with
  respect to all possible pairs where the first member is the anchor and the
  second member is a positive or a negative.

  A typical way to define a positive is to define samples from the
  same class (but not the anchor itself) regardless of what view they are
  from. Similarly, a typical way to define a negative is for it to be any view
  of a sample from a different class.

  There are two ways to define which feature pairs should be treated as
  positives and negatives. All views of the same sample are always treated as
  positives. You can declare other samples to be positives by providing
  `labels` such that all samples with the same label will be positives for
  each other.

  If `labels` is not provided then we default to every sample belonging to its
  own unique class. Therefore, the only positive used is another view of the
  anchor itself. In that case this implements the loss described in:

    https://arxiv.org/pdf/2002.05709.pdf
    A Simple Framework for Contrastive Learning of Visual Representations
    Chen T., Kornblith S., Norouzi M., Hinton G.

  It is recommended to use features whose L_2 norm is 1, since that ensures
  that the loss does not return NaN values without changing the intended
  behaviour of the loss function.

  In (1) above, note that the summation over positives is located outside of
  the \log(). However, one can permute these two operations. The result is
  Eq. 3 in https://arxiv.org/abs/2004.11362. Users can specify the location of
  the summation relative to the \log() via the `summation_location` argument:
   - LossSummationLocation.OUTSIDE: Eq. 2 in https://arxiv.org/abs/2004.11362.
   - LossSummationLocation.INSIDE : Eq. 3 in https://arxiv.org/abs/2004.11362.

  Additionally, in (2) above, note that the denominator sums over *all* entries
  distinct from i. One can change which terms are included in the denominator
  via the `denominator_mode` argument:
   - LossDenominatorMode.ALL : All entries (i.e., all negatives and all
             positives) distinct from i are included.
   - LossDenominatorMode.ONE_POSITIVE : All negatives are included but only the
             single positive in the numerator of (2) is included. Any other
             positives are excluded.
   - LossDenominatorMode.ONLY_NEGATIVES: All negatives are included but no
             positives are, not even the single positive in the numerator of
             (2).

  On TPUs, this method will internally perform the cross-replica operations
  that enable using the samples from all cores in computing the loss. The
  inputs to this function should be the features and labels from a single
  core and each core will compute the loss using just these features as
  anchors, but will use positives and negatives from the full global batch.
  Since the loss for each anchor is only computed on one TPU core, it's still
  necessary to have a cross-replica reduction in the final loss computation.

  Also, though it is not applicable to multiview contrastive learning, this
  function will work if |features| contains only 1 view. In the high batch
  size limit, the implemented contrastive loss with only 1 view,
  positives_cap = 1, and temperature = 1.0 is equivalent to the N-pairs loss
  (https://papers.nips.cc/paper/6200-improved-deep-metric-learning-with-multi-class-n-pair-loss-objective.pdf)

  Args:
    features: A Tensor of rank at least 3, where the first 2 dimensions are
      batch_size and num_views, and the remaining dimensions are the feature
      shape. Note that when running on TPU, batch_size is the per-core batch
      size.
    labels: One-hot labels to be used to construct the supervised contrastive
      loss. Samples with the same labels are used as positives for each other.
      Labels must have shape [batch_size, num_labels] with numeric dtype and
      be 0-1 valued. Note that when running on TPU, batch_size is the per-core
      batch size.
    temperature: Temperature at which softmax evaluation is done. Temperature
      must be a python scalar or scalar Tensor of numeric dtype.
    contrast_mode: LossContrastMode specifying which views get used as anchors
      (f_i in the expression above)
      'ALL_VIEWS': All the views of all samples are used as anchors (f_i in
        the expression above).
      'ONE_VIEW': Just the first view of each sample is used as an anchor (f_i
        in the expression above). This view is called the `core` view against
        which other views are contrasted.
    summation_location: LossSummationLocation specifying location of positives
      summation. See documentation above for more details.
    denominator_mode: LossDenominatorMode specifying which positives to include
      in contrastive denominator. See documentation above for more details.
    positives_cap: Integer maximum number of positives *other* than
      augmentations of anchor. Infinite if < 0. Must be multiple of num_views.
      Including augmentations, a maximum of (positives_cap + num_views - 1)
      positives is possible. This parameter modifies the contrastive numerator
      by selecting which positives are present in the summation, and which
      positives contribute to the denominator if denominator_mode ==
      enums.LossDenominatorMode.ALL.
    scale_by_temperature: Boolean. Whether to scale the loss by `temperature`.
      The loss gradient naturally has a 1/temperature scaling factor, so this
      counteracts it.

  Returns:
    A tf.float32 tensor of shape [local_batch_size] holding the contrastive
    loss of each local sample. When num_views > 1 it is the mean over the
    anchor views of that sample.

  Raises:
    ValueError: If the shapes of any of the Tensors are unexpected (checked by
      `_validate_contrastive_loss_inputs`).
  """
  features = tf.convert_to_tensor(features)
  labels = tf.convert_to_tensor(labels) if labels is not None else None

  local_batch_size, num_views = _validate_contrastive_loss_inputs(
      features, labels, contrast_mode, summation_location, denominator_mode,
      positives_cap)

  # Flatten `features` to a single dimension per view per sample so it has
  # shape [local_batch_size, num_views, num_features].
  if features.shape.rank > 3:
    features = tf.reshape(
        features, tf.concat([tf.shape(features)[:2], [-1]], axis=0),
        'flattened_features')
  if features.dtype != tf.float32:
    features = tf.cast(features, tf.float32)

  # Grab the features from all TPU cores. We use the local batch as anchors
  # and the full global batch as contrastives. If not on TPU, global_features
  # is the same as features.
  global_features = utils.cross_replica_concat(features)
  global_batch_size = tf.compat.dimension_at_index(global_features.shape,
                                                   0).value
  local_replica_id = utils.local_tpu_replica_id()

  # Generate the [local_batch_size, global_batch_size] slice of the
  # [global_batch_size, global_batch_size] identity matrix that corresponds to
  # the current replica.
  diagonal_mask = tf.one_hot(
      tf.range(local_batch_size) + (local_replica_id * local_batch_size),
      global_batch_size)

  # Generate `mask` with shape [local_batch_size, global_batch_size] that
  # indicates which samples should be considered positives for each other.
  if labels is None:
    # Defaults to every sample belonging to its own unique class, containing
    # just that sample and other views of it.
    mask = diagonal_mask
  else:
    labels = tf.cast(labels, tf.float32)  # TPU matmul op unsupported for ints.
    global_labels = utils.cross_replica_concat(labels)
    mask = tf.linalg.matmul(labels, global_labels, transpose_b=True)
  mask = tf.ensure_shape(mask, [local_batch_size, global_batch_size])

  # To streamline the subsequent TF, the first two dimensions of
  # `global_features` (i.e., global_batch_size and num_views) should be
  # transposed and then flattened. The result has shape
  # [num_views * global_batch_size, num_features], and its first dimension
  # elements are grouped by view, not by sample.
  all_global_features = tf.reshape(
      tf.transpose(global_features, perm=[1, 0, 2]),
      [num_views * global_batch_size, -1])

  if contrast_mode == enums.LossContrastMode.ONE_VIEW:
    anchor_features = features[:, 0]
    num_anchor_views = 1
  else:  # contrast_mode == enums.LossContrastMode.ALL_VIEWS
    # Reshape features to match how global_features is reshaped above.
    anchor_features = tf.reshape(
        tf.transpose(features, perm=[1, 0, 2]),
        [num_views * local_batch_size, -1])
    num_anchor_views = num_views

  # Generate `logits`, the tensor of (temperature-scaled) dot products of the
  # anchor features with all features. It has shape
  # [local_batch_size * num_anchor_views, global_batch_size * num_views]. To
  # improve numerical stability, subtract out the largest |logits| element in
  # each row from all elements in that row. Since |logits| is only ever used
  # as a ratio of exponentials of |logits| values, this subtraction does not
  # change the correctness of the results. A stop_gradient() is needed because
  # this change is just for numerical precision.
  logits = tf.linalg.matmul(
      anchor_features, all_global_features, transpose_b=True)
  temperature = tf.cast(temperature, tf.float32)
  logits = logits / temperature
  logits = (
      logits - tf.reduce_max(tf.stop_gradient(logits), axis=1, keepdims=True))
  exp_logits = tf.exp(logits)

  # The following masks are all tiled by the number of views, i.e., they have
  # shape [local_batch_size * num_anchor_views, global_batch_size * num_views].
  positives_mask, negatives_mask = _create_tiled_masks(
      mask, diagonal_mask, num_views, num_anchor_views, positives_cap)
  num_positives_per_row = tf.reduce_sum(positives_mask, axis=1)

  if denominator_mode == enums.LossDenominatorMode.ALL:
    denominator = tf.reduce_sum(
        exp_logits * negatives_mask, axis=1, keepdims=True) + tf.reduce_sum(
            exp_logits * positives_mask, axis=1, keepdims=True)
  elif denominator_mode == enums.LossDenominatorMode.ONE_POSITIVE:
    denominator = exp_logits + tf.reduce_sum(
        exp_logits * negatives_mask, axis=1, keepdims=True)
  else:  # denominator_mode == enums.LossDenominatorMode.ONLY_NEGATIVES
    denominator = tf.reduce_sum(
        exp_logits * negatives_mask, axis=1, keepdims=True)

  # Note that num_positives_per_row can be zero only if 1 view is used. The
  # various tf.math.divide_no_nan() calls below are to handle this case.
  if summation_location == enums.LossSummationLocation.OUTSIDE:
    log_probs = (logits - tf.math.log(denominator)) * positives_mask
    log_probs = tf.reduce_sum(log_probs, axis=1)
    log_probs = tf.math.divide_no_nan(log_probs, num_positives_per_row)
  else:  # summation_location == enums.LossSummationLocation.INSIDE
    log_probs = exp_logits / denominator * positives_mask
    log_probs = tf.reduce_sum(log_probs, axis=1)
    log_probs = tf.math.divide_no_nan(log_probs, num_positives_per_row)
    log_probs = tf.math.log(log_probs)

  loss = -log_probs
  if scale_by_temperature:
    loss *= temperature

  loss = tf.reshape(loss, [num_anchor_views, local_batch_size])

  if num_views != 1:
    loss = tf.reduce_mean(loss, axis=0)
  else:
    # The 1 view case requires special handling bc, unlike in the > 1 view
    # case, not all samples are guaranteed to have a positive. Also, no
    # reduction over views is needed.
    # NOTE(review): this divides a loss that was already averaged over
    # positives by the positive count a second time; confirm this
    # normalization is intended rather than a plain zero-positive mask.
    num_valid_views_per_sample = tf.reshape(num_positives_per_row,
                                            [1, local_batch_size])
    # Squeeze only the leading (view) axis so the result keeps shape
    # [local_batch_size] even when local_batch_size == 1.
    loss = tf.squeeze(
        tf.math.divide_no_nan(loss, num_valid_views_per_sample), axis=0)

  return loss
def fn(x, y):
  """Returns `x**2 + y**2`, where only `x` contributes a gradient."""
  frozen_y = tf.stop_gradient(y)
  return x**2 + frozen_y**2
def _center_previous_state(x):
  """Subtracts a gradient-stopped empirical mean from `x`.

  The mean comes from `_reduce_mean_with_axes` over `batch_axes` and
  `reduce_chain_axis_names` (both from the enclosing scope). It is a stand-in
  for the true mean, so no gradient flows through it.
  """
  frozen_mean = tf.stop_gradient(
      _reduce_mean_with_axes(x, batch_axes, reduce_chain_axis_names))
  return x - frozen_mean
def one_step(self, current_state, previous_kernel_results):
  """Runs one HMC leapfrog proposal starting at `current_state`.

  Momentum is drawn with `tf.random.normal`, seeded from `self._seed_stream()`
  once per state part, in state-part order.

  Args:
    current_state: A `Tensor` or list of `Tensor`s with the current state.
    previous_kernel_results: Results of the previous step. Supplies the cached
      target log-prob and gradients and, when
      `self._store_parameters_in_results` is set, the step size and number of
      leapfrog steps.

  Returns:
    A `(next_state, kernel_results)` pair; `next_state` mirrors the list /
    non-list structure of `current_state`.
  """
  with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
    params = (previous_kernel_results
              if self._store_parameters_in_results else self)
    step_size = params.step_size
    num_leapfrog_steps = params.num_leapfrog_steps

    (state_parts, step_sizes, target_log_prob,
     target_log_prob_grad_parts) = _prepare_args(
         self.target_log_prob_fn,
         current_state,
         step_size,
         previous_kernel_results.target_log_prob,
         previous_kernel_results.grads_target_log_prob,
         maybe_expand=True,
         state_gradients_are_stopped=self.state_gradients_are_stopped)

    initial_momentum = [
        tf.random.normal(
            shape=tf.shape(input=part),
            dtype=self._momentum_dtype or part.dtype.base_dtype,
            seed=self._seed_stream())
        for part in state_parts
    ]

    leapfrog = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn, step_sizes, num_leapfrog_steps)
    (final_momentum, next_state_parts, next_target_log_prob,
     next_target_log_prob_grad_parts) = leapfrog(
         initial_momentum, state_parts, target_log_prob,
         target_log_prob_grad_parts)

    if self.state_gradients_are_stopped:
      next_state_parts = [tf.stop_gradient(part) for part in next_state_parts]

    chain_ndims = distribution_util.prefer_static_rank(target_log_prob)
    new_kernel_results = previous_kernel_results._replace(
        log_acceptance_correction=_compute_log_acceptance_correction(
            initial_momentum, final_momentum, chain_ndims),
        target_log_prob=next_target_log_prob,
        grads_target_log_prob=next_target_log_prob_grad_parts,
    )

    if mcmc_util.is_list_like(current_state):
      return next_state_parts, new_kernel_results
    return next_state_parts[0], new_kernel_results
def regression_loss(logits, labels, num_steps, steps, seq_lens, loss_type,
                    normalize_indices, variance_lambda, huber_delta):
  """Loss function based on regressing to the correct indices.

  In the paper, this is called Cycle-back Regression. There are 3 variants of
  this loss:
  i) regression_mse: MSE of the predicted indices and ground truth indices.
  ii) regression_mse_var: MSE of the predicted indices that takes into account
  the variance of the similarities. This is important when the rate at which
  sequences go through different phases changes a lot. The variance scaling
  allows dynamic weighting of the MSE loss based on the similarities.
  iii) regression_huber: Huber loss between the predicted indices and ground
  truth indices.

  Args:
    logits: Tensor, Pre-softmax similarity scores after cycling back to the
      starting sequence.
    labels: Tensor, One hot labels containing the ground truth. The index where
      the cycle started is 1.
    num_steps: Integer, Number of steps in the sequence embeddings.
    steps: Tensor, step indices/frame indices of the embeddings of the shape
      [N, T] where N is the batch size, T is the number of the timesteps.
    seq_lens: Tensor, Lengths of the sequences from which the sampling was
      done. This can provide additional temporal information to the alignment
      loss.
    loss_type: String, This specifies the kind of regression loss function.
      Currently supported loss functions: regression_mse, regression_mse_var,
      regression_huber.
    normalize_indices: Boolean, If True, normalizes indices by sequence
      lengths. Useful for avoiding numerical instabilities, since sequence
      indices can be large numbers.
    variance_lambda: Float, Weight of the variance of the similarity
      predictions while cycling back. If this is high then the low variance
      similarities are preferred by the loss while making this term low
      results in high variance of the similarities (more uniform/random
      matching).
    huber_delta: float, Huber delta described in tf.keras.losses.huber_loss.

  Returns:
    loss: Tensor, A scalar loss calculated using a variant of regression.

  Raises:
    ValueError: If `loss_type` is not one of the supported loss names.
  """
  # Labels and step indices are generated targets, not model outputs, so we
  # stop gradients from flowing into them.
  labels = tf.stop_gradient(labels)
  steps = tf.stop_gradient(steps)

  if normalize_indices:
    float_seq_lens = tf.cast(seq_lens, tf.float32)
    tile_seq_lens = tf.tile(
        tf.expand_dims(float_seq_lens, axis=1), [1, num_steps])
    steps = tf.cast(steps, tf.float32) / tile_seq_lens
  else:
    steps = tf.cast(steps, tf.float32)

  # Expected (soft) time index under the predicted similarity distribution,
  # compared against the index selected by the one-hot labels.
  beta = tf.nn.softmax(logits)
  true_time = tf.reduce_sum(steps * labels, axis=1)
  pred_time = tf.reduce_sum(steps * beta, axis=1)

  if loss_type == 'regression_mse_var':
    # Variance aware regression.
    pred_time_tiled = tf.tile(tf.expand_dims(pred_time, axis=1),
                              [1, num_steps])
    pred_time_variance = tf.reduce_sum(
        tf.square(steps - pred_time_tiled) * beta, axis=1)
    # Using log of variance as it is numerically stabler.
    # NOTE(review): if `beta` puts all its mass on one step the variance is 0,
    # the log is -inf and the loss becomes inf/nan; consider an epsilon.
    pred_time_log_var = tf.math.log(pred_time_variance)
    squared_error = tf.square(true_time - pred_time)
    return tf.reduce_mean(tf.math.exp(-pred_time_log_var) * squared_error
                          + variance_lambda * pred_time_log_var)
  elif loss_type == 'regression_mse':
    return tf.reduce_mean(
        tf.keras.losses.mean_squared_error(y_true=true_time,
                                           y_pred=pred_time))
  elif loss_type == 'regression_huber':
    return tf.reduce_mean(tf.keras.losses.huber_loss(
        y_true=true_time, y_pred=pred_time,
        delta=huber_delta))
  else:
    # The message previously misspelled 'regression_mse_var'.
    raise ValueError('Unsupported regression loss %s. Supported losses are: '
                     'regression_mse, regression_mse_var and '
                     'regression_huber.' % loss_type)
def testDistribution(self, dist_name, data):
  """Runs generic property checks against the distribution `dist_name`.

  Draws two distributions of the same family (the second with a
  broadcast-compatible batch shape) and checks that:
    * Variable parameters are passed through to accessor properties as-is;
    * statistics, `sample`, KL divergence and probability evaluations do not
      read the distribution's variables more often than permitted;
    * `sample` (if fully reparameterized), KL and `log_prob` produce non-None
      gradients, except for families on the corresponding exclusion lists.

  Args:
    dist_name: Name of the distribution family under test.
    data: Object whose `.draw(strategy)` yields examples; presumably a
      Hypothesis `data()` strategy value — confirm against the decorator.
  """
  # Only run in the mode (eager vs. graph) selected by the `tf_mode` flag.
  if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
    return
  seed = tfp_test_util.test_seed()
  dist = data.draw(distributions(dist_name=dist_name, enable_vars=True))
  batch_shape = dist.batch_shape
  batch_shape2 = data.draw(
      tfp_hps.broadcast_compatible_shape(batch_shape))
  dist2 = data.draw(
      distributions(
          dist_name=dist_name,
          batch_shape=batch_shape2,
          event_dim=get_event_dim(dist),
          enable_vars=True))
  self.evaluate([var.initializer for var in dist.variables])

  # Check that the distribution passes Variables through to the accessor
  # properties (without converting them to Tensor or anything like that).
  for k, v in six.iteritems(dist.parameters):
    if not tensor_util.is_ref(v):
      continue
    self.assertIs(getattr(dist, k), v)

  # Check that standard statistics do not read distribution parameters more
  # than twice (once in the stat itself and up to once in any validation
  # assertions).
  for stat in data.draw(
      hps.sets(
          hps.one_of(
              map(hps.just, [
                  'covariance', 'entropy', 'mean', 'mode', 'stddev',
                  'variance'
              ])),
          min_size=3,
          max_size=3)):
    hp.note('Testing excessive var usage in {}.{}'.format(
        dist_name, stat))
    try:
      with tfp_hps.assert_no_excessive_var_usage(
          'statistic `{}` of `{}`'.format(stat, dist)):
        getattr(dist, stat)()

    except NotImplementedError:
      # Statistics a family does not implement are skipped.
      pass

  # Check that `sample` doesn't read distribution parameters more than twice,
  # and that it produces non-None gradients (if the distribution is fully
  # reparameterized).
  with tf.GradientTape() as tape:
    # TDs do bijector assertions twice (once by distribution.sample, and once
    # by bijector.forward).
    max_permissible = (3 if isinstance(
        dist, tfd.TransformedDistribution) else 2)
    with tfp_hps.assert_no_excessive_var_usage(
        'method `sample` of `{}`'.format(dist),
        max_permissible=max_permissible):
      sample = dist.sample(seed=seed)
  if dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
    grads = tape.gradient(sample, dist.variables)
    for grad, var in zip(grads, dist.variables):
      # Strip TF's uniquifying suffix (e.g. "loc_1:0" -> "loc") before looking
      # the parameter up in the exclusion map.
      var_name = var.name.rstrip('_0123456789:')
      if var_name in NO_SAMPLE_PARAM_GRADS.get(dist_name, ()):
        continue
      if grad is None:
        raise AssertionError(
            'Missing sample -> {} grad for distribution {}'.format(
                var_name, dist_name))

  # Turn off validations, since TODO(b/129271256) log_prob can choke on dist's
  # own samples. Also, to relax conversion counts for KL (might do >2 w/
  # validate_args).
  dist = dist.copy(validate_args=False)
  dist2 = dist2.copy(validate_args=False)

  # Test that KL divergence reads distribution parameters at most once, and
  # that it produces non-None gradients. Both argument orders are checked.
  try:
    for d1, d2 in (dist, dist2), (dist2, dist):
      with tf.GradientTape() as tape:
        with tfp_hps.assert_no_excessive_var_usage(
            '`kl_divergence` of (`{}` (vars {}), `{}` (vars {}))'.format(
                d1, d1.variables, d2, d2.variables),
            max_permissible=1):  # No validation => 1 convert per var.
          kl = d1.kl_divergence(d2)
      wrt_vars = list(d1.variables) + list(d2.variables)
      grads = tape.gradient(kl, wrt_vars)
      for grad, var in zip(grads, wrt_vars):
        if grad is None and dist_name not in NO_KL_PARAM_GRADS:
          raise AssertionError('Missing KL({} || {}) -> {} grad:\n'
                               '{} vars: {}\n{} vars: {}'.format(
                                   d1, d2, var, d1, d1.variables, d2,
                                   d2.variables))
  except NotImplementedError:
    # Families without a registered KL are skipped.
    pass

  # Test that log_prob produces non-None gradients, except for distributions
  # on the NO_LOG_PROB_PARAM_GRADS exclusion list.
  if dist_name not in NO_LOG_PROB_PARAM_GRADS:
    with tf.GradientTape() as tape:
      # Stop the gradient through the sample so only the parameter path of
      # log_prob is exercised.
      lp = dist.log_prob(tf.stop_gradient(sample))
    grads = tape.gradient(lp, dist.variables)
    for grad, var in zip(grads, dist.variables):
      if grad is None:
        raise AssertionError(
            'Missing log_prob -> {} grad for distribution {}'. format(
                var, dist_name))

  # Test that all forms of probability evaluation avoid reading distribution
  # parameters more than once.
  for evaluative in data.draw(
      hps.sets(
          hps.one_of(
              map(hps.just, [
                  'log_prob', 'prob', 'log_cdf', 'cdf',
                  'log_survival_function', 'survival_function'
              ])),
          min_size=3,
          max_size=3)):
    hp.note('Testing excessive var usage in {}.{}'.format(
        dist_name, evaluative))
    try:
      # No validation => 1 convert. But for TD we allow 2:
      # dist.log_prob(bijector.inverse(samp)) + bijector.ildj(samp)
      max_permissible = (2 if isinstance(
          dist, tfd.TransformedDistribution) else 1)
      with tfp_hps.assert_no_excessive_var_usage(
          'evaluative `{}` of `{}`'.format(evaluative, dist),
          max_permissible=max_permissible):
        getattr(dist, evaluative)(sample)
    except NotImplementedError:
      pass
def call(self, x, training=False):
  """Quantizes `x` to its nearest centroids, one code per head.

  Each input vector is split into `self.num_heads` equal segments. All heads
  share a single set of `self.k` centroids, computed as
  `self.sums / self.counts` (division-by-zero yields 0). During training,
  under-used centroids are restarted from batch elements (or uniform noise as
  a fallback), and the running counts/sums are updated by an exponential
  moving average with factor `self.gamma`.

  Args:
    x: Input tensor. The final reshape of the codes requires its last axis to
      have size `self.depth`.
    training: If True, restart under-used centroids and update the EMA
      statistics in `self.counts` / `self.sums`.

  Returns:
    A pair `(z, c)`:
      z: Tensor shaped like `x` holding the quantized vectors. Its forward
        value is the centroid; its gradient is passed straight through to `x`.
      c: Centroid indices from `tf.argmin`, shaped
        `tf.shape(x)[:-1] + [self.num_heads]`.
  """
  x_flat = tf.reshape(x, shape=(-1, self.depth))
  # Split each input vector into one segment per head and stack the segments
  # along the batch axis, so every head is matched against the same centroids.
  x_flat_split = tf.split(x_flat, self.num_heads, axis=1)
  x_flat = tf.concat(x_flat_split, axis=0)

  # Current centroids: summed vectors divided by the corresponding counts.
  centroids = tf.math.divide_no_nan(self.sums, tf.expand_dims(self.counts, 1))

  if training:
    # Figure out which centroids we want to keep, and which we want to
    # restart. Use the dynamic row count: the static `x_flat.shape[0]` is
    # None when the batch size is unknown (e.g. inside `tf.function`), which
    # breaks the arithmetic below.
    n = tf.shape(x_flat)[0]
    keep = (self.counts * self.k >
            self.restart_threshold * tf.cast(n, self.counts.dtype))
    restart = tf.math.logical_not(keep)

    # Replace centroids to restart with elements from the batch, using samples
    # from a uniform distribution as a fallback in case we need to restart
    # more centroids than we have elements in the batch.
    restart_idx = tf.squeeze(tf.where(restart), -1)
    n_replace = tf.minimum(tf.shape(restart_idx)[0], n)
    e_restart = tf.tensor_scatter_nd_update(
        tf.random.uniform([self.k, self.depth // self.num_heads]),
        tf.expand_dims(restart_idx[:n_replace], 1),
        tf.random.shuffle(x_flat)[:n_replace])

    e = tf.where(tf.expand_dims(keep, 1), centroids, e_restart)
  else:
    # If not training, just use the centroids as is with no restarts.
    e = centroids

  # Squared Euclidean distance between each input vector and each centroid,
  # expanded as |x|^2 - 2 x.e + |e|^2.
  distances = (
      tf.expand_dims(tf.reduce_sum(x_flat**2, axis=1), 1)
      - 2 * tf.matmul(x_flat, tf.transpose(e))
      + tf.expand_dims(tf.reduce_sum(e**2, axis=1), 0))

  # Find nearest cluster center for each input vector.
  c = tf.argmin(distances, axis=1)

  # Quantize input vectors with straight-through estimator: undo the per-head
  # stacking, restore the input shape, then copy gradients from z to x.
  z = tf.nn.embedding_lookup(e, c)
  z_split = tf.split(z, self.num_heads, axis=0)
  z = tf.concat(z_split, axis=1)
  z = tf.reshape(z, tf.shape(x))
  z = x + tf.stop_gradient(z - x)

  if training:
    # Compute cluster counts and vector sums over the batch.
    oh = tf.one_hot(indices=c, depth=self.k)
    counts = tf.reduce_sum(oh, axis=0)
    sums = tf.matmul(oh, x_flat, transpose_a=True)

    # Apply exponential moving average to cluster counts and vector sums:
    # v <- gamma * v + (1 - gamma) * batch_value.
    self.counts.assign_sub((1 - self.gamma) * (self.counts - counts))
    self.sums.assign_sub((1 - self.gamma) * (self.sums - sums))

  # Regroup the codes so each input position gets one code per head.
  c_split = tf.split(c, self.num_heads, axis=0)
  c = tf.stack(c_split, axis=1)
  c = tf.reshape(c, tf.concat([tf.shape(x)[:-1], [self.num_heads]], axis=0))

  return z, c
def _reparameterize_sample(self, x, event_shape): """Adds reparameterization (pathwise) gradients to samples of the mixture. Implicit reparameterization gradients are dx/dphi = -(d transform(x, phi) / dx)^-1 * d transform(x, phi) / dphi, where transform(x, phi) is distributional transform that removes all parameters from samples x. We implement them by replacing x with -stop_gradient(d transform(x, phi) / dx)^-1 * transform(x, phi)] for the backward pass (gradient computation). The derivative of this quantity w.r.t. phi is then the implicit reparameterization gradient. Note that this replaces the gradients w.r.t. both the mixture distribution parameters and components distributions parameters. Limitations: 1. Fundamental: components must be fully reparameterized. 2. Distributional transform is currently only implemented for factorized components. 3. Distributional transform currently only works for known rank of the batch tensor. Args: x: Sample of mixture distribution event_shape: The event shape of this distribution Returns: Tensor with same value as x, but with reparameterization gradients """ # Remove the existing gradients of x wrt parameters of the components. x = tf.stop_gradient(x) event_size = ps.cast(ps.reduce_prod(event_shape), dtype=tf.int32) x_2d_shape = [-1, event_size] # [S*prod(B), prod(E)] # Perform distributional transform of x in [S, B, E] shape, # but have Jacobian of size [S*prod(B), prod(E), prod(E)]. def reshaped_distributional_transform(x_2d): return tf.reshape( self._distributional_transform(tf.reshape(x_2d, ps.shape(x)), event_shape), x_2d_shape) # transform_2d: [S*prod(B), prod(E)] # jacobian: [S*prod(B), prod(E), prod(E)] x_2d = tf.reshape(x, x_2d_shape) transform_2d, jacobian = value_and_batch_jacobian( reshaped_distributional_transform, x_2d) # We only provide the first derivative; the second derivative computed by # autodiff would be incorrect, so we raise an error if it is requested. 
transform_2d = _prevent_2nd_derivative(transform_2d) # Compute [- stop_gradient(jacobian)^-1 * transform] by solving a linear # system. The Jacobian is lower triangular because the distributional # transform for i-th event dimension does not depend on the next # dimensions. surrogate_x_2d = -tf.linalg.triangular_solve( tf.stop_gradient(jacobian), transform_2d[..., tf.newaxis], lower=True) # [S*prod(B), prod(E), 1] surrogate_x = tf.reshape(surrogate_x_2d, ps.shape(x)) # Replace gradients of x with gradients of surrogate_x, but keep the value. return x + (surrogate_x - tf.stop_gradient(surrogate_x))
def forward(self, features, training=True):
  """Runs the model forward on a dictionary of features (no losses).

  Audio is first encoded into sinusoidal synthesizer parameters, which are
  rendered to audio. If a harmonic encoder is configured, the (optionally
  gradient-stopped) sinusoids are further encoded into harmonics and rendered
  a second time.

  Args:
    features: Dict of input features; must contain 'audio'.
    training: Forwarded to the sinusoidal encoder.

  Returns:
    Dict with the input audio, noise magnitudes and sinusoidal outputs, plus
    harmonic outputs when `self.harmonic_encoder` is set. Note that the
    encoder's output dict is updated in place with the scaled (and, for the
    harmonic pass, harmonic) parameters.
  """
  # Audio -> Sinusoids -------------------------------------------------------
  input_audio = features['audio']
  synth_params = self.sinusoidal_encoder(features, training=training)

  # The scaling nonlinearities are applied here rather than in the encoder,
  # and the scaled values are written back for the processor group.
  freqs = self.freq_scale_fn(synth_params['frequencies'])
  amps = self.amps_scale_fn(synth_params['amplitudes'])
  noise = self.amps_scale_fn(synth_params['noise_magnitudes'])
  synth_params.update(frequencies=freqs, amplitudes=amps,
                      noise_magnitudes=noise)

  outputs = dict(
      audio=input_audio,        # Input signal.
      noise_magnitudes=noise,   # Filtered noise signal.
      sin_audio=self.processor_group(synth_params),  # Sinusoidal signal.
      sin_amps=amps,
      sin_freqs=freqs,
  )

  # Sinusoids -> Harmonics ---------------------------------------------------
  if self.stop_gradient:
    freqs, amps, noise = (tf.stop_gradient(t) for t in (freqs, amps, noise))

  if self.harmonic_encoder is None:
    return outputs

  harm_amp, harm_dist, f0_hz = self.harmonic_encoder(freqs, amps)

  # Decode harmonics back to sinusoids and render them.
  harmonic_count = int(harm_dist.shape[-1])
  harm_freqs = ddsp.core.get_harmonic_frequencies(f0_hz, harmonic_count)
  harm_amps = harm_amp * harm_dist
  synth_params.update(frequencies=harm_freqs, amplitudes=harm_amps,
                      noise_magnitudes=noise)

  outputs.update(
      harm_audio=self.processor_group(synth_params),  # Harmonic signal.
      harm_amp=harm_amp,
      harm_dist=harm_dist,
      f0_hz=f0_hz,
      harm_freqs=harm_freqs,  # Harmonic sinusoids.
      harm_amps=harm_amps,
  )
  return outputs