Code example #1
def SequenceAppendToken(x, x_paddings, token, extend=False):
    """Appends <token> to sequence `x`.

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    token: The token to append (of type integer).
    extend: Whether to extend `x` along the length dimension; this must be True
      for any sequence in `x` whose length is already `x_len_max`, or else an
      invalid sequence will be emitted.

  Returns:
    A tuple.
      - The new sequence, Tensor of shape [batch_size, x_len_max].
      - The new paddings, Tensor of shape [batch_size, x_len_max].
  """
    batch_size = py_utils.GetShape(x)[0]
    x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
    if extend:
        x = tf.pad(x, [[0, 0], [0, 1]])
    # Mask all invalid entries of `x` to 0.
    x *= tf.sequence_mask(x_len, py_utils.GetShape(x)[1], x.dtype)
    # Append the <token> based on `x_len`.
    x += tf.scatter_nd(tf.stack([tf.range(batch_size), x_len], axis=1),
                       tf.cast(tf.fill([batch_size], token), x.dtype),
                       py_utils.GetShape(x))
    x_paddings = 1 - tf.sequence_mask(x_len + 1,
                                      py_utils.GetShape(x)[1],
                                      x_paddings.dtype)
    return x, x_paddings
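
A minimal eager-mode sketch of the scatter-based append above, with tf.shape standing in for py_utils.GetShape and an assumed <eos> id of 2 (toy values, outputs worked out by hand):

import tensorflow as tf

x = tf.constant([[7, 8, 0, 0],
                 [5, 0, 0, 0]], tf.int32)
x_paddings = tf.constant([[0., 0., 1., 1.],
                          [0., 1., 1., 1.]])
token = 2  # assumed <eos> id

x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)  # [2, 1]
batch_size = tf.shape(x)[0]
# Zero out invalid entries, then scatter <token> at position x_len per row.
x = x * tf.sequence_mask(x_len, tf.shape(x)[1], x.dtype)
x = x + tf.scatter_nd(tf.stack([tf.range(batch_size), x_len], axis=1),
                      tf.fill([batch_size], token), tf.shape(x))
print(x.numpy())  # [[7 8 2 0], [5 2 0 0]]
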
Code example #2
def AddMultiCurveSubplot(fig,
                         tensors,
                         paddings,
                         labels,
                         xlabels=None,
                         **kwargs):
    """Adds a multi curve subplot to Matplotlib figure.

  Plots one line for each entry in tensors and assigns it a legend label.

  Args:
    fig: The Matplotlib figure.
    tensors: List of tensors of shape [batch, length]
    paddings: Paddings for 'tensors' with shape [batch, length], with 0. in
      valid positions and 1. in invalid positions.
    labels: A list of tensor names (strings) of the same length as 'tensors'.
    xlabels: A string tensor of shape [batch] with an xlabel per batch.
    **kwargs: Optional keyword arguments: title, xlabel, ylabel, fontsize.
  """
    data = []
    row_labels = []
    for t, l in zip(tensors, labels):
        if t is not None:
            data.append(py_utils.ApplyPadding(paddings, t))
            row_labels.append(l)
    shape = py_utils.GetShape(data[0], 2)
    data = tf.reshape(tf.concat(data, -1), [shape[0], len(data), shape[1]])

    args = [data, py_utils.LengthsFromPaddings(paddings)]
    if xlabels is not None:
        args.append(xlabels)
    fig.AddSubplot(args,
                   plot_func=_AddMultiCurveRowPlots,
                   row_labels=row_labels,
                   **kwargs)
Code example #3
def ComputeConvOutputPadding(paddings, window, stride,
                             padding_algorithm='SAME'):
  """Computes paddings for convolution and pooling output.

  out_padding[i] == 1 iff any in_padding corresponding to that output is 1.

  Args:
    paddings: The paddings tensor. It is expected to be of shape [batch, time].
    window: The size of the windows.
    stride: The time-stride between adjacent windows.
    padding_algorithm: 'SAME' or 'VALID'.

  Returns:
    out_padding: The new padding tensor of shape [batch, ceil(time / stride)].
  """
  if stride == 1:
    return paddings

  # Pad so input_length divides stride.
  input_length = py_utils.GetShape(paddings)[1]
  pad_len = (input_length + stride - 1) // stride * stride - input_length
  paddings = tf.pad(paddings, [[0, 0], [0, pad_len]], constant_values=1.0)
  out_padding = tf.nn.pool(
      tf.expand_dims(paddings, -1),
      [window],
      'MAX',
      padding=padding_algorithm,
      strides=[stride],
  )
  return tf.squeeze(out_padding, -1)
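
A quick eager check of the pooled-paddings idea under the same 'SAME' convention, with plain tf.shape in place of py_utils.GetShape (toy values):

import tensorflow as tf

paddings = tf.constant([[0., 0., 0., 1., 1., 1.]])
window, stride = 3, 2

input_length = tf.shape(paddings)[1]
pad_len = (input_length + stride - 1) // stride * stride - input_length
padded = tf.pad(paddings, [[0, 0], [0, pad_len]], constant_values=1.0)
out = tf.nn.pool(padded[..., tf.newaxis], [window], 'MAX',
                 padding='SAME', strides=[stride])
# Any output whose window touches a padded input frame is marked 1.
print(tf.squeeze(out, -1).numpy())  # [[0. 1. 1.]]
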
Code example #4
def _RelPositionBias(query, abs_pos_emb):
  """Computes relative position bias for general cases."""
  _, t, n, h = py_utils.GetShape(query)
  abs_pos_emb = py_utils.HasShape(abs_pos_emb, [2 * t - 1, n, h])

  # abs_pos_emb is [-(T-1), -(T-2), ... 0, 1, 2, ... T-1]
  # Change to [T-1, T-2, ... 0, -1, -2, ... -(T-2), -(T-1)]
  abs_pos_emb = tf.reverse(abs_pos_emb, [0])

  # [B, N, T, L=2T-1]
  term_bd = tf.einsum('BTNH,LNH->BNTL', query, abs_pos_emb)

  # Convert to [B, N, T, T]
  # part1
  term_bd_left = term_bd[:, :, :, :t]
  term_bd_left = tf.reverse(term_bd_left, [2, 3])
  term_bd_left = RelShift(term_bd_left)
  # [B, N, T, T]
  term_bd_left = tf.reverse(term_bd_left, [2, 3])
  # part 2
  term_bd_right = term_bd[:, :, :, t - 1:]
  # [B, N, T, T]
  term_bd_right = RelShift(term_bd_right)
  # [lower triangle]
  mask = tf.linalg.band_part(tf.ones_like(term_bd_right), -1, 0)

  # Stitch the two halves together.
  return tf.where(mask > 0, term_bd_left, term_bd_right)
Code example #5
    def FProp(self, theta, inputs, paddings, domain_ids=None):
        """Applies data augmentation by randomly mask spectrum in inputs.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      inputs: A tensor of shape [batch, time, freq, num_channels].
      paddings: A 0/1 tensor of shape [batch, time].
      domain_ids: input domain_ids of shape [batch, time].

    Returns:
      A pair of tensors:

      - augmented_inputs: A tensor of shape [batch, time, freq, num_channels].
      - paddings: A 0/1 tensor of shape [batch, time].
    """
        p = self.params

        global_seed = None  # A tensor seed in case stateless random ops are needed.
        if p.use_input_dependent_random_seed:
            global_seed = _global_seed_from_inputs(inputs)

        batch_size, series_length, _, _ = py_utils.GetShape(inputs)
        if len(p.domain_ids) > 1:
            augmented_inputs = tf.zeros_like(inputs)
            original_inputs = inputs
            for i, domain_id in enumerate(p.domain_ids):
                augmented_domain = self._AugmentationNetwork(
                    series_length,
                    inputs,
                    paddings,
                    global_seed=global_seed,
                    domain_id_index=i)
                target_domain = tf.cast(tf.expand_dims(
                    tf.tile([domain_id], [batch_size]), -1),
                                        dtype=p.dtype)
                # [batch, time].
                domain_mask = tf.cast(tf.equal(domain_ids, target_domain),
                                      dtype=p.dtype)
                augmented_domain = self.EinsumBxycBxBxyc(
                    augmented_domain, domain_mask, name='einsum_domainmasking')
                original_inputs = self.EinsumBxycBxBxyc(
                    original_inputs,
                    1.0 - domain_mask,
                    name='einsum_domainmasking2')
                augmented_inputs = augmented_domain + augmented_inputs
            augmented_inputs = original_inputs + augmented_inputs
        else:
            augmented_inputs = self._AugmentationNetwork(
                series_length,
                inputs,
                paddings,
                global_seed=global_seed,
                domain_id_index=0)
        return augmented_inputs, paddings
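
The EinsumBxycBxBxyc calls above blend augmented and original frames per position; a hedged standalone equivalent with plain tf.einsum (toy shapes; random tensors stand in for the augmentation network output):

import tensorflow as tf

inputs = tf.random.normal([4, 10, 8, 1])     # [batch, time, freq, channels]
augmented = tf.random.normal([4, 10, 8, 1])  # stand-in for _AugmentationNetwork
domain_ids = tf.concat([tf.zeros([2, 10]), tf.ones([2, 10])], 0)  # [batch, time]
domain_mask = tf.cast(tf.equal(domain_ids, 1.0), tf.float32)

# Frames whose domain matches take the augmented values; others keep the input.
blended = (tf.einsum('bxyc,bx->bxyc', augmented, domain_mask) +
           tf.einsum('bxyc,bx->bxyc', inputs, 1.0 - domain_mask))
print(blended.shape)  # (4, 10, 8, 1)
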
Code example #6
 def UnstackFeatures(self, src_inputs, src_paddings):
     """Unstacks src_input and src_paddings based off stack height."""
     sh = self.params.stack_height
     bs, old_series_length, _, channels = py_utils.GetShape(src_inputs)
     unstacked_series_length = old_series_length * sh
     src_inputs = tf.reshape(src_inputs,
                             [bs, unstacked_series_length, -1, channels])
     content = 1 - src_paddings
     lengths = tf.cast(sh * tf.reduce_sum(content, axis=1), tf.int32)
     mask = tf.sequence_mask(lengths, maxlen=unstacked_series_length)
     src_paddings = 1 - tf.cast(mask, tf.int32)
     return src_inputs, src_paddings
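
A toy check of the unstacking reshape above: stack_height 2 folded pairs of frames into the feature axis, and the reshape doubles the series length back:

import tensorflow as tf

sh = 2
stacked = tf.random.normal([3, 5, 8 * sh, 1])  # [batch, time, feat * sh, channels]
unstacked = tf.reshape(stacked, [3, 5 * sh, -1, 1])
print(unstacked.shape)  # (3, 10, 8, 1)
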
Code example #7
    def _TimeWarp(self,
                  inputs,
                  seq_lengths,
                  global_seed,
                  dtype=tf.float32,
                  domain_id_index=0):
        """Applies time warping with given degree to inputs.

    Args:
      inputs: Batch of input features of shape (batch_size, time_length,
        num_freq, channels).
      seq_lengths: The actual sequence lengths, from which the mask has been
        sampled, of shape (batch_size,).
      global_seed: an integer seed tensor for stateless random ops.
      dtype: Data type.
      domain_id_index: Domain ID index.

    Returns:
      Inputs with random time warping applied.
    """
        p = self.params
        batch_size, time_length, _, _ = py_utils.GetShape(inputs)

        # Get parameters for warping.
        time_warp_max_frames = p.time_warp_max_frames[domain_id_index]
        max_ratio = p.time_warp_max_ratio[domain_id_index]
        time_warp_bound = p.time_warp_bound[domain_id_index]
        assert time_warp_bound in ('static', 'dynamic')

        # If maximum warp length is zero, do nothing.
        if ((time_warp_max_frames == 0 and time_warp_bound == 'static')
                or max_ratio <= 0.0):
            return inputs
        seq_lengths = tf.cast(seq_lengths, tf.int32)

        # Discard upper-bound on time-warp frames when
        # dynamic time warping is used.
        if time_warp_bound == 'dynamic':
            time_warp_max_frames = None

        # Create warping matrix in time direction and apply
        warp_matrix = self._GetWarpMatrix(batch_size,
                                          choose_range=seq_lengths,
                                          matrix_size=time_length,
                                          global_seed=global_seed,
                                          max_warp_frames=time_warp_max_frames,
                                          dtype=dtype,
                                          max_ratio=max_ratio)

        return self.EinsumBxycBzxBzyc(inputs,
                                      warp_matrix,
                                      name='einsum_forwarping')
Code example #8
def SequenceConcat(x, x_paddings, y, y_paddings, pad=0):
    """Concats sequence `x` with sequence `y`.

  This function is length-aware (based on the paddings).

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    y: A sequence of tokens of shape [batch_size, y_len_max].
    y_paddings: The paddings of `y`.
    pad: The <pad> token to fill the concatenated sequence (of type integer).

  Returns:
    A tuple.
      - Concatenation of `x` and `y` of shape
        [batch_size, x_len_max + y_len_max].
      - Paddings of the concatenation of shape
        [batch_size, x_len_max + y_len_max].
  """
    # Get the length (w/ eos).
    x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
    y_len = tf.cast(tf.round(tf.reduce_sum(1 - y_paddings, 1)), tf.int32)

    batch_size = py_utils.GetShape(x)[0]
    y_len_max = py_utils.GetShape(y)[1]

    # Pad `x` with necessary <pad>.
    x = tf.concat([x, tf.fill(py_utils.GetShape(y), pad)], 1)
    # Replace all <pad> with 0.
    x = tf.where(tf.not_equal(x, pad), x, tf.fill(py_utils.GetShape(x), 0))

    # Compute the write indices of `y` in `xy`.
    indices = tf.stack([
        tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, y_len_max]),
        (tf.tile(tf.expand_dims(tf.range(y_len_max), 0), [batch_size, 1]) +
         tf.expand_dims(x_len, 1)),
    ], 2)

    xy = x + tf.scatter_nd(indices, y, py_utils.GetShape(x))

    # We need to remap all <pad> to `pad`.
    xy = tf.where(
        tf.less(tf.expand_dims(tf.range(py_utils.GetShape(xy)[1]), 0),
                tf.expand_dims(x_len + y_len, 1)), xy,
        tf.fill(py_utils.GetShape(xy), pad))
    xy_paddings = 1 - tf.sequence_mask(x_len + y_len,
                                       py_utils.GetShape(xy)[1],
                                       x_paddings.dtype)
    return xy, xy_paddings
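
A hedged usage sketch of the function above (it assumes lingvo's py_utils is importable so GetShape resolves; expected outputs worked out by hand):

x = tf.constant([[7, 8, 9]]);  x_pad = tf.constant([[0., 0., 0.]])
y = tf.constant([[5, 0, 0]]);  y_pad = tf.constant([[0., 1., 1.]])
xy, xy_pad = SequenceConcat(x, x_pad, y, y_pad, pad=0)
# xy     -> [[7 8 9 5 0 0]]
# xy_pad -> [[0. 0. 0. 0. 1. 1.]]
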
Code example #9
def _AttenLogits(query,
                 key,
                 abs_pos_emb,
                 content_bias=None,
                 positional_bias=None,
                 is_causal=False):
  """Attention logits from ...

  Transformer-XL(https://arxiv.org/pdf/1901.02860.pdf, section 3.3) version of
  self attention with relative position embedding.

  Notice padding is supposed to be masked by the caller of this function.

  B: batch size
  T: sequence length
  N: num of attention heads.
  H: per-head attention dimension.

  Args:
    tensors of the following shapes:
    query:           [B, T, N, H]
    key:             [B, T, N, H]
    abs_pos_emb:     [2T - 1, N, H]. The sinusoid positional embedding from
      https://arxiv.org/abs/1706.03762. abs_pos_emb[i] is the emb of relative
      distance i - (T-1).
    content_bias:    [N, H] or None
    positional_bias: [N, H] or None
    is_causal: A Python bool or a scalar bool Tensor. True for causal self
      attention.

  Returns:
    The attention logits tensor. [B, N, T, T]
  """
  b, t, n, h = py_utils.GetShape(query)

  key = py_utils.HasShape(key, [b, t, n, h])
  if content_bias is not None:
    content_bias = py_utils.HasShape(content_bias, [n, h])
  else:
    content_bias = 0
  if positional_bias is not None:
    positional_bias = py_utils.HasShape(positional_bias, [n, h])
  else:
    positional_bias = 0

  # [B, N, T, S=T]
  term_ac = tf.einsum('BTNH,BSNH->BNTS', query + content_bias, key)
  term_bd = RelPositionBias(query + positional_bias, abs_pos_emb, is_causal)
  return term_ac + term_bd
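
A shape-only sketch of the content term: one einsum contracts the head dimension and yields per-head logits for every query/key pair (toy sizes):

import tensorflow as tf

b, t, n, h = 2, 5, 4, 8
query = tf.random.normal([b, t, n, h])
key = tf.random.normal([b, t, n, h])
term_ac = tf.einsum('BTNH,BSNH->BNTS', query, key)
print(term_ac.shape)  # (2, 4, 5, 5) == [B, N, T, S]
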
Code example #10
  def FProp(self, theta, inputs, paddings):
    """Apply global spatial pooling to inputs.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor. It is expected to be of shape [batch, time,
        frequency, channel]. The time dimension corresponds to the height
        dimension as in images and the frequency dimension corresponds to the
        width dimension as in images.
      paddings: The paddings tensor. It is expected to be of shape [batch,
        time]. Defaults to None, which means there are no paddings.

    Returns:
      outputs, out_paddings pair.
       - outputs: has shape [batch, 1, 1, channel].
       - out_paddings: None or has shape [batch, 1].
    """
    p = self.params
    assert p.pooling_type in ['MAX', 'AVG'], p.pooling_type
    b, t, f = py_utils.GetShape(inputs, ndims=3)

    if paddings is not None:
      paddings = py_utils.HasShape(paddings, [b, t])

    if paddings is not None:
      mask = 1.0 - paddings[..., tf.newaxis, tf.newaxis]
    else:
      mask = tf.ones([b, t, 1, 1], p.dtype)
    if p.pooling_type == 'AVG':
      global_sum = tf.reduce_sum(inputs * mask, axis=[1, 2], keepdims=True)
      f = tf.cast(tf.convert_to_tensor(f), p.dtype)
      count = f * tf.reduce_sum(mask, axis=[1, 2], keepdims=True)
      out_feature = global_sum / tf.maximum(1.0, count)
    elif p.pooling_type == 'MAX':
      large_negative = (
          tf.ones_like(inputs) * p.dtype.max * tf.constant(-0.7, dtype=p.dtype))
      padded_inputs = tf.where_v2(mask > 0.0, inputs, large_negative)
      out_feature = tf.reduce_max(padded_inputs, axis=[1, 2], keepdims=True)
    if paddings is None:
      out_paddings = None
    else:
      out_paddings = tf.reduce_min(paddings, axis=1, keepdims=True)
      out_feature *= 1.0 - out_paddings[..., tf.newaxis, tf.newaxis]
    return out_feature, out_paddings
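
A standalone sketch of the AVG branch's masked mean, with literal shapes in place of the py_utils helpers (f = 3 here):

import tensorflow as tf

inputs = tf.random.normal([2, 4, 3, 1])                        # [b, t, f, c]
paddings = tf.constant([[0., 0., 1., 1.], [0., 1., 1., 1.]])
mask = 1.0 - paddings[..., tf.newaxis, tf.newaxis]             # [b, t, 1, 1]
total = tf.reduce_sum(inputs * mask, axis=[1, 2], keepdims=True)
count = 3.0 * tf.reduce_sum(mask, axis=[1, 2], keepdims=True)  # f * valid frames
avg = total / tf.maximum(1.0, count)
print(avg.shape)  # (2, 1, 1, 1)
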
Code example #11
    def ComputeLoss(self, theta, predicted, input_batch):
        diff = predicted - input_batch.tgt_ids
        per_example_loss = diff * diff
        batch_dim = py_utils.GetShape(per_example_loss)[0]

        def replicate_var(name):
            return tf.convert_to_tensor([self._private_vars[name]] * batch_dim,
                                        dtype=tf.float32)

        metrics = {'loss': (tf.reduce_sum(per_example_loss), batch_dim)}
        per_example_tensors = {
            'input': input_batch.src_ids,
            'loss': per_example_loss,
            'diff': diff,
            'm': replicate_var('m'),
            'b': replicate_var('b'),
        }
        return metrics, per_example_tensors
Code example #12
def _RelPositionBiasCausal(query, abs_pos_emb):
  """Computes relative position bias for causal self attention."""
  _, t, n, h = py_utils.GetShape(query)

  abs_pos_emb = py_utils.HasShape(abs_pos_emb, [2 * t - 1, n, h])

  # abs_pos_emb is [-(T-1), -(T-2), ... 0, 1, 2, ... T-1]
  # Retain only half and change order to [T-1, T-2, ... 0]
  # [T, N, H]
  abs_pos_emb = tf.reverse(abs_pos_emb, [0])[:t]

  # [B, N, T, L=T]
  term_bd = tf.einsum('BTNH,LNH->BNTL', query, abs_pos_emb)

  # Perform shifting.
  term_bd = tf.reverse(term_bd, [2, 3])
  term_bd = RelShift(term_bd)
  return tf.reverse(term_bd, [2, 3])
Code example #13
def SequenceTrimLastToken(x, x_paddings):
    """Trims the last token off of sequence `x`, and set trimmed elements to 0.

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.

  Returns:
    A tuple.
      - The new sequence, Tensor of shape [batch_size, x_len_max].
      - The new paddings, Tensor of shape [batch_size, x_len_max].
  """
    x_len = tf.reduce_sum(1 - x_paddings, 1)
    x_len_max = py_utils.GetShape(x)[1]
    x_trimmed_len = tf.maximum(x_len - 1, 0)
    x_trimmed_paddings = tf.sequence_mask(x_trimmed_len, x_len_max,
                                          x_paddings.dtype)
    x_trimmed = x * tf.cast(x_trimmed_paddings, x.dtype)
    return x_trimmed, 1 - x_trimmed_paddings
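
A hedged usage sketch (assumes py_utils resolves; outputs worked out by hand):

x = tf.constant([[7, 8, 2, 0]]);  x_pad = tf.constant([[0., 0., 0., 1.]])
x_trim, new_pad = SequenceTrimLastToken(x, x_pad)
# x_trim  -> [[7 8 0 0]]
# new_pad -> [[0. 0. 1. 1.]]
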
Code example #14
File: step.py  Project: shjwudp/training_results_v0.7
 def _Slice(tensor):
     """Return a slice of this tensor at time=state0.t."""
     shape = py_utils.GetShape(tensor)
     # All zeros except for t in the time dimension.
     # e.g. if params.axis=1, begin is [0, t, 0, 0, 0, ...]
     begin = tf.one_hot(self.params.axis,
                        tf.rank(tensor),
                        on_value=state0.t)
     # Same as shape, but with a 1 in the time dimension.
     # e.g. if params.axis=1, shape is [shape[0], 1, shape[2], shape[3], ...]
     size = tf.concat([
         shape[0:self.params.axis],
         tf.constant([1], dtype=tf.int32), shape[self.params.axis + 1:]
     ],
                      axis=0)
     # Make a slice where the time dimension is fixed at state0.t.
     time_slice = tf.slice(tensor, begin, size)
     # Remove the time dimension.
     return tf.squeeze(time_slice, axis=self.params.axis)
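
The one-hot begin/size trick generalizes to any axis; a standalone sketch with axis=1 and t=2, using tf.shape in place of py_utils.GetShape:

import tensorflow as tf

tensor = tf.random.normal([4, 7, 3])
axis, t = 1, tf.constant(2)
shape = tf.shape(tensor)
begin = tf.one_hot(axis, tf.rank(tensor), on_value=t)  # [0, 2, 0]
size = tf.concat([shape[:axis], [1], shape[axis + 1:]], axis=0)
sliced = tf.squeeze(tf.slice(tensor, begin, size), axis=axis)
print(sliced.shape)  # (4, 3)
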
Code example #15
def RelShift(x):
  """Performs relative shift on 4D tensor (first 2 axis are batching dims).

  Given input of shape [?, ?, W, W], this does "relative shifting" for the
  last two dims, s.t. output[b, n, i, j] = 0 if i > j else input[b, n, i, j-i]

  Args:
    x: A Tensor of shape [?, ?, W, W]

  Returns:
    A Tensor of the same shape as input with its content shifted (as described
    above).
  """
  b, n, w, _ = py_utils.GetShape(x)
  x = py_utils.HasShape(x, [-1, -1, w, w])
  x = tf.pad(x, ((0, 0), (0, 0), (0, 0), (0, 1)))
  x = tf.reshape(x, [b, n, w + 1, w])
  x = x[:, :, :w, :]
  return x
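
A worked 3x3 example of the pad-reshape-slice shift. Note the entries below the diagonal are leftovers rather than exact zeros; callers mask them (e.g., the band_part mask in _RelPositionBias above):

import tensorflow as tf

x = tf.reshape(tf.range(9, dtype=tf.float32), [1, 1, 3, 3])
# x[0, 0] == [[0 1 2], [3 4 5], [6 7 8]]
shifted = tf.reshape(tf.pad(x, ((0, 0), (0, 0), (0, 0), (0, 1))),
                     [1, 1, 4, 3])[:, :, :3, :]
print(shifted[0, 0].numpy())
# [[0. 1. 2.]
#  [0. 3. 4.]    row i holds x[i, j - i] on and above the diagonal
#  [5. 0. 6.]]
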
Code example #16
    def _FrequencyMask(self,
                       inputs,
                       global_seed,
                       dtype=tf.float32,
                       domain_id_index=0):
        """Applies frequency masking with given degree to inputs.

    Args:
      inputs: Batch of input features of shape (batch_size, time_length,
        num_freq, channels).
      global_seed: an integer seed tensor for stateless random ops.
      dtype: Data type.
      domain_id_index: domain id index.

    Returns:
      Inputs with random frequency masking applied.
    """
        p = self.params

        # Mask parameters.
        freq_mask_max_bins = p.freq_mask_max_bins[domain_id_index]
        multiplicity = p.freq_mask_count[domain_id_index]

        # If masking length or count is zero, do nothing.
        if freq_mask_max_bins == 0 or multiplicity == 0:
            return inputs

        # Arguments to pass to mask generator.
        batch_size, _, num_freq, _ = py_utils.GetShape(inputs)
        choose_range = tf.cast(tf.broadcast_to(num_freq, (batch_size, )),
                               dtype=tf.int32)
        # Create masks in frequency direction and apply.
        block_arrays = self._GetMask(tf.shape(inputs)[0],
                                     choose_range=choose_range,
                                     mask_size=num_freq,
                                     global_seed=global_seed,
                                     max_length=freq_mask_max_bins,
                                     masks_per_frame=0.0,
                                     multiplicity=multiplicity,
                                     dtype=dtype,
                                     max_ratio=1.0)
        return self.EinsumBxycByBxyc(inputs, block_arrays)
Code example #17
def PrepareSequenceForPlot(tensor, padding, name):
  """Prepares a sequence feature for plotting.

  The sequence feature is transposed and channels are flattened.

  Args:
    tensor: A n-D Tensor of shape [batch, time, ...].
    padding: A Tensor of shape [batch, time].
    name: A string as the name of the reshaped Tensor, which will be used as the
      subcaption for plotting.

  Returns:
    A tuple of:
      reshaped_tensor: A 3-D Tensor of shape [batch, dim, time].
      sequence_length: A 1-D Tensor of shape [batch].
  """
  # Flatten any dimensions beyond the third into the third.
  batch_size, max_len = py_utils.GetShape(tensor, 2)
  plot_tensor = tf.reshape(tensor, [batch_size, max_len, -1])
  plot_tensor = tf.transpose(plot_tensor, [0, 2, 1], name=name)
  return (plot_tensor, SequenceLength(padding))
Code example #18
def ConvertToBlocks(x, block_size, padding_val=0.0):
  """Turns a sequence to non overlapping blocks.

  Args:
    x: a tensor of [batch, time, ...].
    block_size: int. Number of time frames in a block.
    padding_val: float. value on the padded frames.

  Returns:
    A tensor of [batch, num_blocks, block_size, ...], with necessary paddings,
    where output[:, i, ...] are x[:, i*block_size:(i+1)*block_size, ...].
  """
  shape = py_utils.GetShape(x)
  b, t = shape[:2]
  if block_size < 1:
    raise ValueError('block_size must be at least 1, got {}'.format(block_size))
  w = block_size
  # Pad t to be a multiple of w.
  num_blocks = (t + w - 1) // w
  pad_to_length = num_blocks * w
  padded = py_utils.PadSequenceDimension(x, pad_to_length, padding_val)
  reshaped = tf.reshape(padded, [b, num_blocks, w] + shape[2:])
  return reshaped
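
A toy run of the blocking reshape, with tf.pad standing in for py_utils.PadSequenceDimension:

import tensorflow as tf

x = tf.reshape(tf.range(14, dtype=tf.float32), [2, 7])
w = 3
num_blocks = (7 + w - 1) // w                       # 3
padded = tf.pad(x, [[0, 0], [0, num_blocks * w - 7]])
blocks = tf.reshape(padded, [2, num_blocks, w])
print(blocks.shape)  # (2, 3, 3)
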
Code example #19
  def ZeroState(self, theta, prepared_inputs, batch_size):
    """Produce a zero state for this step.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      prepared_inputs: A set of inputs pre-processed by using
        PrepareExternalInputs.
      batch_size: Number of elements in the batched input.

    Returns:
      state0, a state parameter to pass to FProp on its first invocation.
    """
    max_seq_length = py_utils.GetShape(prepared_inputs.src, 3)[0]
    atten_state = self.atten.ZeroAttentionState(max_seq_length, batch_size)
    (new_atten_context, _,
     new_atten_states) = self.atten.ComputeContextVectorWithSource(
         theta.atten,
         prepared_inputs.packed_src,
         tf.zeros([batch_size, self.params.atten.query_dim],
                  dtype=py_utils.FPropDtype(self.params)),
         attention_state=atten_state)
    return py_utils.NestedMap(
        atten_context=new_atten_context, atten_state=new_atten_states)
Code example #20
  def _StringsToIdsImpl(self, strs, max_length, append_eos, languages):
    """Takes a tensor of strings and returns id/padding tensors.

    This generates `token_ids`, `target_ids`, and `paddings` in the format that
    is expected for tokenizers. This performs padding to a fixed length and
    appends the end-of-sentence token as appropriate.

    Args:
      strs: a string Tensor.
      max_length: a python integer. The second dimension of the returned arrays.
        All sequences are padded or truncated to that length.
      append_eos: a python bool. See `BaseTokenizer` for explanation.
      languages: A vector of strings with the same length as `strs`.

    Returns:
      A tuple of 3 tensors:

      - token_ids: a tensor of sequences of WPM ids starting with SOS. Sequences
        always end with EOS unless the sequence exceeds the maximum length.
        Always padded with EOS.
      - target_ids: a tensor of sequences of WPM ids not starting with SOS
        but ending with EOS. Always padded with EOS.
      - paddings: a tensor of floats indicating, at each position, whether
        the corresponding position is padded.
    """
    p = self.params
    if append_eos is None:
      append_eos = p.append_eos

    batch_size = py_utils.GetShape(strs)[0]
    token_ids_ta = tf.TensorArray(tf.int32, batch_size)
    target_ids_ta = tf.TensorArray(tf.int32, batch_size)
    paddings_ta = tf.TensorArray(tf.float32, batch_size)

    def _TokenizeOneSentence(i, strs, token_ids_ta, target_ids_ta, paddings_ta):
      """Tokenizes a single sentence."""
      ids, _ = self._wpm_encoder.Encode(strs[i])

      if append_eos:
        ids = tf.concat([ids, [self.eos_id]], axis=0)

      # This truncates after the eos is added, so some sentences might
      # not have </s> at the end.
      token_ids_ta = token_ids_ta.write(
          i,
          py_utils.PadOrTrimTo(
              tf.concat([[self.sos_id], ids], axis=0), [max_length],
              self.eos_id))
      target_ids_ta = target_ids_ta.write(
          i, py_utils.PadOrTrimTo(ids, [max_length], self.eos_id))
      paddings_ta = paddings_ta.write(
          i,
          py_utils.PadOrTrimTo(
              tf.zeros_like(ids, dtype=tf.float32), [max_length], 1.))

      return i + 1, strs, token_ids_ta, target_ids_ta, paddings_ta

    _, _, token_ids_ta, target_ids_ta, paddings_ta = tf.while_loop(
        lambda i, *_: i < batch_size,
        _TokenizeOneSentence,
        loop_vars=(tf.constant(0, tf.int32), strs, token_ids_ta, target_ids_ta,
                   paddings_ta),
        parallel_iterations=30,
        back_prop=False)

    token_ids = token_ids_ta.stack()
    target_ids = target_ids_ta.stack()
    paddings = paddings_ta.stack()

    if not p.pad_to_max_length:
      maxlen = tf.cast(
          tf.round(tf.reduce_max(tf.reduce_sum(1.0 - paddings, axis=1))),
          tf.int32)
      token_ids = token_ids[:, :maxlen]
      target_ids = target_ids[:, :maxlen]
      paddings = paddings[:, :maxlen]

    return token_ids, target_ids, paddings
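
The TensorArray-in-while_loop pattern above, reduced to a minimal standalone sketch (a constant row stands in for one tokenized, padded sentence):

import tensorflow as tf

batch_size = 3
ta = tf.TensorArray(tf.int32, size=batch_size)

def _WriteOneRow(i, ta):
  row = tf.fill([4], i)  # stand-in for one tokenized, padded sentence
  return i + 1, ta.write(i, row)

_, ta = tf.while_loop(lambda i, _: i < batch_size, _WriteOneRow,
                      [tf.constant(0), ta])
print(ta.stack().shape)  # (3, 4)
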
Code example #21
    def FProp(self,
              theta,
              x,
              x_paddings=None,
              eos_id=1,
              force_sample_last_token=True):
        """Applies SymbolInsertionLayer.

    We take in `x`, which represents the ground-truth sequence (i.e., the
    English sequence). We return a sampled rollin (observed) canvas (i.e., a
    random subset of the English sequence), as well as the target (indices) for
    an insertion-based model (i.e., the targets given the random observed
    subset).

    Args:
      theta: Ignored, this can be None.
      x: The symbol ids of shape `[batch_size, time_dim]`.
      x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where
        0 is valid and 1 is invalid.
      eos_id: The <eos> token id to represent end-of-slot.
      force_sample_last_token: Set True to force sample the last token of `x`.

    Returns:
      A `NestedMap`.
        - canvas: The canvas (based on the `rollin_policy`) of shape
          [batch_size, c_dim]. Note that `c_dim` <= `time_dim`, but they need
          not be equal.
        - canvas_indices: The canvas indices (into `x`).
        - canvas_paddings: The paddings of `canvas_indices`.
        - target_indices: The target indices of shape [num_targets, 3].
          `num_targets` is the number of total targets in the entire batch.
          [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
          captures the token. Each row [batch, slot, vocab] represents the
          indices of the target -- i.e., the batch, slot and vocab combination
          of the target. Typical usage of these indices is to tf.gather_nd
          the log-probs (from the softmax layer).
        - target_weights: The target weights.

    Raises:
      ValueError: If invalid params.
    """
        p = self.params

        batch_size = py_utils.GetShape(x)[0]
        time_dim = py_utils.GetShape(x)[1]

        if x_paddings is None:
            x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

        oracle_policy = p.oracle_policy
        rollin_policy = (oracle_policy
                         if p.rollin_policy == 'oracle' else p.rollin_policy)

        if rollin_policy != 'uniform':
            raise ValueError('Unknown or unsupported rollin policy: %s' %
                             rollin_policy)
        if oracle_policy != 'uniform':
            raise ValueError('Unknown or unsupported oracle policy: %s' %
                             oracle_policy)

        x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

        # Compute the desired length per example in the batch.
        ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
        if force_sample_last_token:
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32),
                x_len - 1) + 1
        else:
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32),
                x_len)
        # Compute the maximum length across the batch.
        c_len_max = tf.reduce_max(c_len)

        # Grab subset of random valid indices per example.
        z_logits = tf.cast(
            tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
            tf.float32) * -1e9
        if force_sample_last_token:
            # Force sample the last token -- i.e., as indexed by `x_len - 1`. We
            # can accomplish this by adding a large positive number to the logits.
            z_logits += tf.cast(
                tf.equal(tf.expand_dims(tf.range(time_dim), 0),
                         tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9
        # Gumbel-max trick to sample (we only sample valid positions per sample in
        # the batch).
        z = -tf.math.log(-tf.math.log(
            tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
        unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

        # Trim everything > c_len_max.
        c_indices = c_indices[:, :c_len_max]

        # Invalidate any indices >= c_len; we use the last index as the default
        # invalid index.
        c_indices = tf.where(
            tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
            c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

        # Materialize the canvas.
        c_indices = tf.sort(c_indices)
        c = tf.gather_nd(
            x,
            tf.stack([
                tf.reshape(
                    tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                            [1, c_len_max]), [-1]),
                tf.reshape(c_indices, [-1])
            ], 1))
        c = tf.reshape(c, [batch_size, c_len_max])

        # Compute the paddings.
        c_paddings = 1 - tf.sequence_mask(
            c_len, c_len_max, dtype=x_paddings.dtype)
        c *= tf.cast(1 - c_paddings, tf.int32)

        indices = tf.concat([
            tf.reshape(
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, c_len_max]), [batch_size * c_len_max, 1]),
            tf.reshape(c_indices, [batch_size * c_len_max, 1])
        ], 1)
        x_token_is_observed = tf.scatter_nd(
            indices, tf.ones([batch_size * c_len_max], tf.int32),
            py_utils.GetShape(x))
        # `x_segments` captures which slot each `x` belongs to (both observed and
        # tokens that need to be observed).
        x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

        x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
        prev_x_token_is_observed = tf.pad(x_token_is_observed[:, :-1],
                                          [[0, 0], [1, 0]],
                                          constant_values=True)
        x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
        prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
        x_is_valid = tf.cast(1 - x_paddings, tf.bool)
        x_is_valid = tf.reshape(x_is_valid, [-1])

        # Remap all the observed to <eos>, note some of these need a zero weight
        # (or else there would be <eos> and valid token in the same slot).
        target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
        target_indices = tf.where(
            x_token_is_observed,
            tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

        # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
        # we may want to weigh this term by the original sequence length.
        target_weights = tf.ones_like(target_indices, tf.float32)

        # We need to set all the weights for <eos> which actually have valid tokens
        # in the slot to zero.
        target_weights = tf.where(
            x_token_is_observed & ~prev_x_token_is_observed,
            tf.zeros_like(target_weights), target_weights)

        # TODO(williamchan): Consider dropping the entries w/ weight zero.

        # Add the batch and slot indices.
        target_indices = tf.concat([
            tf.reshape(
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, time_dim]), [batch_size * time_dim, 1]),
            tf.reshape(x_segments, [-1, 1]), target_indices
        ], 1)

        # Select only the valid indices. The selected valid ones include slots w/
        # <eos>.
        target_indices = target_indices[x_is_valid]
        target_weights = target_weights[x_is_valid]

        return py_utils.NestedMap(canvas=c,
                                  canvas_indices=c_indices,
                                  canvas_paddings=c_paddings,
                                  target_indices=target_indices,
                                  target_weights=target_weights)
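
The rollin sampling above relies on the Gumbel-max trick; a self-contained sketch of drawing random valid positions per row:

import tensorflow as tf

batch, time = 2, 6
x_len = tf.constant([4, 2])
valid = tf.sequence_mask(x_len, time)                 # [batch, time]
logits = tf.where(valid, 0.0, -1e9)                   # mask out invalid positions
gumbel = -tf.math.log(-tf.math.log(tf.random.uniform([batch, time])))
_, idx = tf.nn.top_k(logits + gumbel, k=time)
# The first x_len[i] columns of idx[i] are a uniformly random ordering of the
# valid positions; the remaining columns point at masked positions.
print(idx.numpy())
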
Code example #22
    def _CreateCanvasAndTargets(self, batch):
        # pyformat: disable
        """Create the canvas and targets.

    Args:
      batch: A `.NestedMap`.

        - src: A `.NestedMap`.
          - ids: The source ids, ends in <eos>.
          - paddings: The source paddings.

        - tgt: A `.NestedMap`.
          - ids: The target ids, ends in <eos>.
          - paddings: The target paddings.

    Returns:
      A `NestedMap`.
        - canvas: The canvas (based on the `rollin_policy`) of shape
          [batch_size, c_dim].
        - canvas_paddings: The paddings of `canvas_indices`.
        - target_indices: The target indices (i.e., use these indices to
          tf.gather_nd the log-probs). Optional, only during training.
        - target_weights: The target weights. Optional, only during training.
    """
        # pyformat: enable
        p = self.params

        if not self.do_eval:
            # Sample our src and tgt canvas.
            src_descriptor = self._SampleCanvasAndTargets(
                batch.src.ids, batch.src.paddings)
            tgt_descriptor = self._SampleCanvasAndTargets(
                batch.tgt.ids, batch.tgt.paddings)

            # Offset the src ids (to unshare embeddings between src/tgt). Note, we
            # only offset the canvas ids, but we do not offset the vocab ids. This
            # will result in unshared embeddings, but shared softmax. This is due to
            # GPU/TPU memory limitations; empirically, it is known that unsharing
            # everything results in better performance.
            vocab_size = p.decoder.softmax.num_classes
            src_descriptor.canvas = tf.where(
                tf.equal(src_descriptor.canvas_paddings, 0),
                src_descriptor.canvas + vocab_size, src_descriptor.canvas)

            # Offset the tgt indices (need shift according to src length).
            batch_size = py_utils.GetShape(batch.src.ids)[0]
            # `target_batch` is a [num_targets, batch_size] tensor where each row
            # identifies which batch the target belongs to. Note that
            # tf.reduce_sum(target_batch, 1) == 1 for all rows.
            target_batch = tf.cast(
                tf.equal(
                    tf.expand_dims(tf.range(batch_size), 0),
                    tf.expand_dims(tgt_descriptor.target_indices[:, 0], 1)),
                tf.int32)
            src_lens = tf.cast(
                tf.reduce_sum(1 - src_descriptor.canvas_paddings, 1), tf.int32)
            # `tgt_offset` is shape [num_targets] where each entry corresponds to the
            # offset needed for that target (due to the source length).
            tgt_offset = tf.matmul(target_batch, tf.expand_dims(src_lens, 1))
            # We shift the tgt slot without touching the batch or vocab.
            tgt_descriptor.target_indices += tf.concat([
                tf.zeros_like(tgt_offset), tgt_offset,
                tf.zeros_like(tgt_offset)
            ], 1)

            # The canvas is simply the sequence-level concat of the src and tgt.
            canvas, canvas_paddings = insertion.SequenceConcat(
                src_descriptor.canvas, src_descriptor.canvas_paddings,
                tgt_descriptor.canvas, tgt_descriptor.canvas_paddings)
            target_indices = tf.concat(
                [src_descriptor.target_indices, tgt_descriptor.target_indices],
                0)
            target_weights = tf.concat(
                [src_descriptor.target_weights, tgt_descriptor.target_weights],
                0)

            return py_utils.NestedMap(canvas=canvas,
                                      canvas_paddings=canvas_paddings,
                                      target_indices=target_indices,
                                      target_weights=target_weights)
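
The tgt_offset computation above maps each target row to its example's source length via a one-hot matmul; a toy standalone version:

import tensorflow as tf

target_batch_ids = tf.constant([0, 0, 1, 2])   # batch id per target row
src_lens = tf.constant([5, 3, 7])              # source length per example
target_batch = tf.cast(
    tf.equal(tf.expand_dims(tf.range(3), 0),
             tf.expand_dims(target_batch_ids, 1)), tf.int32)  # [4, 3] one-hot
tgt_offset = tf.matmul(target_batch, tf.expand_dims(src_lens, 1))
print(tf.squeeze(tgt_offset, 1).numpy())  # [5 5 3 7]
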
Code example #23
  def _OutfeedDequeueLoop(self, per_example_tensors, num_loops, num_devices):
    """Process all per-example tensor outfeed data for a TPU sess.run.

    Args:
      per_example_tensors: dict of key -> tensor as generated by TpuTrainStep.
      num_loops: number of times that TpuTrainStep will be executed by TpuTrain.
      num_devices: number of TPU cores assigned to this process.

    Returns:
      A dict of per-example tensors from the latest TpuTrainStep.
    """
    if not per_example_tensors:
      return tf.no_op()

    tensor_shapes = [
        py_utils.GetShape(per_example_tensors[key])
        for key in sorted(per_example_tensors)
    ]
    tensor_types = [
        tf.as_dtype(per_example_tensors[key].dtype)
        for key in sorted(per_example_tensors)
    ]

    def LoopBody(i, *input_arrays):
      """Process outfeed data for a single TpuTrainStep.

      Args:
        i: current loop index.
        *input_arrays: One tf.TensorArray per outfeed tensor.

      Returns:
        i+1 (new index) plus post-write tf.TensorArray handles.
      """
      # Outfeed ops execute on each JF node, so they must be located on the
      # nodes.
      outfeed_devices = []
      device_assignment = py_utils.GetTpuDeviceAssignment()
      assert device_assignment
      for replica in range(device_assignment.num_replicas):
        for core in range(device_assignment.num_cores_per_replica):
          with tf.device(device_assignment.host_device(replica, core)):
            outfeed_devices.append(
                tpu_ops.outfeed_dequeue_tuple(
                    tensor_types,
                    tensor_shapes,
                    device_ordinal=device_assignment.tpu_ordinal(replica,
                                                                 core)))
      offset = i * num_devices
      output_arrays = list(input_arrays)
      # Each output_array holds a different per-example tensor. We get results
      # for each tensor from each TPU for each TpuTrainStep call.
      for j in range(len(output_arrays)):
        for k in range(len(outfeed_devices)):
          output_arrays[j] = output_arrays[j].write(offset + k,
                                                    outfeed_devices[k][j])

      return tuple([i + 1] + output_arrays)

    def LoopCond(i, *output_arrays):
      del output_arrays
      return i < num_loops

    output_arrays = []
    for i in range(len(tensor_shapes)):
      output_arrays.append(
          tf.TensorArray(
              tensor_types[i],
              size=num_loops * num_devices,
              element_shape=tensor_shapes[i]))
    # Loop once for each time that TpuTrainStep runs.
    output_arrays = tf.while_loop(
        LoopCond, LoopBody, [0] + output_arrays, parallel_iterations=1)[1:]
    concatenated_arrays = [array.concat() for array in output_arrays]
    return dict(zip(sorted(per_example_tensors), concatenated_arrays))
Code example #24
    def FProp(self, theta, input_batch):
        """Embeds source ids and transforms with TransformerStack.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      input_batch: A `.NestedMap` object containing:
        - ids: The inputs tensor of shape [batch, time].
        - paddings: The ids' paddings of shape [batch, time].

    Returns:
      A `.NestedMap` object containing:
        encoded - The encoded features of shape [time, batch, dim] or [batch,
          time, dim], depending on p.output_data_format.
        padding - The encoded features' padding of shape [time, batch] or
          [batch, time].
        segment_id - The segmentation of packed inputs of shape [time, batch] or
          [batch, time] if it is supported by the model, or None otherwise.
        embedded_inputs - The embedded inputs tokens without positional
          encodings of shape [time, batch, dim] or [batch, time, dim].
    """

        p = self.params
        with tf.name_scope(p.name):
            # [batch, time]
            input_ids = input_batch.ids
            # [batch, time]
            paddings = input_batch.paddings

            # [batch, time]
            segment_ids = input_batch.segment_ids if p.packed_input else None

            batch = py_utils.GetShape(input_ids)[0]
            time = py_utils.GetShape(input_ids)[1]

            # Embedding layer.
            # [batch, time, dim]
            if not p.shared_emb:
                input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                                      input_ids)
            else:
                input_embs = self.softmax.EmbLookup(theta.softmax, input_ids)
            orig_input_embs = input_embs

            # [1, time, dim]
            if p.packed_input:
                positions = input_batch.segment_pos
                position_embs = tf.expand_dims(
                    self.position_emb.FPropWithPosition(
                        theta.position_emb, positions), 0)
            else:
                position_embs = tf.expand_dims(
                    self.position_emb.FProp(theta.position_emb, time), 0)

            # [batch, time, dim]
            input_embs += tf.cast(position_embs, tf.bfloat16)

            if p.input_dropout_tpl.fprop_dtype:
                input_embs = tf.cast(input_embs,
                                     p.input_dropout_tpl.fprop_dtype)
                paddings = tf.cast(paddings, p.input_dropout_tpl.fprop_dtype)

            input_embs = self.input_dropout.FProp(theta.input_dropout,
                                                  input_embs)
            # [batch, time, dim]
            transformer_input = input_embs
            # Explicitly set the input shape of the Transformer layers, to avoid
            # an unknown-shape error from tf.einsum on non-TPU devices.
            transformer_input = tf.reshape(transformer_input,
                                           [batch, time, p.model_dim])

            # Compute self-attention segment mask once.
            if p.packed_input:
                segment_mask = batch_major_attention.SegmentMask(
                    segment_ids, segment_ids, dtype=transformer_input.dtype)
            else:
                segment_mask = tf.zeros([batch, 1, time, time])

            encoded, padding = self.transformer_stack.FProp(
                theta.transformer_stack, transformer_input, paddings,
                segment_mask)

            if p.final_layer_norm:
                encoded = self.final_ln.FProp(theta.final_ln, encoded)

            seq_lengths = tf.cast(tf.reduce_sum(1. - padding, axis=1),
                                  tf.int32)

            if p.output_data_format == 'TBC':
                encoded = tf.transpose(encoded,
                                       [1, 0, 2])  # [time, batch, dim]
                padding = tf.transpose(padding)  # [time, batch]
                segment_ids = tf.transpose(
                    segment_ids) if p.packed_input else None
                orig_input_embs = tf.transpose(orig_input_embs, [1, 0, 2])

            return py_utils.NestedMap(
                encoded=encoded,
                padding=padding,
                seq_lengths=seq_lengths,  # used by beam_search_helper.
                segment_id=segment_ids,
                embedded_inputs=orig_input_embs)
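
For packed inputs, SegmentMask builds the additive attention mask; one common formulation (an assumption here, not necessarily lingvo's exact implementation) is:

import tensorflow as tf

segment_ids = tf.constant([[1., 1., 2., 2.]])  # [batch, time]
same = tf.equal(segment_ids[:, :, tf.newaxis], segment_ids[:, tf.newaxis, :])
segment_mask = tf.where(same, 0.0, -1e9)[:, tf.newaxis, :, :]
print(segment_mask.shape)  # (1, 1, 4, 4) == [batch, 1, time, time]
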
Code example #25
    def BuildDataSource(self, data_source_from_file_pattern_fn):
        """Read and return input batch from a p.file_pattern list.

    `p.file_patterns` is a list of file patterns, and `p.weights` contains
    weights for each file pattern. If provided, `p.bprop_variable_filters`
    includes a bprop_variable_filter for each file pattern.

    Args:
      data_source_from_file_pattern_fn: a function that takes file_pattern as an
        argument and returns an input batch.

    Returns:
      A NestedMap containing:
        data: a tuple of tf.Tensor or `.NestedMap` of tf.Tensor
        source_selected: a tensor of size [batch_size, number of data sources]
        selected_bprop: a tensor of size [number of data sources]
        bprop_variable_filters: a list of bprop_variable_filters, one for each
          source.

    Raises:
      ValueError: If unknown token type.
    """
        p = self.params

        def _MakeDataSourceFromFilePatternFunc(
                data_source_from_file_pattern_fn, file_pattern):
            # It's important to invoke self._DataSourceFromFilePattern() inside the
            # lambda to make sure that the record is drawn from data source
            # only if it will be used. Weights are handled by MixByWeight, not the
            # data_source_from_file_pattern_fn.
            return lambda: data_source_from_file_pattern_fn(file_pattern)

        if len(p.weights) != len(p.file_patterns):
            raise ValueError(
                'Expected p.file_patterns and p.weights to be the same length. '
                'Found %d file_patterns, and %d weights' %
                (len(p.file_patterns), len(p.weights)))
        if not all(isinstance(x, six.string_types) for x in p.file_patterns):
            raise ValueError(
                'Expected all elements of p.file_patterns to be strings')

        # TODO(rosenberg) replace this with functools.partial
        inputs = [
            _MakeDataSourceFromFilePatternFunc(
                data_source_from_file_pattern_fn, file_pattern)
            for file_pattern in p.file_patterns
        ]
        weights = p.weights
        if not p.bprop_variable_filters:
            bprop_variable_filters = [''] * len(inputs)
        else:
            bprop_variable_filters = p.bprop_variable_filters

        data_source, selected_bprop = py_utils.MixByWeight(inputs,
                                                           weights,
                                                           seed=p.random_seed)
        # TODO(neerajgaur): Remove _bprop_onehot and change code that uses it to
        # use source_selected from input_batch.
        batch_size = py_utils.GetShape(tf.nest.flatten(data_source)[0])[0]
        ret = py_utils.NestedMap()
        ret.data = data_source
        ret.bprop_variable_filters = bprop_variable_filters
        ret.selected_bprop = selected_bprop
        ret.source_selected = tf.tile(tf.expand_dims(selected_bprop, 0),
                                      [batch_size, 1])
        return ret
Code example #26
    def FProp(self, theta, inputs, paddings, state0=None, segment_id=None):
        """Computes LSTM forward pass.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: A single tensor or a tuple of tensors with cardinality equal to
        rnn_cell.inputs_arity. For every input tensor, the first dimension is
        assumed to be time, second dimension batch, and third dimension depth.
      paddings: A tensor. First dim is time, second dim is batch, and third dim
        is expected to be 1.
      state0: If not None, the initial rnn state in a `.NestedMap`. Defaults to
        the cell's zero-state.
      segment_id: A tensor to support packed inputs. First dim is time, second
        dim is batch, and third dim is expected to be 1.

    Returns:
      A tuple (act, final_state):
      - act: A tensor of [time, batch, dims].
      - final_state: The final recurrent state.
    """
        p = self.params
        rcell = self.cell
        assert isinstance(rcell, (rnn_cell.RNNCell))

        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        # Slicing wm to wm_{i,h} outside the loop to get 20% speedup over regular
        # LSTM baseline.
        # Keeping slicing within the loop gives only < 3% speedup.
        cell_theta = theta.cell.copy()
        num_input_nodes = p.cell.num_input_nodes
        cell_theta['wm_i'] = cell_theta.wm[:num_input_nodes, :]
        cell_theta['wm_h'] = cell_theta.wm[num_input_nodes:, :]
        tf.logging.vlog(1, 'cell_theta: %r', cell_theta)
        if p.packed_input:
            assert segment_id is not None
            reset_mask = rnn_layers.GeneratePackedInputResetMask(
                segment_id, is_reverse=False)
            reset_mask = py_utils.HasShape(reset_mask, tf.shape(paddings))
        else:
            reset_mask = tf.zeros_like(paddings)

        if p.reverse:
            inputs = [tf.reverse(x, [0]) for x in inputs]
            paddings = tf.reverse(paddings, [0])
            reset_mask = tf.reverse(reset_mask, [0])

        if not state0:
            batch_size = py_utils.GetShape(paddings)[1]
            state0 = rcell.zero_state(cell_theta, batch_size)

        # [T, B, H]
        proj_inputs = rcell.ProjectInputSequence(
            cell_theta, py_utils.NestedMap(act=inputs))
        proj_inputs = py_utils.NestedMap(proj_inputs=proj_inputs,
                                         padding=paddings,
                                         reset_mask=reset_mask)

        acc_state, final_state = recurrent.Recurrent(
            theta=cell_theta,
            state0=state0,
            inputs=proj_inputs,
            cell_fn=rcell.FPropWithProjectedInput,
            cell_type=rcell.layer_type,
            accumulator_layer=self,
            allow_implicit_capture=p.allow_implicit_capture)

        act = rcell.GetOutput(acc_state)
        if p.reverse:
            act = tf.reverse(act, [0])
        return act, final_state
Code example #27
    def _BeamSearchDecodeIds(self,
                             theta,
                             encoder_outputs,
                             num_hyps_per_beam,
                             init_beam_search_state=None,
                             pre_beam_search_step_callback=None,
                             post_beam_search_step_callback=None,
                             max_steps=None):
        """Performs beam-search based decoding.

    Args:
      theta: A NestedMap object containing weights' values of the decoder layer
        and its children layers.
      encoder_outputs: A NestedMap computed by encoder.
      num_hyps_per_beam: Number of hyps per beam.
      init_beam_search_state: The InitBeamSearchState callback. Please refer to
          the class header comments for more details.
      pre_beam_search_step_callback: The PreBeamSearchStepCallback callback.
          Please refer to the class header comments for more details.
      post_beam_search_step_callback: The PostBeamSearchStepCallback callback.
          Please refer to the class header comments for more details.
      max_steps: maximum beam search steps. If None, use
          self.params.target_seq_len.

    Returns:
      hyps: A tensor of shape [time, b * k] with ids of the token selected.
      prev_hyps: A tensor of shape [time, b * k] with the index of the previous
        hyp that was selected.
      done_hyps: A boolean tensor of shape [time, b * k] whose values indicate
        whether the hyp was terminated.
      scores: A tensor of shape [time, b * k] with scores of the token
        selected.
      atten_probs: A tensor of shape [time, b * k, seq_len] which contain the
        attention probabilities over the source words against word in the
        previous hyps.
      eos_scores: A tensor of shape [time, b * k] with scores of the eos token
        selected.
      eos_atten_probs: A tensor of shape [time, b * k, seq_len] which contain
        the attention probabilities over the source words against word in the
        previous hyps.
      source_seq_lengths:  A tensor of shape [time] containing the source
        seq_lengths.
      flat_final_other_states: An array of tensors that are part of the other
        states.
    """
        p = self.params
        source_paddings = encoder_outputs.padding

        initial_results, other_states = init_beam_search_state(
            theta, encoder_outputs, num_hyps_per_beam)

        num_hyps = tf.shape(initial_results.log_probs)[0]
        num_beams = num_hyps // num_hyps_per_beam

        # We cache the NestedMap as a member variable so that we can use it to
        # pack the final outputs. TPU rewrite methods force us to strictly pass
        # in, and output, Tensors.
        self._other_states = other_states

        step_ids = tf.fill([num_hyps, 1],
                           tf.constant(p.target_sos_id, dtype=tf.int32))
        min_score = -1e36
        fprop_dtype = py_utils.FPropDtype(p)
        best_scores = (tf.zeros(shape=[num_beams], dtype=fprop_dtype) +
                       min_score)
        cumulative_scores = tf.zeros(shape=[num_hyps], dtype=fprop_dtype)
        histories = tf.zeros(shape=[num_hyps], dtype=tf.int32)
        in_scores = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
        in_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
        in_prev_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
        in_done_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
        in_atten_probs = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
        in_eos_scores = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
        in_eos_atten_probs = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
        cur_step = tf.constant(0, dtype=tf.int32)
        all_done = tf.constant(False, dtype=tf.bool)
        # States for beam search that are inputs into Beam search step.
        accum_bs_states = [best_scores, cumulative_scores, histories]
        # States that are not accumulators.
        non_accum_bs_states = [
            in_scores,
            in_hyps,
            in_prev_hyps,
            in_done_hyps,
            in_atten_probs,
            in_eos_scores,
            in_eos_atten_probs,
        ]
        core_bs_states = tuple(accum_bs_states + non_accum_bs_states)

        flat_other_states = other_states.Flatten()

        # If there is an optimized implementation for short sequence, LoopBodyShort
        # will run first for short_seq_limit steps (after which the
        # LoopBodyShort does not have performance benefit). Then LoopBodyLong (the
        # default implementation) is used to continue the rest of the steps. For
        # decoders which do not have the short sequence specific implementation,
        # only the LoopBodyLong (the default implementation) will run.

        if p.short_seq_limit > 0:

            def LoopContinueShort(cur_step, all_done, unused_step_ids,
                                  unused_core_bs_states,
                                  unused_other_states_list):
                """Use short_seq optimization when cur_step is smaller than limit."""
                return tf.math.logical_and(cur_step < p.short_seq_limit,
                                           tf.math.logical_not(all_done))

            def LoopBodyShort(cur_step, unused_all_done, step_ids,
                              core_bs_states, other_states_list):
                """Loop body of short_seq optimization.

        Instead of doing computation for the entire padded sequence, a while
        loop with early exit is used within each _BeamSearchStep to do
        computation for only the actual sequence (seq_length <= cur_step).
        use_short_seq_opt is the flag that passes this information down to
        the decoder implementation.

        Args:
          cur_step: A scalar int tensor, the current time step, 0-based.
          unused_all_done: A tf.bool, indicating whether the decoding finishes.
          step_ids: An int32 tensor of shape [num_hyps, 1]. The input ids to the
            current search step.
          core_bs_states: A tuple of core beam search states.
          other_states_list: A flattened NestedMap of other beam search states.

        Returns:
          The updated input tuple, with the same shape.
        """
                (cur_step, all_done, new_step_ids, new_bs_states,
                 new_other_states) = self._BeamSearchStep(
                     theta,
                     encoder_outputs,
                     cur_step,
                     step_ids,
                     core_bs_states,
                     other_states.Pack(other_states_list),
                     num_hyps_per_beam,
                     pre_beam_search_step_callback,
                     post_beam_search_step_callback,
                     use_short_seq_opt=True)
                return (cur_step, all_done, new_step_ids, new_bs_states,
                        new_other_states.Flatten())

            (cur_step, all_done, step_ids, core_bs_states,
             flat_other_states) = tf.while_loop(
                 LoopContinueShort,
                 LoopBodyShort,
                 loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                            flat_other_states),
                 parallel_iterations=10,
                 back_prop=False,
                 swap_memory=False,
                 shape_invariants=(
                     tf.TensorShape(cur_step.get_shape()),
                     tf.TensorShape(all_done.get_shape()),
                     tf.TensorShape(step_ids.get_shape()),
                     tuple(
                         list(_GetShapes(accum_bs_states)) +
                         list(_GetShapes(non_accum_bs_states,
                                         none_shapes=True))),
                     _GetShapes(flat_other_states, none_shapes=True)),
                 maximum_iterations=max_steps)

        def LoopContinueLong(cur_step, all_done, unused_step_ids,
                             unused_core_bs_states, unused_other_states_list):
            """Continue default implementation until decoding finishes."""
            return tf.math.logical_and(cur_step < max_steps,
                                       tf.math.logical_not(all_done))

        def LoopBodyLong(cur_step, unused_all_done, step_ids, core_bs_states,
                         other_states_list):
            """Loop body of default long_seq implementation."""
            (cur_step, all_done, new_step_ids, new_bs_states,
             new_other_states) = self._BeamSearchStep(
                 theta,
                 encoder_outputs,
                 cur_step,
                 step_ids,
                 core_bs_states,
                 other_states.Pack(other_states_list),
                 num_hyps_per_beam,
                 pre_beam_search_step_callback,
                 post_beam_search_step_callback,
                 use_short_seq_opt=False)
            return (cur_step, all_done, new_step_ids, new_bs_states,
                    new_other_states.Flatten())

        _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
            LoopContinueLong,
            LoopBodyLong,
            loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                       flat_other_states),
            parallel_iterations=10,
            back_prop=False,
            swap_memory=False,
            shape_invariants=(
                tf.TensorShape(cur_step.get_shape()),
                tf.TensorShape(all_done.get_shape()),
                tf.TensorShape(step_ids.get_shape()),
                tuple(
                    list(_GetShapes(accum_bs_states)) +
                    list(_GetShapes(non_accum_bs_states, none_shapes=True))),
                _GetShapes(flat_other_states, none_shapes=False)),
            maximum_iterations=max_steps)

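        # Recover per-example source lengths from the paddings: transpose to
        # [batch, time] and count the positions where padding == 0.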
        if isinstance(source_paddings, py_utils.NestedMap):
            source_seq_lengths = tf.cast(tf.round(
                tf.reduce_sum(1.0 - tf.transpose(source_paddings.Flatten()[0]),
                              1)),
                                         dtype=tf.int32)
        else:
            source_seq_lengths = tf.cast(tf.round(
                tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
                                         dtype=tf.int32)

        # Concatenate all outputs on axis=0.
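        # Indices 3..9 of final_bs_states are the TensorArrays from
        # non_accum_bs_states, in order: scores, hyps, prev_hyps, done_hyps,
        # atten_probs, eos_scores and eos_atten_probs.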
        scores = final_bs_states[3].stack()
        hyps = final_bs_states[4].stack()
        prev_hyps = final_bs_states[5].stack()
        done_hyps = tf.cast(final_bs_states[6].stack(), tf.bool)
        atten_probs = final_bs_states[7].stack()
        eos_scores = final_bs_states[8].stack()
        eos_atten_probs = final_bs_states[9].stack()
        rets = (hyps, prev_hyps, done_hyps, scores, atten_probs, eos_scores,
                eos_atten_probs, source_seq_lengths)

        # TODO(rohananil): Only send a single R1 tensor to host instead of 3 after
        # b/111131551 is resolved.
        # Canonical shapes for tensors of various ranks.
        r_shapes = [
            py_utils.GetShape(source_seq_lengths),
            py_utils.GetShape(hyps),
            py_utils.GetShape(atten_probs)
        ]
        # Reshape all tensors to [-1] to avoid the cost of copies due to
        # padding.
        rets_r1 = [tf.reshape(r, [-1]) for r in rets]

        return tuple(r_shapes) + tuple(rets_r1) + tuple(
            flat_final_other_states)
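The two-phase loop structure above is easier to see in isolation. Below is a minimal, self-contained sketch of the same pattern, in which _Step, TwoPhaseDecode, SHORT_SEQ_LIMIT and MAX_STEPS are all hypothetical stand-ins for the real decoder machinery: a first tf.while_loop runs the optimized body while the step counter is below the short-sequence limit, and a second tf.while_loop continues with the default body until decoding is done.

import tensorflow as tf

SHORT_SEQ_LIMIT = 8   # Stand-in for p.short_seq_limit.
MAX_STEPS = 32        # Stand-in for max_steps.

def _Step(cur_step, state, use_short_seq_opt):
    # Hypothetical single decode step; a real implementation would branch on
    # use_short_seq_opt and compute an actual termination signal.
    del use_short_seq_opt
    state = state + 1.0
    all_done = cur_step + 1 >= MAX_STEPS  # Placeholder termination signal.
    return cur_step + 1, all_done, state

def TwoPhaseDecode(state):
    cur_step = tf.constant(0, dtype=tf.int32)
    all_done = tf.constant(False)

    # Phase 1: optimized body, only while cur_step < SHORT_SEQ_LIMIT and
    # decoding has not finished.
    cur_step, all_done, state = tf.while_loop(
        lambda step, done, _: tf.logical_and(step < SHORT_SEQ_LIMIT,
                                             tf.logical_not(done)),
        lambda step, done, s: _Step(step, s, use_short_seq_opt=True),
        loop_vars=(cur_step, all_done, state))

    # Phase 2: default body continues from wherever phase 1 stopped.
    _, _, state = tf.while_loop(
        lambda step, done, _: tf.logical_and(step < MAX_STEPS,
                                             tf.logical_not(done)),
        lambda step, done, s: _Step(step, s, use_short_seq_opt=False),
        loop_vars=(cur_step, all_done, state))
    return state

With these placeholder values, TwoPhaseDecode(tf.zeros([])) runs 8 optimized steps followed by 24 default steps before the termination signal fires.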
Code example #28
0
    def _TimeMask(self,
                  inputs,
                  seq_lengths,
                  global_seed,
                  noisify=False,
                  gaussian_noise=False,
                  dtype=tf.float32,
                  domain_id_index=0):
        """Applies time masking with given degree to inputs.

    Args:
      inputs: Batch of input features of shape (batch_size, time_length,
        num_freq, channels).
      seq_lengths: The actual sequence lengths, based on which the masks are
        sampled, of shape (batch_size,).
      global_seed: An integer seed tensor for stateless random ops.
      noisify: Whether to noisify the masked-out regions.
      gaussian_noise: Whether to use gaussian noise when noisifying.
      dtype: Data type.
      domain_id_index: Index of the domain id.

    Returns:
      Inputs with random time masking applied.
    """
        p = self.params

        # Get time masking parameters.
        time_mask_max_frames = p.time_mask_max_frames[domain_id_index]
        time_masks_per_frame = p.time_masks_per_frame[domain_id_index]
        use_dynamic_time_mask_max_frames = (
            p.use_dynamic_time_mask_max_frames[domain_id_index])
        multiplicity = p.time_mask_count[domain_id_index]
        max_ratio = p.time_mask_max_ratio[domain_id_index]

        # If the maximum mask length, the maximum mask ratio or the mask
        # multiplicity is zero, do nothing.
        if ((time_mask_max_frames == 0
             and not use_dynamic_time_mask_max_frames) or max_ratio <= 0.0):
            return inputs
        if multiplicity == 0:
            return inputs
        seq_lengths = tf.cast(seq_lengths, tf.int32)
        batch_size, time_length, _, _ = py_utils.GetShape(inputs)

        # When using dynamic time mask size, discard upper-bound on
        # maximum allowed frames for time mask.
        if use_dynamic_time_mask_max_frames:
            time_mask_max_frames = None
        # Create masks in time direction and apply.
        block_arrays = self._GetMask(batch_size,
                                     choose_range=seq_lengths,
                                     mask_size=time_length,
                                     global_seed=global_seed,
                                     max_length=time_mask_max_frames,
                                     masks_per_frame=time_masks_per_frame,
                                     multiplicity=multiplicity,
                                     dtype=dtype,
                                     max_ratio=max_ratio)
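        # block_arrays is a [batch, time] multiplicative mask: masked time
        # steps are zeroed when the mask is applied to the inputs below.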

        # Non-empty random seed values are only used for testing or when using
        # stateless random ops. seed_6 and seed_7 are set separately to avoid
        # correlation between the noise magnitude and the noise sample.
        if p.use_input_dependent_random_seed:
            seed_6 = global_seed + 6
            seed_7 = global_seed + 7
        else:
            seed_6 = p.random_seed
            seed_7 = p.random_seed

        outputs = self.EinsumBxycBxBxyc(inputs,
                                        block_arrays,
                                        name='einsum_formasking')
        if noisify:
            # Sample noise; unless gaussian_noise is set, the standard
            # deviation is factor * 0.1 + 0.0001 with factor ~ U[1, 2).
            # TODO(ngyuzh): Make sure this won't affect EOS.
            if gaussian_noise:
                stddev = 1.0
            else:
                random_uniform = _random_uniform_op(
                    p.use_input_dependent_random_seed)
                factor = random_uniform(shape=(),
                                        minval=1.0,
                                        maxval=2.0,
                                        dtype=dtype,
                                        seed=seed_6)
                stddev = factor * 0.1 + 0.0001
            random_normal = _random_normal_op(
                p.use_input_dependent_random_seed)
            noise = random_normal(shape=[
                tf.shape(inputs)[0],
                tf.shape(inputs)[1],
                tf.shape(inputs)[2]
            ],
                                  stddev=stddev,
                                  seed=seed_7)
            if p.fprop_dtype is not None and p.fprop_dtype != p.dtype:
                noise = tf.cast(noise, p.fprop_dtype)
            outputs_mask = self.EinsumBxyBxBxy(noise,
                                               1.0 - block_arrays,
                                               name='einsum_fornoisymasking')
            outputs = outputs + tf.expand_dims(outputs_mask, -1)

        return outputs
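For reference, here is a minimal, self-contained sketch of the core masking step, independent of this class's params and of helpers such as _GetMask (SimpleTimeMask and its arguments are illustrative, not part of the library): one random time span per example is zeroed out, with the span length capped by the example's actual sequence length.

import tensorflow as tf

def SimpleTimeMask(inputs, seq_lengths, max_frames=10, seed=None):
    """Zeroes out one random time span per example.

    inputs: [batch, time, freq, channels]; seq_lengths: [batch].
    A toy stand-in for _TimeMask above: one mask per example, no noisify.
    """
    seq_lengths = tf.cast(seq_lengths, tf.int32)
    batch = tf.shape(inputs)[0]
    time = tf.shape(inputs)[1]
    # Sample a mask length in [0, max_frames], capped by the real length.
    length = tf.random.uniform([batch], 0, max_frames + 1, tf.int32, seed=seed)
    length = tf.minimum(length, tf.minimum(seq_lengths, max_frames))
    # Sample a start position in [0, seq_length - length].
    start = tf.random.uniform([batch], 0, 2**30, tf.int32, seed=seed)
    start = start % tf.maximum(seq_lengths - length + 1, 1)
    # Build a [batch, time] multiplicative mask: 0 inside the span, 1 outside.
    t = tf.range(time)[tf.newaxis, :]
    inside = tf.logical_and(t >= start[:, tf.newaxis],
                            t < (start + length)[:, tf.newaxis])
    mask = 1.0 - tf.cast(inside, inputs.dtype)
    # Broadcast over the freq and channel axes, like the einsum above.
    return inputs * mask[:, :, tf.newaxis, tf.newaxis]

For example, SimpleTimeMask(tf.ones([2, 100, 80, 1]), tf.constant([100, 60]), max_frames=20) zeroes at most 20 frames per utterance and never masks past frame 60 of the second, shorter one.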