Example #1
 def testEinsumReplacementBxycBzxBzyc(self):
     with self.session(use_gpu=False, graph=tf.Graph()) as sess:
         a = tf.random_uniform(shape=[20, 7, 4, 3],
                               minval=0,
                               maxval=1,
                               dtype=tf.float32)
         b = tf.random_uniform(shape=[20, 5, 7],
                               minval=0,
                               maxval=1,
                               dtype=tf.float32)
         einsum = tf.einsum('bxyc,bzx->bzyc', a, b)
         p = spectrum_augmenter_on_device.SpectrumAugmenterOnDevice.Params()
         p.name = 'specAug_layers'
         specaug_layer = p.Instantiate()
         replacement = specaug_layer.EinsumBxycBzxBzyc(a, b)
         einsum, replacement = sess.run([einsum, replacement])
         self.assertAllClose(einsum, replacement)
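The test above checks that the on-device replacement matches tf.einsum. A minimal sketch of how 'bxyc,bzx->bzyc' can be expressed without einsum (assuming static shapes; an illustration, not the library's actual EinsumBxycBzxBzyc implementation):

import tensorflow as tf

def einsum_bxyc_bzx_bzyc_sketch(a, b):
  # a: [B, X, Y, C], b: [B, Z, X] -> out: [B, Z, Y, C]
  bs, x, y, c = a.shape.as_list()
  a_flat = tf.reshape(a, [bs, x, y * c])  # [B, X, Y*C]
  out = tf.matmul(b, a_flat)              # [B, Z, X] @ [B, X, Y*C] -> [B, Z, Y*C]
  return tf.reshape(out, [bs, -1, y, c])  # [B, Z, Y, C]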
Example #2
  def FProp(self, theta, inputs, *args):
    p = self.params
    with tf.name_scope(p.name) as scope:
      expert_dist = self._GetExpertDist(theta, inputs, *args)
      if not self.do_eval:
        summary_utils.histogram('soft_cond_{}'.format(scope), expert_dist)

      # Excludes non-variable extra_theta like global_step.
      var_set = set([key for key, _ in self.body.vars.FlattenItems()])
      values = []
      for key, value in theta.body.FlattenItems():
        if key in var_set and value is not None:
          # Weighted average for all variables created in the body layer.
          value = tf.einsum('i,i...->...', expert_dist, value)
        values.append(value)
      weighted_theta = theta.body.Pack(values)
      return self.body.FProp(weighted_theta, inputs, *args)
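The 'i,i...->...' einsum above takes a weighted average of each variable over the expert dimension. A tiny standalone sketch of that contraction (toy shapes, not the layer code):

import tensorflow as tf

expert_dist = tf.constant([0.2, 0.8])      # [num_experts]
expert_vars = tf.random.normal([2, 3, 4])  # [num_experts, ...]
weighted = tf.einsum('i,i...->...', expert_dist, expert_vars)  # [3, 4]
# Same as sum_i expert_dist[i] * expert_vars[i].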
Example #3
def IsWithinBBox(points, bbox):
    """Checks if points are within a 2-d bbox.

  The function returns true if points are strictly inside the box. It also
  returns true when the points are exactly on the box edges.

  Args:
    points: a float Tensor of shape [..., 2] of points to be tested. The last
      coordinates are (x, y).
    bbox: a float Tensor of shape [..., 4, 2] of bboxes. The last two
      dimensions are the four corners of the bbox, each given as (x, y). The
      corners are assumed to be given in counter-clockwise order.

  Returns:
    Tensor: If ``pshape = tf.shape(points)[:-1]`` and
    ``bshape = tf.shape(bbox)[:-2]``, returns a boolean tensor of shape
    ``tf.concat(pshape, bshape)``, where each element is true if the point is
    inside the corresponding box. If a point falls exactly on an edge of the
    bbox, it is also true.
  """
    bshape = py_utils.GetShape(bbox)[:-2]
    pshape = py_utils.GetShape(points)[:-1]
    bbox = py_utils.HasShape(bbox, tf.concat([bshape, [4, 2]], axis=0))
    points = py_utils.HasShape(points, tf.concat([pshape, [2]], axis=0))
    # Enumerate all 4 edges:
    v1, v2, v3, v4 = (bbox[..., 0, :], bbox[..., 1, :],
                      bbox[..., 2, :], bbox[..., 3, :])
    v1v2v3_check = tf.reduce_all(_IsCounterClockwiseDirection(v1, v2, v3))
    v2v3v4_check = tf.reduce_all(_IsCounterClockwiseDirection(v2, v3, v4))
    v4v1v2_check = tf.reduce_all(_IsCounterClockwiseDirection(v4, v1, v2))
    v3v4v1_check = tf.reduce_all(_IsCounterClockwiseDirection(v3, v4, v1))
    with tf.control_dependencies([
            py_utils.Assert(v1v2v3_check, [v1, v2, v3]),
            py_utils.Assert(v2v3v4_check, [v2, v3, v4]),
            py_utils.Assert(v4v1v2_check, [v4, v1, v2]),
            py_utils.Assert(v3v4v1_check, [v3, v4, v1])
    ]):
        is_inside = tf.math.logical_and(
            tf.math.logical_and(_IsOnLeftHandSideOrOn(points, v1, v2),
                                _IsOnLeftHandSideOrOn(points, v2, v3)),
            tf.math.logical_and(_IsOnLeftHandSideOrOn(points, v3, v4),
                                _IsOnLeftHandSideOrOn(points, v4, v1)))
    # Swap the last two dimensions.
    is_inside = tf.einsum('...ij->...ji', tf.cast(is_inside, tf.int32))
    return tf.cast(is_inside, tf.bool)
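_IsOnLeftHandSideOrOn is not shown in this snippet; a sketch of the standard 2-d cross-product test such a helper would perform (an assumption, not the library's exact implementation): a point p lies on, or to the left of, the directed edge v1 -> v2 iff cross(v2 - v1, p - v1) >= 0.

import tensorflow as tf

def _IsOnLeftHandSideOrOnSketch(points, v1, v2):
  # points, v1, v2: [..., 2] with broadcast-compatible shapes.
  cross = ((v2[..., 0] - v1[..., 0]) * (points[..., 1] - v1[..., 1]) -
           (v2[..., 1] - v1[..., 1]) * (points[..., 0] - v1[..., 0]))
  return cross >= 0.0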
Example #4
def softmax_kernel_transformation(data,
                                  is_query,
                                  projection_matrix=None,
                                  numerical_stabilizer=0.000001):
    """Computes random features for the softmax kernel using FAVOR+ mechanism.

  Computes random features for the softmax kernel using FAVOR+ mechanism from
  https://arxiv.org/pdf/2009.14794.pdf.

  Args:
    data: input data tensor of the shape [B, L, H, D], where: B - batch
      dimension, L - attention dimensions, H - heads, D - features.
    is_query: indicates whether input data is a query or key tensor.
    projection_matrix: random Gaussian matrix of shape [M, D], where M stands
      for the number of random features and each D x D sub-block has pairwise
      orthogonal rows.
    numerical_stabilizer: small positive constant for numerical stability.

  Returns:
    Corresponding kernel feature map.
  """
    projection_matrix = tf.cast(projection_matrix, data.dtype)
    data_normalizer = 1.0 / tf.math.sqrt(
        (tf.math.sqrt(tf.dtypes.cast(data.shape[-1], data.dtype))))
    ratio = 1.0 / tf.math.sqrt(
        tf.dtypes.cast(projection_matrix.shape[0], data.dtype))
    data_dash = tf.einsum("blhd,md->blhm", data_normalizer * data,
                          projection_matrix)
    diag_data = tf.math.square(data)
    diag_data = tf.math.reduce_sum(diag_data,
                                   axis=tf.keras.backend.ndim(data) - 1)
    diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
    diag_data = tf.expand_dims(diag_data, axis=tf.keras.backend.ndim(data) - 1)
    if is_query:
        last_dims_t = (len(data_dash.shape) - 1, )
        data_dash = ratio * (tf.math.exp(
            data_dash - diag_data -
            tf.math.reduce_max(data_dash, axis=last_dims_t, keepdims=True)) +
                             numerical_stabilizer)
    else:
        data_dash = ratio * (tf.math.exp(data_dash - diag_data -
                                         tf.math.reduce_max(data_dash)) +
                             numerical_stabilizer)

    return data_dash
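A minimal sketch of how these kernel features plug into (non-causal) FAVOR+ linear attention, i.e. attention ~ phi(Q) (phi(K)^T V) with a matching normalizer (toy shapes; the projection matrix here is plain Gaussian rather than the orthogonal blocks the docstring assumes):

import tensorflow as tf

B, L, H, D, M = 2, 5, 3, 4, 8
q = tf.random.normal([B, L, H, D])
k = tf.random.normal([B, L, H, D])
v = tf.random.normal([B, L, H, D])
proj = tf.random.normal([M, D])

q_prime = softmax_kernel_transformation(q, True, proj)   # [B, L, H, M]
k_prime = softmax_kernel_transformation(k, False, proj)  # [B, L, H, M]

kv = tf.einsum('blhm,blhd->bhmd', k_prime, v)            # [B, H, M, D]
denom = tf.einsum('blhm,bhm->blh', q_prime, tf.reduce_sum(k_prime, axis=1))
attention = tf.einsum('blhm,bhmd->blhd', q_prime, kv) / denom[..., None]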
Example #5
    def testBatchMakeRotationMatrix(self):
        batch_size, num_points, num_boxes = 10, 8, 1
        points = tf.random.uniform((batch_size, num_boxes, num_points, 3))
        boxes = tf.random.uniform((batch_size, num_boxes, 7))

        # Rotate the points
        rot_matrix = geometry.BatchMakeRotationMatrix(-boxes[..., -1])
        rot_matrix = tf.reshape(rot_matrix, [batch_size, num_boxes, 3, 3])
        rotated_points = tf.einsum('bnpm,bnmc->bnpc', points, rot_matrix)
        with self.session():
            actual_points, actual_rotated_points = self.evaluate(
                (points, rotated_points))

        # Points are the same on the z-axis (no rotation).
        self.assertAllClose(actual_points[..., 2], actual_rotated_points[...,
                                                                         2])
        # Points are transformed, and different.
        self.assertNotAllClose(actual_points, actual_rotated_points)
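BatchMakeRotationMatrix is assumed here to build rotations about the z-axis (consistent with the unchanged z-coordinates the test asserts). For a yaw angle phi, one such matrix looks like:

import numpy as np

phi = 0.3
rot_z = np.array([[np.cos(phi), -np.sin(phi), 0.],
                  [np.sin(phi),  np.cos(phi), 0.],
                  [0.,           0.,          1.]])
# Rotating [x, y, z] by rot_z changes x and y but leaves z unchanged.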
Example #6
def _RelPositionBiasCausal(query, abs_pos_emb):
    """Computes relative position bias for causal self attention."""
    _, t, n, h = py_utils.GetShape(query)

    abs_pos_emb = py_utils.HasShape(abs_pos_emb, [2 * t - 1, n, h])

    # abs_pos_emb is [-(T-1), -(T-2), ... 0, 1, 2, ... T-1]
    # Retain only half and change order to [T-1, T-2, ... 0]
    # [T, N, H]
    abs_pos_emb = tf.reverse(abs_pos_emb, [0])[:t]

    # [B, N, T, L=T]
    term_bd = tf.einsum('BTNH,LNH->BNTL', query, abs_pos_emb)

    # Perform shifting.
    term_bd = tf.reverse(term_bd, [2, 3])
    term_bd = RelShift(term_bd)
    return tf.reverse(term_bd, [2, 3])
Example #7
    def _FrequencyMask(self,
                       inputs,
                       global_seed,
                       dtype=tf.float32,
                       domain_id_index=0):
        """Applies frequency masking with given degree to inputs.

    Args:
      inputs: Batch of input features of shape (batch_size, time_length,
        num_freq, channels).
      global_seed: an integer seed tensor for stateless random ops.
      dtype: Data type.
      domain_id_index: domain id index.

    Returns:
      Inputs with random frequency masking applied.
    """
        p = self.params

        # Mask parameters.
        freq_mask_max_bins = p.freq_mask_max_bins[domain_id_index]
        multiplicity = p.freq_mask_count[domain_id_index]

        # If masking length or count is zero, do nothing.
        if freq_mask_max_bins == 0 or multiplicity == 0:
            return inputs

        # Arguments to pass to mask generator.
        batch_size, _, num_freq, _ = py_utils.GetShape(inputs)
        choose_range = tf.cast(tf.broadcast_to(num_freq, (batch_size, )),
                               dtype=tf.int32)
        # Create masks in frequency direction and apply.
        block_arrays = self._GetMask(tf.shape(inputs)[0],
                                     choose_range=choose_range,
                                     mask_size=num_freq,
                                     global_seed=global_seed,
                                     max_length=freq_mask_max_bins,
                                     masks_per_frame=0.0,
                                     multiplicity=multiplicity,
                                     dtype=dtype,
                                     max_ratio=1.0)
        outputs = tf.einsum('bxyc,by->bxyc', inputs, block_arrays)

        return outputs
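The 'bxyc,by->bxyc' einsum above scales every (time, channel) slice by a per-frequency mask; it is equivalent to the broadcast multiply inputs * block_arrays[:, None, :, None].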
Example #8
    def forward(labels, inputs):
        with tf.name_scope("entmax_loss"):
            assert labels.get_shape().as_list()[0] == inputs.get_shape(
            ).as_list()[0]
            p_star = entmax_support(inputs,
                                    alpha=alpha,
                                    n_iter=n_iter,
                                    ensure_sum_one=ensure_sum_one)
            loss = (1.0 - tf.math.reduce_sum(tf.math.pow(p_star, alpha),
                                             axis=-1)) / (alpha * (alpha - 1))

            p_star -= tf.cast(labels, dtype=inputs.dtype)
            loss += tf.einsum("...IJ,...IJ->...I", p_star, inputs)

        def grad_fn(d_outputs):
            with tf.name_scope("entmax_loss_grad"):
                gradient = tf.expand_dims(d_outputs, -1) * p_star
                return gradient, gradient

        return loss, grad_fn
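The forward/grad_fn pair above follows the tf.custom_gradient pattern (the decorator itself is outside this snippet). A minimal self-contained analogue showing how such a pair is wired up, using a plain squared loss instead of the entmax loss:

import tensorflow as tf

@tf.custom_gradient
def squared_loss(labels, inputs):
  diff = inputs - labels
  loss = tf.reduce_sum(tf.square(diff), axis=-1)
  def grad_fn(d_outputs):
    g = 2.0 * tf.expand_dims(d_outputs, -1) * diff
    return -g, g  # gradients w.r.t. labels and inputs
  return loss, grad_fn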
Example #9
    def ProjectInputSequence(self, theta, inputs):
        """Applies input projection for the entire sequence.

    Args:
      theta: a NestedMap of layer weights. Notably, it's expected to contain
        separate weight tensors for input and hidden state projections, for
        performance reasons, under the keys 'wm_i' (input) and 'wm_h' (hidden
        state).
      inputs: A NestedMap with the following fields:
        - act: A list of Tensors of shape [seqlen, batch, input_dim].

    Returns:
      A Tensor of shape [seqlen, batch, 4 * hidden_dim].
    """
        assert isinstance(inputs.act, list)
        if len(inputs.act) > 1:
            x = tf.concat(inputs.act, -1)
        else:
            x = inputs.act[0]
        # [T, B, 4 * H]
        proj_inputs = tf.einsum('TBD,DH->TBH', x, theta.wm_i)
        return proj_inputs
Example #10
    def FProp(self, theta, inputs):
        """Apply projection to inputs.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      inputs: The inputs tensor.  Shaped [..., input_dims].

    Returns:
      Projected inputs.
    """
        p = self.params
        with tf.name_scope(p.name):
            computation_cost.Add(
                self, 'flops',
                tf.reduce_prod(tf.cast(tf.shape(inputs)[:-1], tf.int64)) *
                tf.cast(
                    symbolic.EvalExpr(symbolic.TENSOR_VALUES, p.input_dims *
                                      p.output_dims), tf.int64) * 2)
            use_tpu = py_utils.use_tpu()
            shape = inputs.shape
            if use_tpu and (shape is not None and shape.rank is not None
                            and shape.rank < 26):
                # Avoids reshape if feasible and uses Einsum.
                if shape.rank == 2:
                    return tf.matmul(inputs, theta.w)
                else:
                    s = ''.join([chr(x) for x in range(97, 123)])  # abc...xyz
                    r = shape.rank
                    return tf.einsum('{0}y,yz->{0}z'.format(s[:r - 1]), inputs,
                                     theta.w)

            input_dim = py_utils.GetShape(inputs)[-1]
            act = tf.matmul(tf.reshape(inputs, [-1, input_dim]), theta.w)
            output_dim = tf.shape(theta.w)[-1]
            act = tf.reshape(
                act, tf.concat([tf.shape(inputs)[:-1], [output_dim]], axis=0))
            return act
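For a rank-3 input the generated equation is 'aby,yz->abz' (s[:r - 1] = 'ab'), which contracts the last input dimension against the first weight dimension. A quick check of the equivalence with the reshape + matmul fallback:

import tensorflow as tf

x = tf.random.normal([2, 3, 4])
w = tf.random.normal([4, 5])
y1 = tf.einsum('aby,yz->abz', x, w)
y2 = tf.reshape(tf.matmul(tf.reshape(x, [-1, 4]), w), [2, 3, 5])
# y1 and y2 are numerically identical.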
Example #11
  def _FrequencyMask(self,
                     inputs,
                     num_freq=80,
                     dtype=tf.float32,
                     domain_id_index=0):
    """Applies frequency masking with given degree to inputs.

    Args:
      inputs: Batch of input features of shape (batch_size, time_length,
        num_freq, channels).
      num_freq: Number of frequencies.
      dtype: Data type.
      domain_id_index: domain id index.

    Returns:
      Inputs with random frequency masking applied.
    """

    # If maximum mask length is zero, do nothing
    p = self.params
    if p.freq_mask_max_bins[domain_id_index] == 0:
      return inputs
    # Choose random masked length
    max_length = tf.random.uniform((tf.shape(inputs)[0],),
                                   maxval=p.freq_mask_max_bins[domain_id_index],
                                   dtype=tf.int32,
                                   seed=p.random_seed)
    # Create masks in frequency direction and apply
    block_arrays = self._GetMask(
        tf.shape(inputs)[0],
        max_length,
        choose_range=num_freq,
        mask_size=num_freq,
        dtype=dtype)
    outputs = tf.einsum('bxyc,by->bxyc', inputs, block_arrays)

    return outputs
Example #12
  def AttenLogitsRPEOneStep(self, query, key, abs_pos_emb):
    """RPE attention logits for one single target (query) step.

    B: batch size
    S: sequence length
    N: num of attention heads.
    H: per-head attention dimension.

    Args:
      query:          [B, N, H].
      key:         [S, B, N, H] or [S, B, N*H/128, 128].
      abs_pos_emb: [S, 1, N, H]

    Returns:
      A Tensor of shape [S, B, N]
    """
    s, b, _, _ = py_utils.GetShape(key, 4)
    _, n, h = py_utils.GetShape(query, 3)
    key = tf.reshape(key, [s, b, n, h])

    key_emb = key + abs_pos_emb
    query, key_emb = self.ToAqtActActInputs(query, key_emb)
    logits = tf.einsum('BNH,SBNH->SBN', query, key_emb)
    return self.FromAqtActActMatmul(logits)
Example #13
def RelPositionBias(content, abs_pos_emb):
  """Compute relative position bias.

  This is a subroutine used by variants of self-attentions with relative
  positional embedding.

  B: batch size
  T: sequence length
  N: num of attention heads.
  H: per-head attention dimension.

  output[b][n][i][j] = content[b][i][n] x abs_pos_emb[i-j+T-1][n]

  Notice padding is supposed to be masked by the caller of this function.

  Args:
    content: a Tensor of shape [B, T, N, H].
    abs_pos_emb: a Tensor of shape [2T - 1, N, H], the absolute positional
      embedding. abs_pos_emb[i] is the emb of relative distance i - (T-1).

  Returns:
    The attention logits tensor. [B, N, T, T]
  """
  b, t, n, h = py_utils.GetShape(content)
  l = 2 * t - 1
  abs_pos_emb = py_utils.HasShape(abs_pos_emb, [l, n, h])

  # [B, N, T, L=2T-1]
  term_bd = tf.einsum('BTNH,LNH->BNTL', content, abs_pos_emb)
  term_bd = tf.reshape(term_bd, [b, n, t * l], name='flatten')
  # [B, N, T * (L + 1)].
  term_bd = tf.pad(term_bd, ((0, 0), (0, 0), (0, t)))
  # [B, N, T, L + 1].
  term_bd = tf.reshape(term_bd, [b, n, t, l + 1], name='restore')
  return term_bd[:, :, :, t - 1::-1]
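A small numeric sketch of the pad/reshape shift above (T=2, L=3): flattening [T, L] row-major, padding T zeros, and reshaping to [T, L + 1] shifts each row one position relative to the previous one, which realigns column j with relative distance i - j before the final reversed slice.

import numpy as np

t, l = 2, 3
x = np.arange(t * l).reshape(t, l)   # [[0 1 2], [3 4 5]]
y = np.concatenate([x.reshape(-1), np.zeros(t, x.dtype)]).reshape(t, l + 1)
# y == [[0 1 2 3], [4 5 0 0]]; y[:, t - 1::-1] picks the realigned columns.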
Example #14
 def testSpectrumAugmenterWarpMatrixConstructor(self):
   with self.session(use_gpu=False, graph=tf.Graph()):
     inputs = tf.broadcast_to(tf.cast(tf.range(10), dtype=tf.float32), (4, 10))
     origin = tf.cast([2, 4, 4, 5], dtype=tf.float32)
     destination = tf.cast([3, 2, 6, 8], dtype=tf.float32)
     choose_range = tf.cast([4, 8, 8, 10], dtype=tf.float32)
     outputs = []
     for p in [
         spectrum_augmenter.SpectrumAugmenter.Params(),
         spectrum_augmenter_on_device.SpectrumAugmenterOnDevice.Params()
     ]:
       p.name = 'specAug_layers'
       specaug_layer = p.Instantiate()
       warp_matrix = specaug_layer._ConstructWarpMatrix(
           batch_size=4,
           matrix_size=10,
           origin=origin,
           destination=destination,
           choose_range=choose_range,
           dtype=tf.float32)
       output = tf.einsum('bij,bj->bi', warp_matrix, inputs)
       outputs.append(output)
     layer_output, layer_output_on_device = self.evaluate(outputs)
     self.assertAllClose(layer_output, layer_output_on_device)
Example #15
  def grad(res_grad):

    grads = tf.zeros_like(tf.einsum("ijk,ijl->ijkl", ks[0], vs[0]))

    gr_sums = sums

    q_grads = []
    k_grads = []
    v_grads = []

    for index in range(qs.shape[0] - 1, -1, -1):

      q_grads.append(
          tf.einsum("ijkl,ijl->ijk", gr_sums, res_grad[index])[None, ...])
      grads = grads + tf.einsum("ijk,ijl->ijkl", qs[index], res_grad[index])
      k_grads.append(tf.einsum("ijkl,ijl->ijk", grads, vs[index])[None, ...])
      v_grads.append(tf.einsum("ijkl,ijk->ijl", grads, ks[index])[None, ...])
      gr_sums = gr_sums - tf.einsum("ijk,ijl->ijkl", ks[index], vs[index])

    q_grads = tf.concat(q_grads[::-1], axis=0)
    k_grads = tf.concat(k_grads[::-1], axis=0)
    v_grads = tf.concat(v_grads[::-1], axis=0)

    return q_grads, k_grads, v_grads
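A sketch of the forward pass this custom gradient pairs with (Performer-style causal numerator; qs, ks: [L, B, H, M] and vs: [L, B, H, D] here, matching the index letters in the einsums above):

import tensorflow as tf

def causal_numerator_fwd(qs, ks, vs):
  result = []
  sums = tf.zeros_like(tf.einsum("ijk,ijl->ijkl", ks[0], vs[0]))  # [B, H, M, D]
  for index in range(qs.shape[0]):
    # Running prefix sum of the outer products k_t (x) v_t.
    sums = sums + tf.einsum("ijk,ijl->ijkl", ks[index], vs[index])
    result.append(tf.einsum("ijkl,ijk->ijl", sums, qs[index])[None, ...])
  return tf.concat(result, axis=0), sums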
Example #16
 def EinsumBBmBm(self, a, b, name=None):
     return tf.einsum('b,bm->bm', a, b, name=name)
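This wrapper is a pure broadcast: tf.einsum('b,bm->bm', a, b) is the same as a[:, None] * b. Presumably these Einsum* helpers exist so that einsums can be swapped for ops that are better supported on device.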
Example #17
def ComputeSparseAttention(q, k, v, sparsity_indices, paddings=None):
  """Computes attention according to a sparsity pattern.

  We use the following capital letters to denote shape parameters:
    B = batch size
    S = length of the source sequence
    T = length of the target sequence
    N = number of attention heads
    H = dimensions of each attention head
    K = number of clusters
    W = attention window (W <= S)

  The 'sparsity_indices' is a tensor of integral type where the last dimension
  contains W indices (W is the attention window) of positions along S in 'k'
  that each query position is allowed to attend to.

  For example, if sparsity_indices[batch_idx, target time step, head_idx] =
  [1, 7, 8], it means the query token at that time step attends to values with
  indices 1, 7, and 8, and the attention window here is 3.

  The valid values in 'sparsity_indices' are in the range [-1, S-1]. Note that
  the value -1 is reserved to mean paddings, distinct from the value (S-1).

  For example, if W=S and 'sparsity_indices' contains range(S) on the last
  dimension, this degenerates to the original full attention.

  We require that 'sparsity_indices' does not contain duplicates (except for -1
  to indicate paddings), but we do not require 'sparsity_indices' to be sorted.

  Note that this implementation is flexible and generic but is not optimized for
  time or space complexity. Please consider grouping queries that attend to the
  same subset of values first for efficiency.

  Args:
    q: (projected) queries, [B, T, N, H];
    k: (projected) keys, [B, S, N, H];
    v: (projected) values, [B, S, N, H];
    sparsity_indices: [B, T, N, W], where W is the attention window;
    paddings: paddings for keys, [B, S] if not None.

  Returns:
    output: the encoded output, [B, T, N, H].
    atten_probs: the attention weights, [B, T, N, S].
  """
  q = tf.convert_to_tensor(q)
  k = tf.convert_to_tensor(k)
  v = tf.convert_to_tensor(v)
  sparsity_indices = tf.convert_to_tensor(sparsity_indices)

  k = py_utils.HasRank(k, 4)
  _, source_length, _, dim_per_head = py_utils.GetShape(k, 4)
  sparsity_indices = py_utils.HasRank(sparsity_indices, 4)
  batch_size, target_length, num_heads, attention_window = py_utils.GetShape(
      sparsity_indices, 4)
  py_utils.assert_less_equal(
      attention_window, source_length,
      'The provided sparsity_indices has attention window '
      '> source length. This is likely an error.')

  # To prepare for gathering the relevant vectors from 'k', we prepare
  # gather_idx of shape [B, T, N, W, 3] where the last dimension corresponds to
  # slices in 'k' indexed by (batch index, source time step, head index),
  # where the source length index comes from the original W dimension in
  # 'sparsity_indices'.
  seq_idx = tf.expand_dims(sparsity_indices, axis=-1)
  # Overwrite the paddings -1 with valid gather indices (zeros). We will
  # fix the logits with -inf in these positions later.
  seq_idx = tf.where(seq_idx < 0, tf.zeros_like(seq_idx), seq_idx)
  batch_idx = tf.reshape(
      tf.range(0, batch_size, dtype=sparsity_indices.dtype),
      [batch_size, 1, 1, 1, 1])
  batch_idx = tf.tile(batch_idx,
                      [1, target_length, num_heads, attention_window, 1])
  head_idx = tf.reshape(
      tf.range(0, num_heads, dtype=sparsity_indices.dtype),
      [1, 1, num_heads, 1, 1])
  head_idx = tf.tile(head_idx,
                     [batch_size, target_length, 1, attention_window, 1])
  # [B, T, N, W, 3], where last dimension is (batch index, source length index,
  # head index).
  gather_idx = tf.concat([batch_idx, seq_idx, head_idx], axis=-1)

  # Both the gathered k and v have shape [B, T, N, W, H]
  k = tf.gather_nd(k, gather_idx)
  v = tf.gather_nd(v, gather_idx)

  if paddings is None:
    paddings = tf.zeros([batch_size, source_length])
  paddings = tf.convert_to_tensor(paddings)
  paddings = tf.expand_dims(paddings, axis=-1)
  # [B, S, N]
  paddings = tf.tile(paddings, [1, 1, num_heads])
  # [B, T, N, W]
  paddings = tf.gather_nd(paddings, gather_idx)

  logits = tf.einsum('BTNH, BTNWH -> BTNW', q, k)
  logits *= tf.math.rsqrt(tf.cast(dim_per_head, q.dtype))

  very_negative_logits = (
      tf.ones_like(logits) * logits.dtype.max *
      tf.constant(-0.7, dtype=logits.dtype))
  padded_logits = tf.where(
      tf.math.logical_or(sparsity_indices < 0, paddings > 0.0),
      very_negative_logits, logits)

  # [B, T, N, W]
  atten_probs = tf.nn.softmax(padded_logits, name='attention_weights')
  atten_probs = tf.where(sparsity_indices < 0, tf.zeros_like(logits),
                         atten_probs)
  output = tf.einsum('BTNW, BTNWH -> BTNH', atten_probs, v)

  # Scatter 'atten_probs' back into the original source length.
  # [B, T, N, W, 1]
  batch_idx = tf.tile(
      tf.range(batch_size)[:, None, None, None, None],
      [1, target_length, num_heads, attention_window, 1])
  # [B, T, N, W, 1]
  target_seq_idx = tf.tile(
      tf.range(target_length)[None, :, None, None, None],
      [batch_size, 1, num_heads, attention_window, 1])
  # [B, T, N, W, 1]
  head_idx = tf.tile(
      tf.range(num_heads)[None, None, :, None, None],
      [batch_size, target_length, 1, attention_window, 1])
  # seq_idx: [B, T, N, W, 1]
  # [B, T, N, W, 4]
  scatter_idx = tf.concat([batch_idx, target_seq_idx, head_idx, seq_idx], -1)
  # [B, T, N, S]
  scattered_probs = tf.scatter_nd(
      scatter_idx, atten_probs,
      [batch_size, target_length, num_heads, source_length])
  return output, scattered_probs
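A minimal usage sketch with toy shapes (B=1, T=2, S=4, N=1, H=3, W=2), where each query attends to two source positions and -1 marks padding:

import tensorflow as tf

q = tf.random.normal([1, 2, 1, 3])
k = tf.random.normal([1, 4, 1, 3])
v = tf.random.normal([1, 4, 1, 3])
sparsity_indices = tf.constant([[[[0, 1]], [[2, -1]]]])  # [1, 2, 1, 2]
output, atten_probs = ComputeSparseAttention(q, k, v, sparsity_indices)
# output: [1, 2, 1, 3]; atten_probs: [1, 2, 1, 4]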
Example #18
  def FProp(self, theta, x, paddings=None, update=False):
    """Computes distances of the given input 'x' to all centroids.

    This implementation applies layer normalization on 'x' internally first,
    and the returned 'dists' is computed using the normalized 'x'.

    Args:
      theta: A `.NestedMap` of weights' values of this layer.
      x: A tensor of shape [B, L, N, H].
      paddings: If not None, a tensor of shape [B, L].
      update: bool, whether to update centroids using x.

    Returns:
      dists: "distances" of the given input 'x' to all centroids.
             Shape [B, L, N, K].
      k_means_loss: the average squared Euclidean distances to the closest
                    centroid, a scalar.
    """
    p = self.params
    if paddings is None:
      paddings = tf.zeros_like(x[:, :, 0, 0])
    # Shape [B, L, 1, 1]
    paddings_4d = paddings[:, :, None, None]

    if p.apply_layer_norm:
      x = KMeansClusteringForAtten.LayerNorm(x, p.epsilon)

    # Since 'x' is normalized (but theta.means is not), we use the negative dot
    # product to approximate the Euclidean distance here.
    dists = -tf.einsum('BLNH, NKH -> BLNK', x, theta.means)

    # For padded positions we update the distances to very large numbers.
    very_large_dists = tf.ones_like(dists) * tf.constant(
        0.1, dtype=dists.dtype) * dists.dtype.max
    paddings_tiled = tf.tile(paddings_4d, [1, 1, p.num_heads, p.num_clusters])
    dists = tf.where(paddings_tiled > 0.0, very_large_dists, dists)

    # Shape [B, L, N, K], the same as 'dists' above.
    nearest_one_hot = tf.one_hot(
        tf.math.argmin(dists, axis=-1),
        p.num_clusters,
        dtype=py_utils.FPropDtype(p))
    # Same shape as the input 'x'.
    nearest_centroid = tf.einsum('BLNK, NKH -> BLNH', nearest_one_hot,
                                 theta.means)
    diff = tf.math.squared_difference(x, tf.stop_gradient(nearest_centroid))
    diff = py_utils.ApplyPadding(paddings_4d, diff)
    diff = tf.math.reduce_mean(diff, axis=2)

    # The commitment loss which, when backpropagated, encourages the 'x'
    # values to commit to their chosen centroids.
    k_means_loss = tf.math.reduce_sum(diff) / tf.math.reduce_sum(1.0 - paddings)
    summary_utils.scalar('k_means/squared_distance_loss', k_means_loss)

    # TODO(zhouwk): investigate normalizing theta.means after each update.
    means_norm = tf.norm(theta.means)
    summary_utils.scalar('k_means/centroid_l2_norm/min',
                         tf.math.reduce_min(means_norm))
    summary_utils.scalar('k_means/centroid_l2_norm/mean',
                         tf.math.reduce_mean(means_norm))

    if not update:
      return dists, k_means_loss

    # To update the centroids (self.vars.means), we apply gradient descent on
    # the mini-batch of input 'x', which yields the following:
    #   new_centroid = centroid + (1 - decay) * (x_mean - centroid)
    # where x_mean is the average over all the input vectors closest to this
    # centroid.
    #
    # Note that this approach is equivalent with backprop via
    #    loss = tf.math.reduce_mean(
    #        tf.math.squared_difference(tf.stop_gradient(x), nearest_centroid)))
    # , except that here the learning rate is independently set via 'decay'.

    # Ensure that the padded positions are not used to update the centroids.
    nearest_one_hot = py_utils.ApplyPadding(paddings_4d, nearest_one_hot)

    # Sum away batch and sequence length dimensions to get per cluster count.
    # Shape: [N, K]
    per_cluster_count = tf.reduce_sum(nearest_one_hot, axis=[0, 1])
    summary_utils.histogram('k_means/per_cluster_vec_count', per_cluster_count)

    # Sum of the input 'x' per each closest centroid.
    sum_x = tf.einsum('BLNK, BLNH -> NKH', nearest_one_hot, x)

    if py_utils.use_tpu():
      per_cluster_count = tf.tpu.cross_replica_sum(per_cluster_count)
      sum_x = tf.tpu.cross_replica_sum(sum_x)

    # If per_cluster_count for a cluster is 0, then 'nearest_one_hot' in that
    # cluster's position will always be 0, hence 'sum_x' in that dimension will
    # be 0.
    new_means = sum_x / tf.maximum(
        tf.constant(1.0, dtype=per_cluster_count.dtype),
        tf.expand_dims(per_cluster_count, axis=-1))

    # We use exponential moving average. TODO(zhouwk): investigate smoothing
    # this over an exponentially moving averaged per cluster count.
    #
    # Note that we intentionally do not normalize the means after this update
    # as empirically this works better.
    update_means_diff = tf.cast((1.0 - p.decay) * (new_means - theta.means),
                                self.vars.means.dtype)
    return py_utils.with_dependencies(
        [tf.assign_add(self.vars.means, update_means_diff)],
        dists), k_means_loss
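The centroid update above is an exponential moving average; in isolation:

import tensorflow as tf

decay = 0.999
centroid = tf.Variable([[1.0, 0.0]])
x_mean = tf.constant([[0.5, 0.5]])
# Same as: centroid = decay * centroid + (1 - decay) * x_mean.
centroid.assign_add((1.0 - decay) * (x_mean - centroid))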
Example #19
def flat_beam_search(batch_size,
                     beam_size,
                     max_steps,
                     dec_callback,
                     dec_state,
                     bos_id=1,
                     eos_id=2,
                     length_norm_alpha=0.8,
                     beam_gap=3.0,
                     top_k_fn=tf.math.top_k,
                     prefix=None,
                     prefix_len=None,
                     fprop_dtype=tf.float32,
                     ext_size=0,
                     nbest_size=None,
                     debug=True):
    """Flat beam search.

  Args:
    batch_size: batch size
    beam_size: beam size limit in number of hyps
    max_steps: max steps
    dec_callback: decoder callback (see above)
    dec_state: decoder state
    bos_id: <s> token id
    eos_id: </s> token id
    length_norm_alpha: length normalization parameter
    beam_gap: early stopping threshold; None to disable
    top_k_fn: top_k function to call
    prefix: (optional) int32 tensor [batch_size, prefix_max]
    prefix_len: (optional) int32 tensor [batch_size]
    fprop_dtype: fprop dtype
    ext_size: int >= beam_size, extension buffer size
    nbest_size: number of returned hyps, default is beam_size
    debug: log intermediate values with tpu_summary.tensor()

  Returns:
    (loop_vars, dec_state, nbest) where
    nbest = (topk_ids, topk_len, topk_score)
  """
    assert beam_size > 0
    assert batch_size > 0
    assert max_steps > 0

    buf_size = beam_size * max_steps
    output_len = max_steps

    if prefix is None:
        assert prefix_len is None
        prefix = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        prefix += tf.one_hot(0, beam_size, dtype=tf.int32) * bos_id
        prefix_len = tf.ones([batch_size], dtype=tf.int32)
    else:
        assert int(prefix.shape[0]) == batch_size, (batch_size, prefix.shape)
        assert int(prefix_len.shape[0]) == batch_size, (batch_size,
                                                        prefix_len.shape)
        output_len += int(prefix.shape[1])

    if debug:
        tpu_summary.tensor('prefix', prefix)
        tpu_summary.tensor('prefix_len', prefix_len)

    with tf.name_scope('init_state'):
        t = tf.constant(0)
        tgt_id = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        tgt_id += bos_id
        tgt_pos = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        tgt_mask = tf.zeros([batch_size, beam_size, buf_size],
                            dtype=fprop_dtype)
        tgt_mask += tf.one_hot(tf.range(beam_size),
                               buf_size,
                               dtype=fprop_dtype)
        hyp_score = tf.zeros([batch_size, beam_size], dtype=fprop_dtype)
        # penalize all hyps except the first
        hyp_score -= tf.cast(tf.range(beam_size, dtype=tf.float32) * 1e5,
                             dtype=fprop_dtype)
        nbest_size = nbest_size or beam_size
        nbest_score = tf.zeros([batch_size, nbest_size], dtype=fprop_dtype)
        nbest_score -= 1e9
        nbest_score_norm = nbest_score
        nbest_mask = tf.zeros([batch_size, nbest_size, buf_size],
                              dtype=fprop_dtype)

    with tf.name_scope('init_ext'):
        # Initialize the extension buffer.
        #
        # Extension buffer stores a (potentially large) set of 'extensions',
        # which consist of a hypothesis (represented by ext_mask) and next token
        # (represented by ext_id). At each decoder iteration, top_k extensions
        # from each hypothesis are added to the buffer and sorted by score.
        #
        # Then top beam_size extensions are removed from the buffer and used
        # in the next decoder iteration. And top 'ext_size' remaining extensions
        # are carried over to be possibly evaluated at a later step.
        #
        # As a result of this manipulation, the decoder is no longer restricted
        # to always compare hyps of the same token length at each iteration.
        # In particular, for a fixed length N it can generate more than beam_size
        # terminated hyps.
        #
        # Setting ext_size = 0 disables this feature.
        if ext_size:
            ext_id = tf.zeros([batch_size, ext_size], dtype=tf.int32)
            ext_score = tf.zeros([batch_size, ext_size], dtype=fprop_dtype)
            ext_score -= 1e9
            ext_mask = tf.zeros([batch_size, ext_size, buf_size],
                                dtype=fprop_dtype)
        else:
            ext_size = ext_id = ext_score = ext_mask = 0

    with tf.name_scope('init_prefix'):
        # rename prefix->pfx for shorter variables
        pfx = tf.cast(prefix, tf.int32)
        pfx_len = tf.cast(prefix_len, tf.int32)
        del prefix, prefix_len
        # Before the first call to dec_callback() the prefix shall be packed into
        # the tgt_id buffer as follows:
        #
        # [ P P P P P P - - - - - - P* - - - ]   ^
        # [ P P P P P P P P P P - - P* - - - ]   | batch
        # [ P - - - - - - - - - - - P* - - - ]   V
        # |<---- prefix len ---->  |<-- beam -->
        #
        # The last meaningful token in the prefix (P*)
        # must be located at the same position in all batch rows.
        #
        # We then make one dec_callback() with full prefix (minus P*)
        # which will populate the initial dec_state
        # (for transformer -- self-attention key/value cache)
        #
        # The last block [batch, beam] then becomes the first tgt_id for the loop.
        pfx_max = int(pfx.shape[1])
        pfx_mul = pfx_max // beam_size
        assert pfx_max == pfx_mul * beam_size, (pfx_max, pfx_mul, beam_size)
        pfx_time = tf.range(pfx_max)
        pfx_pad = tf.cast(
            tf.less(tf.expand_dims(pfx_time, 0),
                    tf.expand_dims(pfx_len - 1, 1)), tf.int32)
        pfx_id = pfx * pfx_pad
        pfx_last = einsum_i32(
            'BT,BT->B', pfx, tf.one_hot(pfx_len - 1,
                                        pfx_max,
                                        dtype=fprop_dtype))

        buf_time = tf.range(buf_size)
        pfx_time_mask = tf.cast(
            tf.less_equal(tf.expand_dims(buf_time, 0),
                          tf.expand_dims(pfx_time, 1)), fprop_dtype)
        pfx_mask = tf.einsum('BQ,QK->BQK', tf.cast(pfx_pad, fprop_dtype),
                             pfx_time_mask)
        pfx_segment_id = pfx_pad
        pfx_pos = pfx_time * pfx_pad

        if debug:
            tpu_summary.tensor('pfx_id', pfx_id)
            tpu_summary.tensor('pfx_len', pfx_len)
            tpu_summary.tensor('pfx_pos', pfx_pos)
            tpu_summary.tensor('pfx_last', pfx_last)

        # Now call decoder with prefix minus P*:
        # 'dec_state' now shall contain the key/value cache for prefix tokens
        # (for transformer models), and 'logits' we can either discard or
        # roll into the initial hyp_score. Discard is simpler.
        with tf.name_scope('prefix_fprop'):
            # TODO(krikun): remove extra type checks
            assert (pfx_id.dtype == tf.int32), (pfx_id.dtype)
            assert (pfx_segment_id.dtype == tf.int32), (pfx_segment_id.dtype)
            assert (pfx_pos.dtype == tf.int32), (pfx_pos.dtype)
            assert (pfx_mask.dtype == fprop_dtype), (pfx_mask.dtype)
            assert (t.dtype == tf.int32), (t.dtype)
            logits, dec_state = dec_callback(pfx_id, pfx_segment_id, pfx_pos,
                                             pfx_mask, dec_state, t)
            del logits

        # Now construct the initial state for the rest of the beam search loop.
        # 'tgt_id' is simply 'pfx_last' padded to [batch, beam] shape
        # 'tgt_pos' is different for each batch row and is equal to prefix_len
        # 'tgt_segment_id' always 1 (no packing)
        # 'hyp_score' is 0 for beam=0 and negative for beam>=1
        tgt_id = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
            pfx_last, 1)
        tgt_pos = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
            (pfx_len - 1), 1)
        hyp_score = tf.zeros(
            [batch_size, beam_size], dtype=fprop_dtype) - tf.cast(
                tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)

        # TODO(krikun) Here we make initial 't' constant and determined by the
        # shape of the prefix tensor 'pfx_max'. It is possible to make it
        # dynamic as t ~ max(pfx_len) / beam_size, which would allow more beam
        # search steps; however, 'max' results in a very slow all-to-all on
        # 16x16, and a variable number of decoder steps may result in bad
        # latency.
        t = tf.cast(tf.math.ceil(pfx_max / beam_size), tf.int32)

        # Initial tgt_mask is such that each token P* has attention on itself
        # (as usual) and on all prefix tokens before it, which are not padding.
        tgt_mask = tf.zeros([batch_size, beam_size, buf_size],
                            dtype=fprop_dtype)
        tgt_mask += tf.cast(
            tf.expand_dims(
                tf.pad(pfx_pad, [[0, 0], [0, (buf_size - pfx_max)]]), 1),
            fprop_dtype)
        tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                               buf_size,
                               dtype=fprop_dtype)

        if debug:
            tpu_summary.tensor('tgt_id', tgt_id)
            tpu_summary.tensor('tgt_pos', tgt_pos)
            tpu_summary.tensor('tgt_mask', tgt_mask)
            tpu_summary.tensor('t', t)

    with tf.name_scope('init_hist'):
        # h_tgt_id is used to recover topk_ids from nbest_mask
        h_tgt_id = tf.TensorArray(dtype=tf.int32, size=max_steps)
        h_tgt_pos = tf.TensorArray(dtype=tf.int32, size=max_steps)

        # When non-trivial prefix is present we also write prefix ids to
        # h_tgt_id so that the full sequence including prefix can be recovered
        # by unmask() below.  When prefix is empty, pfx_id shape is [batch, 0]
        # and the loop below becomes a no-op.
        # TODO(krikun): maybe a tf.while_loop is more appropriate here.
        for i, x_i in enumerate(tf.split(pfx_id, pfx_mul, 1)):
            h_tgt_id = h_tgt_id.write(i, x_i)
        for i, x_i in enumerate(tf.split(pfx_pos, pfx_mul, 1)):
            h_tgt_pos = h_tgt_pos.write(i, x_i)

        hist = (h_tgt_id, h_tgt_pos)
        tf.logging.info('hist=%r', hist)

    nbest_hyps = (nbest_mask, nbest_score, nbest_score_norm)
    tf.logging.info('nbest_hyps=%r', nbest_hyps)

    ext = (ext_id, ext_score, ext_mask)
    tf.logging.info('ext=%r', ext)

    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)
    tf.logging.info('loop_vars=%r', loop_vars)

    def loop_step(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
         hist) = loop_vars
        (ext_id, ext_score, ext_mask) = ext
        (h_tgt_id, h_tgt_pos) = hist
        h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id')
        h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos')
        # not using tf.ones() here because of XLA compilation error
        tgt_segment_id = tgt_id * 0 + 1
        logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos,
                                         tgt_mask, dec_state, t)
        # take predicted EOS score for each hyp and compute normalized score
        eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype)

        def length_norm(t):
            t = tf.cast(t, fprop_dtype)
            alpha = length_norm_alpha
            tf.logging.info('length_norm.alpha=%r', alpha)
            return tf.math.pow((t + 5.) / 5., alpha)

        hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1)
        eos_score_norm = eos_score / length_norm(hyp_len)
        # update the n-best list
        nbest_hyps = update_nbest(nbest_hyps,
                                  (tgt_mask, hyp_score, eos_score_norm))

        if debug:
            tpu_summary.tensor('eos_score', eos_score)
            tpu_summary.tensor('hyp_len', hyp_len)

        # take top k tokens for each hyp
        k = beam_size
        with tf.name_scope('topk1'):
            top_score, top_id = top_k_fn(logits, k)
            top_score = tf.cast(top_score, fprop_dtype)

        top_score += tf.expand_dims(hyp_score, -1)
        top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype)

        top_score = tf.reshape(top_score, [batch_size, beam_size * k])
        top_id = tf.reshape(top_id, [batch_size, beam_size * k])
        top_mask = tf.repeat(tgt_mask, beam_size, 1)

        if debug:
            tpu_summary.tensor('top_id', top_id)
            tpu_summary.tensor('top_score', top_score)
            # tpu_summary.tensor('top_mask', top_mask)

        with tf.name_scope('update_ext'):
            # combine top k tokens with extension buffer (if any)
            if ext_size:
                ext_id = tf.concat([ext_id, top_id], 1)
                ext_score = tf.concat([ext_score, top_score], 1)
                ext_mask = tf.concat([ext_mask, top_mask], 1)
            else:
                ext_id, ext_score, ext_mask = top_id, top_score, top_mask

            # sort by score
            ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size)
            i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype)
            ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1)
            ext_id = einsum_i32('bk,bjk->bj', ext_id, i1)

            # pick top beam_size extensions to evaluate at next iteration
            if ext_size:
                hyp_score = ext_score[:, :beam_size]
                ext_score = ext_score[:, beam_size:]
                tgt_id = ext_id[:, :beam_size]
                ext_id = ext_id[:, beam_size:]
                tgt_mask = ext_mask[:, :beam_size]
                ext_mask = ext_mask[:, beam_size:]
            else:
                hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask
                ext_score = ext_id = ext_mask = 0

        tgt_pos = tf.reduce_sum(tgt_mask, -1)
        tgt_pos = tf.cast(tgt_pos, tf.int32)

        t += 1
        with tf.name_scope('tgt_mask_extend'):
            tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                                   buf_size,
                                   dtype=fprop_dtype)

        ext = (ext_id, ext_score, ext_mask)
        hist = (h_tgt_id, h_tgt_pos)
        loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                     hist)
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        return loop_vars, dec_state

    def loop_cond(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        if beam_gap is None:
            (t, _, _, _, _, _, _, _) = loop_vars
            return t < max_steps
        else:
            (t, _, _, _, hyp_score, nbest_hyps, _, _) = loop_vars
            (_, nbest_score, _) = nbest_hyps
            # stop early if all current hyps are significantly worse than nbest
            diff = tf.reduce_min(
                tf.reduce_min(nbest_score, -1) - tf.reduce_max(hyp_score, -1))
            return tf.math.logical_and(t < max_steps, diff < beam_gap)

    with tf.name_scope('flat_beam_search_loop'):
        (loop_vars, dec_state) = tf.while_loop(loop_cond,
                                               loop_step,
                                               loop_vars=(loop_vars,
                                                          dec_state),
                                               back_prop=False,
                                               swap_memory=False,
                                               maximum_iterations=max_steps)

    # flatten all tensorarrays into tensors
    (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
     hist) = loop_vars
    (nbest_mask, nbest_score, nbest_score_norm) = nbest_hyps
    (h_tgt_id, h_tgt_pos) = hist
    h_tgt_id = h_tgt_id.stack()
    h_tgt_pos = h_tgt_pos.stack()
    hist = (h_tgt_id, h_tgt_pos)
    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)

    # recover topk_ids from nbest_mask and tgt_id history
    h = tf.transpose(h_tgt_id, [1, 0, 2])
    h = tf.reshape(h, [batch_size, buf_size])

    def unmask(h, m):
        with tf.name_scope('unmask'):
            tpu_summary.tensor('unmask_h', h)
            tpu_summary.tensor('unmask_m', m)
            t = tf.cumsum(m, -1) * m - 1
            mh = einsum_i32('bkt,bt->bkt', m, h)
            t2 = tf.one_hot(tf.cast(t, tf.int32),
                            output_len,
                            dtype=fprop_dtype)
            x = einsum_i32('bkt,bktT->bkT', mh, t2)
            return tf.cast(x, h.dtype)

    topk_ids = unmask(h, nbest_mask)
    topk_len = tf.reduce_sum(nbest_mask, -1)
    topk_len = tf.cast(topk_len, tf.int32)
    # add eos, because nbest_mask does not encode eos
    topk_ids += eos_id * tf.one_hot(topk_len, output_len, dtype=tf.int32)
    topk_len += 1
    topk_len = tf.minimum(topk_len, output_len)
    topk_score = nbest_score_norm

    nbest = (topk_ids, topk_len, topk_score)

    return loop_vars, dec_state, nbest
Example #20
def einsum_i32(eq, *args):
    y = tf.einsum(eq, *[tf.cast(x, tf.int32) for x in args])
    return tf.cast(y, tf.int32)
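einsum_i32 is used above for integer "gathers" via one-hot selection matrices, e.g. einsum_i32('bk,bjk->bj', ext_id, i1) picks the top-scoring token ids. A toy check:

import tensorflow as tf

ids = tf.constant([[7, 8, 9]])               # [B=1, K=3]
sel = tf.one_hot([[2, 0]], 3)                # [B=1, J=2, K=3]
picked = einsum_i32('bk,bjk->bj', ids, sel)  # [[9, 7]]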
Example #22
  def _XYZFromRangeImage(self,
                         lidar_image,
                         lidar_image_mask,
                         extrinsics,
                         inclinations,
                         pixel_pose=None,
                         frame_pose=None):
    """Extract the cartesian coordinates from the range image.

    Args:
       lidar_image: [H, W, C] range image Tensor.
       lidar_image_mask: [H, W] boolean indicating which 2d coordinates in the
         lidar image are present.
       extrinsics: [4, 4] float matrix representing transformation matrix to
         world coordinates.
       inclinations: [V] beam inclinations vector.
       pixel_pose: [H, W, 4, 4] tensor representing per pixel pose of GBR
         (e.g. [64, 2650, 4, 4]).
       frame_pose: [4, 4] matrix representing vehicle to world transformation.

    Returns:
      [H, W, 3] range image cartesian coordinates.
    """
    height, width, channels = py_utils.GetShape(lidar_image, 3)

    conversion_dtype = tf.float32
    lidar_image = tf.cast(lidar_image, conversion_dtype)
    extrinsics = tf.cast(extrinsics, conversion_dtype)
    inclinations = tf.cast(inclinations, conversion_dtype)
    inclinations = tf.reverse(inclinations, axis=[-1])

    az_correction = py_utils.HasShape(
        tf.atan2(extrinsics[1, 0], extrinsics[0, 0]), [])
    ratios = (tf.cast(tf.range(width, 0, -1), dtype=conversion_dtype) -
              .5) / tf.cast(width, conversion_dtype)
    ratios = py_utils.HasShape(ratios, [width])

    azimuth = (ratios * 2. - 1.) * np.pi - az_correction[..., tf.newaxis]
    azimuth = py_utils.HasShape(azimuth, [width])

    lidar_image_mask = lidar_image_mask[..., tf.newaxis]
    lidar_image_mask = tf.tile(lidar_image_mask, [1, 1, channels])
    lidar_image = tf.where(lidar_image_mask, lidar_image,
                           tf.zeros_like(lidar_image))
    lidar_image_range = lidar_image[..., 0]

    azimuth = py_utils.HasShape(azimuth[tf.newaxis, ...], [1, width])
    inclinations = py_utils.HasShape(inclinations[..., tf.newaxis], [height, 1])

    cos_azimuth = tf.cos(azimuth)
    sin_azimuth = tf.sin(azimuth)
    cos_incl = tf.cos(inclinations)
    sin_incl = tf.sin(inclinations)

    x = cos_azimuth * cos_incl * lidar_image_range
    y = sin_azimuth * cos_incl * lidar_image_range
    z = sin_incl * lidar_image_range

    lidar_image_points = tf.stack([x, y, z], -1)
    lidar_image_points = py_utils.HasShape(lidar_image_points,
                                           [height, width, 3])
    rotation = extrinsics[0:3, 0:3]
    translation = extrinsics[0:3, 3][tf.newaxis, ...]

    # Transform the image points in cartesian coordinates to
    # the world coordinate system using the extrinsics matrix.
    #
    # We first flatten the points, apply rotation, then
    # reshape to restore the original input and then apply
    # translation.
    lidar_image_points = tf.matmul(
        tf.reshape(lidar_image_points, [-1, 3]), rotation, transpose_b=True)
    lidar_image_points = tf.reshape(lidar_image_points, [height, width, 3])
    lidar_image_points += translation

    lidar_image_points = py_utils.HasShape(lidar_image_points,
                                           [height, width, 3])
    # GBR uses per pixel pose.
    if pixel_pose is not None:
      pixel_pose_rotation = pixel_pose[..., 0:3, 0:3]
      pixel_pose_translation = pixel_pose[..., 0:3, 3]
      lidar_image_points = tf.einsum(
          'hwij,hwj->hwi', pixel_pose_rotation,
          lidar_image_points) + pixel_pose_translation
      if frame_pose is None:
        raise ValueError('frame_pose must be set when pixel_pose is set.')
      # To vehicle frame corresponding to the given frame_pose
      # [4, 4]
      world_to_vehicle = tf.linalg.inv(frame_pose)
      world_to_vehicle_rotation = world_to_vehicle[0:3, 0:3]
      world_to_vehicle_translation = world_to_vehicle[0:3, 3]
      # [H, W, 3]
      lidar_image_points = tf.einsum(
          'ij,hwj->hwi', world_to_vehicle_rotation,
          lidar_image_points) + world_to_vehicle_translation[tf.newaxis,
                                                             tf.newaxis, :]

    return lidar_image_points
Example #23
 def EinsumBxyBxBxy(self, a, b, name=None):
     return tf.einsum('bxy,bx->bxy', a, b, name=name)
Example #24
 def EinsumBmtBmBt(self, a, b, name=None):
     return tf.einsum('bmt,bm->bt', a, b, name=name)
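Unlike the pure broadcast wrappers, 'bmt,bm->bt' contracts away the m axis: tf.einsum('bmt,bm->bt', a, b) equals tf.reduce_sum(a * b[:, :, None], axis=1), i.e. a weighted sum over masks.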
Example #25
  def _GetMask(self,
               batch_size,
               choose_range,
               mask_size,
               max_length=None,
               masks_per_frame=0.0,
               multiplicity=1,
               dtype=tf.float32,
               max_ratio=1.0):
    """Returns fixed size multi-masks starting from random positions.

    A multi-mask is a mask obtained by applying multiple masks.

    This function when max_length is given:
      1) Sample random mask lengths less than max_length with shape
         (batch_size, multiplicity).
      2) Truncate lengths to a max of (choose_range * max_ratio),
         so that each mask is fully contained within the corresponding sequence.
      3) Randomly sample start points of shape (batch_size, multiplicity)
         within (choose_range - lengths).
      4) For each batch, multiple masks (whose number is given by the
         multiplicity) are constructed.
      5) Return a mask of shape (batch_size, mask_size) where masks are
         obtained by composing the masks constructed in step 4).
         If masks_per_frame > 0, the number is given by
         min(masks_per_frame * choose_range, multiplicity).
         If not, all the masks are composed. The masked regions are set to zero.

    This function when max_length is not given:
      1) Sample random mask lengths less than (choose_range * max_ratio)
         with shape (batch_size, multiplicity).
      2) Proceed to steps 3), 4) and 5) of the above.

    Args:
      batch_size: Batch size. Integer number.
      choose_range: Range within which the masked entries must lie. Tensor of
        shape (batch_size,).
      mask_size: Size of the mask. Integer number.
      max_length: Maximum number of allowed consecutive masked entries. Integer
        number or None.
      masks_per_frame: Number of masks per frame. Float number. If > 0, the
        multiplicity of the mask is set to be masks_per_frame * choose_range.
      multiplicity: Maximum number of total masks. Integer number.
      dtype: Data type.
      max_ratio: Maximum portion of the entire range allowed to be masked. Float
        number.

    Returns:
      mask: a fixed size multi-mask starting from a random position with shape
      (batch_size, mask_size).
    """
    p = self.params
    # Non-empty random seed values are only used for testing.
    # seed_1 and seed_2 are set separately to avoid correlation between
    # mask size and mask position.
    if p.random_seed:
      seed_1 = p.random_seed + 1
      seed_2 = 2 * p.random_seed
    else:
      seed_1 = p.random_seed
      seed_2 = p.random_seed
    # Sample lengths for multiple masks.
    if max_length and max_length > 0:
      max_length = tf.broadcast_to(tf.cast(max_length, dtype), (batch_size,))
    else:
      max_length = tf.cast(choose_range, dtype=dtype) * max_ratio
    masked_portion = tf.random.uniform((batch_size, multiplicity),
                                       minval=0.0,
                                       maxval=1.0,
                                       dtype=dtype,
                                       seed=seed_1)
    masked_frame_size = tf.einsum('b,bm->bm', max_length, masked_portion)
    masked_frame_size = tf.cast(masked_frame_size, dtype=tf.int32)
    # Make sure the sampled length is smaller than max_ratio * length_bound.
    # Note that sampling in this way is biased
    # (shorter sequences may be over-masked).
    choose_range = tf.expand_dims(choose_range, -1)
    choose_range = tf.tile(choose_range, [1, multiplicity])
    length_bound = tf.cast(choose_range, dtype=dtype)
    length_bound = tf.cast(max_ratio * length_bound, dtype=tf.int32)
    length = tf.minimum(masked_frame_size, tf.maximum(length_bound, 1))

    # Choose starting point.
    random_start = tf.random.uniform((batch_size, multiplicity),
                                     maxval=1.0,
                                     seed=seed_2)
    start_with_in_valid_range = random_start * tf.cast(
        (choose_range - length + 1), dtype=dtype)
    start = tf.cast(start_with_in_valid_range, tf.int32)
    end = start + length - 1

    # Shift starting and end point by small value.
    delta = tf.constant(0.1)
    start = tf.expand_dims(tf.cast(start, dtype) - delta, -1)
    start = tf.tile(start, [1, 1, mask_size])
    end = tf.expand_dims(tf.cast(end, dtype) + delta, -1)
    end = tf.tile(end, [1, 1, mask_size])

    # Construct pre-mask of shape (batch_size, multiplicity, mask_size).
    diagonal = tf.expand_dims(
        tf.expand_dims(tf.cast(tf.range(mask_size), dtype=dtype), 0), 0)
    diagonal = tf.tile(diagonal, [batch_size, multiplicity, 1])
    pre_mask = tf.cast(
        tf.logical_and(diagonal < end, diagonal > start), dtype=dtype)

    # Sum masks with appropriate multiplicity.
    if masks_per_frame > 0:
      multiplicity_weights = tf.tile(
          tf.expand_dims(tf.range(multiplicity, dtype=dtype), 0),
          [batch_size, 1])
      multiplicity_tensor = masks_per_frame * tf.cast(choose_range, dtype=dtype)
      multiplicity_weights = tf.cast(
          multiplicity_weights < multiplicity_tensor, dtype=dtype)
      pre_mask = tf.einsum('bmt,bm->bt', pre_mask, multiplicity_weights)
    else:
      pre_mask = tf.reduce_sum(pre_mask, 1)
    mask = tf.cast(1.0 - tf.cast(pre_mask > 0, dtype=dtype), dtype=dtype)

    if p.fprop_dtype is not None and p.fprop_dtype != p.dtype:
      mask = tf.cast(mask, p.fprop_dtype)

    return mask
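The core trick in _GetMask is that an interval mask is just two comparisons against a position ramp (the "diagonal" above). A minimal numpy sketch of that step in isolation, with hypothetical start/end values and multiplicity 1:

import numpy as np

mask_size = 10
delta = 0.1
start = np.array([[2.0], [5.0]]) - delta  # [batch, multiplicity]
end = np.array([[4.0], [8.0]]) + delta

positions = np.arange(mask_size, dtype=np.float32)  # the ramp
pre_mask = ((positions > start[..., np.newaxis]) &
            (positions < end[..., np.newaxis])).astype(np.float32)
# Compose over multiplicity and flip: masked frames are 0, the rest 1.
mask = 1.0 - (pre_mask.sum(axis=1) > 0).astype(np.float32)

print(mask)  # row 0 zeros frames 2..4, row 1 zeros frames 5..8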
Example #26
  def _TimeMask(self,
                inputs,
                seq_lengths,
                max_ratio=1.0,
                time_length=2560,
                noisify=False,
                dtype=tf.float32):
    """Applies time masking with given degree to inputs.

    Args:
      inputs: Batch of input features of shape (batch_size, time_length,
        num_freq, channels).
      seq_lengths: The actual sequence lengths, based on which the masks are
        sampled; shape (batch_size,).
      max_ratio: Maximum portion of the utterance allowed to be time-masked.
      time_length: Total length of the time series.
      noisify: Whether to noisify the masked-out regions.
      dtype: Data type.

    Returns:
      Inputs with random time masking applied.
    """
    p = self.params
    # If the maximum mask length is zero, do nothing.
    if (p.time_mask_max_frames == 0 and
        not p.use_dynamic_time_mask_max_frames):
      return inputs
    seq_lengths = tf.cast(seq_lengths, tf.int32)
    batch_size = tf.shape(inputs)[0]
    # Choose random masked lengths.
    if p.use_dynamic_time_mask_max_frames:
      # TODO(ngyuzh): if an utterance is too short, it will never be masked.
      length_range = tf.cast(seq_lengths, dtype=tf.float32) * max_ratio
      max_length = tf.cast(
          tf.random.uniform((batch_size,), maxval=1.0, seed=p.random_seed) *
          length_range, tf.int32)
    else:
      max_length = tf.random.uniform((batch_size,),
                                     maxval=p.time_mask_max_frames,
                                     dtype=tf.int32,
                                     seed=p.random_seed)
    # Create masks in the time direction and apply them.
    block_arrays = self._GetMask(batch_size,
                                 choose_range=seq_lengths,
                                 mask_size=time_length,
                                 max_length=max_length,
                                 dtype=dtype,
                                 max_ratio=max_ratio)

    outputs = tf.einsum('bxyc,bx->bxyc',
                        inputs,
                        block_arrays,
                        name='einsum_formasking')
    if noisify:
      # Sample noise with standard deviation factor * 0.1 + 0.0001.
      # TODO(ngyuzh): Make sure this won't affect EOS.
      factor = tf.random.uniform((),
                                 minval=1.0,
                                 maxval=2.0,
                                 dtype=dtype,
                                 seed=p.random_seed)
      stddev = factor * 0.1 + 0.0001
      noise = tf.random.normal(
          [tf.shape(inputs)[0],
           tf.shape(inputs)[1],
           tf.shape(inputs)[2]],
          stddev=stddev,
          seed=p.random_seed)
      if p.fprop_dtype is not None and p.fprop_dtype != p.dtype:
        noise = tf.cast(noise, p.fprop_dtype)
      outputs_mask = tf.einsum('bxy,bx->bxy',
                               noise,
                               1.0 - block_arrays,
                               name='einsum_fornoisymasking')
      outputs = outputs + tf.expand_dims(outputs_mask, -1)
    return outputs
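Applying the [batch, time] mask to [batch, time, freq, channels] features is again a broadcast product, which is all 'bxyc,bx->bxyc' does. A standalone sketch with toy shapes:

import tensorflow as tf

inputs = tf.random.uniform([2, 6, 4, 1])  # [b, time, freq, channels]
block_arrays = tf.constant([[1., 1., 0., 0., 1., 1.],
                            [1., 0., 0., 1., 1., 1.]])  # [b, time]

masked = tf.einsum('bxyc,bx->bxyc', inputs, block_arrays)
# Frames whose mask value is 0 are zeroed across all freq bins and channels.
tf.debugging.assert_near(masked[0, 2], tf.zeros([4, 1]))
tf.debugging.assert_near(masked[1, 3], inputs[1, 3])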
Example #27
  def EinsumBxycBxBxyc(self, a, b, name=None):
    return tf.einsum('bxyc,bx->bxyc', a, b, name=name)
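This is the broadcast pattern of Example #23 extended over two trailing axes; it is the same op _TimeMask uses above (under the name 'einsum_formasking') to zero out masked frames.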
Example #28
  def RelPositionBias(self, content, abs_pos_emb, skip_term_b=False):
    """Computes the relative position bias.

    This is a subroutine used by variants of self-attention with relative
    positional embedding.

    output[b][n][i][j] = content[b][i][n] · abs_pos_emb[i-j+T-1][n]
    (an inner product over the H dimension).

    Padding should be masked by the caller of this function.

    Notation:
      B: batch size.
      T: sequence length.
      N: num of attention heads.
      H: per-head attention dimension.

    Args:
      content: [N, H] if skip_term_b else [B, T, N, H].
      abs_pos_emb: [2T - 1, N, H], the absolute positional embedding.
        abs_pos_emb[i] is the emb of relative distance i - (T - 1).
      skip_term_b: Whether to skip term_b in section 3.3 of the
        Transformer-XL paper.

    Returns:
      The attention logits tensor. [N, T, T] if skip_term_b else [B, N, T, T].
    """
    if not skip_term_b:
      b, t, n, h = py_utils.GetShape(content)
      l = 2 * t - 1
      abs_pos_emb = py_utils.HasShape(abs_pos_emb, [l, n, h])
    else:
      n, h = py_utils.GetShape(content)
      l = py_utils.GetShape(abs_pos_emb)[0]
      t = (l + 1) // 2

    if not skip_term_b:
      # [B, N, T, L=2T-1]
      content, abs_pos_emb = self.ToAqtActActInputs(content, abs_pos_emb)
      term_bd = tf.einsum('BTNH,LNH->BNTL', content, abs_pos_emb)
      term_bd = self.FromAqtActActMatmul(term_bd)

      # [B, N, T * L].
      term_bd = tf.reshape(term_bd, [b, n, t * l], name='flatten')
      # [B, N, T * (L + 1)].
      term_bd = tf.pad(term_bd, ((0, 0), (0, 0), (0, t)))
      # [B, N, T, L + 1].
      term_bd = tf.reshape(term_bd, [b, n, t, l + 1], name='restore')
      return term_bd[:, :, :, t - 1::-1]
    else:
      # [N, L=2T-1]
      content, abs_pos_emb = self.ToAqtActActInputs(content, abs_pos_emb)
      term_d = tf.einsum('NH,LNH->NL', content, abs_pos_emb)
      term_d = self.FromAqtActActMatmul(term_d)

      # [N, T, L].
      term_d = tf.tile(tf.expand_dims(term_d, axis=1), [1, t, 1], name='tile')
      # [N, T * L].
      term_d = tf.reshape(term_d, [n, t * l])
      # [N, T * (L + 1)].
      term_d = tf.pad(term_d, ((0, 0), (0, t)))
      # [N, T, L + 1].
      term_d = tf.reshape(term_d, [n, t, l + 1], name='restore')
      return term_d[:, :, t - 1::-1]
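The pad-reshape-slice at the end is the standard relative-shift trick: appending T zeros and re-reading the buffer with L + 1 columns skews each row by one position, so a single slice aligns logits by relative distance. A toy numpy sketch of just the index bookkeeping (single head, T = 3):

import numpy as np

t = 3
l = 2 * t - 1
# term[i, k]: logit of query i against relative-distance column k.
term = np.arange(t * l, dtype=float).reshape(t, l)

padded = np.pad(term.reshape(t * l), (0, t))  # [T * (L + 1)]
shifted = padded.reshape(t, l + 1)            # each row skewed by one
out = shifted[:, t - 1::-1]                   # [T, T]

# out[i, j] == term[i, i - j + t - 1], matching the docstring's indexing.
assert out[2, 0] == term[2, 2 - 0 + t - 1]
print(out)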
Example #29
  def EinsumBxycBzxBzyc(self, a, b, name=None):
    return tf.einsum('bxyc,bzx->bzyc', a, b, name=name)
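'bxyc,bzx->bzyc' contracts over x, which is a batched matmul once the (y, c) axes are flattened. An equivalence sketch with illustrative shapes:

import tensorflow as tf

a = tf.random.uniform([2, 5, 3, 4])  # [b, x, y, c]
w = tf.random.uniform([2, 7, 5])     # [b, z, x]

einsum = tf.einsum('bxyc,bzx->bzyc', a, w)

# Flatten (y, c), contract over x with a batched matmul, then restore.
matmul = tf.reshape(tf.matmul(w, tf.reshape(a, [2, 5, 3 * 4])), [2, 7, 3, 4])

tf.debugging.assert_near(einsum, matmul, atol=1e-4)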
Example #30
  def EinsumBxycBzyBxzc(self, a, b, name=None):
    return tf.einsum('bxyc,bzy->bxzc', a, b, name=name)
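Same batched-matmul pattern as Example #29, but contracting over y instead of x, so the result keeps a's x axis: out[b][x][z][c] = sum_y a[b][x][y][c] * b[b][z][y].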