Example #1
 def init(self, params):
     shape = params.shape
     slots = []
     if self._factored and len(shape) >= 2:
         v_row = np.zeros(shape[:-1], dtype=np.float32)
         v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)
         slots.extend([v_row, v_col])
     else:
         v = np.zeros_like(params)
         slots.append(v)
     if self._do_momentum:
         m = np.zeros_like(params)
         slots.append(m)
     return slots
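For reference, a minimal standalone sketch of the factored slot shapes produced above, assuming plain NumPy and an illustrative 2-D parameter (e.g. an embedding matrix of shape (vocab, d_model)); the factored branch keeps one row vector and one column vector instead of a full second-moment accumulator.

import numpy as np

params = np.zeros((8, 512), dtype=np.float32)  # illustrative (vocab, d_model)
shape = params.shape
v_row = np.zeros(shape[:-1], dtype=np.float32)               # shape (8,)
v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)  # shape (512,)
print(v_row.shape, v_col.shape)  # (8,) (512,)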
Example #2
def EncoderDecoderMask(x, **unused_kwargs):
    """Makes encoder-decoder mask from decoder input and a padding mask."""
    decoder_input, padding_mask = x
    padding_mask = np.reshape(
        padding_mask, (padding_mask.shape[0], 1, 1, padding_mask.shape[-1]))
    # Final mask shape is [batch, 1 for heads, decoder-len, encoder-len].
    return padding_mask + np.zeros((1, 1, decoder_input.shape[1], 1))
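A minimal sketch of the broadcast above with concrete, illustrative sizes (batch=2, decoder-len=3, encoder-len=5), using plain NumPy in place of the backend numpy:

import numpy as np

decoder_input = np.zeros((2, 3, 8))  # [batch, decoder-len, d_model]
padding_mask = np.ones((2, 5))       # [batch, encoder-len]
padding_mask = np.reshape(
    padding_mask, (padding_mask.shape[0], 1, 1, padding_mask.shape[-1]))
mask = padding_mask + np.zeros((1, 1, decoder_input.shape[1], 1))
print(mask.shape)  # (2, 1, 3, 5) == [batch, 1, decoder-len, encoder-len]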
Example #3
def NewPositionalEncoding(x, positions=None, **kwargs):
  """Implements new positional encoding."""
  del kwargs
  x_length = np.shape(x)[1]
  pos = np.array(positions)[np.newaxis, :x_length, :]
  pos += np.zeros((np.shape(x)[0], 1, 1))  # Broadcast on batch.
  return pos
Example #4
def _layer_norm_weights(input_signature):
    """Helper: create layer norm parameters."""
    features = input_signature.shape[-1]
    scale = np.ones(features)
    bias = np.zeros(features)
    weights = (scale, bias)
    return weights
Example #5
def _layer_norm_params_and_state(input_shape, input_dtype, rng):
    """Helper: create layer norm parameters."""
    del input_dtype, rng
    features = input_shape[-1]
    scale = np.ones(features)
    bias = np.zeros(features)
    params = (scale, bias)
    return params, ()
Example #6
def _fast_inference_init_state(input_signature, buffer_length):
    """Returns an initial state for causal attention layer fast inference."""
    def zeros_for(batch_size, shape_dtype):
        shape, dtype = shape_dtype.as_tuple()
        depth = shape[-1]
        return np.zeros((batch_size, buffer_length, depth), dtype=dtype)

    batch_size = input_signature[0].shape[0]
    k = zeros_for(batch_size, input_signature[1])
    v = zeros_for(batch_size, input_signature[2])
    mask = np.zeros((batch_size, 1, buffer_length))
    index = 0
    return (k, v, mask, index)
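A usage sketch, assuming the function above is in scope with np bound to plain NumPy; ShapeDtype below is a hypothetical stand-in for the signature objects, providing only the .shape attribute and .as_tuple() method that the code uses.

import numpy as np

class ShapeDtype:  # hypothetical stand-in for the real signature type
    def __init__(self, shape, dtype):
        self.shape, self.dtype = shape, dtype
    def as_tuple(self):
        return self.shape, self.dtype

sig = (ShapeDtype((4, 1, 64), np.float32),) * 3  # (query, key, value) signatures
k, v, mask, index = _fast_inference_init_state(sig, buffer_length=16)
print(k.shape, v.shape, mask.shape, index)  # (4, 16, 64) (4, 16, 64) (4, 1, 16) 0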
Example #7
    def new_weights(self, input_signature):
        # Usually (B, W, H, C)
        shape = input_signature.shape
        num_channels = shape[-1]

        gamma = np.ones((num_channels, ), dtype=np.float32)
        beta = np.zeros((num_channels, ), dtype=np.float32)

        epsilon_l = base.EMPTY_WEIGHTS
        if self._learn_epsilon:
            epsilon_l = (self._init_learnt_epsilon, )

        return gamma, beta, epsilon_l
Example #8
    def new_weights_and_state(self, input_signature):
        """Helper to initialize batch norm weights."""
        axis = self._axis
        axis = (axis, ) if np.isscalar(axis) else axis
        input_shape = input_signature.shape
        shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
        beta = np.zeros(shape, dtype='float32') if self._center else ()
        gamma = np.ones(shape, dtype='float32') if self._scale else ()

        def get_stats_axis(i, d):
            if i in axis:
                return 1
            else:
                return d

        stats_shape = tuple(
            get_stats_axis(i, d) for i, d in enumerate(input_shape))
        running_mean = np.zeros(stats_shape, dtype=np.float32)
        running_var = np.ones(stats_shape, dtype=np.float32)
        n_batches = np.zeros((), dtype=np.int64)
        weights = (beta, gamma)
        state = (running_mean, running_var, n_batches)
        return weights, state
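The shape bookkeeping above, traced by hand for an illustrative NHWC input of shape (8, 32, 32, 16) normalized over axis=(0, 1, 2):

axis = (0, 1, 2)
input_shape = (8, 32, 32, 16)
shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
stats_shape = tuple(1 if i in axis else d for i, d in enumerate(input_shape))
print(shape)        # (16,)         -> beta and gamma are per-channel
print(stats_shape)  # (1, 1, 1, 16) -> running mean/var broadcast over N, H, W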
Example #9
 def forward(self, inp, weights):
   """Reshape input to have heads dimension and concatenate positions there."""
   x = inp[0]
   n_batches, seqlen = x.shape[0], x.shape[1]
   d_head = x.shape[-1] // self._n_heads
   res = np.reshape(x, (n_batches, seqlen, self._n_heads, d_head))
   res = np.transpose(res, (0, 2, 1, 3))  # (batch, heads, len, depth)
   if self._n_pos == 1:  # Just one position given, tile into each head.
     pos_shape = list(res.shape)[:-1] + [inp[1].shape[-1]]
     pos = inp[1][:, None, :, :] + np.zeros(pos_shape)  # Add 0 to broadcast.
   else:  # As many positions as heads, concatenate them in.
     pos = [p[:, None, :, :] for p in inp[1:]]
     pos = np.concatenate(pos, axis=1)
   res = np.concatenate([res, pos], axis=-1)
   # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
   res = np.reshape(res, (-1, seqlen, d_head + POS_VECTOR_SIZE))
   return res
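A sketch of the head split performed above, with illustrative sizes (batch=2, seqlen=5, n_heads=4, d_feature=32, so d_head=8) and the positional concatenation omitted:

import numpy as np

n_heads = 4
x = np.zeros((2, 5, 32))
n_batches, seqlen = x.shape[0], x.shape[1]
d_head = x.shape[-1] // n_heads
res = np.reshape(x, (n_batches, seqlen, n_heads, d_head))
res = np.transpose(res, (0, 2, 1, 3))
print(res.shape)  # (2, 4, 5, 8) == (batch, heads, len, depth)
res = np.reshape(res, (-1, seqlen, d_head))
print(res.shape)  # (8, 5, 8)    == (batch*heads, len, depth)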
Example #10
    def forward_and_backward(self,
                             inputs,
                             ct,
                             state=base.EMPTY_STATE,
                             new_state=base.EMPTY_STATE,
                             rng=None,
                             **kwargs):
        del state, new_state, kwargs
        query, key, value = inputs
        depth = np.shape(query)[-1]
        do_backprop = ct is not None

        # jax uses the term cotangent (ct) to refer to gradient signals, and
        # vector-Jacobian product (vjp) for back-propagation through a layer.

        def make_mask(N, M, k):  # pylint: disable=invalid-name
            """Constructs a slice of the causal attention mask.

      Args:
        N: number of query positions
        M: number of key positions
        k: position of the initial query element

      Returns:
        N x M mask, where 1.0 indicates that attention is not allowed.
      """
            x = jax.lax.tie_in(k, np.arange(N, dtype=np.int32))
            y = jax.lax.tie_in(k, np.arange(M, dtype=np.int32))
            mask = jax.lax.lt((jax.lax.broadcast_in_dim(
                x, shape=(N, M), broadcast_dimensions=(0, )) + k),
                              jax.lax.broadcast(y, [N]))
            mask = jax.lax.convert_element_type(mask, np.float32)
            return mask

        def make_self_mask(N, M, k):  # pylint: disable=invalid-name
            """Masks out elements attending to self.

      Args:
        N: number of query positions
        M: number of key positions
        k: position of the initial query element

      Returns:
        N x M mask, where 1.0 indicates that attention is not allowed.
      """
            x = jax.lax.tie_in(k, np.arange(N, dtype=np.int32))
            y = jax.lax.tie_in(k, np.arange(M, dtype=np.int32))
            mask = jax.lax.eq((jax.lax.broadcast_in_dim(
                x, shape=(N, M), broadcast_dimensions=(0, )) + k),
                              jax.lax.broadcast(y, [N]))
            mask = jax.lax.convert_element_type(mask, np.float32)
            return mask

        def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
            """Forward pass for a subset of the query vectors."""
            if self._share_qk:
                key = self.make_unit_length(key)

            dots = np.matmul(query_slice, np.swapaxes(key, -1,
                                                      -2)) / np.sqrt(depth)

            # Causal masking
            mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
            dots = dots - 1e9 * mask

            # Mask out attention to self except when no other targets are available.
            if self._share_qk:
                self_mask = make_self_mask(dots.shape[-2], dots.shape[-1],
                                           q_loop_idx)
                dots = dots - 1e5 * self_mask

            # Softmax.
            dots = np.exp(dots -
                          backend.logsumexp(dots, axis=-1, keepdims=True))

            if self.dropout is not None and self.dropout > 0.0:
                # Dropout is broadcast across the batch+head dimension
                dropout_shape = (1, dots.shape[-2], dots.shape[-1])
                slice_rng = jax.random.fold_in(rng, q_loop_idx)
                keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
                keep = backend.random.bernoulli(slice_rng, keep_prob,
                                                dropout_shape)
                multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(
                    keep, keep_prob)
                dots = dots * multiplier

            if self._hard_k > 0:
                # Keep only the k largest weights: find the k-th largest value,
                # subtract it, clip the rest at zero and re-normalize.
                top_k = np.sort(dots)[..., -self._hard_k]
                top_k = jax.lax.stop_gradient(top_k)
                dots -= top_k[..., np.newaxis]
                dots = np.maximum(dots, 0)
                dots_sum = np.sum(dots, axis=-1, keepdims=True)
                dots /= dots_sum

            out_slice = np.matmul(dots, value)
            return out_slice

        def forward_and_vjp_slice(query_slice, q_loop_idx, key, value,
                                  ct_slice):  # pylint: disable=invalid-name
            # Capture q_loop_idx to avoid calculating gradients w.r.t. it.
            def forward_slice_with_q_loop_idx(query_slice, key, value):  # pylint: disable=invalid-name
                return forward_slice(query_slice, q_loop_idx, key, value)

            output_slice, vjpfun = jax.vjp(forward_slice_with_q_loop_idx,
                                           query_slice, key, value)
            return output_slice, vjpfun(ct_slice)

        q_loop_idx = np.zeros((), dtype=np.int32)
        q_loop_max = query.shape[-2]
        q_loop_stride = self._loop_stride
        if q_loop_max == 1:  # For abstract runs with unknown shapes.
            q_loop_stride = 1
        assert q_loop_max % q_loop_stride == 0, (
            'Stride must evenly divide the number of query elements.')

        out_accum = np.zeros_like(query)
        if do_backprop:
            query_ct_accum = np.zeros_like(query)
            key_ct_accum = np.zeros_like(key)
            value_ct_accum = np.zeros_like(value)
            init_vals = (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                         value_ct_accum)
        else:
            init_vals = (q_loop_idx, out_accum)

        def cond_fun(vals):  # pylint: disable=invalid-name
            q_loop_idx = vals[0]
            return jax.lax.lt(q_loop_idx, q_loop_max)

        def body_fun(vals):  # pylint: disable=invalid-name
            """Compute a slice of the attention mechanism."""
            if do_backprop:
                (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                 value_ct_accum) = vals
            else:
                q_loop_idx, out_accum = vals

            query_slice = jax.lax.dynamic_slice_in_dim(query,
                                                       q_loop_idx,
                                                       q_loop_stride,
                                                       axis=-2)

            if do_backprop:
                ct_slice = jax.lax.dynamic_slice_in_dim(ct,
                                                        q_loop_idx,
                                                        q_loop_stride,
                                                        axis=-2)
                out_slice, partial_ct = forward_and_vjp_slice(
                    query_slice, q_loop_idx, key, value, ct_slice)
                query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
                    query_ct_accum, partial_ct[0], q_loop_idx, axis=-2)
                key_ct_accum = key_ct_accum + partial_ct[1]
                value_ct_accum = value_ct_accum + partial_ct[2]
            else:
                out_slice = forward_slice(query_slice, q_loop_idx, key, value)

            out_accum = jax.lax.dynamic_update_slice_in_dim(out_accum,
                                                            out_slice,
                                                            q_loop_idx,
                                                            axis=-2)
            q_loop_idx = q_loop_idx + q_loop_stride

            if do_backprop:
                return (q_loop_idx, out_accum, query_ct_accum, key_ct_accum,
                        value_ct_accum)
            else:
                return (q_loop_idx, out_accum)

        final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

        if not do_backprop:
            return final_vals[1], None
        else:
            return final_vals[1], final_vals[2:]
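The control flow above is an instance of a general jax.lax.while_loop chunking pattern: slice a fixed-size block along the query axis, process it, and write the result back into a preallocated accumulator. A minimal, self-contained sketch of just that pattern (the doubling stands in for forward_slice and is not the real attention computation):

import jax
import jax.numpy as jnp

def process_in_chunks(x, chunk):
    """Applies a per-slice function to x in fixed-size chunks along axis -2."""
    out_accum = jnp.zeros_like(x)

    def cond_fun(vals):
        idx, _ = vals
        return jax.lax.lt(idx, x.shape[-2])

    def body_fun(vals):
        idx, out_accum = vals
        x_slice = jax.lax.dynamic_slice_in_dim(x, idx, chunk, axis=-2)
        out_slice = x_slice * 2.0  # stand-in for the per-slice computation
        out_accum = jax.lax.dynamic_update_slice_in_dim(
            out_accum, out_slice, idx, axis=-2)
        return idx + chunk, out_accum

    idx0 = jnp.zeros((), dtype=jnp.int32)
    _, out = jax.lax.while_loop(cond_fun, body_fun, (idx0, out_accum))
    return out

print(process_in_chunks(jnp.ones((2, 8, 4)), chunk=2).shape)  # (2, 8, 4)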
Example #11
    def batch_call_and_or_grad(self,
                               qk,
                               v,
                               ct=None,
                               return_output=True,
                               new_state=None,
                               return_state=False,
                               rng=None):
        assert return_output or ct is not None, 'No work to perform!'
        if new_state is not None and new_state is not base.EMPTY_STATE:
            buckets = new_state
        else:
            buckets = None

        # The approach here is to perform attention for one batch element and head
        # at a time. Note that there is absolutely no interaction across examples or
        # heads: this layer has no parameters, and hashing patterns are also
        # different across examples/heads. As a result, batching doesn't give any
        # performance gains except in the case of accelerator under-utilization. We
        # assume that hash-based attention will be applied primarily to long
        # sequences, where unbatched attention for a single head has sufficient
        # computation to fill up the accelerator.

        batch_loop_idx = np.zeros((), dtype=np.int32)
        batch_loop_max = qk.shape[0]

        init_vals = (batch_loop_idx, )
        if return_output:
            out_accum = np.zeros_like(qk)
            init_vals = init_vals + (out_accum, )
        if return_state:
            buckets_accum = np.zeros(
                [qk.shape[0], self.n_hashes * qk.shape[1]], dtype=np.int32)
            init_vals = init_vals + (buckets_accum, )
        if ct is not None:
            qk_ct_accum = np.zeros_like(qk)
            v_ct_accum = np.zeros_like(v)
            init_vals = init_vals + (qk_ct_accum, v_ct_accum)

        def cond_fun(vals):
            batch_loop_idx = vals[0]
            return jax.lax.lt(batch_loop_idx, batch_loop_max)

        def body_fun(vals):
            """Performs attention for a single batch element and head."""
            batch_loop_idx = vals[0]
            if self._prng is None:
                hash_slice_rng = jax.random.fold_in(rng, batch_loop_idx)
                hash_rng, slice_rng = backend.random.split(hash_slice_rng)
            else:
                # TODO(kitaev): Maybe use the same RNG across examples (but not heads)?
                hash_rng = jax.random.fold_in(self._prng, batch_loop_idx)
                slice_rng = jax.random.fold_in(rng, batch_loop_idx)
            qk_slice = jax.lax.dynamic_index_in_dim(qk,
                                                    batch_loop_idx,
                                                    axis=0,
                                                    keepdims=False)
            v_slice = jax.lax.dynamic_index_in_dim(v,
                                                   batch_loop_idx,
                                                   axis=0,
                                                   keepdims=False)

            if buckets is None:
                buckets_slice = self.hash_vectors(qk_slice, rng=hash_rng)
            else:
                buckets_slice = jax.lax.dynamic_index_in_dim(buckets,
                                                             batch_loop_idx,
                                                             axis=0,
                                                             keepdims=False)

            if ct is None:
                out_slice = self.single_call(qk_slice,
                                             v_slice,
                                             buckets_slice,
                                             rng=slice_rng)
            else:

                def _do_single_call(qk_slice, v_slice):
                    return self.single_call(qk_slice,
                                            v_slice,
                                            buckets_slice,
                                            rng=slice_rng)

                ct_slice = jax.lax.dynamic_index_in_dim(ct,
                                                        batch_loop_idx,
                                                        axis=0,
                                                        keepdims=False)
                out_slice, vjpfun = jax.vjp(_do_single_call, qk_slice, v_slice)
                qk_ct_slice, v_ct_slice = vjpfun(ct_slice)

            new_vals = (batch_loop_idx + 1, )
            if return_output:
                out_accum = vals[1]
                out_accum = jax.lax.dynamic_update_index_in_dim(out_accum,
                                                                out_slice,
                                                                batch_loop_idx,
                                                                axis=0)
                new_vals = new_vals + (out_accum, )
            if return_state:
                buckets_accum = vals[2]
                buckets_accum = jax.lax.dynamic_update_index_in_dim(
                    buckets_accum, buckets_slice, batch_loop_idx, axis=0)
                new_vals = new_vals + (buckets_accum, )
            if ct is not None:
                qk_ct_accum, v_ct_accum = vals[-2:]
                qk_ct_accum = jax.lax.dynamic_update_index_in_dim(
                    qk_ct_accum, qk_ct_slice, batch_loop_idx, axis=0)
                v_ct_accum = jax.lax.dynamic_update_index_in_dim(
                    v_ct_accum, v_ct_slice, batch_loop_idx, axis=0)
                new_vals = new_vals + (qk_ct_accum, v_ct_accum)

            return new_vals

        final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

        if return_output:
            out = final_vals[1]
        else:
            out = None

        if return_state:
            state = final_vals[2]
        else:
            state = None

        if ct is not None:
            input_ct = final_vals[-2:]
        else:
            input_ct = None

        return out, state, input_ct
Example #12
 def new_weights_and_state(self, input_signature):
     qk = input_signature[0]
     state = np.zeros((qk.shape[0], self.n_hashes * qk.shape[1]),
                      dtype=np.int32)
     return self.new_weights(input_signature), state
Example #13
 def new_weights(self, input_signature):
   del input_signature
   return (np.zeros((), dtype=np.float32),)
Example #14
 def init(self, params):
     vs = [np.zeros(sz, dtype=params.dtype) for sz in params.shape]
     return (np.zeros_like(params), vs)
Example #15
 def zeros_for(batch_size, shape_dtype):
     shape, dtype = shape_dtype.as_tuple()
     depth = shape[-1]
     return np.zeros((batch_size, buffer_length, depth), dtype=dtype)
Example #16
 def dummy_loss_fn(weights):
     inputs = (np.zeros(input_sd.shape, dtype=np.int32), ) * 2
     output = model(inputs, weights=weights, state=state, rng=rng)
     dummy_loss = backend.numpy.sum(output[0])
     return dummy_loss
Example #17
 def dummy_loss_fn(params):
   inputs = (np.zeros(input_shape[0], dtype=np.int32),) * 2
   output = model(inputs, params=params, state=state, rng=rng)
   dummy_loss = backend.numpy.sum(output[0])
   return dummy_loss
Example #18
def MakeZeroState(x, depth_multiplier=1, **unused_kwargs):
    """Makes zeros of shape like x but removing the length (axis 1)."""
    assert len(x.shape) == 3, 'Expecting x of shape [batch, length, depth].'
    return np.zeros((x.shape[0], depth_multiplier * x.shape[-1]),
                    dtype=np.float32)
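A usage sketch, assuming the function above with np bound to plain NumPy and an illustrative input of shape [batch=2, length=7, depth=16]:

import numpy as np

x = np.ones((2, 7, 16), dtype=np.float32)
state = MakeZeroState(x, depth_multiplier=2)
print(state.shape)  # (2, 32): the length axis is dropped and depth is doubled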