Exemple #1
0
  def BeamSearch(self, x, y, decoder_reduce_sum=False):
    """Toy encoder followed by a while_loop decoder, emitting TPU summaries.

    Three "encoder" sub-steps each increment `x`, and three "decoder"
    sub-steps per while_loop iteration each increment `y`. After every
    sub-step the means of `x` and `y` are recorded with tpu_summary.scalar.

    Args:
      x: Tensor incremented by the encoder sub-steps.
      y: Tensor incremented by the decoder sub-steps.
      decoder_reduce_sum: If True, decoder summaries are recorded with
        `while_loop_reduce='sum'`; otherwise tpu_summary.scalar's default
        reduction is used.

    Returns:
      The `(x, y)` pair produced by the while loop.
    """
    for step in range(3):
      with tf.name_scope('encoder%03d' % step):
        x, y = tf.identity(x), tf.identity(y)
        x = x + 1
        tpu_summary.scalar('x_mean', tf.reduce_mean(x))
        tpu_summary.scalar('y_mean', tf.reduce_mean(y))

    # Only pass `while_loop_reduce` when requested, so the default path calls
    # tpu_summary.scalar with exactly two arguments.
    reduce_kwargs = {'while_loop_reduce': 'sum'} if decoder_reduce_sum else {}

    def DecoderStep(x, y):
      for step in range(3):
        with tf.name_scope('decoder%03d' % step):
          x, y = tf.identity(x), tf.identity(y)
          y = y + 1
          tpu_summary.scalar('x_mean', tf.reduce_mean(x), **reduce_kwargs)
          tpu_summary.scalar('y_mean', tf.reduce_mean(y), **reduce_kwargs)
      return x, y

    def DecoderCond(unused_x, unused_y):
      # Never stops on its own; the loop ends via maximum_iterations.
      return True

    x, y = tf.while_loop(
        cond=DecoderCond,
        body=DecoderStep,
        loop_vars=(x, y),
        maximum_iterations=10)
    return x, y
Exemple #2
0
 def DecoderStep(x, y):
   """Runs three decoder sub-steps, incrementing `y` and summarizing means.

   `decoder_reduce_sum` is taken from the enclosing scope; when truthy the
   summaries are recorded with `while_loop_reduce='sum'`.

   Returns:
     The updated `(x, y)` pair.
   """
   reduce_kwargs = {'while_loop_reduce': 'sum'} if decoder_reduce_sum else {}
   for step in range(3):
     with tf.name_scope('decoder%03d' % step):
       x, y = tf.identity(x), tf.identity(y)
       y = y + 1
       tpu_summary.scalar('x_mean', tf.reduce_mean(x), **reduce_kwargs)
       tpu_summary.scalar('y_mean', tf.reduce_mean(y), **reduce_kwargs)
   return x, y
    def _AveNeigh(self, spellings, pronunciations, theta, batch_size):
        """Encodes neighbour spellings and pronunciations and mean-pools them.

        Spellings are encoded by `self.spell_encoder` and pronunciations by
        `self.pron_encoder`; each encoder output is averaged over time.

        Args:
          spellings: Neighbour spelling ids, reshaped here to
            [batch_size * max_neighbors, -1].
          pronunciations: Neighbour pronunciation ids, reshaped here to
            [batch_size * max_neighbors, -1].
          theta: A `.NestedMap` of variables; its `spell_encoder` and
            `pron_encoder` entries are used.
          batch_size: The batch size.

        Returns:
          A pair `([spell_enc, pron_enc], [padding, padding])` where both
          encodings have shape [max_neighbors, batch_size, enc_units] and
          `padding` is an all-zero [max_neighbors, batch_size] tensor.
        """
        p = self.params
        spellings = tf.reshape(spellings, (batch_size * p.max_neighbors, -1))
        pronunciations = tf.reshape(pronunciations,
                                    (batch_size * p.max_neighbors, -1))

        spell_inp = py_utils.NestedMap({
            "ids":
            spellings,
            "paddings":
            self._GetPaddings(spellings, dtype=tf.int32),
        })

        pron_inp = py_utils.NestedMap({
            "ids":
            pronunciations,
            "paddings":
            self._GetPaddings(pronunciations, dtype=tf.int32),
        })

        if p.use_neigh_id_emb:
            # Add the same ID embeddings to both spelling and pron so that the
            # model knows how they pair up.
            neigh_ids = tf.range(p.max_neighbors)[:, tf.newaxis]
            spell_inp["task_ids"] = tf.tile(neigh_ids,
                                            [batch_size, p.max_spelling_len])
            pron_inp["task_ids"] = tf.tile(
                neigh_ids, [batch_size, p.max_pronunciation_len])

        spell_enc_out = self.spell_encoder.FProp(theta.spell_encoder,
                                                 spell_inp)
        # Each encoder layer is run with its own theta, so pronunciations go
        # through the pronunciation encoder's configuration and weights.
        pron_enc_out = self.pron_encoder.FProp(theta.pron_encoder, pron_inp)

        # The reshapes assume time-major encoder outputs of shape
        # [seq_len, batch_size * max_neighbors, enc_units].
        # NOTE(review): the mean over time also includes padded positions.
        spell_enc = tf.reshape(
            spell_enc_out["encoded"],
            (p.max_spelling_len, batch_size, p.max_neighbors, p.enc_units))
        spell_enc = tf.reduce_mean(spell_enc, axis=0)

        pron_enc = tf.reshape(pron_enc_out["encoded"],
                              (p.max_pronunciation_len, batch_size,
                               p.max_neighbors, p.enc_units))
        pron_enc = tf.reduce_mean(pron_enc, axis=0)

        # [batch_size, max_neighbors, enc_units] -> neighbours-major.
        spell_enc = tf.transpose(spell_enc, (1, 0, 2))
        pron_enc = tf.transpose(pron_enc, (1, 0, 2))
        padding = tf.zeros((p.max_neighbors, batch_size))

        return [spell_enc, pron_enc], [padding, padding]
Exemple #4
0
 def FProp(self, x, y):
   """Runs 3 encoder sub-steps (bumping x) then 3 decoder sub-steps (bumping y).

   After every sub-step the means of `x` and `y` are recorded with
   tpu_summary.scalar under an 'encoderNNN' / 'decoderNNN' name scope.

   Returns:
     The updated `(x, y)` pair.
   """
   for scope_prefix, bump_x in (('encoder', True), ('decoder', False)):
     for step in range(3):
       with tf.name_scope('%s%03d' % (scope_prefix, step)):
         x, y = tf.identity(x), tf.identity(y)
         if bump_x:
           x = x + 1
         else:
           y = y + 1
         tpu_summary.scalar('x_mean', tf.reduce_mean(x))
         tpu_summary.scalar('y_mean', tf.reduce_mean(y))
   return x, y
Exemple #5
0
 def AddSummary(self, lr, optimizer, var_grad):
     """Summarizes the Adagrad learning rate and per-variable accumulators.

     Args:
       lr: Learning rate value, recorded as 'adagrad_lr'.
       optimizer: Optimizer whose `get_slot(var, 'accumulator')` must return a
         slot for every variable.
       var_grad: `.NestedMap` whose flattened entries are (variable, grad)
         pairs; only the variables are used.
     """
     summary_utils.scalar('adagrad_lr', lr)
     for var, _unused_grad in var_grad.Flatten():
         accumulator = optimizer.get_slot(var, 'accumulator')
         assert accumulator is not None
         mean_accumulator = tf.reduce_mean(accumulator)
         summary_utils.scalar('optimizer/adagrad_accum_%s' % var.name,
                              mean_accumulator)
Exemple #6
0
 def ComputePredictions(self, theta, input_batch):
     """Extracts features, average-pools axes 1 and 2, and computes logits.

     Returns:
       A `.NestedMap` with `act` (pooled activations) and `logits`.
     """
     features = self.extract.FProp(theta.extract, input_batch.data)
     # Average pool over axes 1 and 2 (presumably the spatial dimensions).
     pooled = tf.reduce_mean(features, axis=[1, 2])
     return py_utils.NestedMap(
         act=pooled, logits=self.softmax.Logits(theta.softmax, pooled))
    def _ConcatAveNeigh(self, spellings, pronunciations, theta, batch_size):
        """Encodes each neighbour's concatenated spelling and pronunciation.

        Spelling and pronunciation ids are concatenated along axis 1, encoded
        by `self.spell_encoder`, and mean-pooled over time.

        Args:
          spellings: Neighbour spelling ids, reshaped here to
            [batch_size * max_neighbors, -1].
          pronunciations: Neighbour pronunciation ids, reshaped here to
            [batch_size * max_neighbors, -1].
          theta: A `.NestedMap` of variables; its `spell_encoder` entry is used.
          batch_size: The batch size.

        Returns:
          A pair `([neigh_enc], [padding])` where `neigh_enc` has shape
          [max_neighbors, batch_size, enc_units] and `padding` is an all-zero
          [max_neighbors, batch_size] tensor.

        Raises:
          NotImplementedError: If `p.use_neigh_id_emb` is set.
        """
        p = self.params
        # TODO(llion): Add task ids to concatenated info?
        # Checked up front so no ops are built for an unsupported config.
        if p.use_neigh_id_emb:
            raise NotImplementedError(
                "use_neigh_id_emb is not supported when concatenating "
                "neighbour spellings and pronunciations.")

        spellings = tf.reshape(spellings, (batch_size * p.max_neighbors, -1))
        pronunciations = tf.reshape(pronunciations,
                                    (batch_size * p.max_neighbors, -1))

        # ->(batch_size * max_neighbors, max_spelling_len + max_pronunciation_len)
        neigh_info = tf.concat([spellings, pronunciations], axis=1)

        neigh_inp = py_utils.NestedMap({
            "ids":
            neigh_info,
            "paddings":
            self._GetPaddings(neigh_info, dtype=tf.int32),
        })

        neigh_out = self.spell_encoder.FProp(theta.spell_encoder, neigh_inp)

        # The reshape assumes a time-major encoder output.
        neigh_enc = tf.reshape(neigh_out["encoded"],
                               (p.max_spelling_len + p.max_pronunciation_len,
                                batch_size, p.max_neighbors, p.enc_units))
        neigh_enc = tf.reduce_mean(neigh_enc, axis=0)

        neigh_enc = tf.transpose(neigh_enc, (1, 0, 2))
        padding = tf.zeros((p.max_neighbors, batch_size))

        return [neigh_enc], [padding]
Exemple #8
0
    def _verify_timestep_counts(self, num_splits):
        """Checks grad norm and pipeline accumulator values of a dummy CNN.

        Args:
          num_splits: Number of pipeline splits for `_BuildDummyPipelineCnn`.
        """
        batch_size = 16
        num_micro_batches = 8
        # Accumulator values should be equal to number of time steps in pipeline.
        expected_ts = num_micro_batches if num_splits > 1 else 1
        with self.session(graph=tf.Graph()) as sess:
            tf.set_random_seed(1245)
            inputs = tf.random_uniform([batch_size, 8, 8, 1], seed=12345)
            net = _BuildDummyPipelineCnn(num_splits=num_splits,
                                         num_micro_batches=num_micro_batches)
            endpoints = net.FPropDefaultTheta(inputs)
            aux_logits = None
            if isinstance(endpoints, (list, tuple)):
                logits, aux_logits = endpoints
            else:
                logits = endpoints
            loss = tf.reduce_mean(logits)
            grad_norm = tf.sqrt(
                py_utils.SumSquared(
                    tf.gradients(loss, tf.trainable_variables())))
            ts = net.GetAccumulatorValues().Flatten()

            sess.run(tf.global_variables_initializer())
            grad_norm_val, ts_vals = sess.run([grad_norm, ts])
            test_utils.CompareToGoldenSingleFloat(self, 0.268087,
                                                  grad_norm_val)
            for ts_val in ts_vals:
                self.assertEqual(ts_val, expected_ts)
            if aux_logits is not None:
                self.assertEqual(
                    sess.run(aux_logits).shape, (batch_size, 8, 8, 1))
    def testCheckNumerics(self):
        """CheckNumerics passes finite values through and rejects NaN."""
        finite = tf.convert_to_tensor([2.0, 3.0], tf.float32)
        passed_through = py_utils.CheckNumerics(finite).numpy().tolist()
        self.assertListEqual([2.0, 3.0], passed_through)

        empty = tf.convert_to_tensor([], tf.float32)
        with self.assertRaisesRegex(tf.errors.InvalidArgumentError, 'NaN'):
            # The mean of an empty tensor is expected to be NaN.
            py_utils.CheckNumerics(tf.reduce_mean(empty))
Exemple #10
0
    def FProp(self, theta, input_batch):
        # pyformat: disable
        """Compute features for the pillars and convert them back to a dense grid.

        Args:
          theta: A `.NestedMap` object containing variable values of this task.
          input_batch: A `.NestedMap` object containing input tensors. Following
            keys are required:

            - grid_num_points: Integer tensor with shape [batch size, nx, ny, nz],
              where nx, ny, nz corresponds to the grid sizes (i.e., number of voxels
              in each axis dimension).
            - pillar_points: Float tensor with shape [batch size, num_pillars,
              num_points_per_pillar, 3 + num_laser_features]
            - pillar_centers: Float tensor with shape [batch size, num_pillars,
              1, 3]
            - pillar_locations: Float tensor with shape [batch size, num_pillars, 3]

        Returns:
          The dense features with shape [b, nx, ny, nz * fdims].
        """
        # pyformat: enable
        p = self.params
        bs, nx, ny, nz = py_utils.GetShape(input_batch.grid_num_points, 4)
        # Process points to concatenate a set of fixed features (e.g.,
        # add means, centers, normalize points to means).
        num_features = 3 + p.num_laser_features
        pillar_points = py_utils.HasShape(input_batch.pillar_points,
                                          [bs, -1, -1, num_features])
        _, npillars, npoints, _ = py_utils.GetShape(pillar_points, 4)
        pillar_xyz = pillar_points[..., :3]
        # Per-pillar mean xyz, [bs, npillars, 1, 3]. `keepdims` replaces the
        # deprecated `keep_dims` argument, which TF2's reduce_mean rejects.
        pillar_means = tf.reduce_mean(pillar_xyz, axis=2, keepdims=True)
        pillar_feats = pillar_points[..., 3:]
        pillar_centers = py_utils.HasShape(input_batch.pillar_centers,
                                           [bs, -1, 1, 3])
        # Per point: xyz relative to the pillar mean, the laser features, and
        # the pillar mean and center repeated for every point in the pillar.
        pillar_concat = tf.concat(axis=3,
                                  values=[
                                      pillar_xyz - pillar_means, pillar_feats,
                                      tf.tile(pillar_means,
                                              [1, 1, npoints, 1]),
                                      tf.tile(pillar_centers,
                                              [1, 1, npoints, 1])
                                  ])
        # Featurize pillars.
        pillar_features = self.featurizer.FProp(theta.featurizer,
                                                pillar_concat)

        # Convert back to the dense grid.
        pillar_locations = py_utils.HasShape(input_batch.pillar_locations,
                                             [bs, npillars, 3])
        dense_features = SparseToDense(grid_shape=(nx, ny, nz),
                                       locations=pillar_locations,
                                       feats=pillar_features)
        return dense_features
Exemple #11
0
 def _GetExpertDist(self, theta, inputs, *args):
   """Get the task id from inputs tensors.

   Every position of `inputs` is pooled into a single `cond_dim` vector:
   either the plain mean, or (with `params.nonzeros_mean`) the sum divided by
   the per-dimension count of non-zero entries. The pooled vector is
   projected by `theta.w` and passed through a sigmoid.

   Returns:
     The sigmoid expert distribution, one value per column of `theta.w`.
   """
   # TODO(huangyp): support the more general case when batch size is not 1.
   # Input shape can be either [batch, length, dim] or [length, batch, dim]
   del args  # Unused.
   flat_inputs = tf.reshape(inputs, [-1, self.params.cond_dim])
   if not self.params.nonzeros_mean:
     pooled = tf.reduce_mean(flat_inputs, 0)
   else:
     total = tf.reduce_sum(flat_inputs, 0)
     count = tf.cast(tf.count_nonzero(flat_inputs, 0), dtype=tf.float32)
     # The epsilon guards against dividing by zero for all-zero dimensions.
     pooled = total / (count + 1e-10)
   projected = tf.einsum('i,ij->j', pooled, theta.w)
   return tf.nn.sigmoid(projected)
Exemple #12
0
 def _AddAttentionSummaries(self, name, atten_probs):
   """Plots attention-probability summaries for the joint network.

   Nothing is added on TPU, when `params.allow_attention_summaries` is falsy,
   or when `atten_probs` is None. Otherwise the first batch example, averaged
   over axis 1, is handed to the image and histogram summary helpers.
   """
   # TODO(ankurbpn): Check why op wasn't compiling on TPUs.
   p = self.params
   if (py_utils.use_tpu() or not p.allow_attention_summaries or
       atten_probs is None):
     return
   dims = tf.shape(atten_probs)
   # Fold every axis between axis 1 and the last axis into a single axis.
   folded = tf.reshape(atten_probs, [dims[0], dims[1], -1, dims[-1]])
   # Only plots first example of the batch.
   first_example = tf.reduce_mean(folded[0:1, :, :, :], 1)
   self._AddAttenProbsImageSummary(name, first_example)
   self._AddAttenProbsHistogramSummary(name, first_example)
Exemple #13
0
    def ComputeLoss(self, theta, predictions, input_batch):
        """Computes the mean CTC loss of `predictions` against the targets.

        `predictions.encoded` is passed to tf.nn.ctc_loss as time-major logits.

        Args:
          theta: Unused.
          predictions: `.NestedMap` with `encoded` logits and `padding`.
          input_batch: `.NestedMap` with `tgt.labels` and `tgt.paddings`.

        Returns:
          `(metrics, per_sequence)`: `metrics['loss']` is `(mean_loss, 1.0)`;
          `per_sequence['loss']` is the unreduced CTC loss, shape (B).
        """
        tgt = input_batch.tgt
        per_sequence_ctc = tf.nn.ctc_loss(
            tgt.labels,
            predictions.encoded,
            py_utils.LengthsFromBitMask(tgt.paddings, 1),
            py_utils.LengthsFromBitMask(predictions.padding, 0),
            logits_time_major=True,
            blank_index=self.params.blank_index)
        mean_ctc = tf.reduce_mean(per_sequence_ctc)
        return {'loss': (mean_ctc, 1.0)}, {'loss': per_sequence_ctc}
 def get_accuracy(self, loss, pred, target):
   """Adds sequence- and character-level accuracy metrics to `loss` in place.

   Positions where `target` equals the configured pad value are ignored.

   Args:
     loss: Metrics dict; receives `accuracy_per_example` and
       `accuracy_per_char` entries as `(value, p.input.batch_size)` pairs.
     pred: Predicted ids; compared element-wise with `target` (rank 2 is
       assumed by the axis=1 reduction).
     target: Target ids, cast to `pred.dtype`.
   """
   p = self.params
   target = tf.cast(target, pred.dtype)
   pad_id = int(p.input.feature_neighborhood_input.batch_opts.pad_value)
   # True at real (non-padding) target positions.
   is_real = tf.math.not_equal(target, pad_id)
   num_real = tf.reduce_sum(tf.cast(is_real, tf.float32))
   equal = tf.math.equal(pred, target)

   # An example is correct when every real position matches. Padding
   # positions count as matches explicitly, so the result does not depend on
   # the model's output there nor on whether pad_id happens to be 0.
   example_correct = tf.reduce_all(
       tf.math.logical_or(equal, tf.math.logical_not(is_real)), axis=1)
   loss["accuracy_per_example"] = (tf.reduce_mean(
       tf.cast(example_correct, tf.float32)), p.input.batch_size)

   # Fraction of real positions predicted correctly; 0 rather than NaN when
   # the batch contains no real positions.
   char_correct = tf.cast(tf.math.logical_and(equal, is_real), tf.float32)
   loss["accuracy_per_char"] = (tf.math.divide_no_nan(
       tf.reduce_sum(char_correct), num_real), p.input.batch_size)
Exemple #15
0
    def _verify_timestep_counts(self,
                                num_splits,
                                auto_partition=False,
                                micro_batch_size=None):
        """Checks grad norm and accumulator values of a pipelined dummy model.

        Args:
          num_splits: Number of pipeline splits.
          auto_partition: If True, build a PipeliningLayer from 16
            `_SimpyLayer`s partitioned by `_Partition`; otherwise use
            `_BuildDummyPipelineCnn`.
          micro_batch_size: Forwarded to `_BuildDummyPipelineCnn`.
        """
        batch_size = 16
        num_micro_batches = 8
        # Accumulator values should be equal to number of time steps in pipeline.
        expected_ts = num_micro_batches if num_splits > 1 else 1
        with self.session(graph=tf.Graph()) as sess:
            tf.random.set_seed(1245)
            inputs = tf.random.uniform([batch_size, 8, 8, 1], seed=12345)
            if not auto_partition:
                net = _BuildDummyPipelineCnn(
                    num_splits=num_splits,
                    micro_batch_size=micro_batch_size,
                    num_micro_batches=num_micro_batches)
            else:
                layer_params = [
                    _SimpyLayer.Params().Set(name='layer_{}'.format(idx))
                    for idx in range(16)
                ]
                cell_tpl = _Partition(layer_params, num_splits,
                                      tshape.Shape([batch_size, 8, 8, 1]))
                net = PipeliningLayer.Params().Set(
                    name='pipeline',
                    num_micro_batches=num_micro_batches,
                    cell_tpl=cell_tpl).Instantiate()
            endpoints = net.FPropDefaultTheta(inputs)
            aux_logits = None
            if isinstance(endpoints, (list, tuple)):
                logits, aux_logits = endpoints
            else:
                logits = endpoints
            loss = tf.reduce_mean(logits)
            grad_norm = tf.sqrt(
                py_utils.SumSquared(
                    tf.gradients(loss, tf.trainable_variables())))
            ts = net.GetAccumulatorValues().Flatten()

            sess.run(tf.global_variables_initializer())
            grad_norm_val, ts_vals = sess.run([grad_norm, ts])
            test_utils.CompareToGoldenSingleFloat(self, 0.268087,
                                                  grad_norm_val)
            for ts_val in ts_vals:
                self.assertEqual(ts_val, expected_ts)
            if aux_logits is not None:
                self.assertEqual(
                    sess.run(aux_logits).shape, (batch_size, 8, 8, 1))
Exemple #16
0
def AddAttentionSummaryBatchMajor(attention_tensors,
                                  src_paddings,
                                  tgt_paddings,
                                  transcripts=None,
                                  max_outputs=3):
    """Adds an image summary showing the attention probability matrix and state.

    Unlike AddAttentionSummary(), this takes all tensors with the batch
    dimension in axis 0. For each attention tensor it also records a scalar
    summary of the mean of -p * log(p) / log(source length).

    Args:
      attention_tensors: A list of 3D tensors shaped [batch_size, target_len,
        source_len] where attention[b, i, j] is the probability for the i-th
        output attending to the j-th input for element b in the batch.
      src_paddings: A tensor of binary paddings shaped [batch, source_len] for
        the source sequence.
      tgt_paddings: A tensor of binary paddings shaped [batch, target_len] for
        the target sequence.
      transcripts: Optional, transcripts shaped [batch, source_len] for the
        source sequence. Only attached to the first subplot.
      max_outputs: Integer maximum number of elements of the batch to plot.
    """
    name = attention_tensors[0].name + '/Attention'
    if not _ShouldAddSummary():
        return
    with plot.MatplotlibFigureSummary(name, max_outputs=max_outputs) as fig:
        src_lens = SequenceLength(src_paddings)
        tgt_lens = SequenceLength(tgt_paddings)
        # Normaliser log(src_len), expanded twice (presumably [batch] ->
        # [batch, 1, 1]) to broadcast against each attention tensor. It only
        # depends on the source lengths, so it is built once, not per tensor.
        max_entropy = tf.log(tf.cast(src_lens, tf.float32))
        max_entropy = tf.expand_dims(tf.expand_dims(max_entropy, -1), -1)
        for n, atten in enumerate(attention_tensors):
            # Diagnostic metric that decreases as attention picks up.
            atten_normalized_entropy = -atten * tf.log(atten +
                                                       1e-10) / max_entropy
            scalar('Attention/average_normalized_entropy/%d' % n,
                   tf.reduce_mean(atten_normalized_entropy))
            args = [atten, src_lens, tgt_lens]
            if transcripts is not None and n == 0:
                args.append(transcripts)
            fig.AddSubplot(args,
                           TrimPaddingAndPlotAttention,
                           title=atten.name,
                           xlabel='Input',
                           ylabel='Output')
Exemple #17
0
def ProcessPillars(input_batch, num_laser_features, featurizer,
                   theta_featurizer):
  """Compute features for the pillars and convert them back to a dense grid.

  Args:
    input_batch: A `.NestedMap` object containing input tensors. Reads
      `grid_num_points` [bs, nx, ny, nz], `pillar_points`
      [bs, num_pillars, num_points_per_pillar, 3 + num_laser_features],
      `pillar_centers` [bs, num_pillars, 1, 3] and `pillar_locations`
      [bs, num_pillars, 3].
    num_laser_features: Number of laser features (excluding pillar features)
    featurizer: The featurizer layer.
    theta_featurizer: The weights for featurizer.

  Returns:
    The dense features with shape [b, nx, ny, nz * fdims].
  """
  bs, nx, ny, nz = py_utils.GetShape(input_batch.grid_num_points, 4)
  # Process points to concatenate a set of fixed features (e.g.,
  # add means, centers, normalize points to means).
  num_features = 3 + num_laser_features
  pillar_points = py_utils.HasShape(input_batch.pillar_points,
                                    [bs, -1, -1, num_features])
  _, npillars, npoints, _ = py_utils.GetShape(pillar_points, 4)
  pillar_xyz = pillar_points[..., :3]
  # Per-pillar mean xyz, [bs, npillars, 1, 3]. `keepdims` replaces the
  # deprecated `keep_dims` argument, which TF2's reduce_mean rejects.
  pillar_means = tf.reduce_mean(pillar_xyz, axis=2, keepdims=True)
  pillar_feats = pillar_points[..., 3:]
  pillar_centers = py_utils.HasShape(input_batch.pillar_centers, [bs, -1, 1, 3])
  # Per point: xyz relative to the pillar mean, the laser features, and the
  # pillar mean and center repeated for every point in the pillar.
  pillar_concat = tf.concat(
      axis=3,
      values=[
          pillar_xyz - pillar_means, pillar_feats,
          tf.tile(pillar_means, [1, 1, npoints, 1]),
          tf.tile(pillar_centers, [1, 1, npoints, 1])
      ])
  # Featurize pillars.
  pillar_features = featurizer.FProp(theta_featurizer, pillar_concat)

  # Convert back to the dense grid.
  pillar_locations = py_utils.HasShape(input_batch.pillar_locations,
                                       [bs, npillars, 3])
  dense_features = SparseToDense(
      grid_shape=(nx, ny, nz),
      locations=pillar_locations,
      feats=pillar_features)
  return dense_features
Exemple #18
0
        def _Gradient(inputs, _, original_grad):
            """Custom gradient applying GradDrop to the per-loss gradients.

            Args:
              inputs: Tensor the per-loss gradients are multiplied with (or
                with its approximate sign when `p.use_input_sign_only`).
              _: Unused.
              original_grad: Incoming gradient; supplies the output dtype and,
                when `p.keep_gradnorm_constant`, the norm to rescale to.

            Returns:
              The transformed gradient, cast to `original_grad.dtype`.

            Raises:
              ValueError: If no loss in `self._losses` yields a gradient.
            """

            # Compute the gradients for each loss w.r.t. the inputs.
            # TODO(jngiam): Look into whether TF dedups this computation.
            # Leak ratios are collected in the same loop so that they stay
            # aligned with per_loss_grads when a loss yields no gradient.
            per_loss_grads = []
            leak_ratios = []
            for loss, leak_ratio in self._losses:
                per_loss_grad = tf.gradients(loss, self._output_tensor)[0]
                if per_loss_grad is None:
                    tf.logging.warning(
                        'Loss %s did not result in a gradient during '
                        'GradDrop computation.', loss)
                else:
                    per_loss_grads.append(per_loss_grad)
                    leak_ratios.append(leak_ratio)

            if not per_loss_grads:
                raise ValueError('No valid gradients for GradDrop.')

            # Multiply the gradients with the inputs.
            grads = per_loss_grads
            if p.use_input_sign_only:
                # Inputs within epsilon of zero get +1 added to the
                # denominator so the division cannot be by zero.
                input_abs = tf.abs(
                    tf.cast(tf.abs(inputs) <= p.epsilon, tf.float32) + inputs)
                grads = [grad * ((inputs) / (input_abs)) for grad in grads]
            else:
                grads = [grad * inputs for grad in grads]

            # Sum gradient over batch, assuming that batch is always on dim 0.
            if p.marginalize_batch_dim:
                grads = [
                    tf.reduce_sum(grad, axis=0, keepdims=True)
                    for grad in grads
                ]

            # First discretize all gradients into their sign values.
            grad_sign_positive = [
                tf.cast(grad > 0.0, tf.float32) for grad in grads
            ]
            grad_sign_negative = [
                tf.cast(grad < 0.0, tf.float32) for grad in grads
            ]

            # Calculate the probability of positive gradients based on equation (1)
            # in the GradDrop paper.
            grad_abs_sum = tf.add_n([tf.abs(grad) for grad in grads])
            prob_pos = (tf.add_n(grads) / (2. * grad_abs_sum + p.epsilon))
            # Implementation of different scales for the keep function. Larger
            # scales result in steeper keep functions.
            prob_pos *= p.keep_prob_function_scale

            if p.keep_prob_function == 'sigmoid':
                # Standard sigmoid has derivative of 0.25 at 0 so the factor of 4.0
                # allows the function scale in sigmoid to be compatible with the
                # function scale in the linear case.
                prob_pos = tf.sigmoid(4.0 * prob_pos)
            elif p.keep_prob_function == 'linear':
                prob_pos += 0.5

            # The main, default mode of GradDrop. Only gradients of one sign are kept,
            # and which sign is calculated via equation (1) of the main paper.
            # tf.shape (rather than the static .shape) keeps this working when
            # some dimensions are not statically known.
            prob_pos = tf.cast(
                prob_pos >= tf.random.uniform(tf.shape(prob_pos)),
                tf.float32) - 0.5
            grad_masks = [
                (gsp - gsn) * prob_pos >= 0
                for (gsn, gsp) in zip(grad_sign_negative, grad_sign_positive)
            ]

            # This diag value gives us the percentage of grads which are kept.
            gradmask_diag = [tf.cast(gm, tf.float32) for gm in grad_masks]
            diag = tf.reduce_mean(tf.add_n(gradmask_diag) / len(grad_masks))
            summary_utils.scalar('average_grad_mask', diag)
            # Masks are derived from the input-weighted grads above but applied
            # to the raw per-loss gradients; each loss always keeps a `leak`
            # fraction of its gradient regardless of the mask.
            transformed_per_loss_grads = [
                grad * (leak + (1.0 - leak) * tf.cast(grad_mask, tf.float32))
                for (leak, grad,
                     grad_mask) in zip(leak_ratios, per_loss_grads, grad_masks)
            ]

            transformed_grad = tf.cast(tf.add_n(transformed_per_loss_grads),
                                       original_grad.dtype)

            if not p.keep_gradnorm_constant:
                return transformed_grad

            transformed_grad_norm = tf.sqrt(tf.reduce_sum(transformed_grad**2))
            original_grad_norm = tf.sqrt(tf.reduce_sum(original_grad**2))
            return transformed_grad * original_grad_norm / (
                transformed_grad_norm + p.epsilon)
Exemple #19
0
def AddAttentionSummaryBatchMajor(attention_tensors,
                                  src_paddings,
                                  tgt_paddings,
                                  transcripts=None,
                                  max_outputs=3):
    """Adds an image summary showing the attention probability matrix and state.

    Unlike AddAttentionSummary(), this takes all tensors with the batch
    dimension in axis 0.

    Args:
      attention_tensors: A list of 3D tensors shaped [batch_size, target_len,
        source_len] where attention[b, i, j] is the probability for the i-th
        output attending to the j-th input for element b in the batch.
      src_paddings: A tensor of binary paddings shaped [batch, source_len] for
        the source sequence, or a list with either one entry (shared by all
        attention tensors) or one entry per attention tensor.
      tgt_paddings: A tensor of binary paddings shaped [batch, target_len] for
        the target sequence, or a list with either one entry (shared by all
        attention tensors) or one entry per attention tensor.
      transcripts: Optional, transcripts shaped [batch, source_len] for the
        source sequence. Only attached to the first subplot.
      max_outputs: Integer maximum number of elements of the batch to plot.

    Raises:
      ValueError: If a paddings list has neither one entry nor one entry per
        attention tensor.
    """
    def _CheckPaddingsLen(paddings):
        count = len(paddings) if isinstance(paddings, list) else 1
        if count == 1 or count == len(attention_tensors):
            return
        raise ValueError('Bad length of paddings list {}'.format(count))

    def _Lengths(paddings):
        padding_list = paddings if isinstance(paddings, list) else [paddings]
        return [SequenceLength(pad) for pad in padding_list]

    def _Select(lengths, idx):
        # A single entry is shared by every attention tensor.
        return lengths[idx] if len(lengths) != 1 else lengths[0]

    _CheckPaddingsLen(src_paddings)
    _CheckPaddingsLen(tgt_paddings)

    name = attention_tensors[0].name + '/Attention'
    if not _ShouldAddSummary():
        return

    src_lens = _Lengths(src_paddings)
    tgt_lens = _Lengths(tgt_paddings)

    with plot.MatplotlibFigureSummary(name,
                                      max_outputs=max_outputs,
                                      gridspec_kwargs={'hspace': 0.3}) as fig:
        for n, atten in enumerate(attention_tensors):
            src_len = _Select(src_lens, n)
            tgt_len = _Select(tgt_lens, n)
            # Diagnostic metric that decreases as attention picks up.
            log_src_len = tf.log(tf.cast(src_len, tf.float32))
            max_entropy = tf.expand_dims(tf.expand_dims(log_src_len, -1), -1)
            normalized_entropy = -atten * tf.log(atten + 1e-10) / max_entropy
            scalar('Attention/average_normalized_entropy/%d' % n,
                   tf.reduce_mean(normalized_entropy))
            subplot_args = [atten, src_len, tgt_len]
            if n == 0 and transcripts is not None:
                subplot_args.append(transcripts)
            fig.AddSubplot(subplot_args,
                           TrimPaddingAndPlotAttention,
                           title=atten.name,
                           xlabel='Input',
                           ylabel='Output')
Exemple #20
0
 def sum_embeddings(t):
     """Reshapes `t` to [batch_size, max_neighbors, -1, enc_units] and averages axis 2.

     Despite the name, this takes the mean (not the sum) over axis 2. `p` is
     read from the enclosing scope.
     """
     target_shape = (p.input.batch_size, p.max_neighbors, -1, p.enc_units)
     return tf.reduce_mean(tf.reshape(t, target_shape), axis=2)
Exemple #21
0
    def ComputeLoss(self, theta, predictions, input_batch):
        """Computes loss and other metrics for the given predictions.

        Args:
          theta: A `.NestedMap` object containing variable values of this task.
          predictions: The output of `ComputePredictions`.
          input_batch: A `.NestedMap` object containing input tensors to this
            tower.

        Returns:
          A tuple (metrics, per_example_tensors), where
            - `metrics` is a dict of str keys to (metric, weight) values
            - `per_example_tensors` is a dict of str keys to tensors describing
              each training example, where the first dimension of each tensor
              is the batch index (always empty here).
        """
        p = self.params

        # During TPU training, collect the encodings and ids from all TPUs so
        # the loss can be computed over all query-result pairs in the global
        # batch. To avoid duplicating work, each TPU operates on a
        # non-overlapping slice of these pairs. Specifically, each TPU uses
        # queries drawn from its local batch and results from the global batch.

        # Encodings of the local and global examples, keyed by modality.
        encodings_by_modality = {
            modality: tf.reshape(predictions[modality].encodings,
                                 [-1, p.joint_embedding_dim])
            for modality in predictions
        }
        local_flat_encodings = py_utils.NestedMap(encodings_by_modality)
        global_flat_encodings = tpu_utils.ConcatenateAcrossReplicas(
            local_flat_encodings)

        def _PerQueryLoss(query_modality, result_modality):
            """Returns the [num_queries] contrastive losses for one direction."""
            pairs = label_lib.ExamplePairs.BetweenLocalAndGlobalBatches(
                input_batch,
                query_modality=query_modality,
                result_modality=result_modality)
            raw_labels = p.label_fn(pairs)
            # [num_queries, num_results]
            similarities = self.score_function(
                local_flat_encodings[query_modality],
                global_flat_encodings[result_modality])
            labels = tf.reshape(raw_labels, similarities.shape)
            return label_lib.MultiLabelContrastiveLoss(labels=labels,
                                                       logits=similarities)

        weighted_terms = []
        metrics = {}
        for direction, weight in p.loss_weights.items():
            query_modality, result_modality = direction
            if not weight:
                logging.info('Skipping %s retrieval', direction)
                continue
            mean_loss = tf.reduce_mean(
                _PerQueryLoss(query_modality, result_modality))
            weighted_terms.append(weight * mean_loss)
            metric_name = 'loss_{}_to_{}'.format(query_modality,
                                                 result_modality)
            metrics[metric_name] = (mean_loss, 1)

        reg_losses = utils.CollectRegularizationLosses(self)
        if p.regularization_loss_weight and reg_losses:
            tf.logging.info('Adding TF1 regularization loss: %s', reg_losses)
            total_reg_loss = tf.reduce_sum(reg_losses)
            weighted_terms.append(p.regularization_loss_weight *
                                  total_reg_loss)
            metrics['loss_regularization'] = (total_reg_loss, 1)

        total_loss = tf.add_n(weighted_terms)
        metrics['loss'] = (total_loss, 1)
        return metrics, {}
Exemple #22
0
    def try_apply_dense(self, grad, var):
        """Dry-runs the dense Adafactor update, recording every intermediate.

        Follows the same math as `_resource_apply_dense` but assigns nothing:
        each intermediate tensor is stored in `stats` and paired with an
        all-elements-finite check, which helps locate where a NaN/Inf first
        appears.

        Args:
          grad: Gradient tensor for `var`. Must not be None.
          var: The variable being optimized. The last two characters of its
            name (presumably a ':0' suffix -- TODO confirm) are dropped to
            build the variable scope.

        Returns:
          A tuple `(is_finite_checks, stats)`. `stats` maps intermediate names
          to tensors; `is_finite_checks` is a list of scalar boolean tensors,
          one per recorded intermediate, in recording order.
        """
        assert grad is not None

        is_finite_checks = []
        stats = {}

        grad_dtype = var.dtype  # TODO(lepikhin): add to params
        grad = tf.cast(grad, grad_dtype)
        factored_dims = self._factored_dims(var.shape.as_list())
        if factored_dims:
            vr = self.get_slot(var, 'vr')
            vc = self.get_slot(var, 'vc')
        else:
            v = self.get_slot(var, 'v')
        if self._beta1:
            m = self.get_slot(var, 'm')

        def _Record(name, x):
            """Stores `x` as stats[name] and appends its all-finite check."""
            stats[name] = x
            is_finite_checks.append(tf.reduce_all(tf.math.is_finite(x)))

        with tf.variable_scope(var.name[:-2] + '/Adafactor'):
            grad_squared = tf.math.square(grad) + tf.cast(
                self._epsilon1, grad_dtype)
            _Record('grad_squared', grad_squared)  # 0 (factored)
            decay_rate = tf.cast(self._decay_rate, var.dtype)
            old_val = tf.identity(
                var)  # TODO(lepikhin): introduce gradient dtype
            # The dry run expects to report 'parameter_scale'; the else branch
            # below is only reachable when asserts are stripped (python -O).
            assert self._multiply_by_parameter_scale
            lr = GetLrValue(self._learning_rate)
            if self._multiply_by_parameter_scale:
                parameter_scale = self._parameter_scale(old_val)
                _Record('parameter_scale', parameter_scale)  # 1 (factored)
                # Reuse the recorded tensor instead of building the same
                # parameter-scale subgraph a second time.
                update_scale = parameter_scale * tf.cast(lr, grad_dtype)
            else:
                update_scale = lr
            mixing_rate = tf.cast(1.0 - decay_rate, grad_dtype)
            update_scale = tf.cast(update_scale, grad_dtype)
            if factored_dims:
                vr_axis, vc_axis = factored_dims
                grad_squared_row_mean = tf.reduce_mean(grad_squared,
                                                       axis=vr_axis)
                grad_squared_col_mean = tf.reduce_mean(grad_squared,
                                                       axis=vc_axis)
                # Decayed running averages of the factored second moments.
                new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
                new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
                _Record('new_vr', new_vr)  # 2 (factored)
                _Record('new_vc', new_vc)  # 3 (factored)
                long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
                r_factor = tf.math.rsqrt(new_vr / long_term_mean)
                c_factor = tf.math.rsqrt(new_vc)
                mult = tf.expand_dims(r_factor, vr_axis) * tf.expand_dims(
                    c_factor, vc_axis)
                _Record('mult', mult)  # 4 (factored)
                x = grad * mult
            else:
                new_v = v * decay_rate + grad_squared * mixing_rate
                _Record('new_v', new_v)
                x = grad * tf.math.rsqrt(new_v)

            assert self._clipping_threshold is not None

            if self._clipping_threshold is not None:
                clipping_denom = tf.maximum(
                    tf.constant(1.0, grad_dtype),
                    py_utils.ReduceRms(x) /
                    tf.constant(self._clipping_threshold, grad_dtype))
                x /= clipping_denom
            _Record('x', x)
            subtrahend = x * update_scale
            if self._beta1:
                new_m = (m * tf.constant(self._beta1, dtype=grad_dtype) +
                         subtrahend *
                         tf.constant(1.0 - self._beta1, dtype=grad_dtype))
                subtrahend = new_m
                _Record('new_m', new_m)

            # `_resource_apply_dense` applies this value with assign_sub; the
            # dry run only records it and leaves every variable and slot as is.
            _Record('subtrahend', subtrahend)  # 5 (factored)

            return is_finite_checks, stats
Exemple #23
0
    def ComputePredictions(self,
                           encoder_outputs,
                           pronunciations,
                           is_inference=False):
        """Runs the decoder over every target position and computes the loss.

        Despite the name, this method does the bulk of the decoding: it runs a
        `tf.while_loop` with one `self.Decode` call per target time step,
        collecting predictions and attention weights, then computes the masked
        loss.

        Args:
          encoder_outputs: a NestedMap of FeatureNeighborhoodEncoder outputs,
            read here for `state` (initial decoder state) and `batch`; the
            whole map is also handed to `self.Decode` at every step.
          pronunciations: NestedMap whose `pronunciations` field is a
            [*, max_pronunciation_len] tensor of target ids; id 0 is treated
            as padding and masked out of the loss.
          is_inference: If False, feeds the gold target back as the next
            decoder input (teacher forcing); if True, feeds back the argmax of
            the step's predictions (autoregression).

        Returns:
          NestedMap with `loss` (mean of `per_sequence_losses`),
          `per_sequence_losses` (summed masked loss divided by
          max_pronunciation_len), `labels` (argmax predictions),
          `attention` ([*, max_pronunciation_len, max_spelling_len]),
          `neighbor_attention` and `batch` (passed through from the encoder).
        """
        p = self.params
        targets = pronunciations.pronunciations
        num_steps = int(targets.get_shape().as_list()[1])
        first_step = tf.constant(0)
        attention_ta = tf.TensorArray(dtype=tf.float32, size=num_steps)
        neighbor_attention_ta = tf.TensorArray(dtype=tf.float32,
                                               size=num_steps)

        outputs_ta = tf.TensorArray(dtype=tf.float32, size=num_steps)

        def _KeepDecoding(step, unused_prev_input, *unused_rest):
            return tf.less(step, num_steps)

        seed_input = tf.convert_to_tensor([p.start] * p.input.batch_size)
        initial_state = encoder_outputs.state

        def _DecodeOneStep(step, prev_input, att_ta, nbr_att_ta, dec_state,
                           out_ta):
            """Decodes position `step` and writes its outputs into the arrays."""
            step_result = self.Decode(encoder_outputs, prev_input, dec_state)

            out_ta = out_ta.write(step, step_result.predictions)
            att_ta = att_ta.write(step, step_result.attention_weights)
            nbr_att_ta = nbr_att_ta.write(
                step,
                tf.cast(step_result.neighbor_attention_weights,
                        dtype=tf.float32))

            if not is_inference:
                # Teacher forcing: feed the gold token for this position.
                next_input = targets[:, step]
            else:
                next_input = tf.cast(tf.argmax(step_result.predictions, 1),
                                     tf.int32)
            return (step + 1, next_input, att_ta, nbr_att_ta,
                    step_result.state, out_ta)

        loop_result = tf.while_loop(
            _KeepDecoding,
            _DecodeOneStep,
            loop_vars=[
                first_step, seed_input, attention_ta, neighbor_attention_ta,
                initial_state, outputs_ta
            ])
        attention_ta = loop_result[2]
        neighbor_attention_ta = loop_result[3]
        outputs_ta = loop_result[5]

        # [time, batch, vocab] -> [batch, time, vocab].
        logits = tf.transpose(outputs_ta.stack(), [1, 0, 2])
        labels = tf.argmax(logits, axis=-1)
        mask = tf.cast(tf.math.logical_not(tf.math.equal(targets, 0)),
                       dtype=tf.float32)
        token_losses = self._loss_object(targets, logits, sample_weight=mask)
        summed_losses = tf.reduce_sum(token_losses, axis=1)
        per_sequence_losses = (summed_losses / num_steps)
        mean_loss = tf.reduce_mean(per_sequence_losses)

        predictions = py_utils.NestedMap()
        predictions.loss = mean_loss
        predictions.per_sequence_losses = per_sequence_losses
        predictions.labels = labels
        # NOTE(review): tf.squeeze has no axis, so a batch (or time) dimension
        # of size 1 would also be dropped; confirm that cannot happen.
        predictions.attention = tf.transpose(
            tf.squeeze(attention_ta.stack()), perm=[1, 0, 2])
        stacked_neighbor_attention = tf.squeeze(neighbor_attention_ta.stack())
        if p.use_neighbors:
            stacked_neighbor_attention = tf.transpose(
                stacked_neighbor_attention, perm=[1, 0, 2])
        predictions.neighbor_attention = stacked_neighbor_attention
        # Expose this for subsequent data analysis
        predictions.batch = encoder_outputs.batch
        return predictions
Exemple #24
0
    def _resource_apply_dense(self, grad, var):
        """Applies one dense Adafactor step to `var` and its slots.

        When `self._cond_is_finite` is set, every assignment is wrapped in a
        `tf.cond` that skips it unless all intermediates checked so far are
        finite.

        Args:
          grad: Gradient tensor for `var`, or None.
          var: The variable being optimized. The last two characters of its
            name (presumably a ':0' suffix -- TODO confirm) are dropped to
            build the variable scope.

        Returns:
          A `tf.group` of all slot and variable updates, or `[]` (after logging
          a warning) when `grad` is None.
        """
        if grad is None:
            tf.logging.warning('Gradient is None for variable %s', var.name)
            return []

        grad_dtype = var.dtype  # TODO(lepikhin): add to params
        grad = tf.cast(grad, grad_dtype)
        factored_dims = self._factored_dims(var.shape.as_list())
        if factored_dims:
            vr = self.get_slot(var, 'vr')
            vc = self.get_slot(var, 'vc')
        else:
            v = self.get_slot(var, 'v')
        if self._beta1:
            m = self.get_slot(var, 'm')

        cond = tf.constant(True)

        def _Upd(c, x):
            """ANDs `c` with 'all of x is finite' when finiteness gating is on."""
            if not self._cond_is_finite:
                return c
            # tf.math.is_finite is False for NaN and for +/-Inf, so a separate
            # is_inf check would be redundant.
            return tf.math.logical_and(c, tf.reduce_all(tf.math.is_finite(x)))

        def _Wrap(fn, x, y):
            """Applies fn(x, y), gated on the current value of `cond`."""
            if not self._cond_is_finite:
                return fn(x, y)
            # `cond` is read from the enclosing scope at call time, so each
            # update is gated on the checks accumulated up to that point.
            return tf.cond(cond, lambda: fn(x, y), lambda: x)

        with tf.variable_scope(var.name[:-2] + '/Adafactor'):
            grad_squared = tf.math.square(grad) + tf.cast(
                self._epsilon1, grad_dtype)
            cond = _Upd(cond, grad_squared)
            decay_rate = tf.cast(self._decay_rate, var.dtype)
            old_val = tf.identity(
                var)  # TODO(lepikhin): introduce gradient dtype
            lr = GetLrValue(self._learning_rate)
            if self._multiply_by_parameter_scale:
                update_scale = self._parameter_scale(old_val) * tf.cast(
                    lr, grad_dtype)
            else:
                update_scale = lr
            mixing_rate = tf.cast(1.0 - decay_rate, grad_dtype)
            update_scale = tf.cast(update_scale, grad_dtype)
            updates = []
            if factored_dims:
                vr_axis, vc_axis = factored_dims
                grad_squared_row_mean = tf.reduce_mean(grad_squared,
                                                       axis=vr_axis)
                grad_squared_col_mean = tf.reduce_mean(grad_squared,
                                                       axis=vc_axis)
                # Decayed running averages of the factored second moments.
                new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
                new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
                cond = _Upd(cond, new_vr)
                cond = _Upd(cond, new_vc)
                vr_update = _Wrap(tf.assign, vr, new_vr)
                vc_update = _Wrap(tf.assign, vc, new_vc)
                updates.extend([vr_update, vc_update])
                long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
                r_factor = tf.math.rsqrt(new_vr / long_term_mean)
                c_factor = tf.math.rsqrt(new_vc)
                x = grad * tf.expand_dims(r_factor, vr_axis) * tf.expand_dims(
                    c_factor, vc_axis)
            else:
                new_v = v * decay_rate + grad_squared * mixing_rate
                cond = _Upd(cond, new_v)
                v_update = _Wrap(tf.assign, v, new_v)
                updates.append(v_update)
                x = grad * tf.math.rsqrt(new_v)
            if self._clipping_threshold is not None:
                clipping_denom = tf.maximum(
                    tf.constant(1.0, grad_dtype),
                    py_utils.ReduceRms(x) /
                    tf.constant(self._clipping_threshold, grad_dtype))
                x /= clipping_denom
            subtrahend = x * update_scale
            if self._beta1:
                new_m = (m * tf.constant(self._beta1, dtype=grad_dtype) +
                         subtrahend *
                         tf.constant(1.0 - self._beta1, dtype=grad_dtype))
                subtrahend = new_m
                cond = _Upd(cond, new_m)
                updates.append(_Wrap(tf.assign, m, new_m))
            # It is critical to use assign_sub instead of tf.assign(var - subtrahend)
            #  for the case of bfloat16 activations, so as to avoid repeatedly
            #  rounding the slice value, which results in poor quality.
            cond = _Upd(cond, subtrahend)
            var_update = _Wrap(tf.assign_sub, var, subtrahend)
            updates.append(var_update)
            return tf.group(*updates)
Exemple #25
0
    def ComputeLoss(self, theta, predictions, input_batch):
        """Compute loss for the sparse detector model v1.

        Args:
          theta: A `.NestedMap` object containing variable values of this task.
          predictions: A `.NestedMap` object containing residuals and
            classification_logits.
          input_batch: A `.NestedMap` expected to contain cell_center_xyz,
            cell_points_xyz, cell_feature, anchor_bboxes,
            anchor_localization_residuals, assigned_gt_labels,
            assigned_cls_mask, and assigned_reg_mask. See class doc string for
            details.

        Returns:
          Two dicts:

          - A dict containing str keys and (metric, weight) pairs as values, where
            one of the keys is expected to be 'loss'.
          - A dict containing arbitrary tensors describing something about each
            training example, where the first dimension of each tensor is the
            batch index.
        """
        p = self.params

        batch_size, num_centers = py_utils.GetShape(
            input_batch.cell_center_xyz, 2)

        def _SumOverBatch(loss):
            """Sums every element of `loss` and divides by the batch size."""
            # GetShape may yield a Python int (static shape) or an integer
            # tensor (dynamic shape); casting keeps the float division valid
            # in both cases.
            return tf.reduce_sum(loss) / tf.cast(batch_size, loss.dtype)

        # Assert shapes of inputs.
        anchor_bboxes = py_utils.HasShape(
            input_batch.anchor_bboxes,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center, 7])
        anchor_localization_residuals = py_utils.HasShape(
            input_batch.anchor_localization_residuals,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center, 7])
        predicted_residuals = py_utils.HasShape(
            predictions.residuals,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center, 7])

        assigned_gt_labels = py_utils.HasShape(
            input_batch.assigned_gt_labels,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center])
        predicted_classification_logits = py_utils.HasShape(
            predictions.classification_logits, [
                batch_size, num_centers, p.num_anchor_bboxes_per_center,
                p.num_classes
            ])

        # assigned_cls_mask is for weighting the classification loss.
        # Ignored targets will have their mask = 0; this happens when their IOU is
        # not high enough to be a foreground object and not low enough to be
        # background.
        class_weights = py_utils.HasShape(
            input_batch.assigned_cls_mask,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center])
        class_weights = tf.reshape(
            class_weights,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center, 1])

        # Broadcast per class loss weights. For each anchor, there are num_classes
        # prediction heads, we weight the outputs of these heads by the per class
        # loss weights.
        per_class_loss_weight = tf.constant([[[p.per_class_loss_weight]]],
                                            dtype=tf.float32)
        per_class_loss_weight = py_utils.HasShape(per_class_loss_weight,
                                                  [1, 1, 1, p.num_classes])
        class_weights *= per_class_loss_weight
        class_weights = py_utils.HasShape(class_weights, [
            batch_size, num_centers, p.num_anchor_bboxes_per_center,
            p.num_classes
        ])

        # We use assigned_reg_mask for masking the regression loss.
        # Only foreground objects will have assigned_reg_mask = 1.
        reg_mask = py_utils.HasShape(
            input_batch.assigned_reg_mask,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center])
        reg_weights = tf.reshape(
            reg_mask,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center, 1])

        if p.loss_norm_type == LossNormType.NORM_BY_NUM_POS_PER_CENTER:
            # Count foreground anchors for each (example, center), floored at
            # one so the division below never divides by zero.
            loss_normalization = tf.reduce_sum(reg_mask, axis=2)
            loss_normalization = tf.maximum(loss_normalization,
                                            tf.ones_like(loss_normalization))

            # Reshape for broadcasting.
            loss_normalization = tf.reshape(loss_normalization,
                                            [batch_size, num_centers, 1, 1])

            # Normalize so that the loss is independent of # centers.
            loss_normalization *= tf.cast(num_centers,
                                          loss_normalization.dtype)
            class_weights /= loss_normalization
            reg_weights /= loss_normalization

        classification_loss = py_utils.SigmoidCrossEntropyFocalLoss(
            logits=predicted_classification_logits,
            labels=tf.one_hot(assigned_gt_labels, p.num_classes),
            alpha=p.focal_loss_alpha,
            gamma=p.focal_loss_gamma)

        # Apply mask.
        classification_loss *= class_weights

        # TODO(jngiam): Consider normalizing by num_foreground_anchors for each
        # example instead. This would match the 1/N_positive normalization in
        # point pillars.

        # Reduce sum over centers, boxes and classes.
        classification_loss = tf.reduce_sum(classification_loss,
                                            axis=[1, 2, 3])

        # Reduce mean over batch.
        classification_loss = tf.reduce_mean(classification_loss)

        # Localization regression loss with Huber loss (SmoothL1).
        regression_loc_and_dims_loss = self._utils_3d.ScaledHuberLoss(
            labels=anchor_localization_residuals[..., :6],
            predictions=predicted_residuals[..., :6],
            delta=p.huber_loss_delta)

        # Rotation loss is computed on a transform on rotation_delta. For a
        # direction aware loss, we simply wrap the angles to -pi to pi; for a loss
        # that is symmetric to direction (i.e., rotating by pi), we use a sin
        # transform.
        rotation_delta_transform = tf.sin
        if p.direction_aware_rot_loss:
            rotation_delta_transform = functools.partial(geometry.WrapAngleRad,
                                                         min_val=-np.pi,
                                                         max_val=np.pi)
        rotation_delta = (predicted_residuals[..., 6:] -
                          anchor_localization_residuals[..., 6:])
        regression_rotation_loss = self._utils_3d.ScaledHuberLoss(
            labels=tf.zeros_like(rotation_delta),
            predictions=rotation_delta_transform(rotation_delta),
            delta=p.huber_loss_delta)

        reg_loc_loss = regression_loc_and_dims_loss[..., :3]
        reg_dim_loss = regression_loc_and_dims_loss[..., 3:6]

        gt_bboxes = self._utils_3d.ResidualsToBBoxes(
            anchor_bboxes,
            anchor_localization_residuals,
            min_angle_rad=-np.pi,
            max_angle_rad=np.pi)
        predicted_bboxes = self._utils_3d.ResidualsToBBoxes(
            anchor_bboxes,
            predicted_residuals,
            min_angle_rad=-np.pi,
            max_angle_rad=np.pi)

        # Apply mask to individual losses, then reduce sum over centers, boxes,
        # residuals, and batch and divide by the batch_size.
        reg_rot_loss = _SumOverBatch(regression_rotation_loss * reg_weights)
        reg_loc_loss = _SumOverBatch(reg_loc_loss * reg_weights)
        reg_dim_loss = _SumOverBatch(reg_dim_loss * reg_weights)

        # Do not create corner loss graph if weight is 0.0
        # TODO(bcyang): Remove condition after fixing corner loss NaN issue
        if p.corner_loss_weight != 0.0:
            reg_corner_loss = self._utils_3d.CornerLoss(
                gt_bboxes=gt_bboxes, predicted_bboxes=predicted_bboxes)
            reg_corner_loss = tf.expand_dims(reg_corner_loss, axis=-1)
            reg_corner_loss = _SumOverBatch(reg_corner_loss * reg_weights)
        else:
            reg_corner_loss = 0.0

        # Sum components of regression loss.
        regression_loss = (p.location_loss_weight * reg_loc_loss +
                           p.dimension_loss_weight * reg_dim_loss +
                           p.rotation_loss_weight * reg_rot_loss +
                           p.corner_loss_weight * reg_corner_loss)

        # Compute total loss.
        total_loss = (p.loss_weight_localization * regression_loss +
                      p.loss_weight_classification * classification_loss)

        metrics_dict = py_utils.NestedMap({
            'loss': (total_loss, batch_size),
            'loss/regression': (regression_loss, batch_size),
            'loss/regression/loc': (reg_loc_loss, batch_size),
            'loss/regression/dim': (reg_dim_loss, batch_size),
            'loss/regression/rot': (reg_rot_loss, batch_size),
            'loss/regression/corner': (reg_corner_loss, batch_size),
            'loss/classification': (classification_loss, batch_size),
        })

        # Calculate dimension errors
        dimension_errors_dict = self._BBoxDimensionErrors(
            gt_bboxes, predicted_bboxes, reg_weights)
        metrics_dict.update(dimension_errors_dict)

        per_example_dict = py_utils.NestedMap({
            'residuals': predicted_residuals,
            'classification_logits': predicted_classification_logits,
            'predicted_bboxes': predicted_bboxes,
            'gt_bboxes': gt_bboxes,
            'reg_weights': reg_weights,
        })

        return metrics_dict, per_example_dict
Exemple #26
0
    def ComputePredictions(self, theta, input_batch):
        """Computes predictions for `input_batch`.

        Featurizes the pillars, scatters them back onto the dense grid, runs
        the backbone, and applies the classification / regression (and
        optionally direction) heads.

        Args:
          theta: A `.NestedMap` object containing variable values of this task.
          input_batch: A `.NestedMap` object containing input tensors to this
            tower; read here for grid_num_points, pillar_points,
            pillar_centers, pillar_locations and, when an oracle option is
            enabled, anchor_localization_residuals.

        Returns:
          A `.NestedMap` containing
            residuals - [bs, nx, ny, grid_size_z, num_anchors, 7],
            classification_logits - [bs, nx, ny, grid_size_z, num_anchors,
              num_classes],
            predicted_dir - [bs, nx, ny, grid_size_z, num_anchors, 2], present
              only when p.direction_classifier_weight > 0.
          Here nx, ny are the spatial dimensions of the detector head outputs.
        """
        p = self.params
        input_batch.Transform(lambda x:
                              (x.shape, x.shape.num_elements())).VLog(
                                  0, 'input_batch shapes: ')

        bs, nx, ny, nz = py_utils.GetShape(input_batch.grid_num_points, 4)
        # Process points to concatenate a set of fixed features (e.g.,
        # add means, centers, normalize points to means).
        num_features = 3 + p.num_laser_features
        pillar_points = py_utils.HasShape(input_batch.pillar_points,
                                          [bs, -1, -1, num_features])
        _, npillars, npoints, _ = py_utils.GetShape(pillar_points, 4)
        pillar_xyz = pillar_points[..., :3]
        # `keepdims` replaces the deprecated `keep_dims` spelling (matches the
        # other reduce_mean calls in this file).
        pillar_means = tf.reduce_mean(pillar_xyz, axis=2, keepdims=True)
        pillar_feats = pillar_points[..., 3:]
        pillar_centers = py_utils.HasShape(input_batch.pillar_centers,
                                           [bs, -1, 1, 3])
        pillar_concat = tf.concat(axis=3,
                                  values=[
                                      pillar_xyz - pillar_means, pillar_feats,
                                      tf.tile(pillar_means,
                                              [1, 1, npoints, 1]),
                                      tf.tile(pillar_centers,
                                              [1, 1, npoints, 1])
                                  ])

        # Featurize pillars.
        pillar_features = self.featurizer.FProp(theta.featurizer,
                                                pillar_concat)

        # Convert back to the dense grid.
        pillar_locations = py_utils.HasShape(input_batch.pillar_locations,
                                             [bs, npillars, 3])
        dense_features = _SparseToDense(grid_shape=(nx, ny, nz),
                                        locations=pillar_locations,
                                        feats=pillar_features)

        # Backbone
        tf.logging.vlog(1, 'dense_features.shape = %s', dense_features.shape)
        act = self.backbone.FProp(theta.backbone, dense_features)
        tf.logging.vlog(1, 'act.shape = %s', act.shape)

        # Convert the output of the backbone into class logits and regression
        # residuals using two different layers.
        class_detection = self.class_detector.FProp(theta.class_detector, act)
        reg_detection = self.regression_detector.FProp(
            theta.regression_detector, act)
        # nx, ny are re-read from the head output, which may differ from the
        # input grid size.
        bs, nx, ny, _ = py_utils.GetShape(class_detection, 4)
        predicted_classification_logits = tf.reshape(
            class_detection,
            [bs, nx, ny, p.grid_size_z, p.num_anchors, p.num_classes])
        predicted_residuals = tf.reshape(
            reg_detection, [bs, nx, ny, p.grid_size_z, p.num_anchors, 7])

        if p.squash_rotation_predictions:
            # Bound the rotation residual to (-pi, pi) via tanh.
            predicted_rotations = predicted_residuals[..., 6:]
            predicted_rotations = np.pi * tf.tanh(predicted_rotations)
            predicted_residuals = tf.concat(
                [predicted_residuals[..., :6], predicted_rotations], axis=-1)

        if p.oracle_location or p.oracle_dimension or p.oracle_rotation:
            gt_residuals = py_utils.HasShape(
                input_batch.anchor_localization_residuals,
                [bs, nx, ny, p.grid_size_z, p.num_anchors, 7])

            # Replace the predicted components with the ground truth if needed.
            if p.oracle_location:
                location = gt_residuals[..., 0:3]
            else:
                location = predicted_residuals[..., 0:3]

            if p.oracle_dimension:
                dimension = gt_residuals[..., 3:6]
            else:
                dimension = predicted_residuals[..., 3:6]

            if p.oracle_rotation:
                rotation = gt_residuals[..., 6:]
            else:
                rotation = predicted_residuals[..., 6:]
            predicted_residuals = tf.concat([location, dimension, rotation],
                                            axis=-1)

        ret = py_utils.NestedMap({
            'residuals':
            predicted_residuals,
            'classification_logits':
            predicted_classification_logits,
        })

        if p.direction_classifier_weight > 0.0:
            predicted_dir = self.direction_classifier.FProp(
                theta.direction_classifier, act)
            predicted_dir = tf.reshape(
                predicted_dir, [bs, nx, ny, p.grid_size_z, p.num_anchors, 2])
            ret.predicted_dir = predicted_dir

        return ret
Exemple #27
0
def Top2GatingOnLogits(inputs,
                       paddings,
                       logits,
                       num_devices,
                       experts_dim,
                       expert_capacity_dim,
                       fprop_dtype,
                       use_xla_sharding=True,
                       second_expert_policy='all',
                       second_expert_threshold=0.0,
                       legacy_mtf_behavior=True,
                       capacity_factor=None):
  """Computes Top-2 gating for Mixture-of-Experts.

  Each token (a position along S) is routed to up to two experts: its argmax
  expert under softmax(logits), and a second expert chosen according to
  `second_expert_policy`. Per group, each expert accepts at most
  `expert_capacity_dim` tokens; assignments that do not fit are dropped (the
  corresponding mask and gate entries are zeroed).

  There are two expected usages of this function:

  1. used with xla_sharding. In this case, 'inputs' corresponds to a sharded
     tensor across multiple tpu cores. The operations within this function are
     automatically sharded/replicated across tpu cores.
  2. used within ML-Pathways. In this case, 'inputs' is always local to one tpu
     core. All computations below are carried out on one tpu core only. This
     function tries to dispatch examples across tpu cores in such a way that
     each expert is assigned no more than 'expert_capacity_dim' number of
     examples.

  Below ` indicates common way of splitting along mesh dimension.

  Dimensions cheat sheet:

    G: group_dim
    S: group_size_dim
    E: number of experts
    C: capacity per expert
    M: model_dim (same as input_dim, same as output_dim)
    B: original batch_dim
    L: original sequence_length_dim

  Note that for local_dispatch original batch BLM is reshaped into GSM, each
  group `g = 0...G-1` is being dispatched independently.

  NOTE(review): keep the op creation order in this function stable. The
  'sampling' and 'random' policies create tf.random.uniform ops without an
  explicit op seed; if a graph-level seed is set, TF derives op seeds from
  graph state (see tf.random.set_seed), so adding/reordering ops before them
  may change which experts get sampled.

  NOTE(review): `logits.shape` and `gate_2.shape` are used directly as shapes
  for tf.random.uniform and `int(logits.shape[1])` is taken when
  capacity_factor is set, so static shapes appear to be assumed — confirm.

  Args:
    inputs: G`SM Tensor. Currently unused (deleted on entry).
    paddings: G`S Tensor, or None for no padding. Positions are weighted by
      `1.0 - paddings`; presumably 0/1-valued (TODO confirm with callers), in
      which case padded positions are never dispatched and contribute zero to
      the aux-loss densities.
    logits: G`SE Tensor of gating logits; softmax is taken along E.
    num_devices: number of MoE devices for local dispatch. Only used as the
      split count passed to `Split` when use_xla_sharding is True.
    experts_dim: number of experts.
    expert_capacity_dim: number of examples per minibatch(group) per expert.
      Each example is typically a vector of size input_dim, representing
      embedded token or an element of Transformer layer output.
    fprop_dtype: activations datatype to use.
    use_xla_sharding: bool, True if this function is used for the xla_sharding
      case. When True, intermediate tensors are passed through
      `Split(x, 0, num_devices)`.
    second_expert_policy: 'all', 'sampling' or 'random'.

      - 'all': we greedily pick the 2nd expert.
      - 'sampling': we sample the 2nd expert from the softmax.
      - 'random': we optionally 'random'-ize dispatch to second-best expert
        proportional to (weight / second_expert_threshold).

    second_expert_threshold: threshold for probability normalization for
      second_expert_policy == 'random'. Values below 1e-9 are clamped to 1e-9.
    legacy_mtf_behavior: bool, True if to match legacy mtf behavior exactly.
      Controls when gate_1/gate_2 are renormalized (before vs. after capacity
      masking) and whether aux-loss densities are divided by the fraction of
      non-padded positions.
    capacity_factor: if set, computes
      int(group_size * capacity_factor / experts_dim) (truncated), where
      `group_size` is the size of the S dimension of `logits`
      (`logits.shape[1]`). If that is larger than expert_capacity_dim, it
      replaces expert_capacity_dim and is then rounded up to a multiple of 4;
      otherwise no change is made. The resulting capacity is also reported
      via tpu_summary.scalar('expert_capacity', ...).

  TODO(lepikhin): get rid of the legacy_mtf_behavior flag.

  Returns:
    A tuple (aux_loss, combine_tensor, dispatch_tensor).

    - aux_loss: scalar auxiliary loss, for equalizing the expert assignment
      ratios: mean over G and E of (density_1_proxy * density_1), scaled by
      experts_dim * experts_dim.
    - combine_tensor: G`SEC Tensor for combining expert outputs. Sum of the
      first- and second-expert contributions: entry [g, s, e, c] carries the
      token's gate weight for expert e when token s of group g occupies
      capacity slot c of that expert, and 0 otherwise.
    - dispatch_tensor: G`SEC Tensor, scattering/dispatching inputs to
      experts: 1 where combine_tensor is nonzero, 0 elsewhere, in
      fprop_dtype.

  Raises:
    ValueError: if second_expert_policy is not one of 'all', 'sampling' or
      'random'.
  """
  del inputs  # inputs is currently not used.
  raw_gates = tf.nn.softmax(logits)  # along E dim

  if capacity_factor is not None:
    # Determine expert capacity automatically depending on the input size.
    group_size_dim = int(logits.shape[1])
    # Note: int() truncates, so this can be slightly below the exact quotient.
    auto_expert_capacity = int((group_size_dim * capacity_factor) / experts_dim)
    if expert_capacity_dim < auto_expert_capacity:
      expert_capacity_dim = auto_expert_capacity
      # Round up to a multiple of 4 to avoid possible padding.
      while expert_capacity_dim % 4:
        expert_capacity_dim += 1
      tf.logging.info(
          'Setting expert_capacity_dim=%r (capacity_factor=%r '
          'group_size_dim=%r experts_dim=%r name_scope=%r)',
          expert_capacity_dim, capacity_factor, group_size_dim, experts_dim,
          tf.get_default_graph().get_name_scope())
    tpu_summary.scalar('expert_capacity', expert_capacity_dim)

  # Helpers used throughout the gating computation below.
  def _MaybeSplit(x):
    """Returns Split(x, 0, num_devices) if use_xla_sharding, else x unchanged.

    `Split` is defined elsewhere in this file; presumably it attaches an XLA
    sharding annotation along dimension 0 (the G dimension) — confirm there.
    """
    if use_xla_sharding:
      return Split(x, 0, num_devices)
    else:
      return x

  def _CreateOverCapacityRatioSummary(mask, position_in_expert, capacity, name):
    """Reports the fraction of assignments whose slot index >= capacity.

    Counts entries where `mask * position_in_expert >= capacity`, divides by
    the total number of assignments `sum(mask)`, and emits the ratio both via
    py_utils.AddTpuSummaryTensor and tpu_summary.scalar (averaged across
    while-loop iterations with while_loop_reduce='mean').

    NOTE(review): if `mask` sums to 0 (e.g. an all-padding batch) the ratio
    is 0/0 = NaN in the summary; this does not affect the returned tensors.
    """
    over_capacity = tf.reduce_sum(
        tf.cast(
            tf.greater_equal(mask * position_in_expert, capacity), mask.dtype))
    over_capacity_ratio = over_capacity / tf.reduce_sum(mask)
    py_utils.AddTpuSummaryTensor(name, over_capacity_ratio)
    tpu_summary.scalar(name, over_capacity_ratio, while_loop_reduce='mean')

  # As pointed out by zhifengc@ this method needs to be refactored. lepikhin@
  # and krikun@ will:
  #   - expand moe_spmd_test to compare Adafactor updates, slots on TPU
  #   including 2x2 with sharding
  #
  #   - add more tests for policy="random"
  #
  #   - add single step test for full size WMT model on CPU
  #
  # and then break this function into modules.
  #
  # GS: index of the top-1 expert for each token.
  index_1 = tf.math.argmax(raw_gates, axis=-1, output_type=tf.int32)
  index_1 = _MaybeSplit(index_1)
  tpu_summary.tensor('index_1', index_1)

  # GSE: one-hot of the top-1 expert.
  mask_1 = tf.one_hot(index_1, experts_dim, dtype=fprop_dtype)
  mask_1 = _MaybeSplit(mask_1)
  density_1_proxy = raw_gates

  # GS: per-token weight; all ones unless paddings are given.
  importance = tf.ones_like(mask_1[:, :, 0])

  if paddings is not None:
    importance = 1.0 - paddings
    # Zero out padded positions so they are never dispatched and do not
    # contribute to the aux-loss densities. These `*=` build new tensors;
    # raw_gates itself is not modified.
    mask_1 *= tf.expand_dims(importance, -1)
    density_1_proxy *= tf.expand_dims(importance, -1)

  # GS: gate value of the top-1 expert (0 for padded positions).
  gate_1 = tf.einsum('GSE,GSE->GS', raw_gates, mask_1)
  # GSE: gates with the top-1 expert removed, used to pick the 2nd expert.
  gates_without_top_1 = raw_gates * (1.0 - mask_1)

  if second_expert_policy == 'sampling':
    # We directly sample the 2nd expert index from the softmax over the
    # remaining experts by getting rid of the 1st expert already selected
    # above. To do so, we set a very negative value to the logit corresponding
    # to the 1st expert. Then we sample from the softmax (categorical)
    # distribution using the Gumbel max trick.
    noise = _MaybeSplit(tf.random.uniform(logits.shape, dtype=logits.dtype))
    # Generates standard Gumbel(0, 1) noise, GSE Tensors
    # NOTE(review): tf.random.uniform samples from [0, 1); a sample of exactly
    # 0 yields -inf noise, which simply makes that expert unselectable.
    noise = -tf.math.log(-tf.math.log(noise))
    # -0.7 * dtype.max rather than -dtype.max, presumably so that adding the
    # noise cannot overflow to -inf — TODO confirm.
    very_negative_logits = _MaybeSplit(
        (tf.ones_like(logits) * logits.dtype.max *
         tf.constant(-0.7, dtype=logits.dtype)))
    # Gets rid of the first expert by setting its logit to be very negative
    updated_logits = _MaybeSplit(
        tf.where(mask_1 > 0.0, very_negative_logits, logits))
    # Adds the Gumbel noise to the updated logits
    noised_logits = _MaybeSplit(updated_logits + noise)
    # Picks the index of the largest noised logit as the 2nd expert. This is
    # equivalent to sampling from the softmax over the 2nd experts.
    index_2 = tf.math.argmax(noised_logits, axis=-1, output_type=tf.int32)
  else:
    # 'all' and 'random': greedily take the best remaining expert. ('random'
    # may later drop it; an unknown policy raises ValueError further below.)
    index_2 = tf.math.argmax(gates_without_top_1, axis=-1, output_type=tf.int32)

  index_2 = _MaybeSplit(index_2)
  # GSE: one-hot of the 2nd expert, zeroed at padded positions.
  mask_2 = tf.one_hot(index_2, experts_dim, dtype=fprop_dtype)
  mask_2 = _MaybeSplit(mask_2)
  if paddings is not None:
    mask_2 *= tf.expand_dims(importance, -1)
  # GS: gate value of the 2nd expert.
  gate_2 = tf.einsum('GSE,GSE->GS', gates_without_top_1, mask_2)

  if legacy_mtf_behavior:
    # cl/298510175 moved this branch for gate_{1,2} denom calculation here.
    #
    # For policy=random, it's better to normalize gate_{1,2} before taking
    # capacity into account and before potentially dropping second expert.
    #
    # According to mean_xent (http://short/_NzbZ5rINr5):
    #   MoE_512_102xen_PolicyAll_298510175
    #   MoE_512_102xen_PolicyRandom_298510175
    #
    # vs pre-cl/298510175
    #   MoE_512_102xen_PolicyRandom
    #   MoE_512_102xen_PolicyAll
    #
    # it substantially improves policy=random with threshold=0.5 which
    # historically was better than policy="all"
    #
    # Also confirmed this by decoding
    #   nmt_train/m4/data/es_en/test.txt
    #   nmt_train/m4/data/ru_en/test.txt
    #   nmt_train/m4/data/zh_en/test.txt
    # and improving BLEU
    #
    # moe_decode.MoE_512_102xen_PolicyRandom_298510175-160000.batch1024.beam4.c_dim4.ln0.8.rkv.mteval102
    #   0.421443
    #   0.327102
    #   0.315693
    # vs
    # moe_decode.feb18_non_fig_snapshot_2626_MoE_512_102xen_PolicyRandom-190000.batch1024.beam4.c_dim4.ln0.8.rkv.mteval102
    #   0.399232
    #   0.310606
    #   0.288229
    #
    # Additional comparison, see mean_xent http://short/_YHccOhQtdu with
    # legacy_mtf_behavior=False models
    #   3 - MoE_512_102xen_PolicyAll_LegacyFalse
    #   6 - MoE_512_102xen_PolicyRandom_LegacyFalse
    # shows that policy="random" gets worse with legacy_mtf_behavior=False, and
    # is similar to pre-cl/298510175
    #   4 - MoE_512_102xen_PolicyRandom
    #
    # gate_1 can become 0 due to Expert being out of capacity.
    #
    # gate_2 can become 0 due to
    #   second_expert_policy == 'random'
    # or "out of capacity" scenario.
    #
    # Here we renormalize regardless of cases above.
    denom = gate_1 + gate_2 + 1e-9
    gate_1 /= denom
    gate_2 /= denom

  # For each group and each expert e in 0..E-1 independently, compute the
  # running count of earlier tokens (along S, axis=1) assigned to e. The
  # exclusive=True flag excludes the token itself, so the first token assigned
  # to an expert gets position 0. (No reshape happens; mask_1 stays GSE.)
  position_in_expert_1 = tf.cumsum(mask_1, exclusive=True, axis=1)

  # Scalar tensor
  capacity = tf.cast(expert_capacity_dim, dtype=position_in_expert_1.dtype)

  # GE Tensor (reducing S out of GSE tensor mask_1)
  # density_1[:, e] represents assignment ratio (num assigned / total) to
  # expert e as top_1 expert without taking capacity into account.
  # With legacy_mtf_behavior=False the ratio is taken over non-padded
  # positions only (divided by the mean importance of the group).
  if legacy_mtf_behavior:
    density_denom = 1.0
  else:
    density_denom = tf.reduce_mean(
        importance, axis=(1))[:, tf.newaxis] + 1e-6
  density_1 = tf.reduce_mean(mask_1, axis=(1)) / density_denom
  # density_1_proxy[:, e] represents mean of raw_gates for expert e, including
  # those of examples not assigned to e with top_k.
  density_1_proxy = tf.reduce_mean(density_1_proxy, axis=1) / density_denom

  # The MoE paper (https://arxiv.org/pdf/1701.06538.pdf) uses an aux loss of
  # reduce_mean(density_1_proxy * density_1_proxy). Here we replace one of
  # the density_1_proxy with the discrete density_1 following
  # mesh_tensorflow/transformer/moe.py?rcl=283569345.
  aux_loss = tf.reduce_mean(density_1_proxy * density_1)  # element-wise
  aux_loss *= experts_dim * experts_dim  # const coefficient

  # Add the over capacity ratio for expert 1
  _CreateOverCapacityRatioSummary(mask_1, position_in_expert_1, capacity,
                                  'over_capacity_1_ratio')

  # Drop top-1 assignments whose slot index does not fit within capacity.
  mask_1 *= tf.cast(tf.less(position_in_expert_1, capacity), dtype=mask_1.dtype)
  # GS: slot index of each token in its top-1 expert (0 where dropped).
  position_in_expert_1 = tf.einsum('GSE,GSE->GS', position_in_expert_1, mask_1)

  # GE: how many tokens in this group were kept for each expert as top-1.
  mask_1_count = tf.einsum('GSE->GE', mask_1)
  # GS Tensor - mostly ones, but zeros where something didn't fit (or padding)
  mask_1_flat = tf.einsum('GSE->GS', mask_1)

  if second_expert_policy == 'all' or second_expert_policy == 'sampling':
    # No extra filtering of the 2nd expert for these policies.
    pass
  elif second_expert_policy == 'random':
    # gate_2 is between 0 and 1, reminder:
    #
    #   raw_gates = tf.nn.softmax(logits)
    #   index_1 = tf.math.argmax(raw_gates, axis=-1, output_type=tf.int32)
    #   mask_1 = tf.one_hot(index_1, experts_dim, dtype=fprop_dtype)
    #   gate_1 = tf.einsum('GSE,GSE->GS', raw_gates, mask_1)
    #
    # E.g. if gate_2 exceeds second_expert_threshold, then we definitely
    # dispatch to second-best expert. Otherwise we dispatch with probability
    # proportional to (gate_2 / threshold).
    #
    sampled_2 = tf.less(
        _MaybeSplit(tf.random.uniform(gate_2.shape, dtype=gate_2.dtype)),
        (gate_2 / max(second_expert_threshold, 1e-9)))
    gate_2 *= tf.cast(sampled_2, gate_2.dtype)
    mask_2 *= tf.cast(tf.expand_dims(sampled_2, -1), mask_2.dtype)
  else:
    raise ValueError(second_expert_policy)

  # Second-choice slots start after all kept first-choice assignments of the
  # same expert, hence the mask_1_count offset.
  position_in_expert_2 = tf.cumsum(
      mask_2, exclusive=True, axis=1) + tf.expand_dims(mask_1_count, 1)

  # Add the over capacity ratio for expert 2
  _CreateOverCapacityRatioSummary(mask_2, position_in_expert_2, capacity,
                                  'over_capacity_2_ratio')

  # Drop 2nd-expert assignments that do not fit within capacity.
  mask_2 *= tf.cast(tf.less(position_in_expert_2, capacity), mask_2.dtype)
  # GS: slot index of each token in its 2nd expert (0 where dropped).
  position_in_expert_2 = tf.einsum('GSE,GSE->GS', position_in_expert_2, mask_2)
  # GS: 1 where the 2nd-expert assignment was kept, else 0.
  mask_2_flat = tf.reduce_sum(mask_2, axis=-1)

  # Equivalent non-einsum implementation:
  #
  # position_in_expert_2 *= mask_2
  # position_in_expert_2 = tf.reduce_sum(
  #     position_in_expert_2, axis=-1, name='position_in_expert_2')

  # Zero the gates of dropped assignments.
  gate_1 *= mask_1_flat
  gate_2 *= mask_2_flat

  if not legacy_mtf_behavior:
    # Renormalize after capacity masking so kept gates sum to 1 per token.
    denom = gate_1 + gate_2
    # To avoid divide by 0.
    denom = tf.where(denom > 0, denom, tf.ones_like(denom))
    gate_1 /= denom
    gate_2 /= denom

  # GSC Tensor: one-hot of each token's slot in its top-1 expert. Dropped
  # tokens have position 0 here but get zero weight through `a` below.
  b = tf.one_hot(
      tf.cast(position_in_expert_1, dtype=tf.int32),
      expert_capacity_dim,
      dtype=fprop_dtype,
      name='one_hot_b_0')
  # GSE Tensor: gate_1 placed at the top-1 expert index. gate_1 was already
  # multiplied by mask_1_flat above, so the extra multiply here is redundant
  # (the mask is 0/1) but harmless.
  a = tf.expand_dims(gate_1 * mask_1_flat, -1) * tf.one_hot(
      index_1, experts_dim, dtype=fprop_dtype)
  # GSEC Tensor
  first_part_of_combine_tensor = tf.einsum(
      'GSE,GSC->GSEC', a, b, name='first_part_of_combine_tensor')

  # GSC Tensor: one-hot of each token's slot in its 2nd expert.
  b = tf.one_hot(
      tf.cast(position_in_expert_2, dtype=tf.int32),
      expert_capacity_dim,
      dtype=fprop_dtype,
      name='one_hot_b_1')
  # GSE Tensor: gate_2 placed at the 2nd expert index (mask_2_flat multiply
  # is likewise redundant with the one applied to gate_2 above).
  a = tf.expand_dims(gate_2 * mask_2_flat, -1) * tf.one_hot(
      index_2, experts_dim, dtype=fprop_dtype)
  second_part_of_combine_tensor = tf.einsum(
      'GSE,GSC->GSEC', a, b, name='second_part_of_combine_tensor')

  # GSEC Tensor
  combine_tensor = (
      first_part_of_combine_tensor + second_part_of_combine_tensor)
  combine_tensor = _MaybeSplit(combine_tensor)

  # GSEC Tensor: 1 wherever combine_tensor is nonzero, else 0.
  dispatch_tensor = tf.cast(tf.cast(combine_tensor, tf.bool), fprop_dtype)
  dispatch_tensor = _MaybeSplit(dispatch_tensor)

  # TODO(yonghui): compute and return per-group aux_loss.
  return aux_loss, combine_tensor, dispatch_tensor