def _AttenLogits(query,
                 key,
                 abs_pos_emb,
                 content_bias=None,
                 positional_bias=None,
                 is_causal=False):
    """Computes self-attention logits with relative position embedding.

    This is the Transformer-XL formulation
    (https://arxiv.org/pdf/1901.02860.pdf, section 3.3): the logits are the
    sum of a content term (biased query against key) and a positional term
    (biased query against the relative position embeddings).

    Padding is NOT handled here; the caller is expected to mask it.

    Dimension names used below:
      B: batch size
      T: sequence length
      N: number of attention heads
      H: per-head attention dimension

    Args:
      query: Tensor of shape [B, T, N, H].
      key: Tensor of shape [B, T, N, H].
      abs_pos_emb: Tensor of shape [2T - 1, N, H]. The sinusoid positional
        embedding from https://arxiv.org/abs/1706.03762. abs_pos_emb[i] is the
        embedding of relative distance i - (T - 1).
      content_bias: Tensor of shape [N, H], or None. None is treated as a
        bias of 0.
      positional_bias: Tensor of shape [N, H], or None. None is treated as a
        bias of 0.
      is_causal: A Python bool or a scalar bool Tensor. True for causal self
        attention.

    Returns:
      The attention logits tensor of shape [B, N, T, T].
    """
    b, t, n, h = py_utils.GetShape(query)
    key = py_utils.HasShape(key, [b, t, n, h])

    # A missing bias degenerates to adding the scalar 0 to the query.
    content_bias = (0 if content_bias is None else py_utils.HasShape(
        content_bias, [n, h]))
    positional_bias = (0 if positional_bias is None else py_utils.HasShape(
        positional_bias, [n, h]))

    # Content term: [B, N, T, S=T]
    content_query = query + content_bias
    content_logits = tf.einsum('BTNH,BSNH->BNTS', content_query, key)

    # Positional term, same shape as the content term.
    position_query = query + positional_bias
    position_logits = RelPositionBias(position_query, abs_pos_emb, is_causal)

    return content_logits + position_logits
    def _InputBatch(self):
        """Builds one input batch from data stored in a checkpoint.

        The tensors named `p.data` and `p.label` are restored from `p.ckpt`,
        cast to float32 and cached. A batch of sample ids is drawn with
        `ops.random_permutation_sequence`, the matching rows are gathered, and
        everything is passed through `py_utils.PadOrTrimTo` so the leading
        dimension equals the infeed batch size.

        Returns:
          A `.NestedMap` with fields:
            raw: gathered data, shape [batch] + p.data_shape.
            data: `self._Preprocess(raw)`.
            label: gathered labels, shape [batch].
            weight: ones for the gathered samples, passed through PadOrTrimTo
              to shape [batch].
            sample_ids: the drawn sample ids; only set when not on TPU.
        """
        p = self.params

        @tf.function
        def ReadData():
            restored_data, restored_label = io_ops.restore_v2(
                p.ckpt, [p.data, p.label], [''] * 2,
                [p.data_dtype, p.label_dtype])
            # Regardless of the stored dtypes, hand back float32.
            return (tf.cast(restored_data, tf.float32),
                    tf.cast(restored_label, tf.float32))

        # Read data and label once and keep them resident in memory.
        data, label = ops.cached_call(f=ReadData.get_concrete_function(),
                                      T=[tf.float32, tf.float32])

        batch_size = self.InfeedBatchSize()
        sample_shape = list(p.data_shape)

        data = tf.reshape(data, [-1] + sample_shape)
        label = tf.reshape(label, [-1])
        # There must be exactly one label per data row.
        label = py_utils.HasShape(label, [tf.shape(data)[0]])

        sample_ids = ops.random_permutation_sequence(
            num=p.num_samples,
            batch=batch_size,
            repeat=p.repeat,
            seed=p.random_seed or 0)
        num_drawn = tf.shape(sample_ids)[0]

        raw = py_utils.PadOrTrimTo(tf.gather(data, sample_ids),
                                   [batch_size] + sample_shape)
        processed = self._Preprocess(raw)
        batch_label = py_utils.PadOrTrimTo(tf.gather(label, sample_ids),
                                           [batch_size])
        batch_weight = py_utils.PadOrTrimTo(
            tf.ones([num_drawn], dtype=tf.float32), [batch_size])

        ret = py_utils.NestedMap(raw=raw,
                                 data=processed,
                                 label=batch_label,
                                 weight=batch_weight)
        if not py_utils.use_tpu():
            ret['sample_ids'] = sample_ids
        return ret
def _RelPositionBias(query, abs_pos_emb):
    """Computes the relative position bias for the general (non-causal) case.

    Args:
      query: Tensor of shape [B, T, N, H].
      abs_pos_emb: Tensor of shape [2T - 1, N, H], ordered along axis 0 by
        relative distance -(T-1), -(T-2), ..., 0, 1, ..., T-1.

    Returns:
      A tensor of shape [B, N, T, T]. Its lower triangle (diagonal included)
      comes from the first T columns of the reversed-embedding scores and the
      rest from the last T columns, each aligned with `RelShift`.
    """
    _, t, n, h = py_utils.GetShape(query)
    abs_pos_emb = py_utils.HasShape(abs_pos_emb, [2 * t - 1, n, h])

    # Flip the distance axis so it runs T-1, T-2, ..., 0, -1, ..., -(T-1).
    flipped_emb = tf.reverse(abs_pos_emb, [0])

    # Score every query position against every relative distance.
    # [B, N, T, L=2T-1]
    scores = tf.einsum('BTNH,LNH->BNTL', query, flipped_emb)

    # The [B, N, T, 2T-1] scores are folded into [B, N, T, T] in two halves.
    # Half 1: first T columns. Flip both trailing axes, shift, flip back.
    # [B, N, T, T]
    left_half = tf.reverse(
        RelShift(tf.reverse(scores[:, :, :, :t], [2, 3])), [2, 3])

    # Half 2: last T columns, shifted directly.
    # [B, N, T, T]
    right_half = RelShift(scores[:, :, :, t - 1:])

    # Ones on and below the main diagonal, zeros above it.
    lower_triangle = tf.linalg.band_part(tf.ones_like(right_half), -1, 0)

    # Stitch the two halves together.
    return tf.where(lower_triangle > 0, left_half, right_half)
Exemple #4
0
    def FProp(self, theta, inputs, paddings):
        """Applies the convolution to `inputs`.

        Args:
          theta: A `.NestedMap` object containing weights' values of this
            layer and its children layers.
          inputs: The inputs tensor, expected to be of shape [batch, time,
            frequency, channel]. Time plays the role of image height and
            frequency the role of image width.
          paddings: The paddings tensor, expected to be of shape
            [batch, time].

        Returns:
          An (outputs, out_paddings) pair. Padded positions of `outputs` are
          NOT zeroed here.
        """
        p = self.params
        with tf.name_scope(p.name):
            # paddings must be rank 2, and inputs must be
            # [batch, time, ?, input_channels] with batch/time matching it.
            paddings_check = py_utils.assert_shape_match(
                tf.shape(paddings), [-1, -1])
            inputs_shape = tf.shape(inputs)
            expected_inputs_shape = tf.concat(
                [tf.shape(paddings),
                 [-1, symbolic.ToStatic(self.input_channels)]], 0)
            inputs_check = py_utils.assert_shape_match(inputs_shape,
                                                       expected_inputs_shape)
            inputs = py_utils.with_dependencies([paddings_check, inputs_check],
                                                inputs)

            # Zero out padded input frames: broadcast [batch, time] paddings
            # over the two trailing axes.
            broadcast_paddings = tf.expand_dims(tf.expand_dims(paddings, -1),
                                                -1)
            inputs = inputs * (1.0 - broadcast_paddings)

            # Run the convolution itself.
            out = self._ApplyConv(theta, inputs)
            if p.partial_conv:
                out = self._RescaleBoundary(out, paddings)

            # NOTE: this may be slightly inaccurate when p.dilation_rate[0] > 1.
            # But there's likely no real problems. Trying to set it gives an
            # error: pooling with SAME padding is not implemented for
            # dilation_rate > 1.
            # NOTE: window is set to p.filter_stride[0] to stay compatible
            # with the legacy implementation. Consider updating it to be the
            # actual shape.
            time_stride = p.filter_stride[0]
            conv_padding = ComputeConvOutputPadding(paddings,
                                                    window=time_stride,
                                                    stride=time_stride)

            # The output is deliberately not masked with conv_padding: padded
            # nodes are assumed to be zeroed, if necessary, by subsequent
            # layers.
            expected_out_shape = symbolic.ToStatic(
                self.OutShape(tf.shape(inputs)))
            out = py_utils.HasShape(out, expected_out_shape)
            return out, conv_padding
Exemple #5
0
    def FProp(self, theta, inputs, paddings):
        """Applies global spatial pooling over the time and frequency axes.

        Args:
          theta: A `.NestedMap` object containing weights' values of this
            layer and its children layers. Not read by this method.
          inputs: The inputs tensor, expected to be of shape [batch, time,
            frequency, channel]. Time plays the role of image height and
            frequency the role of image width.
          paddings: The paddings tensor, expected to be of shape
            [batch, time], or None, which means there are no paddings.

        Returns:
          An (outputs, out_paddings) pair.
           - outputs: has shape [batch, 1, 1, channel].
           - out_paddings: None or has shape [batch, 1].
        """
        p = self.params
        assert p.pooling_type in ['MAX', 'AVG'], p.pooling_type
        b, t, f = py_utils.GetShape(inputs, ndims=3)

        # mask is 1.0 on real frames and 0.0 on padded ones, and broadcasts
        # against inputs.
        if paddings is None:
            mask = tf.ones([b, t, 1, 1], p.dtype)
        else:
            paddings = py_utils.HasShape(paddings, [b, t])
            mask = 1.0 - paddings[..., tf.newaxis, tf.newaxis]

        if p.pooling_type == 'AVG':
            masked_sum = tf.reduce_sum(inputs * mask,
                                       axis=[1, 2],
                                       keepdims=True)
            # Elements averaged per channel: frequency bins times the number
            # of unpadded frames.
            f = tf.cast(tf.convert_to_tensor(f), p.dtype)
            valid_count = f * tf.reduce_sum(mask, axis=[1, 2], keepdims=True)
            # Clamp the divisor to at least 1 so an all-padded example does
            # not divide by zero.
            out_feature = masked_sum / tf.maximum(1.0, valid_count)
        elif p.pooling_type == 'MAX':
            # Padded positions get a very negative value so they never win
            # the max.
            large_negative = (tf.ones_like(inputs) * p.dtype.max *
                              tf.constant(-0.7, dtype=p.dtype))
            masked_inputs = tf.where_v2(mask > 0.0, inputs, large_negative)
            out_feature = tf.reduce_max(masked_inputs,
                                        axis=[1, 2],
                                        keepdims=True)

        if paddings is None:
            return out_feature, None

        # The pooled output counts as padding only if every frame was padded;
        # in that case its features are zeroed.
        out_paddings = tf.reduce_min(paddings, axis=1, keepdims=True)
        out_feature *= 1.0 - out_paddings[..., tf.newaxis, tf.newaxis]
        return out_feature, out_paddings
def _RelPositionBiasCausal(query, abs_pos_emb):
    """Computes the relative position bias for causal self attention.

    Args:
      query: Tensor of shape [B, T, N, H].
      abs_pos_emb: Tensor of shape [2T - 1, N, H], ordered along axis 0 by
        relative distance -(T-1), -(T-2), ..., 0, 1, ..., T-1.

    Returns:
      A tensor with the same [B, N, T, T] shape as the query/embedding scores,
      realigned with `RelShift`.
    """
    _, t, n, h = py_utils.GetShape(query)
    abs_pos_emb = py_utils.HasShape(abs_pos_emb, [2 * t - 1, n, h])

    # Flip the distance axis and keep the first T entries, i.e. distances
    # T-1, T-2, ..., 0.
    # [T, N, H]
    kept_emb = tf.reverse(abs_pos_emb, [0])[:t]

    # [B, N, T, L=T]
    scores = tf.einsum('BTNH,LNH->BNTL', query, kept_emb)

    # Shift: flip both trailing axes, apply RelShift, then flip back.
    flipped = tf.reverse(scores, [2, 3])
    shifted = RelShift(flipped)
    return tf.reverse(shifted, [2, 3])
def RelShift(x):
    """Performs relative shift on a 4D tensor (first 2 axes are batching dims).

    Given input of shape [?, ?, W, W], each row i of the last two dims is
    shifted right by i columns using a pad-and-reshape trick:

      output[b, n, i, j] = input[b, n, i, j - i]   for j >= i
      output[b, n, i, i - 1] = 0                   (the padded column)

    Entries with j < i - 1 are NOT zero: they hold values wrapped around from
    the previous input row (input[b, n, i - 1, W + 1 + j - i]). Callers that
    need a clean triangle must mask those entries themselves.

    Args:
      x: A Tensor of shape [?, ?, W, W].

    Returns:
      A Tensor of the same shape as the input with its content shifted as
      described above.
    """
    batch, heads, width, _ = py_utils.GetShape(x)
    x = py_utils.HasShape(x, [-1, -1, width, width])

    # Append one zero column: [?, ?, W, W + 1].
    padded = tf.pad(x, ((0, 0), (0, 0), (0, 0), (0, 1)))
    # Reinterpreting the same buffer as [?, ?, W + 1, W] moves each successive
    # row one more step to the right.
    skewed = tf.reshape(padded, [batch, heads, width + 1, width])
    # Drop the extra trailing row to get back to [?, ?, W, W].
    return skewed[:, :, :width, :]
    def FProp(self, theta, inputs, paddings, state0=None, segment_id=None):
        """Runs the LSTM forward pass over a whole sequence.

        Args:
          theta: A `.NestedMap` object containing weights' values of this
            layer and its children layers.
          inputs: A single tensor or a tuple of tensors with cardinality equal
            to rnn_cell.inputs_arity. For every input tensor, the first
            dimension is assumed to be time, second dimension batch, and
            third dimension depth.
          paddings: A tensor. First dim is time, second dim is batch, and
            third dim is expected to be 1.
          state0: The initial rnn state in a `.NestedMap`. When it is falsy
            (e.g. None), the cell's zero-state is used.
          segment_id: A tensor to support packed inputs. First dim is time,
            second dim is batch, and third dim is expected to be 1. Must not
            be None when p.packed_input is set.

        Returns:
          A tensor of [time, batch, dims].
          The final recurrent state.
        """
        p = self.params
        cell = self.cell
        assert isinstance(cell, rnn_cell.RNNCell)

        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        # Slicing wm to wm_{i,h} outside the loop to get 20% speedup over
        # regular LSTM baseline.
        # Keeping slicing within the loop gives only < 3% speedup.
        cell_theta = theta.cell.copy()
        split_at = p.cell.num_input_nodes
        cell_theta['wm_i'] = cell_theta.wm[:split_at, :]
        cell_theta['wm_h'] = cell_theta.wm[split_at:, :]
        tf.logging.vlog(1, 'cell_theta: %r', cell_theta)

        if not p.packed_input:
            # Without packing there is nothing to reset between segments.
            reset_mask = tf.zeros_like(paddings)
        else:
            assert segment_id is not None
            reset_mask = rnn_layers.GeneratePackedInputResetMask(
                segment_id, is_reverse=False)
            reset_mask = py_utils.HasShape(reset_mask, tf.shape(paddings))

        if p.reverse:
            # Flip every time-major tensor along the time axis.
            inputs = [tf.reverse(x, [0]) for x in inputs]
            paddings = tf.reverse(paddings, [0])
            reset_mask = tf.reverse(reset_mask, [0])

        if not state0:
            batch_size = py_utils.GetShape(paddings)[1]
            state0 = cell.zero_state(cell_theta, batch_size)

        # Project the full input sequence up front.
        # [T, B, H]
        projected = cell.ProjectInputSequence(cell_theta,
                                              py_utils.NestedMap(act=inputs))
        recurrent_inputs = py_utils.NestedMap(proj_inputs=projected,
                                              padding=paddings,
                                              reset_mask=reset_mask)

        acc_state, final_state = recurrent.Recurrent(
            theta=cell_theta,
            state0=state0,
            inputs=recurrent_inputs,
            cell_fn=cell.FPropWithProjectedInput,
            cell_type=cell.layer_type,
            accumulator_layer=self,
            allow_implicit_capture=p.allow_implicit_capture)

        act = cell.GetOutput(acc_state)
        if p.reverse:
            # Put the outputs back into the original time order.
            act = tf.reverse(act, [0])
        return act, final_state
Exemple #9
0
    def FProp(self, theta, *args):
        """Runs p.repeat copies of self.body.FProp independently.

        Args:
          theta: Layer model parameters. The shape of each variable in theta
            is always [p.repeat, ...]. And the i-th slice theta[i] becomes
            theta of the i-th copy of self.body.
          *args: Input arguments. The shape of each tensor in args is always
            [p.repeat, ....]. And the list [arg[i] for arg in args] becomes
            inputs to the i-th copy of self.body.FProp. Entries may be None.

        Returns:
          The accumulated output_tensors. Each tensor t in the return has the
          shape [p.repeat, ....] and the tuple (t[i] for i in output_tensors)
          is the return tuple of the i-th self.body.FProp. A single tensor
          (not a tuple) is returned when exactly one input argument was
          given.
        """
        p = self.params

        # Validate the leading dimension of every non-None input and KEEP the
        # checked tensors. Assigning the HasShape result to the loop variable
        # only would discard it, so the check would never be tied to the
        # tensors that are actually consumed below.
        checked_args = []
        for arg in args:
            if arg is not None:
                arg = py_utils.HasShape(arg, [p.repeat], ndims=1)
            checked_args.append(arg)
        args = tuple(checked_args)

        theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars,
                                            p.repeat)
        inputs = py_utils.NestedMap(theta=theta_stack, args=list(args))
        # Infer out_shapes from FPropMeta.
        out_shapes = self._InferOutShapes(args)

        def _CellFn(unused_theta, unused_state0, inputs):
            """Recurrent cell function wrapper of body.FProp."""
            # Sets shapes for both theta and inputs to self.body.FProp: each
            # per-step slice has the stacked tensor's shape minus the leading
            # p.repeat dimension.
            for dst, src in zip(inputs.args + inputs.theta.Flatten(),
                                list(args) + theta_stack.Flatten()):
                if src is not None:
                    dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))

            # Runs the actual body.FProp
            fprop_outputs = self.body.FProp(inputs.theta, *inputs.args)
            fprop_outputs = _ToTuple(fprop_outputs)
            assert len(fprop_outputs) == len(out_shapes)
            # Passes fprop outputs to the next layer through state.
            state1 = py_utils.NestedMap(outputs=list(fprop_outputs))
            return state1, py_utils.NestedMap()

        with tf.name_scope(p.name):
            # Initiate state0 with inferred output shapes.
            state0 = py_utils.NestedMap(outputs=[
                tf.zeros(shape, args[0].dtype) for shape in out_shapes
            ])
            # Runs body.FProp p.repeat times using Recurrent.
            acc_states, _ = recurrent.Recurrent(theta=py_utils.NestedMap(),
                                                state0=state0,
                                                inputs=inputs,
                                                cell_fn=_CellFn)

            # Retrieves fprop outputs from state1 and sets shapes.
            output_tensors = tuple(acc_states.outputs)
            for output_tensor, out_shape in zip(output_tensors, out_shapes):
                output_tensor.set_shape(
                    tf.TensorShape([p.repeat] + out_shape.as_list()))

            # NOTE(review): the single-tensor return is keyed on the number of
            # inputs, not outputs -- presumably the body maps k inputs to k
            # outputs; confirm before changing.
            return output_tensors[0] if len(args) == 1 else tuple(
                output_tensors)