Esempio n. 1
0
def _AttenLogits(query,
                 key,
                 abs_pos_emb,
                 content_bias=None,
                 positional_bias=None,
                 is_causal=False):
  """Computes self-attention logits with relative position embeddings.

  Follows the Transformer-XL formulation
  (https://arxiv.org/pdf/1901.02860.pdf, section 3.3): the result is the sum
  of a content term (biased query against key) and a position term (biased
  query against the relative position embeddings).

  Padding is not handled here; the caller is expected to mask it.

  Shape symbols:
    B: batch size.
    T: sequence length.
    N: number of attention heads.
    H: per-head attention dimension.

  Args:
    query: Tensor of shape [B, T, N, H].
    key: Tensor of shape [B, T, N, H].
    abs_pos_emb: Tensor of shape [2T - 1, N, H]. The sinusoid positional
      embedding from https://arxiv.org/abs/1706.03762, where abs_pos_emb[i] is
      the embedding of relative distance i - (T-1).
    content_bias: Tensor of shape [N, H] added to the query for the content
      term, or None for no bias.
    positional_bias: Tensor of shape [N, H] added to the query for the
      position term, or None for no bias.
    is_causal: A Python bool or a scalar bool Tensor. True for causal self
      attention. Forwarded unchanged to RelPositionBias.

  Returns:
    The attention logits tensor of shape [B, N, T, T].
  """
  b, t, n, h = py_utils.GetShape(query)
  key = py_utils.HasShape(key, [b, t, n, h])

  # A missing bias degenerates to adding the scalar 0 to the query.
  content_bias = (0 if content_bias is None else
                  py_utils.HasShape(content_bias, [n, h]))
  positional_bias = (0 if positional_bias is None else
                     py_utils.HasShape(positional_bias, [n, h]))

  # Content term: [B, N, T, S=T].
  content_query = query + content_bias
  content_logits = tf.einsum('BTNH,BSNH->BNTS', content_query, key)

  # Position term, computed by RelPositionBias from the position-biased query.
  position_query = query + positional_bias
  position_logits = RelPositionBias(position_query, abs_pos_emb, is_causal)

  return content_logits + position_logits
Esempio n. 2
0
  def FProp(self, theta, *args):
    """Runs p.repeat copies of self.body.FProp independently.

    Args:
      theta: Layer model parameters. The shape of each variable in theta is
        always [p.repeat, ...]. And the i-th slice theta[i] becomes theta of the
        i-th copy of self.body.
      *args: Input arguments. The shape of each tensor in args is always
        [p.repeat, ....]. And the list [arg[i] for arg in args] becomes inputs
        to the i-th copy of self.body.FProp. Entries may be None, in which
        case no shape check is applied to them.

    Returns:
      The accumulated output_tensors. Each tensor t in the return has the shape
      [p.repeat, ....] and the tuple (t[i] for t in output_tensors) is the
      return tuple of the i-th self.body.FProp. When exactly one input
      argument was passed, the first output tensor is returned on its own
      instead of a tuple.
    """
    p = self.params
    # Rebind args to the tensors returned by HasShape. Assigning the result to
    # a loop variable would discard it, leaving the leading-dimension check
    # detached from the tensors that are actually consumed below.
    args = tuple(
        arg if arg is None else py_utils.HasShape(arg, [p.repeat], ndims=1)
        for arg in args)

    theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars, p.repeat)
    inputs = py_utils.NestedMap(theta=theta_stack, args=list(args))
    # Per-copy output shapes, i.e. without the leading p.repeat dimension.
    out_shapes = self._InferOutShapes(args)

    def _CellFn(unused_theta, unused_state0, inputs):
      """Recurrent cell function wrapper of body.FProp."""
      # Sets shapes for both theta and inputs to self.body.FProp: each source
      # tensor's static shape with its leading dimension dropped.
      for dst, src in zip(inputs.args + inputs.theta.Flatten(),
                          list(args) + theta_stack.Flatten()):
        if src is not None:
          dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))

      # Runs the actual body.FProp
      fprop_outputs = self.body.FProp(inputs.theta, *inputs.args)
      fprop_outputs = _ToTuple(fprop_outputs)
      assert len(fprop_outputs) == len(out_shapes)
      # Passes fprop outputs to the next layer through state.
      state1 = py_utils.NestedMap(outputs=list(fprop_outputs))
      return state1, py_utils.NestedMap()

    with tf.name_scope(p.name):
      # Initiate state0 with inferred output shapes.
      state0 = py_utils.NestedMap(
          outputs=[tf.zeros(shape, args[0].dtype) for shape in out_shapes])
      # Runs body.FProp p.repeat times using Recurrent.
      acc_states, _ = recurrent.Recurrent(
          theta=py_utils.NestedMap(),
          state0=state0,
          inputs=inputs,
          cell_fn=_CellFn)

      # Retrieves fprop outputs from the accumulated states and restores the
      # static shapes, now including the leading p.repeat dimension.
      output_tensors = tuple(acc_states.outputs)
      for output, out_shape in zip(output_tensors, out_shapes):
        output.set_shape(tf.TensorShape([p.repeat] + out_shape.as_list()))

      return output_tensors[0] if len(args) == 1 else tuple(output_tensors)
Esempio n. 3
0
def _RelPositionBias(query, abs_pos_emb):
  """Computes relative position bias for general (non-causal) attention.

  Args:
    query: Tensor of shape [B, T, N, H].
    abs_pos_emb: Tensor of shape [2T - 1, N, H], ordered by relative distance
      -(T-1), -(T-2), ..., 0, 1, ..., T-1.

  Returns:
    A [B, N, T, T] tensor. Its lower triangle (diagonal included) comes from
    the first T columns of the query/embedding scores and its strict upper
    triangle from the last T columns, each after relative shifting.
  """
  _, t, n, h = py_utils.GetShape(query)
  abs_pos_emb = py_utils.HasShape(abs_pos_emb, [2 * t - 1, n, h])

  # Flip the distance order from [-(T-1), ..., 0, ..., T-1]
  # to [T-1, ..., 0, ..., -(T-1)].
  flipped_emb = tf.reverse(abs_pos_emb, [0])

  # [B, N, T, L=2T-1]
  scores = tf.einsum('BTNH,LNH->BNTL', query, flipped_emb)

  # First T columns: reverse both trailing axes, shift, then reverse back.
  # [B, N, T, T]
  lower_part = RelShift(tf.reverse(scores[:, :, :, :t], [2, 3]))
  lower_part = tf.reverse(lower_part, [2, 3])

  # Last T columns: a plain shift is enough.
  # [B, N, T, T]
  upper_part = RelShift(scores[:, :, :, t - 1:])

  # Ones on and below the main diagonal, zeros above it.
  lower_triangle = tf.linalg.band_part(tf.ones_like(upper_part), -1, 0)

  # Stitch the two halves together.
  return tf.where(lower_triangle > 0, lower_part, upper_part)
    def _InputBatch(self):
        """Builds one input batch from tensors stored in a checkpoint.

        The data and label tensors named by p.data / p.label are restored from
        p.ckpt through a cached call, cast to float32, and indexed by sample
        ids produced by ops.random_permutation_sequence.

        Returns:
          A `.NestedMap` with:
            raw: the gathered samples, padded or trimmed to
              [batch] + p.data_shape.
            data: self._Preprocess(raw).
            label: the gathered labels, padded or trimmed to [batch].
            weight: ones for each gathered sample id, padded or trimmed to
              [batch].
            sample_ids: the sample ids themselves; only present when not
              running on TPU.
        """
        p = self.params

        @tf.function
        def ReadData():
            """Restores the data and label tensors and casts both to float32."""
            x, y = io_ops.restore_v2(p.ckpt, [p.data, p.label], ['', ''],
                                     [p.data_dtype, p.label_dtype])
            return tf.cast(x, tf.float32), tf.cast(y, tf.float32)

        # Loads data and label into memory and keeps them around.
        data, label = ops.cached_call(f=ReadData.get_concrete_function(),
                                      T=[tf.float32, tf.float32])

        batch_size = self.InfeedBatchSize()
        sample_shape = list(p.data_shape)

        # One row per sample; label must have exactly one entry per data row.
        data = tf.reshape(data, [-1] + sample_shape)
        label = tf.reshape(label, [-1])
        label = py_utils.HasShape(label, [tf.shape(data)[0]])

        sample_ids = ops.random_permutation_sequence(
            num=p.num_samples,
            batch=batch_size,
            repeat=p.repeat,
            seed=p.random_seed or 0)
        num_ids = tf.shape(sample_ids)[0]

        # The number of ids may differ from batch_size, so every field is
        # padded or trimmed to the full batch size.
        raw = py_utils.PadOrTrimTo(
            tf.gather(data, sample_ids), [batch_size] + sample_shape)
        preprocessed = self._Preprocess(raw)
        batch_label = py_utils.PadOrTrimTo(
            tf.gather(label, sample_ids), [batch_size])
        batch_weight = py_utils.PadOrTrimTo(
            tf.ones([num_ids], dtype=tf.float32), [batch_size])

        ret = py_utils.NestedMap(
            raw=raw, data=preprocessed, label=batch_label, weight=batch_weight)
        if not py_utils.use_tpu():
            ret['sample_ids'] = sample_ids
        return ret
Esempio n. 5
0
  def FProp(self, theta, inputs, paddings):
    """Applies the convolution to inputs and computes the output paddings.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor. It is expected to be of shape [batch, time,
        frequency, channel]. The time dimension corresponds to the height
        dimension as in images and the frequency dimension corresponds to the
        width dimension as in images.
      paddings: The paddings tensor, expected to be of shape [batch, time].

    Returns:
      outputs, out_paddings pair. The outputs are not re-masked with
      out_paddings.
    """
    p = self.params
    with tf.name_scope(p.name):
      # paddings must be rank 2, and inputs must be
      # [batch, time, <any>, input_channels] with matching batch and time.
      paddings_check = py_utils.assert_shape_match(
          tf.shape(paddings), [-1, -1])
      inputs_shape = tf.shape(inputs)
      expected_inputs_shape = tf.concat(
          [tf.shape(paddings), [-1, symbolic.ToStatic(self.input_channels)]],
          0)
      inputs_check = py_utils.assert_shape_match(inputs_shape,
                                                 expected_inputs_shape)
      inputs = py_utils.with_dependencies([paddings_check, inputs_check],
                                          inputs)

      # Zero out padded inputs: broadcast paddings to [batch, time, 1, 1].
      broadcast_paddings = tf.expand_dims(tf.expand_dims(paddings, -1), -1)
      inputs = inputs * (1.0 - broadcast_paddings)

      # Apply conv on 'inputs'.
      out = self._ApplyConv(theta, inputs)
      if p.partial_conv:
        out = self._RescaleBoundary(out, paddings)

      # NOTE: this may be slightly inaccurate when p.dilation_rate[0] > 1.
      # But there's likely no real problems. Trying to set it gives an error:
      # pooling with SAME padding is not implemented for dilation_rate > 1.
      # NOTE: we use window=p.filter_stride[0] to be compatible with legacy
      # implementation.  Consider updating it to be the actual shape.
      time_stride = p.filter_stride[0]
      conv_padding = ComputeConvOutputPadding(
          paddings, window=time_stride, stride=time_stride)

      # The output is intentionally not masked with conv_padding here; padded
      # positions are assumed to be zeroed out, if necessary, by subsequent
      # layers.
      expected_out_shape = symbolic.ToStatic(self.OutShape(tf.shape(inputs)))
      out = py_utils.HasShape(out, expected_out_shape)
      return out, conv_padding
Esempio n. 6
0
  def FProp(self, theta, inputs, paddings):
    """Applies global pooling over the time and frequency axes.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor. It is expected to be of shape [batch, time,
        frequency, channel]. The time dimension corresponds to the height
        dimension as in images and the frequency dimension corresponds to the
        width dimension as in images.
      paddings: The paddings tensor of shape [batch, time], or None, which
        means there are no paddings.

    Returns:
      outputs, out_paddings pair.
       - outputs: has shape [batch, 1, 1, channel].
       - out_paddings: None if paddings is None; otherwise the minimum of
         paddings over the time axis, with shape [batch, 1].
    """
    p = self.params
    assert p.pooling_type in ('MAX', 'AVG'), p.pooling_type
    b, t, f = py_utils.GetShape(inputs, ndims=3)

    # mask is 1.0 at positions to keep; it broadcasts against inputs.
    if paddings is None:
      mask = tf.ones([b, t, 1, 1], p.dtype)
    else:
      paddings = py_utils.HasShape(paddings, [b, t])
      mask = 1.0 - paddings[..., tf.newaxis, tf.newaxis]

    if p.pooling_type == 'MAX':
      # Replace masked-out positions by a large negative value so that they
      # do not win the max.
      large_negative = (
          tf.ones_like(inputs) * p.dtype.max * tf.constant(-0.7, dtype=p.dtype))
      padded_inputs = tf.where_v2(mask > 0.0, inputs, large_negative)
      out_feature = tf.reduce_max(padded_inputs, axis=[1, 2], keepdims=True)
    elif p.pooling_type == 'AVG':
      global_sum = tf.reduce_sum(inputs * mask, axis=[1, 2], keepdims=True)
      # Number of averaged elements: frequency bins times the mask sum over
      # time, clamped to at least 1 to avoid dividing by zero.
      f = tf.cast(tf.convert_to_tensor(f), p.dtype)
      count = f * tf.reduce_sum(mask, axis=[1, 2], keepdims=True)
      out_feature = global_sum / tf.maximum(1.0, count)

    if paddings is None:
      return out_feature, None

    out_paddings = tf.reduce_min(paddings, axis=1, keepdims=True)
    # Scale the pooled feature by (1 - out_paddings), broadcast over the two
    # trailing axes.
    out_feature = out_feature * (
        1.0 - out_paddings[..., tf.newaxis, tf.newaxis])
    return out_feature, out_paddings
Esempio n. 7
0
def _RelPositionBiasCausal(query, abs_pos_emb):
  """Computes relative position bias for causal self attention.

  Args:
    query: Tensor of shape [B, T, N, H].
    abs_pos_emb: Tensor of shape [2T - 1, N, H], ordered by relative distance
      -(T-1), -(T-2), ..., 0, 1, ..., T-1.

  Returns:
    A [B, N, T, T] tensor obtained by relative-shifting the scores of the
    query against the T embeddings that remain after flipping and truncating.
  """
  _, t, n, h = py_utils.GetShape(query)
  abs_pos_emb = py_utils.HasShape(abs_pos_emb, [2 * t - 1, n, h])

  # Flip the distance order and keep only the first T entries, i.e. the
  # distances [T-1, T-2, ..., 0].
  # [T, N, H]
  kept_emb = tf.reverse(abs_pos_emb, [0])[:t]

  # [B, N, T, L=T]
  scores = tf.einsum('BTNH,LNH->BNTL', query, kept_emb)

  # Shift in the reversed frame: reverse both trailing axes, shift, then
  # reverse back.
  shifted = RelShift(tf.reverse(scores, [2, 3]))
  return tf.reverse(shifted, [2, 3])
Esempio n. 8
0
def RelShift(x):
  """Performs relative shift on 4D tensor (first 2 axis are batching dims).

  Given input of shape [?, ?, W, W], row i of the last two dims is shifted
  right by i positions, s.t. output[b, n, i, j] = input[b, n, i, j-i] for
  j >= i.

  Entries with j < i are a by-product of the pad-and-reshape trick and should
  not be relied upon: output[b, n, i, i-1] is 0 (the padding element), while
  entries further to the left hold values wrapped around from the end of row
  i-1 of the input.

  Args:
    x: A Tensor of shape [?, ?, W, W]

  Returns:
    A Tensor of the same shape as input with its content shifted (as described
    above).
  """
  b, n, w, _ = py_utils.GetShape(x)
  x = py_utils.HasShape(x, [-1, -1, w, w])

  # Append one zero column ([..., W, W+1]) and reinterpret the same data as
  # [..., W+1, W]: each successive row then starts one element later.
  padded = tf.pad(x, ((0, 0), (0, 0), (0, 0), (0, 1)))
  skewed = tf.reshape(padded, [b, n, w + 1, w])

  # Drop the extra trailing row to get back to [..., W, W].
  return skewed[:, :, :w, :]
Esempio n. 9
0
    def FProp(self, theta, inputs, paddings, state0=None, segment_id=None):
        """Computes LSTM forward pass.

        Args:
          theta: A `.NestedMap` object containing weights' values of this
            layer and its children layers.
          inputs: A single tensor or a tuple of tensors with cardinality equal
            to rnn_cell.inputs_arity. For every input tensor, the first
            dimension is assumed to be time, second dimension batch, and third
            dimension depth.
          paddings: A tensor. First dim is time, second dim is batch, and
            third dim is expected to be 1.
          state0: If not None, the initial rnn state in a `.NestedMap`.
            Defaults to the cell's zero-state.
          segment_id: A tensor to support packed inputs. First dim is time,
            second dim is batch, and third dim is expected to be 1. Must not
            be None when p.packed_input is set.

        Returns:
          A tensor of [time, batch, dims].
          The final recurrent state.
        """
        p = self.params
        cell = self.cell
        assert isinstance(cell, rnn_cell.RNNCell)

        # Normalize a single tensor into a one-element list.
        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        # Split wm into its input rows (wm_i) and remaining rows (wm_h) once,
        # outside the recurrent loop. The original authors report a 20%
        # speedup over the regular LSTM baseline from doing this here, versus
        # < 3% when slicing inside the loop.
        cell_theta = theta.cell.copy()
        num_input_rows = p.cell.num_input_nodes
        cell_theta['wm_i'] = cell_theta.wm[:num_input_rows, :]
        cell_theta['wm_h'] = cell_theta.wm[num_input_rows:, :]
        tf.logging.vlog(1, 'cell_theta: %r', cell_theta)

        if not p.packed_input:
            reset_mask = tf.zeros_like(paddings)
        else:
            assert segment_id is not None
            reset_mask = rnn_layers.GeneratePackedInputResetMask(
                segment_id, is_reverse=False)
            reset_mask = py_utils.HasShape(reset_mask, tf.shape(paddings))

        # For a reversed layer, flip every time-major tensor along time.
        if p.reverse:
            inputs = [tf.reverse(x, [0]) for x in inputs]
            paddings = tf.reverse(paddings, [0])
            reset_mask = tf.reverse(reset_mask, [0])

        if not state0:
            state0 = cell.zero_state(cell_theta,
                                     py_utils.GetShape(paddings)[1])

        # Project the whole input sequence up front.
        # [T, B, H]
        projected = cell.ProjectInputSequence(
            cell_theta, py_utils.NestedMap(act=inputs))
        recurrent_inputs = py_utils.NestedMap(proj_inputs=projected,
                                              padding=paddings,
                                              reset_mask=reset_mask)

        acc_state, final_state = recurrent.Recurrent(
            theta=cell_theta,
            state0=state0,
            inputs=recurrent_inputs,
            cell_fn=cell.FPropWithProjectedInput,
            cell_type=cell.layer_type,
            accumulator_layer=self,
            allow_implicit_capture=p.allow_implicit_capture)

        act = cell.GetOutput(acc_state)
        # Undo the time reversal so the output lines up with the caller's
        # inputs.
        if p.reverse:
            act = tf.reverse(act, [0])
        return act, final_state