def BeamSearch(self, x, y, decoder_reduce_sum=False):
  """Runs three encoder scopes, then a decoder body inside `tf.while_loop`.

  Each `encoder%03d` name scope passes `x` and `y` through `tf.identity`,
  adds 1 to `x`, and records `x_mean` / `y_mean` scalars through
  `tpu_summary.scalar`. The decoder body does the same under `decoder%03d`
  scopes but adds 1 to `y`; it is the body of a while loop whose condition is
  always true and which is capped at 10 iterations.

  Args:
    x: Tensor threaded through the encoder and decoder steps.
    y: Tensor threaded through the encoder and decoder steps.
    decoder_reduce_sum: If true, decoder summaries are recorded with
      `while_loop_reduce='sum'`; otherwise that argument is not passed.

  Returns:
    The `(x, y)` pair produced by the while loop.
  """
  for step in range(3):
    with tf.name_scope('encoder%03d' % step):
      x, y = tf.identity(x), tf.identity(y)
      x = x + 1
      for tag, value in (('x_mean', x), ('y_mean', y)):
        tpu_summary.scalar(tag, tf.reduce_mean(value))

  # Extra keyword arguments applied to every decoder summary.
  reduce_kwargs = {'while_loop_reduce': 'sum'} if decoder_reduce_sum else {}

  def DecoderStep(x, y):
    for step in range(3):
      with tf.name_scope('decoder%03d' % step):
        x, y = tf.identity(x), tf.identity(y)
        y = y + 1
        for tag, value in (('x_mean', x), ('y_mean', y)):
          tpu_summary.scalar(tag, tf.reduce_mean(value), **reduce_kwargs)
    return x, y

  def DecoderCond(x, y):
    # The loop is bounded only by `maximum_iterations` below.
    del x, y
    return True

  final_x, final_y = tf.while_loop(
      cond=DecoderCond,
      body=DecoderStep,
      loop_vars=(x, y),
      maximum_iterations=10)
  return final_x, final_y
def DecoderStep(x, y):
  """Applies three `decoder%03d` scopes to `(x, y)` and records summaries.

  In each scope both tensors go through `tf.identity`, `y` is incremented by
  1, and `x_mean` / `y_mean` scalars are recorded. `decoder_reduce_sum` is read
  from the enclosing scope: when true the summaries are recorded with
  `while_loop_reduce='sum'`.

  Args:
    x: Tensor passed through unchanged apart from `tf.identity`.
    y: Tensor incremented by 1 in every scope.

  Returns:
    The updated `(x, y)` pair.
  """
  reduce_kwargs = {'while_loop_reduce': 'sum'} if decoder_reduce_sum else {}
  for step in range(3):
    with tf.name_scope('decoder%03d' % step):
      x, y = tf.identity(x), tf.identity(y)
      y = y + 1
      for tag, value in (('x_mean', x), ('y_mean', y)):
        tpu_summary.scalar(tag, tf.reduce_mean(value), **reduce_kwargs)
  return x, y
def _AveNeigh(self, spellings, pronunciations, theta, batch_size):
  """Encodes neighbor spellings and pronunciations and averages over length.

  Both inputs are flattened to `(batch_size * max_neighbors, -1)`, encoded
  separately, and each encoder's "encoded" output is reshaped to
  `(len, batch_size, max_neighbors, enc_units)`, averaged over axis 0 and
  transposed to `(max_neighbors, batch_size, enc_units)`.

  Args:
    spellings: Neighbor spelling ids; reshaped to
      `(batch_size * p.max_neighbors, -1)`.
    pronunciations: Neighbor pronunciation ids; reshaped the same way.
    theta: Variable values; `theta.spell_encoder` and `theta.pron_encoder`
      are read.
    batch_size: Batch size used in the reshapes and tiling.

  Returns:
    A pair `([spell_enc, pron_enc], [padding, padding])` where `padding` is an
    all-zero tensor of shape `(max_neighbors, batch_size)`.
  """
  p = self.params
  spellings = tf.reshape(spellings, (batch_size * p.max_neighbors, -1))
  pronunciations = tf.reshape(pronunciations,
                              (batch_size * p.max_neighbors, -1))
  spell_inp = py_utils.NestedMap({
      "ids": spellings,
      "paddings": self._GetPaddings(spellings, dtype=tf.int32),
  })
  pron_inp = py_utils.NestedMap({
      "ids": pronunciations,
      "paddings": self._GetPaddings(pronunciations, dtype=tf.int32),
  })
  if p.use_neigh_id_emb:
    # Add the same ID embeddings to both spelling and pron so that the
    # model knows how they pair up.
    neigh_ids = tf.range(p.max_neighbors)[:, tf.newaxis]
    spell_inp["task_ids"] = tf.tile(neigh_ids,
                                    [batch_size, p.max_spelling_len])
    pron_inp["task_ids"] = tf.tile(neigh_ids,
                                   [batch_size, p.max_pronunciation_len])

  spell_enc_out = self.spell_encoder.FProp(theta.spell_encoder, spell_inp)
  # Pair the pronunciation weights with the pronunciation encoder layer; the
  # previous code ran `self.spell_encoder` with `theta.pron_encoder`.
  pron_enc_out = self.pron_encoder.FProp(theta.pron_encoder, pron_inp)

  spell_enc = tf.reshape(
      spell_enc_out["encoded"],
      (p.max_spelling_len, batch_size, p.max_neighbors, p.enc_units))
  spell_enc = tf.reduce_mean(spell_enc, axis=0)
  pron_enc = tf.reshape(
      pron_enc_out["encoded"],
      (p.max_pronunciation_len, batch_size, p.max_neighbors, p.enc_units))
  pron_enc = tf.reduce_mean(pron_enc, axis=0)

  # (batch, neighbors, units) -> (neighbors, batch, units).
  spell_enc = tf.transpose(spell_enc, (1, 0, 2))
  pron_enc = tf.transpose(pron_enc, (1, 0, 2))
  padding = tf.zeros((p.max_neighbors, batch_size))
  return [spell_enc, pron_enc], [padding, padding]
def FProp(self, x, y):
  """Runs three `encoder%03d` scopes followed by three `decoder%03d` scopes.

  Every scope passes `x` and `y` through `tf.identity` and records `x_mean`
  and `y_mean` scalars via `tpu_summary.scalar`. Encoder scopes add 1 to `x`;
  decoder scopes add 1 to `y`.

  Args:
    x: Input tensor.
    y: Input tensor.

  Returns:
    The updated `(x, y)` pair.
  """

  def _Step(scope_name, x, y, bump_x):
    # One scope's worth of work; `bump_x` selects which tensor is incremented.
    with tf.name_scope(scope_name):
      x, y = tf.identity(x), tf.identity(y)
      if bump_x:
        x = x + 1
      else:
        y = y + 1
      tpu_summary.scalar('x_mean', tf.reduce_mean(x))
      tpu_summary.scalar('y_mean', tf.reduce_mean(y))
    return x, y

  for idx in range(3):
    x, y = _Step('encoder%03d' % idx, x, y, bump_x=True)
  for idx in range(3):
    x, y = _Step('decoder%03d' % idx, x, y, bump_x=False)
  return x, y
def AddSummary(self, lr, optimizer, var_grad):
  """Records the learning rate and the mean of each 'accumulator' slot.

  Args:
    lr: Learning rate value, recorded as 'adagrad_lr'.
    optimizer: Optimizer whose `get_slot(var, 'accumulator')` is queried.
    var_grad: Structure whose `Flatten()` yields `(variable, gradient)` pairs;
      only the variables are used.
  """
  summary_utils.scalar('adagrad_lr', lr)
  for variable, _ in var_grad.Flatten():
    accumulator = optimizer.get_slot(variable, 'accumulator')
    assert accumulator is not None
    tag = 'optimizer/adagrad_accum_%s' % variable.name
    summary_utils.scalar(tag, tf.reduce_mean(accumulator))
def ComputePredictions(self, theta, input_batch):
  """Extracts features, average-pools them and computes softmax logits.

  Args:
    theta: Variable values; `theta.extract` and `theta.softmax` are read.
    input_batch: Input whose `data` field is fed to the extractor.

  Returns:
    A `NestedMap` with `act` (features averaged over axes 1 and 2) and
    `logits`.
  """
  # Forward through layers.
  features = self.extract.FProp(theta.extract, input_batch.data)
  # Avg pool
  pooled = tf.reduce_mean(features, axis=[1, 2])
  return py_utils.NestedMap(
      act=pooled, logits=self.softmax.Logits(theta.softmax, pooled))
def _ConcatAveNeigh(self, spellings, pronunciations, theta, batch_size):
  """Encodes concatenated neighbor spelling+pronunciation and averages it.

  Args:
    spellings: Neighbor spelling ids; reshaped to
      `(batch_size * p.max_neighbors, -1)`.
    pronunciations: Neighbor pronunciation ids; reshaped the same way.
    theta: Variable values; `theta.spell_encoder` is read.
    batch_size: Batch size used in the reshapes.

  Returns:
    A pair `([neigh_enc], [padding])`; `neigh_enc` is
    `(max_neighbors, batch_size, enc_units)` and `padding` is an all-zero
    `(max_neighbors, batch_size)` tensor.

  Raises:
    NotImplementedError: If `p.use_neigh_id_emb` is set.
  """
  p = self.params
  flat_shape = (batch_size * p.max_neighbors, -1)
  flat_spellings = tf.reshape(spellings, flat_shape)
  flat_pronunciations = tf.reshape(pronunciations, flat_shape)

  # ->(batch_size * max_neighbors, max_spelling_len + max_pronunciation_len)
  neigh_info = tf.concat([flat_spellings, flat_pronunciations], axis=1)

  # TODO(llion): Add task ids to concatenated info?
  if p.use_neigh_id_emb:
    raise NotImplementedError()

  encoder_input = {
      "ids": neigh_info,
      "paddings": self._GetPaddings(neigh_info, dtype=tf.int32),
  }
  encoder_output = self.spell_encoder.FProp(theta.spell_encoder,
                                            py_utils.NestedMap(encoder_input))

  total_len = p.max_spelling_len + p.max_pronunciation_len
  per_step = tf.reshape(
      encoder_output["encoded"],
      (total_len, batch_size, p.max_neighbors, p.enc_units))
  # Average over the length axis, then put neighbors first.
  averaged = tf.transpose(tf.reduce_mean(per_step, axis=0), (1, 0, 2))
  return [averaged], [tf.zeros((p.max_neighbors, batch_size))]
def _verify_timestep_counts(self, num_splits):
  """Checks gradient norm and accumulator step counts of a dummy pipeline CNN.

  Builds the dummy pipeline with 8 micro batches on a 16-example random input,
  compares the gradient norm against a golden value, and asserts that every
  accumulator equals the micro-batch count (or 1 when `num_splits` is not
  greater than 1).

  Args:
    num_splits: Number of pipeline splits passed to `_BuildDummyPipelineCnn`.
  """
  num_micro_batches = 8
  batch_size = 16
  input_shape = [batch_size, 8, 8, 1]
  # Accumulator values should be equal to number of time steps in pipeline.
  expected_ts = num_micro_batches if num_splits > 1 else 1

  with self.session(graph=tf.Graph()) as sess:
    tf.set_random_seed(1245)
    inputs = tf.random_uniform(input_shape, seed=12345)
    net = _BuildDummyPipelineCnn(
        num_splits=num_splits, num_micro_batches=num_micro_batches)
    endpoints = net.FPropDefaultTheta(inputs)
    logits, aux_logits = (
        endpoints if isinstance(endpoints, (list, tuple)) else
        (endpoints, None))

    grads = tf.gradients(tf.reduce_mean(logits), tf.trainable_variables())
    grad_norm = tf.sqrt(py_utils.SumSquared(grads))
    accumulators = net.GetAccumulatorValues().Flatten()

    sess.run(tf.global_variables_initializer())
    grad_norm_val, accumulator_vals = sess.run([grad_norm, accumulators])
    test_utils.CompareToGoldenSingleFloat(self, 0.268087, grad_norm_val)
    for accumulator_val in accumulator_vals:
      self.assertEqual(accumulator_val, expected_ts)

    if aux_logits is not None:
      aux_logit_tensor = sess.run(aux_logits)
      self.assertEqual(aux_logit_tensor.shape, tuple(input_shape))
def testCheckNumerics(self):
  """CheckNumerics passes finite values through and rejects a NaN mean."""
  finite = tf.convert_to_tensor([2.0, 3.0], tf.float32)
  passed_through = py_utils.CheckNumerics(finite)
  self.assertListEqual([2.0, 3.0], passed_through.numpy().tolist())

  with self.assertRaisesRegex(tf.errors.InvalidArgumentError, 'NaN'):
    # The mean of an empty tensor is the NaN input being tested.
    empty = tf.convert_to_tensor([], tf.float32)
    py_utils.CheckNumerics(tf.reduce_mean(empty))
def FProp(self, theta, input_batch):
  # pyformat: disable
  """Compute features for the pillars and convert them back to a dense grid.

  Args:
    theta: A `.NestedMap` object containing variable values of this task.
    input_batch: A `.NestedMap` object containing input tensors. Following
      keys are required:

      - grid_num_points: Integer tensor with shape [batch size, nx, ny, nz],
        where nx, ny, nz corresponds to the grid sizes (i.e., number of voxels
        in each axis dimension).
      - pillar_points: Float tensor with shape [batch size, num_pillars,
        num_points_per_pillar, 3 + num_laser_features]
      - pillar_centers: Float tensor with shape [batch size, num_pillars, 1,
        3]; it is tiled along axis 2 to num_points_per_pillar here.
      - pillar_locations: Float tensor with shape [batch size, num_pillars, 3]

  Returns:
    The dense features with shape [b, nx, ny, nz * fdims].
  """
  # pyformat: enable
  p = self.params
  bs, nx, ny, nz = py_utils.GetShape(input_batch.grid_num_points, 4)

  # Process points to concatenate a set of fixed features (e.g.,
  # add means, centers, normalize points to means).
  num_features = 3 + p.num_laser_features
  pillar_points = py_utils.HasShape(input_batch.pillar_points,
                                    [bs, -1, -1, num_features])
  _, npillars, npoints, _ = py_utils.GetShape(pillar_points, 4)

  pillar_xyz = pillar_points[..., :3]
  # `keepdims` replaces the deprecated `keep_dims` alias.
  pillar_means = tf.reduce_mean(pillar_xyz, axis=2, keepdims=True)
  pillar_feats = pillar_points[..., 3:]
  pillar_centers = py_utils.HasShape(input_batch.pillar_centers,
                                     [bs, -1, 1, 3])
  pillar_concat = tf.concat(
      axis=3,
      values=[
          pillar_xyz - pillar_means,
          pillar_feats,
          tf.tile(pillar_means, [1, 1, npoints, 1]),
          tf.tile(pillar_centers, [1, 1, npoints, 1]),
      ])

  # Featurize pillars.
  pillar_features = self.featurizer.FProp(theta.featurizer, pillar_concat)

  # Convert back to the dense grid.
  pillar_locations = py_utils.HasShape(input_batch.pillar_locations,
                                       [bs, npillars, 3])
  dense_features = SparseToDense(
      grid_shape=(nx, ny, nz),
      locations=pillar_locations,
      feats=pillar_features)
  return dense_features
def _GetExpertDist(self, theta, inputs, *args):
  """Get the task id from inputs tensors.

  Args:
    theta: Variable values; `theta.w` is contracted with the averaged input.
    inputs: Tensor reshaped to `[-1, params.cond_dim]` before averaging.
    *args: Ignored.

  Returns:
    Sigmoid of the averaged embedding contracted with `theta.w`.
  """
  params = self.params
  # TODO(huangyp): support the more general case when batch size is not 1.
  # Input shape can be either [batch, length, dim] or [length, batch, dim]
  flat_inputs = tf.reshape(inputs, [-1, params.cond_dim])
  if not params.nonzeros_mean:
    mean_emb = tf.reduce_mean(flat_inputs, 0)
  else:
    # Average over non-zero entries only; the epsilon guards all-zero columns.
    column_sum = tf.reduce_sum(flat_inputs, 0)
    nonzero_count = tf.cast(
        tf.count_nonzero(flat_inputs, 0), dtype=tf.float32)
    mean_emb = column_sum / (nonzero_count + 1e-10)
  return tf.nn.sigmoid(tf.einsum('i,ij->j', mean_emb, theta.w))
def _AddAttentionSummaries(self, name, atten_probs):
  """Adds image and histogram summaries of attention probabilities.

  Does nothing on TPU, when `params.allow_attention_summaries` is false, or
  when `atten_probs` is None.

  Args:
    name: Summary name forwarded to the summary helpers.
    atten_probs: Attention probabilities, or None.
  """
  # Plots attention prob summaries for joint network.
  # TODO(ankurbpn): Check why op wasn't compiling on TPUs.
  p = self.params
  if (py_utils.use_tpu() or not p.allow_attention_summaries or
      atten_probs is None):
    return

  shape = tf.shape(atten_probs)
  reshaped = tf.reshape(atten_probs, [shape[0], shape[1], -1, shape[-1]])
  # Only plots first example of the batch.
  first_example = tf.reduce_mean(reshaped[0:1, :, :, :], 1)
  self._AddAttenProbsImageSummary(name, first_example)
  self._AddAttenProbsHistogramSummary(name, first_example)
def ComputeLoss(self, theta, predictions, input_batch):
  """Computes the CTC loss of `predictions` against the target labels.

  Args:
    theta: Unused.
    predictions: Object with `encoded` (passed as logits) and `padding`.
    input_batch: Object with `tgt.labels` and `tgt.paddings`.

  Returns:
    A pair `(metrics, per_sequence)`: `metrics` maps 'loss' to
    `(mean CTC loss, 1.0)` and `per_sequence` maps 'loss' to the
    un-reduced CTC loss.
  """
  label_lengths = py_utils.LengthsFromBitMask(input_batch.tgt.paddings, 1)
  logit_lengths = py_utils.LengthsFromBitMask(predictions.padding, 0)
  # ctc_loss.shape = (B)
  per_sequence = tf.nn.ctc_loss(
      input_batch.tgt.labels,
      predictions.encoded,
      label_lengths,
      logit_lengths,
      logits_time_major=True,
      blank_index=self.params.blank_index)
  mean_loss = tf.reduce_mean(per_sequence)
  return {'loss': (mean_loss, 1.0)}, {'loss': per_sequence}
def get_accuracy(self, loss, pred, target):
  """Adds per-example and per-character accuracy entries to `loss`.

  Args:
    loss: Dict-like object updated in place with "accuracy_per_example" and
      "accuracy_per_char", each a `(value, p.input.batch_size)` pair.
    pred: Predicted ids; its dtype is used for the integer casts.
    target: Target ids, cast to `pred.dtype`.
  """
  p = self.params
  id_dtype = pred.dtype
  target = tf.cast(target, id_dtype)
  pad_id = int(p.input.feature_neighborhood_input.batch_opts.pad_value)

  non_pad = tf.cast(tf.math.not_equal(target, pad_id), id_dtype)
  # NOTE(review): masked predictions become 0 at padded positions, so they
  # match the target there only if pad_id is 0 - confirm that is intended.
  masked_pred = pred * non_pad
  non_pad_count = tf.cast(tf.reduce_sum(non_pad), tf.float32)

  matches = tf.math.equal(masked_pred, target)
  whole_sequence_match = tf.cast(tf.reduce_all(matches, axis=1), tf.float32)
  loss["accuracy_per_example"] = (tf.reduce_mean(whole_sequence_match),
                                  p.input.batch_size)

  char_matches = tf.cast(matches, tf.float32) * tf.cast(non_pad, tf.float32)
  loss["accuracy_per_char"] = (tf.reduce_sum(char_matches) / non_pad_count,
                               p.input.batch_size)
def _verify_timestep_counts(self,
                            num_splits,
                            auto_partition=False,
                            micro_batch_size=None):
  """Checks gradient norm and accumulator step counts of a pipelined net.

  Args:
    num_splits: Number of pipeline splits.
    auto_partition: If true, builds a `PipeliningLayer` from 16 `_SimpyLayer`
      params partitioned by `_Partition`; otherwise uses
      `_BuildDummyPipelineCnn`.
    micro_batch_size: Forwarded to `_BuildDummyPipelineCnn` when
      `auto_partition` is false.
  """
  num_micro_batches = 8
  batch_size = 16
  input_shape = [batch_size, 8, 8, 1]
  # Accumulator values should be equal to number of time steps in pipeline.
  expected_ts = num_micro_batches if num_splits > 1 else 1

  with self.session(graph=tf.Graph()) as sess:
    tf.random.set_seed(1245)
    inputs = tf.random.uniform(input_shape, seed=12345)
    if not auto_partition:
      net = _BuildDummyPipelineCnn(
          num_splits=num_splits,
          micro_batch_size=micro_batch_size,
          num_micro_batches=num_micro_batches)
    else:
      layer_params = []
      for idx in range(16):
        layer_params.append(
            _SimpyLayer.Params().Set(name='layer_{}'.format(idx)))
      cell_tpl = _Partition(layer_params, num_splits,
                            tshape.Shape([batch_size, 8, 8, 1]))
      net = PipeliningLayer.Params().Set(
          name='pipeline',
          num_micro_batches=num_micro_batches,
          cell_tpl=cell_tpl).Instantiate()

    endpoints = net.FPropDefaultTheta(inputs)
    logits, aux_logits = (
        endpoints if isinstance(endpoints, (list, tuple)) else
        (endpoints, None))

    grads = tf.gradients(tf.reduce_mean(logits), tf.trainable_variables())
    grad_norm = tf.sqrt(py_utils.SumSquared(grads))
    accumulators = net.GetAccumulatorValues().Flatten()

    sess.run(tf.global_variables_initializer())
    grad_norm_val, accumulator_vals = sess.run([grad_norm, accumulators])
    test_utils.CompareToGoldenSingleFloat(self, 0.268087, grad_norm_val)
    for accumulator_val in accumulator_vals:
      self.assertEqual(accumulator_val, expected_ts)

    if aux_logits is not None:
      aux_logit_tensor = sess.run(aux_logits)
      self.assertEqual(aux_logit_tensor.shape, tuple(input_shape))
def AddAttentionSummaryBatchMajor(attention_tensors,
                                  src_paddings,
                                  tgt_paddings,
                                  transcripts=None,
                                  max_outputs=3):
  """Plots batch-major attention matrices and logs their normalized entropy.

  Unlike AddAttentionSummary(), every tensor here has the batch dimension in
  axis 0.

  Args:
    attention_tensors: A list of 3D tensors shaped [batch_size, target_len,
      source_len], where attention[b, i, j] is the probability of the i-th
      output attending to the j-th input for batch element b.
    src_paddings: Binary paddings shaped [batch, source_len] for the source
      sequence.
    tgt_paddings: Binary paddings shaped [batch, target_len] for the target
      sequence.
    transcripts: Optional transcripts shaped [batch, source_len] for the
      source sequence; attached to the first subplot only.
    max_outputs: Integer maximum number of batch elements to plot.
  """
  name = attention_tensors[0].name + '/Attention'
  if not _ShouldAddSummary():
    return

  with plot.MatplotlibFigureSummary(name, max_outputs=max_outputs) as fig:
    src_lens = SequenceLength(src_paddings)
    tgt_lens = SequenceLength(tgt_paddings)
    for index, atten in enumerate(attention_tensors):
      # Diagnostic metric that decreases as attention picks up.
      log_src_len = tf.log(tf.cast(src_lens, tf.float32))
      max_entropy = tf.expand_dims(tf.expand_dims(log_src_len, -1), -1)
      normalized_entropy = -atten * tf.log(atten + 1e-10) / max_entropy
      scalar('Attention/average_normalized_entropy/%d' % index,
             tf.reduce_mean(normalized_entropy))

      plot_args = [atten, src_lens, tgt_lens]
      if index == 0 and transcripts is not None:
        plot_args.append(transcripts)
      fig.AddSubplot(
          plot_args,
          TrimPaddingAndPlotAttention,
          title=atten.name,
          xlabel='Input',
          ylabel='Output')
def ProcessPillars(input_batch, num_laser_features, featurizer,
                   theta_featurizer):
  """Compute features for the pillars and convert them back to a dense grid.

  Args:
    input_batch: A `.NestedMap` object containing input tensors. The fields
      `grid_num_points`, `pillar_points`, `pillar_centers` and
      `pillar_locations` are read; `pillar_centers` must have shape
      [batch size, num_pillars, 1, 3].
    num_laser_features: Number of laser features (excluding pillar features)
    featurizer: The featurizer layer.
    theta_featurizer: The weights for featurizer.

  Returns:
    The dense features with shape [b, nx, ny, nz * fdims].
  """
  bs, nx, ny, nz = py_utils.GetShape(input_batch.grid_num_points, 4)

  # Process points to concatenate a set of fixed features (e.g.,
  # add means, centers, normalize points to means).
  num_features = 3 + num_laser_features
  pillar_points = py_utils.HasShape(input_batch.pillar_points,
                                    [bs, -1, -1, num_features])
  _, npillars, npoints, _ = py_utils.GetShape(pillar_points, 4)

  pillar_xyz = pillar_points[..., :3]
  # `keepdims` replaces the deprecated `keep_dims` alias.
  pillar_means = tf.reduce_mean(pillar_xyz, axis=2, keepdims=True)
  pillar_feats = pillar_points[..., 3:]
  pillar_centers = py_utils.HasShape(input_batch.pillar_centers,
                                     [bs, -1, 1, 3])
  pillar_concat = tf.concat(
      axis=3,
      values=[
          pillar_xyz - pillar_means,
          pillar_feats,
          tf.tile(pillar_means, [1, 1, npoints, 1]),
          tf.tile(pillar_centers, [1, 1, npoints, 1]),
      ])

  # Featurize pillars.
  pillar_features = featurizer.FProp(theta_featurizer, pillar_concat)

  # Convert back to the dense grid.
  pillar_locations = py_utils.HasShape(input_batch.pillar_locations,
                                       [bs, npillars, 3])
  dense_features = SparseToDense(
      grid_shape=(nx, ny, nz),
      locations=pillar_locations,
      feats=pillar_features)
  return dense_features
def _Gradient(inputs, _, original_grad):
  """Replaces `original_grad` with a sign-masked sum of per-loss gradients.

  `self` and `p` are read from the enclosing scope. For each loss in
  `self._losses` the gradient w.r.t. `self._output_tensor` is taken; losses
  without a gradient are skipped with a warning. One sign is sampled per
  element, and each per-loss gradient is kept where its sign agrees and scaled
  by its leak ratio elsewhere.

  Args:
    inputs: Tensor the gradients are multiplied with to determine signs.
    _: Unused.
    original_grad: Incoming gradient; supplies the output dtype and, when
      `p.keep_gradnorm_constant` is set, the target norm.

  Returns:
    The transformed gradient, cast to `original_grad.dtype`.

  Raises:
    ValueError: If no loss produced a gradient.
  """
  # Compute the gradients for each loss w.r.t. the inputs.
  # TODO(jngiam): Look into whether TF dedups this computation.
  per_loss_grads = []
  # Collected alongside the gradients so that a skipped loss does not shift
  # the leak ratios of the remaining losses.
  leak_ratios = []
  for loss, leak_ratio in self._losses:
    per_loss_grad = tf.gradients(loss, self._output_tensor)[0]
    if per_loss_grad is None:
      tf.logging.warning(
          'Loss %s did not result in a gradient during '
          'GradDrop computation.', loss)
    else:
      per_loss_grads.append(per_loss_grad)
      leak_ratios.append(leak_ratio)

  if not per_loss_grads:
    raise ValueError('No valid gradients for GradDrop.')

  # Multiply the gradients with the inputs.
  grads = per_loss_grads
  if p.use_input_sign_only:
    # Adding the indicator keeps the denominator away from zero for inputs
    # whose magnitude is at most epsilon.
    input_abs = tf.abs(
        tf.cast(tf.abs(inputs) <= p.epsilon, tf.float32) + inputs)
    grads = [grad * ((inputs) / (input_abs)) for grad in grads]
  else:
    grads = [grad * inputs for grad in grads]

  # Sum gradient over batch, assuming that batch is always on dim 0.
  if p.marginalize_batch_dim:
    grads = [tf.reduce_sum(grad, axis=0, keepdims=True) for grad in grads]

  # First discretize all gradients into their sign values.
  grad_sign_positive = [tf.cast(grad > 0.0, tf.float32) for grad in grads]
  grad_sign_negative = [tf.cast(grad < 0.0, tf.float32) for grad in grads]

  # Calculate the probability of positive gradients based on equation (1)
  # in the GradDrop paper.
  grad_abs_sum = tf.add_n([tf.abs(grad) for grad in grads])
  prob_pos = (tf.add_n(grads) / (2. * grad_abs_sum + p.epsilon))
  # Implementation of different scales for the keep function. Larger
  # scales result in steeper keep functions.
  prob_pos *= p.keep_prob_function_scale

  if p.keep_prob_function == 'sigmoid':
    # Standard sigmoid has derivative of 0.25 at 0 so the factor of 4.0
    # allows the function scale in sigmoid to be compatible with the
    # function scale in the linear case.
    prob_pos = tf.sigmoid(4.0 * prob_pos)
  elif p.keep_prob_function == 'linear':
    prob_pos += 0.5

  # The main, default mode of GradDrop. Only gradients of one sign are kept,
  # and which sign is calculated via equation (1) of the main paper.
  prob_pos = tf.cast(prob_pos >= tf.random.uniform(prob_pos.shape),
                     tf.float32) - 0.5
  grad_masks = [
      (gsp - gsn) * prob_pos >= 0
      for (gsn, gsp) in zip(grad_sign_negative, grad_sign_positive)
  ]

  # This diag value gives us the percentage of grads which are kept.
  gradmask_diag = [tf.cast(gm, tf.float32) for gm in grad_masks]
  diag = tf.reduce_mean(tf.add_n(gradmask_diag) / len(grad_masks))
  summary_utils.scalar('average_grad_mask', diag)

  transformed_per_loss_grads = [
      grad * (leak + (1.0 - leak) * tf.cast(grad_mask, tf.float32))
      for (leak, grad,
           grad_mask) in zip(leak_ratios, per_loss_grads, grad_masks)
  ]
  transformed_grad = tf.cast(
      tf.add_n(transformed_per_loss_grads), original_grad.dtype)

  if not p.keep_gradnorm_constant:
    return transformed_grad

  # Rescale so the result has the same L2 norm as the incoming gradient.
  transformed_grad_norm = tf.sqrt(tf.reduce_sum(transformed_grad**2))
  original_grad_norm = tf.sqrt(tf.reduce_sum(original_grad**2))
  return transformed_grad * original_grad_norm / (
      transformed_grad_norm + p.epsilon)
def AddAttentionSummaryBatchMajor(attention_tensors,
                                  src_paddings,
                                  tgt_paddings,
                                  transcripts=None,
                                  max_outputs=3):
  """Plots batch-major attention matrices and logs their normalized entropy.

  Unlike AddAttentionSummary(), every tensor here has the batch dimension in
  axis 0.

  Args:
    attention_tensors: A list of 3D tensors shaped [batch_size, target_len,
      source_len], where attention[b, i, j] is the probability of the i-th
      output attending to the j-th input for batch element b.
    src_paddings: Binary paddings shaped [batch, source_len] for the source
      sequence, or a list with one such tensor per entry of
      attention_tensors.
    tgt_paddings: Binary paddings shaped [batch, target_len] for the target
      sequence, or a list with one such tensor per entry of
      attention_tensors.
    transcripts: Optional transcripts shaped [batch, source_len] for the
      source sequence; attached to the first subplot only.
    max_outputs: Integer maximum number of batch elements to plot.

  Raises:
    ValueError: If a paddings list has a length other than 1 or
      len(attention_tensors).
  """

  def _CheckPaddingsLength(paddings):
    length = len(paddings) if isinstance(paddings, list) else 1
    if length not in (1, len(attention_tensors)):
      raise ValueError('Bad length of paddings list {}'.format(length))

  def _AsLengths(paddings):
    as_list = paddings if isinstance(paddings, list) else [paddings]
    return [SequenceLength(padding) for padding in as_list]

  def _Pick(lengths, index):
    # A single shared entry applies to every attention tensor.
    return lengths[0] if len(lengths) == 1 else lengths[index]

  _CheckPaddingsLength(src_paddings)
  _CheckPaddingsLength(tgt_paddings)

  name = attention_tensors[0].name + '/Attention'
  if not _ShouldAddSummary():
    return

  src_lens = _AsLengths(src_paddings)
  tgt_lens = _AsLengths(tgt_paddings)

  with plot.MatplotlibFigureSummary(
      name, max_outputs=max_outputs, gridspec_kwargs={'hspace': 0.3}) as fig:
    for index, atten in enumerate(attention_tensors):
      src_len = _Pick(src_lens, index)
      tgt_len = _Pick(tgt_lens, index)
      # Diagnostic metric that decreases as attention picks up.
      log_src_len = tf.log(tf.cast(src_len, tf.float32))
      max_entropy = tf.expand_dims(tf.expand_dims(log_src_len, -1), -1)
      normalized_entropy = -atten * tf.log(atten + 1e-10) / max_entropy
      scalar('Attention/average_normalized_entropy/%d' % index,
             tf.reduce_mean(normalized_entropy))

      plot_args = [atten, src_len, tgt_len]
      if index == 0 and transcripts is not None:
        plot_args.append(transcripts)
      fig.AddSubplot(
          plot_args,
          TrimPaddingAndPlotAttention,
          title=atten.name,
          xlabel='Input',
          ylabel='Output')
def sum_embeddings(t):
  """Averages `t` over axis 2 after a 4-D reshape.

  `t` is reshaped to `(p.input.batch_size, p.max_neighbors, -1, p.enc_units)`
  with `p` taken from the enclosing scope. Despite the name, the reduction is
  a mean, not a sum.
  """
  grouped_shape = (p.input.batch_size, p.max_neighbors, -1, p.enc_units)
  return tf.reduce_mean(tf.reshape(t, grouped_shape), axis=2)
def ComputeLoss(self, theta, predictions, input_batch):
  """Computes the weighted contrastive retrieval loss and its metrics.

  Args:
    theta: A `.NestedMap` object containing variable values of this task.
    predictions: The output of `ComputePredictions`.
    input_batch: A `.NestedMap` object containing input tensors to this tower.

  Returns:
    A tuple (metrics, per_example_tensors), where

    - `metrics` is a dict of str keys to (metric, weight) values
    - `per_example_tensors` is a dict of str keys to tensors describing each
      training example, where the first dimension of each tensor is the batch
      index. It is always empty here.
  """
  p = self.params

  # During TPU training, collect the encodings and ids from all TPUs so the
  # loss can be computed over all query-result pairs in the global batch.
  # To avoid duplicating work, each TPU operates on a non-overlapping
  # slice of these pairs. Specifically, each TPU uses queries drawn from its
  # local batch and results from the global batch.

  # Encodings of the local and global examples, keyed by modality.
  flat_by_modality = {
      modality: tf.reshape(predictions[modality].encodings,
                           [-1, p.joint_embedding_dim])
      for modality in predictions
  }
  local_flat_encodings = py_utils.NestedMap(flat_by_modality)
  global_flat_encodings = tpu_utils.ConcatenateAcrossReplicas(
      local_flat_encodings)

  def _PerQueryLoss(query_modality, result_modality):
    pairs = label_lib.ExamplePairs.BetweenLocalAndGlobalBatches(
        input_batch,
        query_modality=query_modality,
        result_modality=result_modality)
    labels = p.label_fn(pairs)

    # [num_queries, num_results]
    similarities = self.score_function(
        local_flat_encodings[query_modality],
        global_flat_encodings[result_modality])

    # [num_queries]
    return label_lib.MultiLabelContrastiveLoss(
        labels=tf.reshape(labels, similarities.shape), logits=similarities)

  metrics = {}
  weighted_terms = []
  for direction, weight in p.loss_weights.items():
    query_modality, result_modality = direction
    if weight:
      mean_loss = tf.reduce_mean(
          _PerQueryLoss(query_modality, result_modality))
      weighted_terms.append(weight * mean_loss)
      metric_name = 'loss_{}_to_{}'.format(query_modality, result_modality)
      metrics[metric_name] = (mean_loss, 1)
    else:
      logging.info('Skipping %s retrieval', direction)

  regularization_losses = utils.CollectRegularizationLosses(self)
  if p.regularization_loss_weight and regularization_losses:
    tf.logging.info('Adding TF1 regularization loss: %s',
                    regularization_losses)
    total_reg_loss = tf.reduce_sum(regularization_losses)
    weighted_terms.append(p.regularization_loss_weight * total_reg_loss)
    metrics['loss_regularization'] = (total_reg_loss, 1)

  metrics['loss'] = (tf.add_n(weighted_terms), 1)
  return metrics, {}
def try_apply_dense(self, grad, var):
  """Computes the Adafactor update for `var` without applying it.

  No variable or slot is assigned here: the assignments survive only as
  commented-out lines. Each intermediate tensor is instead recorded in a stats
  dict together with an "all elements finite" check.

  Args:
    grad: Gradient for `var`; must not be None. It is cast to `var.dtype`.
    var: The variable whose slots ('vr'/'vc' or 'v', and 'm' when
      `self._beta1` is set) are read.

  Returns:
    A pair `(is_finite_checks, stats)`: a list of boolean scalar tensors, one
    per recorded intermediate, and a dict mapping intermediate names to their
    tensors.
  """
  assert grad is not None

  # `cond` is threaded through `_Upd` but never modified by it.
  cond = tf.constant(True)
  is_finite_checks = []
  stats = {}

  grad_dtype = var.dtype  # TODO(lepikhin): add to params
  grad = tf.cast(grad, grad_dtype)
  factored_dims = self._factored_dims(var.shape.as_list())
  if factored_dims:
    vr = self.get_slot(var, 'vr')
    vc = self.get_slot(var, 'vc')
  else:
    v = self.get_slot(var, 'v')
  if self._beta1:
    m = self.get_slot(var, 'm')

  def _Upd(c, k, x):
    # Records `x` under key `k` plus its finiteness check; returns `c` as is.
    stats[k] = x
    is_finite_checks.append(tf.reduce_all(tf.math.is_finite(x)))
    return c

  with tf.variable_scope(var.name[:-2] + '/Adafactor'):
    grad_squared = tf.math.square(grad) + tf.cast(self._epsilon1, grad_dtype)
    cond = _Upd(cond, 'grad_squared', grad_squared)  # 0 (factored)
    decay_rate = tf.cast(self._decay_rate, var.dtype)
    old_val = tf.identity(var)  # TODO(lepikhin): introduce gradient dtype
    # The assert makes the `else` branch below unreachable unless asserts are
    # disabled.
    assert self._multiply_by_parameter_scale
    lr = GetLrValue(self._learning_rate)
    if self._multiply_by_parameter_scale:
      parameter_scale = self._parameter_scale(old_val)
      cond = _Upd(cond, 'parameter_scale', parameter_scale)  # 1 (factored)
      # NOTE(review): the parameter scale is computed a second time here
      # rather than reusing `parameter_scale`.
      update_scale = self._parameter_scale(old_val) * tf.cast(lr, grad_dtype)
    else:
      update_scale = lr
    mixing_rate = tf.cast(1.0 - decay_rate, grad_dtype)
    update_scale = tf.cast(update_scale, grad_dtype)
    if factored_dims:
      d0, d1 = factored_dims
      vr_axis, vc_axis = d0, d1
      grad_squared_row_mean = tf.reduce_mean(grad_squared, axis=vr_axis)
      grad_squared_col_mean = tf.reduce_mean(grad_squared, axis=vc_axis)
      # new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean)
      new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
      # new_vc = (decay_rate * vc + mixing_rate * grad_squared_col_mean)
      new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
      cond = _Upd(cond, 'new_vr', new_vr)  # 2 (factored)
      cond = _Upd(cond, 'new_vc', new_vc)  # 3 (factored)
      # vr_update = _Wrap(tf.assign, vr, new_vr)
      # vc_update = _Wrap(tf.assign, vc, new_vc)
      # updates.extend([vr_update, vc_update])
      long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
      r_factor = tf.math.rsqrt(new_vr / long_term_mean)
      c_factor = tf.math.rsqrt(new_vc)
      # Re-insert the reduced axes so the two factors broadcast against grad.
      mult = tf.expand_dims(r_factor, vr_axis) * tf.expand_dims(
          c_factor, vc_axis)
      cond = _Upd(cond, 'mult', mult)  # 4 (factored)
      x = grad * mult
    else:
      new_v = v * decay_rate + grad_squared * mixing_rate
      cond = _Upd(cond, 'new_v', new_v)
      # v_update = _Wrap(tf.assign, v, new_v)
      # updates.append(v_update)
      x = grad * tf.math.rsqrt(new_v)

    assert self._clipping_threshold is not None

    if self._clipping_threshold is not None:
      # Scale `x` down only when its RMS exceeds the clipping threshold.
      clipping_denom = tf.maximum(
          tf.constant(1.0, grad_dtype),
          py_utils.ReduceRms(x) /
          tf.constant(self._clipping_threshold, grad_dtype))
      x /= clipping_denom
    cond = _Upd(cond, 'x', x)
    subtrahend = x * update_scale
    if self._beta1:
      new_m = (
          m * tf.constant(self._beta1, dtype=grad_dtype) +
          subtrahend * tf.constant(1.0 - self._beta1, dtype=grad_dtype))
      subtrahend = new_m
      cond = _Upd(cond, 'new_m', new_m)
      # updates.append(_Wrap(tf.assign, m, new_m))

    # It is critical to use assign_sub instead of tf.assign(var - subtrahend)
    # for the case of bfloat16 activations, so as to avoid repeatedly
    # rounding the slice value, which results in poor quality.
    cond = _Upd(cond, 'subtrahend', subtrahend)  # 5 (factored)

    # var_update = _Wrap(tf.assign_sub, var, subtrahend)
    # updates.append(var_update)

    return is_finite_checks, stats
def ComputePredictions(self,
                       encoder_outputs,
                       pronunciations,
                       is_inference=False):
  """Computes the predictions from the encoder_outputs, updating losses.

  Despite the name, this function does the bulk of the decoding and loss
  computation, incrementing the loss at each time step.

  Args:
    encoder_outputs: a NestedMap consisting of outputs of the
      FeatureNeighborhoodEncoder with

        encoded - encoding of the input spelling
        neighbor_pronunciations_encoded - encodings of the neighbor prons
        neighbor_pronunciations_encoded - encodings of the neighbor spellings
          (NOTE(review): this key repeats the previous one; presumably a
          spelling-specific key was meant - confirm against the encoder)
        state - encoder state to which has been added dec_input - seed output
          for the decoder [*, 1] tensor consisting of sentence start indices
          (corresponding to "<s>")

      Only `state` and `batch` are read directly here; the whole map is
      forwarded to `self.Decode`.
    pronunciations: NestedMap with pronunciations -
      [*, max_pronunciation_len] tensor of pronunciations
    is_inference: If False then uses teacher forcing else does autoregression.

  Returns:
    NestedMap with loss, per_sequence_losses,labels, a
    [*, max_pronunciation_len] tensor of predictions, and attention
    ([*, max_pronunciation_len, max_spelling_len]), and
    neighbor_attention ([*, max_pronunciation_len, max_neighbors])
    tensors, along with the raw batch passed through from the encoder.
  """
  p = self.params
  targets = pronunciations.pronunciations
  # The number of decode steps comes from the static second dimension.
  t_len = int(targets.get_shape().as_list()[1])
  t_idx = tf.constant(0)
  # One TensorArray slot per decode step.
  attention = tf.TensorArray(dtype=tf.float32, size=t_len)
  neighbor_attention = tf.TensorArray(dtype=tf.float32, size=t_len)

  outputs = tf.TensorArray(dtype=tf.float32, size=t_len)

  loop_cond = lambda t_idx, ts, *_: tf.less(t_idx, t_len)

  # Seed the decoder with the start symbol for every batch element.
  dec_input = tf.convert_to_tensor([p.start] * p.input.batch_size)
  state = encoder_outputs.state

  # pylint: disable=missing-docstring
  def loop_body(t_idx, dec_input, attention, neighbor_attention, state,
                outputs):
    decoder_result = self.Decode(encoder_outputs, dec_input, state)
    outputs = outputs.write(t_idx, decoder_result.predictions)
    attention = attention.write(t_idx, decoder_result.attention_weights)
    neighbor_attention = neighbor_attention.write(
        t_idx,
        tf.cast(decoder_result.neighbor_attention_weights, dtype=tf.float32))
    if is_inference:
      # Autoregression: feed back the argmax of this step's predictions.
      dec_input = tf.cast(tf.argmax(decoder_result.predictions, 1), tf.int32)
    else:
      # Teacher forcing: feed the target at this step.
      dec_input = targets[:, t_idx]
    t_idx = t_idx + 1
    state = decoder_result.state
    return t_idx, dec_input, attention, neighbor_attention, state, outputs

  _, _, attention, neighbor_attention, state, outputs = tf.while_loop(
      loop_cond,
      loop_body,
      loop_vars=[
          t_idx, dec_input, attention, neighbor_attention, state, outputs
      ])

  # Stacked outputs have the step axis first; move it to axis 1.
  outputs = tf.transpose(outputs.stack(), [1, 0, 2])
  labels = tf.argmax(outputs, axis=-1)
  # Positions whose target id is 0 get zero weight in the loss.
  mask = tf.cast(
      tf.math.logical_not(tf.math.equal(targets, 0)), dtype=tf.float32)
  loss = self._loss_object(targets, outputs, sample_weight=mask)
  loss = tf.reduce_sum(loss, axis=1)
  # Divides by the full length t_len, not by the count of unmasked positions.
  per_sequence_losses = (loss / t_len)
  loss = tf.reduce_mean(per_sequence_losses)
  predictions = py_utils.NestedMap()
  predictions.loss = loss
  predictions.per_sequence_losses = per_sequence_losses
  predictions.labels = labels
  # NOTE(review): tf.squeeze without an axis drops every size-1 dimension, so
  # the 3-D transpose presumably assumes no other dimension is 1 - confirm.
  predictions.attention = tf.transpose(
      tf.squeeze(attention.stack()), perm=[1, 0, 2])
  if p.use_neighbors:
    predictions.neighbor_attention = tf.transpose(
        tf.squeeze(neighbor_attention.stack()), perm=[1, 0, 2])
  else:
    predictions.neighbor_attention = tf.squeeze(neighbor_attention.stack())
  # Expose this for subsequent data analysis
  predictions.batch = encoder_outputs.batch
  return predictions
def _resource_apply_dense(self, grad, var):
  """Builds the ops applying one Adafactor step of `grad` to `var`.

  Second-moment statistics are kept either factored (slots 'vr' and 'vc') or
  unfactored (slot 'v'), depending on what `self._factored_dims` returns for
  the variable's shape. When `self._beta1` is set, a momentum slot 'm' is
  maintained as well. When `self._cond_is_finite` is set, every assignment is
  wrapped in a `tf.cond` on a running "all values seen so far are finite"
  flag, and is replaced by a no-op read of its target when the flag is False.

  Args:
    grad: Gradient for `var`, or None.
    var: Variable to update. Its dtype is the dtype used for the computation.

  Returns:
    An empty list when `grad` is None (a warning is logged); otherwise a
    `tf.group` of the slot assignments and the final `assign_sub` on `var`.
  """
  if grad is None:
    tf.logging.warning('Gradient is None for variable %s' % var.name)
    return []

  compute_dtype = var.dtype  # TODO(lepikhin): add to params
  grad = tf.cast(grad, compute_dtype)

  factored_dims = self._factored_dims(var.shape.as_list())
  if factored_dims:
    row_slot = self.get_slot(var, 'vr')
    col_slot = self.get_slot(var, 'vc')
  else:
    full_slot = self.get_slot(var, 'v')
  if self._beta1:
    momentum_slot = self.get_slot(var, 'm')

  # Running conjunction of the finiteness checks. The helpers below read the
  # value bound at the moment they are called, so each assignment is gated
  # only on the checks performed before it.
  all_finite = tf.constant(True)

  def _AndFinite(flag, tensor):
    """ANDs `flag` with finiteness checks on `tensor`, if checks are enabled."""
    if self._cond_is_finite:
      flag = tf.math.logical_and(flag,
                                 tf.reduce_all(tf.math.is_finite(tensor)))
      flag = tf.math.logical_and(
          flag, tf.reduce_all(tf.math.logical_not(tf.math.is_inf(tensor))))
    return flag

  def _GuardedApply(assign_fn, target, value):
    """Returns assign_fn(target, value), gated on `all_finite` if enabled."""
    if self._cond_is_finite:
      return tf.cond(all_finite, lambda: assign_fn(target, value),
                     lambda: target)
    return assign_fn(target, value)

  # var.name[:-2] drops the last two characters of the variable name
  # (presumably the ':0' output suffix -- TODO confirm).
  scope_name = var.name[:-2] + '/Adafactor'
  with tf.variable_scope(scope_name):
    grad_sq = tf.math.square(grad) + tf.cast(self._epsilon1, compute_dtype)
    all_finite = _AndFinite(all_finite, grad_sq)
    decay = tf.cast(self._decay_rate, var.dtype)
    var_value = tf.identity(var)  # TODO(lepikhin): introduce gradient dtype
    lr = GetLrValue(self._learning_rate)
    if self._multiply_by_parameter_scale:
      step_scale = self._parameter_scale(var_value) * tf.cast(
          lr, compute_dtype)
    else:
      step_scale = lr
    mix = tf.cast(1.0 - decay, compute_dtype)
    step_scale = tf.cast(step_scale, compute_dtype)

    update_ops = []
    if factored_dims:
      vr_axis, vc_axis = factored_dims
      sq_mean_over_vr_axis = tf.reduce_mean(grad_sq, axis=vr_axis)
      sq_mean_over_vc_axis = tf.reduce_mean(grad_sq, axis=vc_axis)
      # Exponential moving averages of the reduced squared gradients.
      new_vr = row_slot * decay + sq_mean_over_vr_axis * mix
      new_vc = col_slot * decay + sq_mean_over_vc_axis * mix
      all_finite = _AndFinite(all_finite, new_vr)
      all_finite = _AndFinite(all_finite, new_vc)
      update_ops.append(_GuardedApply(tf.assign, row_slot, new_vr))
      update_ops.append(_GuardedApply(tf.assign, col_slot, new_vc))
      vr_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
      row_factor = tf.math.rsqrt(new_vr / vr_mean)
      col_factor = tf.math.rsqrt(new_vc)
      scaled_grad = grad * tf.expand_dims(row_factor, vr_axis) * tf.expand_dims(
          col_factor, vc_axis)
    else:
      new_v = full_slot * decay + grad_sq * mix
      all_finite = _AndFinite(all_finite, new_v)
      update_ops.append(_GuardedApply(tf.assign, full_slot, new_v))
      scaled_grad = grad * tf.math.rsqrt(new_v)

    if self._clipping_threshold is not None:
      # Shrink the update whenever its RMS exceeds the clipping threshold.
      rms_ratio = py_utils.ReduceRms(scaled_grad) / tf.constant(
          self._clipping_threshold, compute_dtype)
      scaled_grad = scaled_grad / tf.maximum(
          tf.constant(1.0, compute_dtype), rms_ratio)

    delta = scaled_grad * step_scale
    if self._beta1:
      new_m = (
          momentum_slot * tf.constant(self._beta1, dtype=compute_dtype) +
          delta * tf.constant(1.0 - self._beta1, dtype=compute_dtype))
      delta = new_m
      all_finite = _AndFinite(all_finite, new_m)
      update_ops.append(_GuardedApply(tf.assign, momentum_slot, new_m))

    # It is critical to use assign_sub instead of tf.assign(var - delta) for
    # the case of bfloat16 activations, so as to avoid repeatedly rounding the
    # slice value, which results in poor quality.
    all_finite = _AndFinite(all_finite, delta)
    update_ops.append(_GuardedApply(tf.assign_sub, var, delta))
    return tf.group(*update_ops)
def ComputeLoss(self, theta, predictions, input_batch):
  """Builds the training loss for the sparse detector model v1.

  The total loss is a weighted sum of a focal classification loss and a
  regression loss; the regression loss is itself a weighted sum of location,
  dimension, rotation and (optionally) corner terms.

  Args:
    theta: A `.NestedMap` object containing variable values of this task.
      Not read by this method.
    predictions: A `.NestedMap` object containing residuals and
      classification_logits.
    input_batch: A `.NestedMap` expected to contain cell_center_xyz,
      anchor_bboxes, anchor_localization_residuals, assigned_gt_labels,
      assigned_cls_mask and assigned_reg_mask.

  Returns:
    Two dicts:

    - A dict containing str keys and (metric, weight) pairs as values, where
      one of the keys is 'loss'.
    - A dict containing arbitrary tensors describing something about each
      training example, where the first dimension of each tensor is the batch
      index.
  """
  p = self.params
  batch_size, num_centers = py_utils.GetShape(input_batch.cell_center_xyz, 2)
  anchors_per_center = p.num_anchor_bboxes_per_center

  def _AnchorShape(*trailing_dims):
    """Returns a fresh [batch, centers, anchors, *trailing_dims] list."""
    return [batch_size, num_centers, anchors_per_center] + list(trailing_dims)

  # Shape-check every tensor consumed below.
  anchor_bboxes = py_utils.HasShape(input_batch.anchor_bboxes, _AnchorShape(7))
  gt_residuals = py_utils.HasShape(input_batch.anchor_localization_residuals,
                                   _AnchorShape(7))
  predicted_residuals = py_utils.HasShape(predictions.residuals,
                                          _AnchorShape(7))
  assigned_gt_labels = py_utils.HasShape(input_batch.assigned_gt_labels,
                                         _AnchorShape())
  predicted_classification_logits = py_utils.HasShape(
      predictions.classification_logits, _AnchorShape(p.num_classes))

  # assigned_cls_mask weights the classification loss. Ignored targets have
  # mask = 0; this happens when their IOU is neither high enough to count as
  # foreground nor low enough to count as background.
  class_weights = py_utils.HasShape(input_batch.assigned_cls_mask,
                                    _AnchorShape())
  class_weights = tf.reshape(class_weights, _AnchorShape(1))

  # Every anchor has num_classes prediction heads; scale each head by its
  # per-class loss weight (broadcast over batch, centers and anchors).
  per_class_loss_weight = py_utils.HasShape(
      tf.constant([[[p.per_class_loss_weight]]], dtype=tf.float32),
      [1, 1, 1, p.num_classes])
  class_weights = class_weights * per_class_loss_weight
  class_weights = py_utils.HasShape(class_weights, _AnchorShape(p.num_classes))

  # assigned_reg_mask masks the regression loss; only foreground objects
  # have assigned_reg_mask = 1.
  reg_weights = py_utils.HasShape(input_batch.assigned_reg_mask,
                                  _AnchorShape())
  reg_weights = tf.reshape(reg_weights, _AnchorShape(1))

  if p.loss_norm_type == LossNormType.NORM_BY_NUM_POS_PER_CENTER:
    foreground_mask = py_utils.HasShape(input_batch.assigned_reg_mask,
                                        _AnchorShape())
    # Number of foreground anchors at each center (summed over the anchor
    # axis), floored at one.
    num_foreground = tf.reduce_sum(foreground_mask, axis=2)
    num_foreground = tf.maximum(num_foreground, tf.ones_like(num_foreground))
    # Reshape for broadcasting against the 4-D weights.
    loss_normalization = tf.reshape(num_foreground,
                                    [batch_size, num_centers, 1, 1])
    # Also divide by the number of centers so that the loss is independent
    # of it.
    loss_normalization = loss_normalization * num_centers
    class_weights = class_weights / loss_normalization
    reg_weights = reg_weights / loss_normalization

  # Focal classification loss, masked, then summed over centers, boxes and
  # classes and averaged over the batch.
  # TODO(jngiam): Consider normalizing by num_foreground_anchors for each
  # example instead. This would match the 1/N_positive normalization in
  # point pillars.
  classification_loss = py_utils.SigmoidCrossEntropyFocalLoss(
      logits=predicted_classification_logits,
      labels=tf.one_hot(assigned_gt_labels, p.num_classes),
      alpha=p.focal_loss_alpha,
      gamma=p.focal_loss_gamma)
  classification_loss = classification_loss * class_weights
  classification_loss = tf.reduce_mean(
      tf.reduce_sum(classification_loss, axis=[1, 2, 3]))

  # Huber (SmoothL1) loss on the first six residual components.
  loc_and_dims_loss = self._utils_3d.ScaledHuberLoss(
      labels=gt_residuals[..., :6],
      predictions=predicted_residuals[..., :6],
      delta=p.huber_loss_delta)

  # The rotation loss is computed on a transform of the rotation delta: for
  # a direction-aware loss the angle is wrapped to [-pi, pi]; for a loss that
  # is symmetric to direction (i.e., rotating by pi) a sin transform is used.
  if p.direction_aware_rot_loss:
    rotation_delta_transform = functools.partial(
        geometry.WrapAngleRad, min_val=-np.pi, max_val=np.pi)
  else:
    rotation_delta_transform = tf.sin
  rotation_delta = predicted_residuals[..., 6:] - gt_residuals[..., 6:]
  rotation_loss = self._utils_3d.ScaledHuberLoss(
      labels=tf.zeros_like(rotation_delta),
      predictions=rotation_delta_transform(rotation_delta),
      delta=p.huber_loss_delta)

  reg_loc_loss = loc_and_dims_loss[..., :3]
  reg_dim_loss = loc_and_dims_loss[..., 3:6]

  gt_bboxes = self._utils_3d.ResidualsToBBoxes(
      anchor_bboxes, gt_residuals, min_angle_rad=-np.pi, max_angle_rad=np.pi)
  predicted_bboxes = self._utils_3d.ResidualsToBBoxes(
      anchor_bboxes,
      predicted_residuals,
      min_angle_rad=-np.pi,
      max_angle_rad=np.pi)

  # Mask each regression term, sum it over every axis (batch included) and
  # divide by the batch size.
  reg_rot_loss = tf.reduce_sum(rotation_loss * reg_weights) / batch_size
  reg_loc_loss = tf.reduce_sum(reg_loc_loss * reg_weights) / batch_size
  reg_dim_loss = tf.reduce_sum(reg_dim_loss * reg_weights) / batch_size

  # Do not create corner loss graph if weight is 0.0
  # TODO(bcyang): Remove condition after fixing corner loss NaN issue
  if p.corner_loss_weight != 0.0:
    corner_loss = self._utils_3d.CornerLoss(
        gt_bboxes=gt_bboxes, predicted_bboxes=predicted_bboxes)
    corner_loss = tf.expand_dims(corner_loss, axis=-1) * reg_weights
    reg_corner_loss = tf.reduce_sum(corner_loss) / batch_size
  else:
    reg_corner_loss = 0.0

  regression_loss = (
      p.location_loss_weight * reg_loc_loss +
      p.dimension_loss_weight * reg_dim_loss +
      p.rotation_loss_weight * reg_rot_loss +
      p.corner_loss_weight * reg_corner_loss)

  total_loss = (
      p.loss_weight_localization * regression_loss +
      p.loss_weight_classification * classification_loss)

  metrics_dict = py_utils.NestedMap({
      'loss': (total_loss, batch_size),
      'loss/regression': (regression_loss, batch_size),
      'loss/regression/loc': (reg_loc_loss, batch_size),
      'loss/regression/dim': (reg_dim_loss, batch_size),
      'loss/regression/rot': (reg_rot_loss, batch_size),
      'loss/regression/corner': (reg_corner_loss, batch_size),
      'loss/classification': (classification_loss, batch_size),
  })

  # Add the bounding-box dimension error metrics.
  metrics_dict.update(
      self._BBoxDimensionErrors(gt_bboxes, predicted_bboxes, reg_weights))

  per_example_dict = py_utils.NestedMap({
      'residuals': predicted_residuals,
      'classification_logits': predicted_classification_logits,
      'predicted_bboxes': predicted_bboxes,
      'gt_bboxes': gt_bboxes,
      'reg_weights': reg_weights,
  })

  return metrics_dict, per_example_dict
def ComputePredictions(self, theta, input_batch):
  """Computes predictions for `input_batch`.

  Pillar points are augmented with per-pillar means and centers, featurized,
  scattered back onto the dense grid, run through the backbone, and finally
  turned into class logits and regression residuals by two separate layers.

  Args:
    theta: A `.NestedMap` object containing variable values of this task.
      The featurizer, backbone, class_detector and regression_detector
      entries are used, plus direction_classifier when
      `p.direction_classifier_weight > 0`.
    input_batch: A `.NestedMap` object containing input tensors to this tower.
      grid_num_points, pillar_points, pillar_centers and pillar_locations are
      read, plus anchor_localization_residuals when any of the `p.oracle_*`
      flags is set.

  Returns:
    A `.NestedMap` with:

    - residuals: [bs, nx, ny, p.grid_size_z, p.num_anchors, 7].
    - classification_logits:
      [bs, nx, ny, p.grid_size_z, p.num_anchors, p.num_classes].
    - predicted_dir: [bs, nx, ny, p.grid_size_z, p.num_anchors, 2]; only
      present when `p.direction_classifier_weight > 0`.
  """
  p = self.params
  input_batch.Transform(lambda x: (x.shape, x.shape.num_elements())).VLog(
      0, 'input_batch shapes: ')
  bs, nx, ny, nz = py_utils.GetShape(input_batch.grid_num_points, 4)

  # Process points to concatenate a set of fixed features (e.g.,
  # add means, centers, normalize points to means).
  num_features = 3 + p.num_laser_features
  pillar_points = py_utils.HasShape(input_batch.pillar_points,
                                    [bs, -1, -1, num_features])
  _, npillars, npoints, _ = py_utils.GetShape(pillar_points, 4)
  pillar_xyz = pillar_points[..., :3]
  # `keepdims` is the supported spelling; `keep_dims` is a deprecated
  # TF1-only alias that the TF2 API no longer accepts.
  pillar_means = tf.reduce_mean(pillar_xyz, axis=2, keepdims=True)
  pillar_feats = pillar_points[..., 3:]
  pillar_centers = py_utils.HasShape(input_batch.pillar_centers,
                                     [bs, -1, 1, 3])
  pillar_concat = tf.concat(
      axis=3,
      values=[
          pillar_xyz - pillar_means, pillar_feats,
          tf.tile(pillar_means, [1, 1, npoints, 1]),
          tf.tile(pillar_centers, [1, 1, npoints, 1])
      ])

  # Featurize pillars.
  pillar_features = self.featurizer.FProp(theta.featurizer, pillar_concat)

  # Convert back to the dense grid.
  pillar_locations = py_utils.HasShape(input_batch.pillar_locations,
                                       [bs, npillars, 3])
  dense_features = _SparseToDense(
      grid_shape=(nx, ny, nz),
      locations=pillar_locations,
      feats=pillar_features)

  # Backbone
  tf.logging.vlog(1, 'dense_features.shape = %s', dense_features.shape)
  act = self.backbone.FProp(theta.backbone, dense_features)
  tf.logging.vlog(1, 'act.shape = %s', act.shape)

  # Convert the output of the backbone into class logits and regression
  # residuals using two different layers.
  class_detection = self.class_detector.FProp(theta.class_detector, act)
  reg_detection = self.regression_detector.FProp(theta.regression_detector,
                                                 act)
  # From here on bs/nx/ny are taken from the detector output rather than
  # from the input grid.
  bs, nx, ny, _ = py_utils.GetShape(class_detection, 4)
  predicted_classification_logits = tf.reshape(
      class_detection,
      [bs, nx, ny, p.grid_size_z, p.num_anchors, p.num_classes])
  predicted_residuals = tf.reshape(
      reg_detection, [bs, nx, ny, p.grid_size_z, p.num_anchors, 7])

  if p.squash_rotation_predictions:
    # Squash the rotation residual (last component) into (-pi, pi).
    predicted_rotations = predicted_residuals[..., 6:]
    predicted_rotations = np.pi * tf.tanh(predicted_rotations)
    predicted_residuals = tf.concat(
        [predicted_residuals[..., :6], predicted_rotations], axis=-1)

  if p.oracle_location or p.oracle_dimension or p.oracle_rotation:
    gt_residuals = py_utils.HasShape(
        input_batch.anchor_localization_residuals,
        [bs, nx, ny, p.grid_size_z, p.num_anchors, 7])

    # Replace the predicted components with the ground truth if needed.
    location = (
        gt_residuals[..., 0:3]
        if p.oracle_location else predicted_residuals[..., 0:3])
    dimension = (
        gt_residuals[..., 3:6]
        if p.oracle_dimension else predicted_residuals[..., 3:6])
    rotation = (
        gt_residuals[..., 6:]
        if p.oracle_rotation else predicted_residuals[..., 6:])
    predicted_residuals = tf.concat([location, dimension, rotation], axis=-1)

  ret = py_utils.NestedMap({
      'residuals': predicted_residuals,
      'classification_logits': predicted_classification_logits,
  })

  if p.direction_classifier_weight > 0.0:
    predicted_dir = self.direction_classifier.FProp(
        theta.direction_classifier, act)
    predicted_dir = tf.reshape(
        predicted_dir, [bs, nx, ny, p.grid_size_z, p.num_anchors, 2])
    ret.predicted_dir = predicted_dir

  return ret
def Top2GatingOnLogits(inputs,
                       paddings,
                       logits,
                       num_devices,
                       experts_dim,
                       expert_capacity_dim,
                       fprop_dtype,
                       use_xla_sharding=True,
                       second_expert_policy='all',
                       second_expert_threshold=0.0,
                       legacy_mtf_behavior=True,
                       capacity_factor=None):
  """Computes Top-2 gating for Mixture-of-Experts.

  There are two expected usages of this function:

  1. used with xla_sharding. In this case, 'inputs' corresponds to a sharded
     tensor across multiple tpu cores. The operations within this function are
     automatically sharded/replicated across tpu cores.
  2. used within ML-Pathways. In this case, 'inputs' is always local to one tpu
     core. All computations below are carried out on one tpu core only. This
     function tries to dispatch examples across tpu cores in such a way that
     each expert is assigned no more than 'expert_capacity_dim' number of
     examples.

  Below ` indicates common way of splitting along mesh dimension.

  Dimensions cheat sheet:

    G: group_dim
    S: group_size_dim
    E: number of experts
    C: capacity per expert
    M: model_dim (same as input_dim, same as output_dim)
    B: original batch_dim
    L: original sequence_length_dim

  Note that for local_dispatch original batch BLM is reshaped into GSM, each
  group `g = 0...G-1` is being dispatched independently.

  Args:
    inputs: G`SM Tensor. Currently unused (deleted on entry).
    paddings: G`S Tensor, or None. When given, `1.0 - paddings` is used to
      zero out the expert masks and the density proxy.
    logits: G`SE Tensor.
    num_devices: number of MoE devices for local dispatch
    experts_dim: number of experts.
    expert_capacity_dim: number of examples per minibatch(group) per expert.
      Each example is typically a vector of size input_dim, representing
      embedded token or an element of Transformer layer output.
    fprop_dtype: activations datatype to use.
    use_xla_sharding: bool, True if this function is used for the xla_sharding
      case.
    second_expert_policy: 'all', 'sampling' or 'random'.

      - 'all': we greedily pick the 2nd expert.
      - 'sampling': we sample the 2nd expert from the softmax.
      - 'random': we optionally 'random'-ize dispatch to second-best expert
        proportional to (weight / second_expert_threshold).

    second_expert_threshold: threshold for probability normalization for
      second_expert_policy == 'random'.
    legacy_mtf_behavior: bool, True if to match legacy mtf behavior exactly.
      TODO(lepikhin): get rid of the legacy_mtf_behavior flag.
    capacity_factor: if set, increases expert_capacity_dim to at least
      (group_size * capacity_factor) / experts_dim, where `group_size` is read
      from `logits.shape[1]` (the S dimension in the layout above). When the
      capacity is increased it is also rounded up to a multiple of 4. If the
      value of expert_capacity_dim is already big enough no change is made.

  Returns:
    A tuple (aux_loss, combine_tensor, dispatch_tensor).

    - aux_loss: auxiliary loss, for equalizing the expert assignment ratios.
    - combine_tensor: G`SEC Tensor for combining expert outputs.
    - dispatch_tensor: G`SEC Tensor, scattering/dispatching inputs to experts.

  Raises:
    ValueError: if second_expert_policy is not 'all', 'sampling' or 'random'.
  """
  del inputs  # inputs is currently not used.
  raw_gates = tf.nn.softmax(logits)  # along E dim

  if capacity_factor is not None:
    # Determine expert capacity automatically depending on the input size.
    group_size_dim = int(logits.shape[1])
    auto_expert_capacity = int((group_size_dim * capacity_factor) / experts_dim)
    if expert_capacity_dim < auto_expert_capacity:
      expert_capacity_dim = auto_expert_capacity
      # Round up to a multiple of 4 to avoid possible padding.
      while expert_capacity_dim % 4:
        expert_capacity_dim += 1
      tf.logging.info(
          'Setting expert_capacity_dim=%r (capacity_factor=%r '
          'group_size_dim=%r experts_dim=%r name_scope=%r)',
          expert_capacity_dim, capacity_factor, group_size_dim, experts_dim,
          tf.get_default_graph().get_name_scope())
    tpu_summary.scalar('expert_capacity', expert_capacity_dim)

  # top first and second gate value and expert index for each input
  #
  # GSK Tensors, K=2
  def _MaybeSplit(x):
    # Applies Split(x, 0, num_devices) only in the xla_sharding case; Split is
    # defined elsewhere (presumably a sharding annotation along dim 0 -- TODO
    # confirm). Otherwise x is returned unchanged.
    if use_xla_sharding:
      return Split(x, 0, num_devices)
    else:
      return x

  def _CreateOverCapacityRatioSummary(mask, position_in_expert, capacity, name):
    # Emits, under `name`, the ratio of mask entries whose position is at or
    # beyond `capacity` to the total mask sum.
    over_capacity = tf.reduce_sum(
        tf.cast(
            tf.greater_equal(mask * position_in_expert, capacity), mask.dtype))
    over_capacity_ratio = over_capacity / tf.reduce_sum(mask)
    py_utils.AddTpuSummaryTensor(name, over_capacity_ratio)
    tpu_summary.scalar(name, over_capacity_ratio, while_loop_reduce='mean')

  # As pointed out by zhifengc@ this method needs to be refactored. lepikhin@
  # and krikun@ will:
  #   - expand moe_spmd_test to compare Adafactor updates, slots on TPU
  #   including 2x2 with sharding
  #
  #   - add more tests for policy="random"
  #
  #   - add single step test for full size WMT model on CPU
  #
  # and then break this function into modules.
  #
  # GS
  index_1 = tf.math.argmax(raw_gates, axis=-1, output_type=tf.int32)
  index_1 = _MaybeSplit(index_1)
  tpu_summary.tensor('index_1', index_1)

  # GSE
  mask_1 = tf.one_hot(index_1, experts_dim, dtype=fprop_dtype)
  mask_1 = _MaybeSplit(mask_1)
  density_1_proxy = raw_gates

  # Per-position weight: all ones without paddings, otherwise 1 - paddings.
  importance = tf.ones_like(mask_1[:, :, 0])
  if paddings is not None:
    importance = 1.0 - paddings
    mask_1 *= tf.expand_dims(importance, -1)
    density_1_proxy *= tf.expand_dims(importance, -1)

  gate_1 = tf.einsum('GSE,GSE->GS', raw_gates, mask_1)
  gates_without_top_1 = raw_gates * (1.0 - mask_1)

  if second_expert_policy == 'sampling':
    # We directly sample the 2nd expert index from the softmax over of the 2nd
    # expert by getting rid of the 1st expert already selected above. To do so,
    # we set a very negative value to the logit corresponding to the 1st expert.
    # Then we sample from the softmax (categorical) distribution using the
    # Gumbel max trick.
    noise = _MaybeSplit(tf.random.uniform(logits.shape, dtype=logits.dtype))
    # Generates standard Gumbel(0, 1) noise, GSE Tensors
    noise = -tf.math.log(-tf.math.log(noise))
    very_negative_logits = _MaybeSplit(
        (tf.ones_like(logits) * logits.dtype.max *
         tf.constant(-0.7, dtype=logits.dtype)))
    # Gets rid of the first expert by setting its logit to be very negative
    updated_logits = _MaybeSplit(
        tf.where(mask_1 > 0.0, very_negative_logits, logits))
    # Adds the Gumbel noise to the updated logits
    noised_logits = _MaybeSplit(updated_logits + noise)
    # Picks the index of the largest noised logit as the 2nd expert. This is
    # equivalent to sampling from the softmax over the 2nd experts.
    index_2 = tf.math.argmax(noised_logits, axis=-1, output_type=tf.int32)
  else:
    # 'all' and 'random' both start from the greedy second-best expert.
    index_2 = tf.math.argmax(gates_without_top_1, axis=-1, output_type=tf.int32)

  index_2 = _MaybeSplit(index_2)
  mask_2 = tf.one_hot(index_2, experts_dim, dtype=fprop_dtype)
  mask_2 = _MaybeSplit(mask_2)
  if paddings is not None:
    mask_2 *= tf.expand_dims(importance, -1)
  gate_2 = tf.einsum('GSE,GSE->GS', gates_without_top_1, mask_2)

  if legacy_mtf_behavior:
    # cl/298510175 moved this branch for gate_{1,2} denom calculation here.
    #
    # For policy=random, it's better to normalize gate_{1,2} before taking
    # capacity into account and before potentially dropping second expert.
    #
    # According to mean_xent (http://short/_NzbZ5rINr5):
    #   MoE_512_102xen_PolicyAll_298510175
    #   MoE_512_102xen_PolicyRandom_298510175
    #
    # vs pre-cl/298510175
    #   MoE_512_102xen_PolicyRandom
    #   MoE_512_102xen_PolicyAll
    #
    # it substantially improves policy=random with threshold=0.5 which
    # historically was better than policy="all"
    #
    # Also confirmed this by decoding
    #   nmt_train/m4/data/es_en/test.txt
    #   nmt_train/m4/data/ru_en/test.txt
    #   nmt_train/m4/data/zh_en/test.txt
    # and improving BLEU
    #
    # moe_decode.MoE_512_102xen_PolicyRandom_298510175-160000.batch1024.beam4.c_dim4.ln0.8.rkv.mteval102
    #   0.421443
    #   0.327102
    #   0.315693
    # vs
    # moe_decode.feb18_non_fig_snapshot_2626_MoE_512_102xen_PolicyRandom-190000.batch1024.beam4.c_dim4.ln0.8.rkv.mteval102
    #   0.399232
    #   0.310606
    #   0.288229
    #
    # Additional comparison, see mean_xent http://short/_YHccOhQtdu with
    # legacy_mtf_behavior=False models
    #   3 - MoE_512_102xen_PolicyAll_LegacyFalse
    #   6 - MoE_512_102xen_PolicyRandom_LegacyFalse
    # shows that policy="random" gets worse with legacy_mtf_behavior=False, and
    # is similar to pre-cl/298510175
    #   4 - MoE_512_102xen_PolicyRandom
    #
    # gate_1 can become 0 due to Expert being out of capacity.
    #
    # gate_2 can become 0 due to
    #   second_expert_policy == 'random'
    # or "out of capacity" scenario.
    #
    # Here we renormalize regardless of cases above.
    denom = gate_1 + gate_2 + 1e-9
    gate_1 /= denom
    gate_2 /= denom

  # Compute cumulative sums of assignment indicators along the S axis (axis=1)
  # for each expert index e \in 0..E-1 independently.
  # First occurrence of assignment indicator is excluded, see exclusive=True
  # flag below.
  position_in_expert_1 = tf.cumsum(mask_1, exclusive=True, axis=1)

  # expert_capacity_dim cast to the dtype of position_in_expert_1 so that it
  # can be compared against positions below.
  capacity = tf.cast(expert_capacity_dim, dtype=position_in_expert_1.dtype)

  # GE Tensor (reducing S out of GSE tensor mask_1)
  # density_1[:, e] represents assignment ratio (num assigned / total) to
  # expert e as top_1 expert without taking capacity into account.
  if legacy_mtf_behavior:
    density_denom = 1.0
  else:
    density_denom = tf.reduce_mean(
        importance, axis=(1))[:, tf.newaxis] + 1e-6
  density_1 = tf.reduce_mean(mask_1, axis=(1)) / density_denom
  # density_1_proxy[:, e] represents mean of raw_gates for expert e, including
  # those of examples not assigned to e with top_k.
  density_1_proxy = tf.reduce_mean(density_1_proxy, axis=1) / density_denom

  # The MoE paper (https://arxiv.org/pdf/1701.06538.pdf) uses an aux loss of
  # reduce_mean(density_1_proxy * density_1_proxy). Here we replace one of
  # the density_1_proxy with the discrete density_1 following
  # mesh_tensorflow/transformer/moe.py?rcl=283569345.
  aux_loss = tf.reduce_mean(density_1_proxy * density_1)  # element-wise
  aux_loss *= experts_dim * experts_dim  # const coefficient

  # Add the over capacity ratio for expert 1
  _CreateOverCapacityRatioSummary(mask_1, position_in_expert_1, capacity,
                                  'over_capacity_1_ratio')

  # Drop top-1 assignments whose position is at or beyond capacity, then
  # collapse the position to one value per (G, S) entry.
  mask_1 *= tf.cast(tf.less(position_in_expert_1, capacity), dtype=mask_1.dtype)
  position_in_expert_1 = tf.einsum('GSE,GSE->GS', position_in_expert_1, mask_1)

  # How many examples in this sequence go to this expert
  mask_1_count = tf.einsum('GSE->GE', mask_1)
  # [batch, group] - mostly ones, but zeros where something didn't fit
  mask_1_flat = tf.einsum('GSE->GS', mask_1)

  if second_expert_policy == 'all' or second_expert_policy == 'sampling':
    pass
  elif second_expert_policy == 'random':
    # gate_2 is between 0 and 1, reminder:
    #
    #   raw_gates = tf.nn.softmax(logits)
    #   index_1 = tf.math.argmax(raw_gates, axis=-1, output_type=tf.int32)
    #   mask_1 = tf.one_hot(index_1, experts_dim, dtype=fprop_dtype)
    #   gate_1 = tf.einsum('GSE,GSE->GS', raw_gates, mask_1)
    #
    # E.g. if gate_2 exceeds second_expert_threshold, then we definitely
    # dispatch to second-best expert. Otherwise we dispatch with probability
    # proportional to (gate_2 / threshold).
    #
    sampled_2 = tf.less(
        _MaybeSplit(tf.random.uniform(gate_2.shape, dtype=gate_2.dtype)),
        (gate_2 / max(second_expert_threshold, 1e-9)))
    gate_2 *= tf.cast(sampled_2, gate_2.dtype)
    mask_2 *= tf.cast(tf.expand_dims(sampled_2, -1), mask_2.dtype)
  else:
    raise ValueError(second_expert_policy)

  # Second-choice positions start after the top-1 assignments already counted
  # for each expert (mask_1_count).
  position_in_expert_2 = tf.cumsum(
      mask_2, exclusive=True, axis=1) + tf.expand_dims(mask_1_count, 1)

  # Add the over capacity ratio for expert 2
  _CreateOverCapacityRatioSummary(mask_2, position_in_expert_2, capacity,
                                  'over_capacity_2_ratio')

  mask_2 *= tf.cast(tf.less(position_in_expert_2, capacity), mask_2.dtype)
  position_in_expert_2 = tf.einsum('GSE,GSE->GS', position_in_expert_2, mask_2)
  mask_2_flat = tf.reduce_sum(mask_2, axis=-1)

  # Equivalent non-einsum implementation:
  #
  # position_in_expert_2 *= mask_2
  # position_in_expert_2 = tf.reduce_sum(
  #     position_in_expert_2, axis=-1, name='position_in_expert_2')

  # Zero the gates of assignments that were dropped above.
  gate_1 *= mask_1_flat
  gate_2 *= mask_2_flat

  if not legacy_mtf_behavior:
    # Non-legacy mode renormalizes after capacity has been applied.
    denom = gate_1 + gate_2
    # To avoid divide by 0.
    denom = tf.where(denom > 0, denom, tf.ones_like(denom))
    gate_1 /= denom
    gate_2 /= denom

  # GSC Tensor
  b = tf.one_hot(
      tf.cast(position_in_expert_1, dtype=tf.int32),
      expert_capacity_dim,
      dtype=fprop_dtype,
      name='one_hot_b_0')
  # GSE Tensor
  a = tf.expand_dims(gate_1 * mask_1_flat, -1) * tf.one_hot(
      index_1, experts_dim, dtype=fprop_dtype)
  # GSEC Tensor
  first_part_of_combine_tensor = tf.einsum(
      'GSE,GSC->GSEC', a, b, name='first_part_of_combine_tensor')

  # GSC Tensor
  b = tf.one_hot(
      tf.cast(position_in_expert_2, dtype=tf.int32),
      expert_capacity_dim,
      dtype=fprop_dtype,
      name='one_hot_b_1')
  # GSE Tensor
  a = tf.expand_dims(gate_2 * mask_2_flat, -1) * tf.one_hot(
      index_2, experts_dim, dtype=fprop_dtype)
  second_part_of_combine_tensor = tf.einsum(
      'GSE,GSC->GSEC', a, b, name='second_part_of_combine_tensor')

  # GSEC Tensor
  combine_tensor = (
      first_part_of_combine_tensor + second_part_of_combine_tensor)
  combine_tensor = _MaybeSplit(combine_tensor)

  # GSEC Tensor; nonzero combine weights become 1, everything else 0.
  dispatch_tensor = tf.cast(tf.cast(combine_tensor, tf.bool), fprop_dtype)
  dispatch_tensor = _MaybeSplit(dispatch_tensor)

  # TODO(yonghui): compute and return per-group aux_loss.
  return aux_loss, combine_tensor, dispatch_tensor