Code example #1
    def FProp(self, theta, inputs, paddings):
        """Applies causal pooling to inputs.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor. It is expected to be of shape [batch, time,
        frequency, channel]. The time dimension corresponds to the height
        dimension as in images and the frequency dimension corresponds to the
        width dimension as in images.
      paddings: The paddings tensor. It is expected to be of shape [batch,
        time].

    Returns:
      outputs, out_paddings pair.
       - outputs: has the same shape as inputs.
       - out_paddings: has the same tshape as paddings.
    """

        p = self.params
        if p.left_context is None:
            raise ValueError('left_context must be set.')
        window_size = p.left_context
        left_pad_size = window_size - 1
        large_negative = p.dtype.max * tf.constant(-0.7, dtype=p.dtype)
        # For max pooling, use a large negative padding value such that the max
        # element is almost always from a non-padding position.
        pad_value = 0 if p.pooling_type == 'AVG' else large_negative
        inputs = tf.pad(inputs, [[0, 0], [left_pad_size, 0], [0, 0], [0, 0]],
                        constant_values=pad_value)

        out_feature = tf.nn.pool(inputs,
                                 window_shape=(window_size, 1),
                                 pooling_type=p.pooling_type,
                                 padding='VALID')

        if p.pooling_type == 'AVG':
            # Count the fraction of non-padding elements inside each pooling window.
            in_mask = tf.pad(1.0 - paddings, [[0, 0], [left_pad_size, 0]])

            non_padding_ratio = tf.nn.pool(in_mask[:, :, tf.newaxis],
                                           window_shape=(window_size, ),
                                           pooling_type='AVG',
                                           padding='VALID')
            # Divide by non-padding ratios to eliminate the effect of padded zeros.
            out_feature *= tf.math.reciprocal_no_nan(
                non_padding_ratio[..., tf.newaxis])
        out_feature *= 1.0 - paddings[..., tf.newaxis, tf.newaxis]
        return out_feature, paddings
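
As a sanity check, here is a minimal standalone sketch of the same causal average-pooling correction using plain TensorFlow ops (the window size and input values are illustrative assumptions, not taken from the layer above):

import tensorflow as tf

window_size = 3  # assumed value for p.left_context
left_pad_size = window_size - 1

# [batch=1, time=4, frequency=1, channel=1]; the last frame is padding.
inputs = tf.reshape(tf.constant([1.0, 2.0, 3.0, 9.0]), [1, 4, 1, 1])
paddings = tf.constant([[0.0, 0.0, 0.0, 1.0]])

# Left-pad along time so each output frame sees only itself and past frames.
padded = tf.pad(inputs, [[0, 0], [left_pad_size, 0], [0, 0], [0, 0]])
out = tf.nn.pool(padded, window_shape=(window_size, 1),
                 pooling_type='AVG', padding='VALID')

# Rescale by the fraction of non-padding frames inside each window.
in_mask = tf.pad(1.0 - paddings, [[0, 0], [left_pad_size, 0]])
ratio = tf.nn.pool(in_mask[:, :, tf.newaxis], window_shape=(window_size,),
                   pooling_type='AVG', padding='VALID')
out *= tf.math.reciprocal_no_nan(ratio[..., tf.newaxis])
out *= 1.0 - paddings[..., tf.newaxis, tf.newaxis]
print(out[0, :, 0, 0])  # [1.0, 1.5, 2.0, 0.0]: the padded 9.0 never leaks.

Because padding sits at the end of the sequence and the window is causal, a padded frame can only affect its own output position, which the final mask zeros out anyway.
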
Code example #2
def SequenceAppendToken(x, x_paddings, token, extend=False):
    """Appends <token> to sequence `x`.

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    token: The token to append (of type integer).
    extend: Whether to extend `x` along the length dimension, this must be true
      for any sequence length in `x` that is `x_len_max` or else an invalid
      sequence will be emitted.

  Returns:
    A tuple.
      - The new sequence, Tensor of shape [batch_size, x_len_max].
      - The new paddings, Tensor of shape [batch_size, x_len_max].
  """
    batch_size = py_utils.GetShape(x)[0]
    x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
    if extend:
        x = tf.pad(x, [[0, 0], [0, 1]])
    # Mask all invalid entries of `x` to 0.
    x *= tf.sequence_mask(x_len, py_utils.GetShape(x)[1], x.dtype)
    # Append the <token> based on `x_len`.
    x += tf.scatter_nd(tf.stack([tf.range(batch_size), x_len], axis=1),
                       tf.cast(tf.fill([batch_size], token), x.dtype),
                       py_utils.GetShape(x))
    x_paddings = 1 - tf.sequence_mask(x_len + 1,
                                      py_utils.GetShape(x)[1],
                                      x_paddings.dtype)
    return x, x_paddings
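
A small usage sketch with hypothetical token ids (it relies on the same `py_utils` import as the function above; batch of two sequences, `x_len_max = 4`, appending token 2):

import tensorflow as tf

x = tf.constant([[10, 11, 12, 0], [20, 21, 0, 0]])
x_paddings = tf.constant([[0., 0., 0., 1.], [0., 0., 1., 1.]])
# No sequence here is at full length, so extend=False is safe.
new_x, new_paddings = SequenceAppendToken(x, x_paddings, token=2)
# new_x        == [[10, 11, 12, 2], [20, 21, 2, 0]]
# new_paddings == [[0., 0., 0., 0.], [0., 0., 0., 1.]]
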
Code example #3
def ComputeConvOutputPadding(paddings,
                             window,
                             stride,
                             padding_algorithm='SAME'):
    """Computes paddings for convolution and pooling output.

  out_padding[i] == 1 iff any in_padding corresponding to that output is 1.

  Args:
    paddings: The paddings tensor. It is expected to be of shape [batch, time].
    window: The size of the windows.
    stride: The time-stride between adjacent windows.
    padding_algorithm: 'SAME' or 'VALID'.

  Returns:
    out_padding, The new padding tensor of size [batch, ceil(time / stride)].
  """
    if stride == 1:
        return paddings

    # Pad so input_length divides stride.
    input_length = py_utils.GetShape(paddings)[1]
    pad_len = (input_length + stride - 1) // stride * stride - input_length
    paddings = tf.pad(paddings, [[0, 0], [0, pad_len]], constant_values=1.0)
    out_padding = tf.nn.pool(
        tf.expand_dims(paddings, -1),
        [window],
        'MAX',
        padding=padding_algorithm,
        strides=[stride],
    )
    return tf.squeeze(out_padding, -1)
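
A worked toy example (values assumed for illustration; uses the function defined above): one length-5 sequence whose last two frames are padding, pooled with window=3 and stride=2:

import tensorflow as tf

paddings = tf.constant([[0., 0., 0., 1., 1.]])
out = ComputeConvOutputPadding(paddings, window=3, stride=2)
# The input is first right-padded to length 6 (a multiple of the stride)
# with 1.0, then MAX-pooled, so any window touching a padded frame maps
# to 1.0:
# out == [[0., 1., 1.]], i.e. shape [batch, ceil(time / stride)].
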
Code example #4
    def _RescaleBoundary(self, out, in_paddings):
        # Rescale every output position by:
        #   (# input positions) / (# non-padding input positions)
        #   where (# input positions) = filter_size.
        p = self.params
        in_mask = 1.0 - in_paddings

        # Compute the left and right implicit padding sizes used in 'SAME' mode.
        filter_t = p.filter_shape[0]
        effective_filter_size = (filter_t - 1) * p.dilation_rate[0] + 1
        left_pad_size = (effective_filter_size - 1) // 2
        right_pad_size = effective_filter_size // 2

        # Compute the rescaling factor.
        # This expanded tensor has 1 on all valid positions and 0 on all padded
        # ones, including both the explicit padding provided by `in_paddings`
        # and the implicit padding at the boundaries.
        in_mask_padded = tf.pad(in_mask,
                                [[0, 0], [left_pad_size, right_pad_size]])
        # (# non-padding input positions) / (# input positions)
        factor_inverse = tf.nn.pool(in_mask_padded[:, :, tf.newaxis],
                                    window_shape=(filter_t, ),
                                    pooling_type='AVG',
                                    strides=(p.filter_stride[0], ),
                                    padding='VALID',
                                    dilations=(p.dilation_rate[0], ))

        factor = tf.math.reciprocal_no_nan(factor_inverse)
        return out * factor[..., tf.newaxis]
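
To make the factor concrete, here is a toy trace of just the rescaling computation (assuming filter_shape[0] == 3, stride 1, dilation 1, and one length-4 sequence whose last frame is padded):

import tensorflow as tf

in_paddings = tf.constant([[0., 0., 0., 1.]])
in_mask = 1.0 - in_paddings                         # [[1, 1, 1, 0]]
# 'SAME'-style implicit padding for a width-3 filter: one frame on each side.
in_mask_padded = tf.pad(in_mask, [[0, 0], [1, 1]])  # [[0, 1, 1, 1, 0, 0]]
factor_inverse = tf.nn.pool(in_mask_padded[:, :, tf.newaxis],
                            window_shape=(3,), pooling_type='AVG',
                            strides=(1,), padding='VALID')
factor = tf.math.reciprocal_no_nan(factor_inverse)
print(factor[0, :, 0])  # [1.5, 1.0, 1.5, 3.0]
# Positions whose windows overlap explicit or implicit padding get scaled
# up by filter_size / (# non-padding frames in the window).
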
Code example #5
    def _EvaluateConvKernel(self, theta, inputs):
        """Apply convolution to inputs."""
        # Same as CausalDepthwiseConv2DLayer.
        p = self.params
        assert p.filter_shape[1] == 1, 'Only 1D causal convolutions supported.'
        padding_algorithm = 'VALID'
        causal_pad_size = (p.filter_shape[0] - 1) * p.dilation_rate[0]
        inputs = tf.pad(inputs, [[0, 0], [causal_pad_size, 0], [0, 0], [0, 0]])
        filter_w = self._GetWeight(theta)
        return tf.nn.depthwise_conv2d(
            inputs,
            filter_w,
            strides=[1, p.filter_stride[0], p.filter_stride[1], 1],
            dilations=p.dilation_rate,
            data_format='NHWC',
            padding=padding_algorithm)
Code example #6
def RelShift(x):
    """Performs relative shift on 4D tensor (first 2 axis are batching dims).

  Given input of shape [?, ?, W, W], this does "relative shifting" for the
  last two dims, s.t. output[b, n, i, j] = 0 if i > j else input[b, n, i, j-i]

  Args:
    x: A Tensor of shape [?, ?, W, W]

  Returns:
    A Tensor of the same shape as input with its content shifted (as described
    above).
  """
    b, n, w, _ = py_utils.GetShape(x)
    x = py_utils.HasShape(x, [-1, -1, w, w])
    x = tf.pad(x, ((0, 0), (0, 0), (0, 0), (0, 1)))
    x = tf.reshape(x, [b, n, w + 1, w])
    x = x[:, :, :w, :]
    return x
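
A toy trace makes the shift concrete (a hypothetical 1 x 1 x 3 x 3 input; relies on the same `py_utils` import as the function above):

import tensorflow as tf

x = tf.reshape(tf.range(1, 10, dtype=tf.float32), [1, 1, 3, 3])
# x[0, 0] == [[1, 2, 3],
#             [4, 5, 6],
#             [7, 8, 9]]
y = RelShift(x)
# y[0, 0] == [[1, 2, 3],
#             [0, 4, 5],
#             [6, 0, 7]]
# Row i is shifted left by i: entries with j >= i are x[i, j-i], the
# j == i-1 band is zero, and the leftover below it (the 6) is the
# wrap-around that downstream code is expected to mask.
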
Code example #7
    def _EvaluateConvKernel(self, theta, inputs):
        """Apply convolution to inputs."""
        p = self.params
        assert p.filter_shape[1] == 1, 'Only 1D causal convolutions supported.'
        # Use VALID padding and shift the inputs to the right to ensure that the
        # first output only depends on the first input and so on. The output is
        # the same size as the input, as if the convolution used SAME padding.
        padding_algorithm = 'VALID'
        # The effective spatial filter width for dilated convolutions is
        # (kernel_width - 1) * dilation_rate + 1, according to
        # https://www.tensorflow.org/api_docs/python/tf/nn/convolution.
        causal_pad_size = (p.filter_shape[0] - 1) * p.dilation_rate[0]
        inputs = tf.pad(inputs, [[0, 0], [causal_pad_size, 0], [0, 0], [0, 0]])
        filter_w = self._GetWeight(theta)
        return tf.nn.depthwise_conv2d(
            inputs,
            filter_w,
            strides=[1, p.filter_stride[0], p.filter_stride[1], 1],
            dilations=p.dilation_rate,
            data_format='NHWC',
            padding=padding_algorithm)
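
A standalone sketch of the left-pad-then-VALID trick with concrete numbers (the averaging filter and shapes are illustrative assumptions; the real layer gets its filter from `self._GetWeight(theta)`):

import tensorflow as tf

filter_t, dilation = 3, 1
causal_pad_size = (filter_t - 1) * dilation
# [batch=1, time=4, width=1, channel=1] inputs; a depthwise averaging filter.
inputs = tf.reshape(tf.constant([1.0, 2.0, 3.0, 4.0]), [1, 4, 1, 1])
filter_w = tf.ones([filter_t, 1, 1, 1]) / filter_t
padded = tf.pad(inputs, [[0, 0], [causal_pad_size, 0], [0, 0], [0, 0]])
out = tf.nn.depthwise_conv2d(padded, filter_w, strides=[1, 1, 1, 1],
                             padding='VALID', data_format='NHWC')
print(out[0, :, 0, 0])  # [1/3, 1, 2, 3]
# The output has the same time length as the input, and out[:, t] only
# mixes inputs[:, :t + 1], as a causal convolution should.
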
Code example #8
    def PadToTargetSeqLen(tensor, constant):
        # `p` is closed over from the enclosing method's scope.
        length = tf.shape(tensor)[1]
        pad = tf.maximum(0, p.beam_search.target_seq_len - length)
        return tf.pad(tensor, [[0, 0], [0, pad]], constant_values=constant)
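
Since the snippet closes over `p`, a hypothetical standalone equivalent just takes the target length explicitly:

import tensorflow as tf

def PadToLen(tensor, target_len, constant):
    # Right-pad dim 1 up to target_len; a no-op if already long enough.
    pad = tf.maximum(0, target_len - tf.shape(tensor)[1])
    return tf.pad(tensor, [[0, 0], [0, pad]], constant_values=constant)

print(PadToLen(tf.constant([[3, 7, 7]]), 5, 0))  # [[3, 7, 7, 0, 0]]
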
Code example #9
    def FProp(self,
              theta,
              x,
              x_paddings=None,
              eos_id=1,
              force_sample_last_token=True):
        """Applies SymbolInsertionLayer.

    We take in a `x`, which represents the groundtruth sequence (i.e., English
    sequence). We return a sampled rollin (observed) canvas (i.e., random subset
    of the English sequence), as well as the target (indices) for an
    insertion-based model (i.e., the targets given the random observed subset).

    Args:
      theta: Ignored, this can be None.
      x: The symbol ids of shape `[batch_size, time_dim]`.
      x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where
        0 is valid and 1 is invalid.
      eos_id: The <eos> token id to represent end-of-slot.
      force_sample_last_token: Set True to force sample the last token of `x`.

    Returns:
      A `NestedMap`.
        - canvas: The canvas (based off of the `rollin_policy`) of shape
          [batch_size, c_dim]. Note that, `c_dim` <= `time_dim` but need not be
          equal.
        - canvas_indices: The canvas indices (into `x`).
        - canvas_paddings: The paddings of `canvas_indices`.
        - target_indices: The target indices of shape [num_targets, 3].
          `num_targets` is the number of total targets in the entire batch.
          [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
          captures the token. Each row [batch, slot, vocab] represents the
          indices of the target -- i.e., the batch, slot and vocab combination
          of the target. Typical usage of these indices is to tf.gather_nd
          the log-probs (from the softmax layer).
        - target_weights: The target weights.

    Raises:
      ValueError: If invalid params.
    """
        p = self.params

        batch_size = py_utils.GetShape(x)[0]
        time_dim = py_utils.GetShape(x)[1]

        if x_paddings is None:
            x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

        oracle_policy = p.oracle_policy
        rollin_policy = (oracle_policy
                         if p.rollin_policy == 'oracle' else p.rollin_policy)

        if rollin_policy != 'uniform':
            raise ValueError('Unknown or unsupported rollin policy: %s' %
                             rollin_policy)
        if oracle_policy != 'uniform':
            raise ValueError('Unknown or unsupported oracle policy: %s' %
                             oracle_policy)

        x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

        # Compute the desired length per example in the batch.
        ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
        if force_sample_last_token:
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32),
                x_len - 1) + 1
        else:
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32),
                x_len)
        # Compute the maximum length across the batch.
        c_len_max = tf.reduce_max(c_len)

        # Grab subset of random valid indices per example.
        z_logits = tf.cast(
            tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
            tf.float32) * -1e9
        if force_sample_last_token:
            # Force sample the last token -- i.e., as indexed by `x_len - 1`. We
            # can accomplish this by adding +LARGE_NUMBER to the logits.
            z_logits += tf.cast(
                tf.equal(tf.expand_dims(tf.range(time_dim), 0),
                         tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9
        # Gumbel-max trick to sample (we only sample valid positions per sample in
        # the batch).
        z = -tf.math.log(-tf.math.log(
            tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
        unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

        # Trim everything > c_len_max.
        c_indices = c_indices[:, :c_len_max]

        # Invalidate any indices >= c_len, we use the last index as the default
        # invalid index.
        c_indices = tf.where(
            tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
            c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

        # Materialize the canvas.
        c_indices = tf.sort(c_indices)
        c = tf.gather_nd(
            x,
            tf.stack([
                tf.reshape(
                    tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                            [1, c_len_max]), [-1]),
                tf.reshape(c_indices, [-1])
            ], 1))
        c = tf.reshape(c, [batch_size, c_len_max])

        # Compute the paddings.
        c_paddings = 1 - tf.sequence_mask(
            c_len, c_len_max, dtype=x_paddings.dtype)
        c *= tf.cast(1 - c_paddings, tf.int32)

        indices = tf.concat([
            tf.reshape(
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, c_len_max]), [batch_size * c_len_max, 1]),
            tf.reshape(c_indices, [batch_size * c_len_max, 1])
        ], 1)
        x_token_is_observed = tf.scatter_nd(
            indices, tf.ones([batch_size * c_len_max], tf.int32),
            py_utils.GetShape(x))
        # `x_segments` captures which slot each token of `x` belongs to (both
        # observed tokens and tokens that still need to be observed).
        x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

        x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
        prev_x_token_is_observed = tf.pad(x_token_is_observed[:, :-1],
                                          [[0, 0], [1, 0]],
                                          constant_values=True)
        x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
        prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
        x_is_valid = tf.cast(1 - x_paddings, tf.bool)
        x_is_valid = tf.reshape(x_is_valid, [-1])

        # Remap all the observed tokens to <eos>. Note some of these need a zero
        # weight (or else there would be an <eos> and a valid token in the same
        # slot).
        target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
        target_indices = tf.where(
            x_token_is_observed,
            tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

        # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
        # we may want to weigh this term by the original sequence length.
        target_weights = tf.ones_like(target_indices, tf.float32)

        # We need to set all the weights for <eos> which actually have valid tokens
        # in the slot to zero.
        target_weights = tf.where(
            x_token_is_observed & ~prev_x_token_is_observed,
            tf.zeros_like(target_weights), target_weights)

        # TODO(williamchan): Consider dropping the entries w/ weight zero.

        # Add the batch and slot indices.
        target_indices = tf.concat([
            tf.reshape(
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, time_dim]), [batch_size * time_dim, 1]),
            tf.reshape(x_segments, [-1, 1]), target_indices
        ], 1)

        # Select only the valid indices. The selected valid ones include slots w/
        # <eos>.
        target_indices = target_indices[x_is_valid]
        target_weights = target_weights[x_is_valid]

        return py_utils.NestedMap(canvas=c,
                                  canvas_indices=c_indices,
                                  canvas_paddings=c_paddings,
                                  target_indices=target_indices,
                                  target_weights=target_weights)
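
The core of the rollin above is the Gumbel-max trick: adding Gumbel(0, 1) noise to masked logits and taking the top-k indices draws a uniform random subset of the valid positions without replacement. A standalone sketch with assumed shapes and lengths:

import tensorflow as tf

batch_size, time_dim = 2, 5
x_len = tf.constant([5, 3])   # valid lengths per example
c_len = tf.constant([2, 1])   # how many positions to sample per example

# -1e9 on invalid (padded) positions so they are never selected.
z_logits = tf.cast(
    tf.range(time_dim)[tf.newaxis, :] >= x_len[:, tf.newaxis],
    tf.float32) * -1e9
# Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1).
z = -tf.math.log(-tf.math.log(tf.random.uniform([batch_size, time_dim])))
_, c_indices = tf.nn.top_k(z_logits + z, time_dim)
c_indices = c_indices[:, :tf.reduce_max(c_len)]
# Row b's first c_len[b] entries are a uniform random subset (without
# replacement) of the valid positions [0, x_len[b]).
print(c_indices)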