Example #1
    def update(self, step, grads, params, slots, opt_params):
        updates = []
        learning_rate = opt_params["learning_rate"]
        beta1 = opt_params["beta1"]
        decay_rate = opt_params["decay_rate"]
        clipping_threshold = opt_params["clipping_threshold"]
        weight_decay_rate = opt_params["weight_decay_rate"]
        epsilon1 = opt_params["epsilon1"]
        epsilon2 = opt_params["epsilon2"]
        decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
        update_scale = learning_rate
        if self._multiply_by_parameter_scale:
            update_scale *= np.maximum(np.sqrt(np.mean(params * params)),
                                       epsilon2)
        mixing_rate = 1.0 - decay_rate

        grads_sqr = grads * grads + epsilon1
        if self._factored and len(params.shape) >= 2:
            v_row = slots.pop(0)
            v_col = slots.pop(0)
            new_v_row = decay_rate * v_row + mixing_rate * np.mean(grads_sqr,
                                                                   axis=-1)
            new_v_col = decay_rate * v_col + mixing_rate * np.mean(grads_sqr,
                                                                   axis=-2)
            updates.extend([new_v_row, new_v_col])
            row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
            row_factor = (new_v_row / row_col_mean)**-0.5
            col_factor = (new_v_col)**-0.5
            y = (grads * np.expand_dims(row_factor, axis=-1) *
                 np.expand_dims(col_factor, axis=-2))
        else:
            v = slots.pop(0)
            new_v = decay_rate * v + mixing_rate * grads_sqr
            updates.append(new_v)
            y = grads * (new_v)**-0.5

        if self._do_clipping:
            clipping_denom = (np.maximum(
                1.0,
                np.sqrt(np.mean(y * y)) / clipping_threshold))
            y /= clipping_denom

        subtrahend = update_scale * y
        if self._do_momentum:
            m = slots.pop(0)
            new_m = beta1 * m + (1.0 - beta1) * subtrahend
            subtrahend = new_m
            updates.append(new_m)

        new_params = (1 - weight_decay_rate) * params - subtrahend
        # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
        return new_params.astype(params.dtype), updates
Example #2
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if backend.get_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax.
    dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
    if dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
    out = np.matmul(dots, value)
    return out
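A minimal standalone sketch of the same core computation (plain NumPy, hypothetical toy shapes, no mask or dropout): softmax(q k^T / sqrt(depth)) applied to v.

import numpy as np

# Toy shapes: (batch*heads, seq_len, depth). Values and shapes are illustrative only.
q = np.random.randn(2, 4, 8)
k = np.random.randn(2, 4, 8)
v = np.random.randn(2, 4, 8)
dots = np.matmul(q, np.swapaxes(k, -1, -2)) / np.sqrt(q.shape[-1])
dots = np.exp(dots - dots.max(axis=-1, keepdims=True))  # numerically stable softmax
dots /= dots.sum(axis=-1, keepdims=True)                # attention rows sum to 1
out = np.matmul(dots, v)                                # shape (2, 4, 8), same as v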
Example #3
 def update(self, step, grads, weights, avg_sq_grad, opt_params):
   del step
   learning_rate = opt_params['learning_rate']
   gamma = opt_params['gamma']
   eps = opt_params['eps']
   avg_sq_grad = avg_sq_grad * gamma + grads**2 * (1. - gamma)
   weights = weights - (learning_rate * grads /
                        (np.sqrt(avg_sq_grad) + eps)).astype(weights.dtype)
   return weights, avg_sq_grad
Example #4
 def update(self, step, grads, params, avg_sq_grad, opt_params):
     del step
     learning_rate = opt_params["learning_rate"]
     gamma = opt_params["gamma"]
     eps = opt_params["eps"]
     avg_sq_grad = avg_sq_grad * gamma + grads**2 * (1. - gamma)
     params = params - (learning_rate * grads /
                        (np.sqrt(avg_sq_grad) + eps)).astype(params.dtype)
     return params, avg_sq_grad
Example #5
 def _update_diagonal(self, grads, params, m, v, opt_params):
     learning_rate = opt_params['learning_rate']
     momentum = opt_params['momentum']
     v[0] += grads * grads
     preconditioner = np.where(v[0] > 0, 1.0 / np.sqrt(v[0]),
                               np.zeros_like(v[0]))
     preconditioned_grads = preconditioner * grads
     m = (1 - momentum) * preconditioned_grads + momentum * m
     params = params - (learning_rate * m).astype(params.dtype)
     return params, (m, v)
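A quick standalone sketch (plain NumPy, made-up values) of the diagonal preconditioner used above: squared gradients accumulate, and each gradient coordinate is scaled by 1/sqrt(accumulator), with zero-accumulator coordinates left at zero.

import numpy as np

grads = np.array([0.5, -2.0, 0.0])
accum = np.zeros(3)
accum += grads * grads                      # [0.25, 4.0, 0.0]
# (plain NumPy warns about the division by zero; np.where still selects the zero branch)
precond = np.where(accum > 0, 1.0 / np.sqrt(accum), np.zeros_like(accum))
print(precond * grads)                      # [ 1. -1.  0.]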
Example #6
 def Init(shape, rng):
     """Returns random values for initializing weights of the given `shape`."""
     fan_in, fan_out = _GetFans(shape, out_dim, in_dim)
     gain = scale
     if mode == 'fan_in':
         gain /= fan_in
     elif mode == 'fan_out':
         gain /= fan_out
     elif mode == 'fan_avg':
         gain /= (fan_in + fan_out) / 2
     if distribution == 'truncated_normal':
         # constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
         stddev = np.sqrt(gain) / .87962566103423978
         new_weights = random.truncated_normal(rng, -2, 2, shape) * stddev
         return new_weights.astype('float32')
     elif distribution == 'normal':
         new_weights = random.normal(rng, shape) * np.sqrt(gain)
         return new_weights.astype('float32')
     elif distribution == 'uniform':
         lim = np.sqrt(3. * gain)
         return random.uniform(rng, shape, np.float32, -lim, lim)
     else:
         raise ValueError('invalid distribution for ScaleInitializer')
Example #7
 def update(self, step, grads, weights, slots, opt_params):
   m, v = slots
   learning_rate = opt_params['learning_rate']
   weight_decay_rate = opt_params['weight_decay_rate']
   b1 = opt_params['b1']
   b2 = opt_params['b2']
   eps = opt_params['eps']
   m = (1 - b1) * grads + b1 * m  # First  moment estimate.
   v = (1 - b2) * (grads ** 2) + b2 * v  # Second moment estimate.
   mhat = m / (1 - b1 ** (step + 1))  # Bias correction.
   vhat = v / (1 - b2 ** (step + 1))
   weights = (1 - weight_decay_rate) * weights - (
       learning_rate * mhat / (np.sqrt(vhat) + eps)).astype(weights.dtype)
   return weights, (m, v)
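A standalone numeric sketch of a single step of this update (plain NumPy, hypothetical hyperparameters, step = 0, no weight decay):

import numpy as np

b1, b2, eps, lr, wd = 0.9, 0.999, 1e-8, 0.001, 0.0
w, m, v = np.array([1.0]), np.zeros(1), np.zeros(1)
g = np.array([0.5])
m = (1 - b1) * g + b1 * m          # 0.05
v = (1 - b2) * g**2 + b2 * v       # 0.00025
mhat = m / (1 - b1**1)             # 0.5   (bias correction recovers g)
vhat = v / (1 - b2**1)             # 0.25
w = (1 - wd) * w - lr * mhat / (np.sqrt(vhat) + eps)
print(w)                           # ~[0.999]: the first step moves by about lr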
Example #8
 def update(self, step, grads, params, slots, opt_params):
     m, v = slots
     learning_rate = opt_params["learning_rate"]
     weight_decay_rate = opt_params["weight_decay_rate"]
     b1 = opt_params["b1"]
     b2 = opt_params["b2"]
     eps = opt_params["eps"]
     m = (1 - b1) * grads + b1 * m  # First  moment estimate.
     v = (1 - b2) * (grads**2) + b2 * v  # Second moment estimate.
     mhat = m / (1 - b1**(step + 1))  # Bias correction.
     vhat = v / (1 - b2**(step + 1))
     params = (1 - weight_decay_rate) * params - (
         learning_rate * mhat / (np.sqrt(vhat) + eps)).astype(params.dtype)
      return params, (m, v)
Example #9
 def learning_rate(step):  # pylint: disable=invalid-name
     """Step to learning rate function."""
     ret = 1.0
     for name in factors:
         if name == 'constant':
             ret *= constant
         elif name == 'linear_warmup':
             ret *= np.minimum(1.0, step / warmup_steps)
         elif name == 'rsqrt_decay':
             ret /= np.sqrt(np.maximum(step, warmup_steps))
         elif name == 'rsqrt_normalized_decay':
             ret *= np.sqrt(warmup_steps)
             ret /= np.sqrt(np.maximum(step, warmup_steps))
         elif name == 'decay_every':
             ret *= (decay_factor**(step // steps_per_decay))
         elif name == 'cosine_decay':
             progress = np.maximum(0.0, (step - warmup_steps) /
                                   float(steps_per_cycle))
             ret *= np.maximum(
                 0.0, 0.5 * (1.0 + np.cos(np.pi * (progress % 1.0))))
         else:
             raise ValueError('Unknown factor %s.' % name)
     ret = np.asarray(ret, dtype=np.float32)
     return {'learning_rate': ret}
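A standalone sketch (plain NumPy, made-up constants) of how the 'constant * linear_warmup * rsqrt_decay' factors compose in a schedule like the one above:

import numpy as np

constant, warmup_steps = 0.1, 1000          # hypothetical values

def lr(step):
    ret = constant
    ret *= np.minimum(1.0, step / warmup_steps)     # linear_warmup
    ret /= np.sqrt(np.maximum(step, warmup_steps))  # rsqrt_decay
    return ret

print(lr(100))    # during warmup: 0.1 * 0.1 / sqrt(1000) ~ 3.2e-4
print(lr(10000))  # after warmup:  0.1 / sqrt(10000) = 1e-3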
Example #10
    def forward(self, inputs, weights):
        gamma, beta, epsilon_l = weights

        epsilon = self._init_epsilon
        if epsilon_l is not base.EMPTY_WEIGHTS:
            epsilon += np.abs(epsilon_l[0])

        # Omit B and C
        axis = tuple(range(1, len(np.shape(inputs)) - 1))
        # (B, 1, 1, C)
        nu2 = np.mean(inputs**2, axis=axis, keepdims=True)
        # (B, W, H, C)
        xhat = inputs / np.sqrt(nu2 + epsilon)

        return gamma * xhat + beta
Example #11
 def learning_rate(step):  # pylint: disable=invalid-name
     """Step to learning rate function."""
     ret = 1.0
     for name in factors:
         if name == "constant":
             ret *= constant
         elif name == "linear_warmup":
             ret *= np.minimum(1.0, step / warmup_steps)
         elif name == "rsqrt_decay":
             ret /= np.sqrt(np.maximum(step, warmup_steps))
         elif name == "decay_every":
             ret *= (decay_factor**(step // steps_per_decay))
         else:
             raise ValueError("Unknown factor %s." % name)
     ret = np.asarray(ret, dtype=np.float32)
     return {"learning_rate": ret}
Example #12
        def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
            """Forward pass for a subset of the query vectors."""
            if self._share_qk:
                key = self.make_unit_length(key)

            dots = np.matmul(query_slice, np.swapaxes(key, -1,
                                                      -2)) / np.sqrt(depth)

            # Causal masking
            mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
            dots = dots - 1e9 * mask

            # Mask out attention to self except when no other targets are available.
            if self._share_qk:
                self_mask = make_self_mask(dots.shape[-2], dots.shape[-1],
                                           q_loop_idx)
                dots = dots - 1e5 * self_mask

            # Softmax.
            dots = np.exp(dots -
                          backend.logsumexp(dots, axis=-1, keepdims=True))

            if self.dropout is not None and self.dropout > 0.0:
                # Dropout is broadcast across the batch+head dimension
                dropout_shape = (1, dots.shape[-2], dots.shape[-1])
                slice_rng = jax.random.fold_in(rng, q_loop_idx)
                keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
                keep = backend.random.bernoulli(slice_rng, keep_prob,
                                                dropout_shape)
                multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(
                    keep, keep_prob)
                dots = dots * multiplier

            if self._hard_k > 0:
                top_k = np.sort(dots)[...,
                                      -self._hard_k]  # Get the top-kth weight.
                top_k = jax.lax.stop_gradient(top_k)
                dots -= top_k[...,
                              np.newaxis]  # Subtract (be 0 for lower ones).
                dots = np.maximum(dots, 0)
                dots_sum = np.sum(dots, axis=-1,
                                  keepdims=True)  # Re-normalize.
                dots /= dots_sum  # Re-normalize.

            out_slice = np.matmul(dots, value)
            return out_slice
Example #13
 def _update_sketched(self, grads, params, m, v, opt_params):
   """Update for higher-rank parameters."""
   learning_rate = opt_params['learning_rate']
   momentum = opt_params['momentum']
   shape = params.shape
   rank = len(shape)
   reshaped_accumulators = [np.reshape(v[i], self._expanded_shape(shape, i))
                            for i in range(rank)]
   current_accumulator = self._minimum(reshaped_accumulators)
   current_accumulator += grads * grads
   accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                   1.0 / np.sqrt(current_accumulator),
                                   np.zeros_like(current_accumulator))
   preconditioned_gradient = grads * accumulator_inv_sqrt
   m = (1.0 - momentum) * preconditioned_gradient + momentum * m
   params = params - (learning_rate * m).astype(params.dtype)
   for i in range(len(v)):
     axes = list(range(int(i))) + list(range(int(i) + 1, rank))
     dim_accumulator = np.amax(current_accumulator, axis=axes)
     v[i] = dim_accumulator
   return params, (m, v)
Example #14
 def test_batch_norm(self):
   input_shape = (2, 3, 4)
   input_dtype = np.float32
   input_signature = ShapeDtype(input_shape, input_dtype)
   eps = 1e-5
   inp1 = np.reshape(np.arange(np.prod(input_shape), dtype=input_dtype),
                     input_shape)
   m1 = 11.5  # Mean of this random input.
   v1 = 47.9167  # Variance of this random input.
   layer = normalization.BatchNorm(axis=(0, 1, 2))
   _, _ = layer.initialize_once(input_signature)
   state = layer.state
   onp.testing.assert_allclose(state[0], 0)
   onp.testing.assert_allclose(state[1], 1)
   self.assertEqual(state[2], 0)
   out = layer(inp1)
   state = layer.state
   onp.testing.assert_allclose(state[0], m1 * 0.001)
   onp.testing.assert_allclose(state[1], 0.999 + v1 * 0.001, rtol=1e-6)
   self.assertEqual(state[2], 1)
   onp.testing.assert_allclose(out, (inp1 - m1) / np.sqrt(v1 + eps),
                               rtol=1e-6)
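The constants above follow directly from the arange input and can be checked with a quick standalone computation (plain NumPy):

import numpy as np

x = np.arange(24, dtype=np.float32)  # same values as inp1, flattened
print(x.mean())                      # 11.5
print(x.var())                       # 47.9167  (= (24**2 - 1) / 12 for 0..23)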
Example #15
 def make_unit_length(self, x, epsilon=1e-6):
     variance = np.mean(x**2, axis=-1, keepdims=True)
     norm_inputs = x / np.sqrt(variance + epsilon)
     return norm_inputs
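A quick standalone check (plain NumPy) of what this normalization produces: each row ends up with unit root-mean-square, i.e. an l2 norm of sqrt(d) rather than 1, because the mean (not the sum) of squares is divided out.

import numpy as np

x = np.random.randn(4, 8)
y = x / np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + 1e-6)
print(np.mean(y**2, axis=-1))         # ~1.0 per row
print(np.linalg.norm(y, axis=-1)**2)  # ~8.0 per row (= d)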
Example #16
    def single_call(self, qk, v, buckets, rng=None):
        # We use the same vector as both a query and a key.
        seqlen = qk.shape[-2]
        assert int(buckets.shape[0]) == self.n_hashes * seqlen

        ticker = jax.lax.tie_in(qk, np.arange(self.n_hashes * seqlen))
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = jax.lax.sort_key_val(buckets_and_t,
                                                       ticker,
                                                       dimension=-1)
        _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
        sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
        sticker = jax.lax.stop_gradient(sticker)
        undo_sort = jax.lax.stop_gradient(undo_sort)

        st = (sticker % seqlen)
        sqk = np.take(qk, st, axis=0)
        sv = np.take(v, st, axis=0)

        # Split off a "bin" axis so that attention only occurs within chunks.
        bq_t = bkv_t = np.reshape(st, (self.n_hashes * self.n_bins, -1))
        bqk = np.reshape(sqk, (self.n_hashes * self.n_bins, -1, sqk.shape[-1]))
        bv = np.reshape(sv, (self.n_hashes * self.n_bins, -1, sv.shape[-1]))
        bq_buckets = bkv_buckets = np.reshape(
            sbuckets_and_t // seqlen, (self.n_hashes * self.n_bins, -1))

        # Hashing operates on unit-length vectors. Unnormalized query vectors are
        # fine because they effectively provide a learnable temperature for the
        # attention softmax, but normalizing keys is needed so that similarity for
        # the purposes of attention correctly corresponds to hash locality.
        bq = bqk
        bk = self.make_unit_length(bqk)

        # Allow each chunk to attend within itself, and also one chunk back. Chunk
        # boundaries might occur in the middle of a sequence of items from the
        # same bucket, so this increases the chances of attending to relevant items.
        # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
        def look_one_back(x):
            if len(x.shape) == 2:
                x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
            else:
                x_extra = np.concatenate([x[-1:, :, :], x[:-1, :, :]], axis=0)
            return np.concatenate([x, x_extra], axis=1)

        bk = look_one_back(bk)
        bv = look_one_back(bv)
        bkv_t = look_one_back(bkv_t)
        bkv_buckets = look_one_back(bkv_buckets)

        # Dot-product attention.
        dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

        # Causal masking
        mask = jax.lax.convert_element_type(
            jax.lax.lt(bq_t[:, :, None], bkv_t[:, None, :]), np.float32)
        dots = dots - 1e9 * mask

        # Mask out attention to self except when no other targets are available.
        self_mask = jax.lax.convert_element_type(
            jax.lax.eq(bq_t[:, :, None], bkv_t[:, None, :]), np.float32)
        dots = dots - 1e5 * self_mask

        # Mask out attention to other hash buckets.
        if not self._attend_across_buckets:
            bucket_mask = jax.lax.convert_element_type(
                jax.lax.ne(bq_buckets[:, :, None], bkv_buckets[:, None, :]),
                np.float32)
            dots = dots - 1e7 * bucket_mask

        # Don't double-count query-key pairs across multiple rounds of hashing.
        # There are two possible strategies here. (1) The default is to count how
        # many times a query-key pair is repeated, and to lower its log-prob
        # correspondingly at each repetition. (2) When hard_k is set, the code
        # instead masks all but the first occurrence of each query-key pair.
        # TODO(kitaev): is one strategy faster or more numerically stable?
        if not self._allow_duplicate_attention:
            locs1 = undo_sort // bq_t.shape[-1]
            locs2 = (locs1 + 1) % (self.n_hashes * self.n_bins)
            if not self._attend_across_buckets:
                locs1 = buckets * (self.n_hashes * self.n_bins) + locs1
                locs2 = buckets * (self.n_hashes * self.n_bins) + locs2
            locs = np.moveaxis(
                np.concatenate([
                    np.reshape(locs1, (self.n_hashes, seqlen)),
                    np.reshape(locs2, (self.n_hashes, seqlen)),
                ], 0), 0, -1)  # produces shape (seqlen, 2 * self.n_hashes)
            slocs = np.take(locs, st, axis=0)
            b_locs = np.reshape(
                slocs, (self.n_hashes * self.n_bins, -1, 2 * self.n_hashes))
            # Queries always use the primary location (based on locs1).
            b_locs1 = b_locs[:, :, None, :self.n_hashes]
            if self._hard_k > 0:
                range_n_hashes = jax.lax.tie_in(b_locs,
                                                np.arange(self.n_hashes))
                nouse_locs = (range_n_hashes[:, None] >
                              range_n_hashes[None, :])
                nouse_locs = 2 * nouse_locs - 1  # 1 = use, -1 = don't use
                nouse_locs = np.reshape(
                    np.broadcast_to(
                        nouse_locs[:, None, :],
                        (self.n_hashes, self.n_bins, self.n_hashes)),
                    (self.n_hashes * self.n_bins, 1, 1, self.n_hashes))
                b_locs1 = b_locs1 * nouse_locs
            bq_locs = np.broadcast_to(b_locs1,
                                      b_locs.shape[:2] + (2, self.n_hashes))
            bq_locs = np.reshape(bq_locs, b_locs.shape)
            bkv_locs = look_one_back(b_locs)

            dup_counts = np.sum(jax.lax.convert_element_type(
                jax.lax.eq(bq_locs[:, :, None, :], bkv_locs[:, None, :, :]),
                np.float32),
                                axis=-1)
            assert dup_counts.shape == dots.shape
            if self._hard_k > 0:
                dots = dots - 1e7 * jax.lax.stop_gradient(dup_counts)
            else:
                dots = dots - jax.lax.stop_gradient(np.log(dup_counts + 1e-9))

        # Each query only attends to the top k most relevant keys.
        if self._hard_k > 0:
            b_top_dots = np.sort(dots)[...,
                                       -self._hard_k:]  # Get the top k dots.
            b_top_dots = jax.lax.stop_gradient(b_top_dots)
            s_top_dots = np.reshape(b_top_dots, (-1, self._hard_k))
            top_dots = np.take(s_top_dots, undo_sort, axis=0)

            merged_top_dots = np.moveaxis(
                np.reshape(top_dots, (self.n_hashes, seqlen, self._hard_k)), 0,
                -1)
            merged_top_dots = np.reshape(merged_top_dots, (seqlen, -1))

            dots_thresh = np.sort(merged_top_dots)[:, -self._hard_k]
            # It's possible to compute the partition function at this point, but right
            # now this codepath isn't set up for backprop, and there might also be
            # issues computing it this way if two dot-products are exactly equal.

            sdots_thresh = dots_thresh[st]
            bdots_thresh = np.reshape(sdots_thresh,
                                      (self.n_hashes * self.n_bins, -1))
            bdots_thresh = jax.lax.stop_gradient(bdots_thresh)

            top_k_mask = jax.lax.convert_element_type(
                dots < bdots_thresh[..., None], np.float32)
            dots = dots - 1e7 * jax.lax.stop_gradient(top_k_mask)

        # Softmax.
        dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
        dots = np.exp(dots - dots_logsumexp)

        if self._dropout > 0.0:
            # Dropout is broadcast across the bin dimension
            dropout_shape = (1, dots.shape[-2], dots.shape[-1])
            keep_prob = jax.lax.tie_in(dots, 1.0 - self._dropout)
            keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
            multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(
                keep, keep_prob)
            dots = dots * multiplier

        bo = np.matmul(dots, bv)
        so = np.reshape(bo, (-1, bo.shape[-1]))
        slogits = np.reshape(dots_logsumexp, (-1, ))

        def unsort_for_output_impl(so, slogits):
            o = np.take(so, undo_sort, axis=0)
            # Sorting is considerably faster than gather, but first we need to get the
            # XLA compiler to abandon the idea of fusing this sort with the input sort
            # (which introduces a computation cycle and leads to a crash).
            # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
            sticker_ = sticker + jax.lax.convert_element_type(
                slogits[0] > 0, sticker.dtype)
            _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
            return o, logits

        def unsort_for_output_vjp(so, slogits):
            """Custom gradient for unsort_for_output."""
            so = jax.lax.stop_gradient(so)
            slogits = jax.lax.stop_gradient(slogits)
            o, logits = unsort_for_output_impl(so, slogits)

            def vjpfun(o_logits_grads):
                so_grad = np.take(o_logits_grads[0], sticker, axis=0)
                # TODO(kitaev): this exists to match the forward pass, but I'm not sure
                # if it's actually required.
                buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
                    o_logits_grads[1][0] > 0, buckets_and_t.dtype)
                _, slogits_grad = jax.lax.sort_key_val(buckets_and_t_,
                                                       o_logits_grads[1],
                                                       dimension=-1)
                return (so_grad, slogits_grad)

            return (o, logits), vjpfun

        unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
        jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
        o, logits = unsort_for_output_impl(so, slogits)

        if self.n_hashes == 1:
            out = o
        else:
            o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
            logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
            probs = np.exp(logits -
                           backend.logsumexp(logits, axis=0, keepdims=True))
            out = np.sum(o * probs, axis=0)

        assert out.shape == v.shape
        return out
Example #17
    def _forward_train_eval(self, inputs, rng):
        (inputs, original_len, n_bins) = self._pad_inputs(inputs)
        q, k, v = inputs
        seqlen = q.shape[-2]
        # q/k/v are n_batch*n_heads, seqlen, d_head
        # Time indices for causal masking.
        t = jax.lax.tie_in(q, np.arange(seqlen))

        # Split off a "bin" axis for chunks of consecutive items.
        bq_t = np.reshape(t, (n_bins, -1))
        bq = np.reshape(q, (q.shape[0], n_bins, -1, q.shape[-1]))
        if self._share_qk:
            bk = self.make_unit_length(bq)
        else:
            bk = np.reshape(k, (k.shape[0], n_bins, -1, k.shape[-1]))
        bv = np.reshape(v, (v.shape[0], n_bins, -1, v.shape[-1]))

        # Allow each chunk to attend within itself, and also one chunk back.
        def look_one_back(x):
            # Output: pairs [ bin_i bin_{i-1} ] concatenated on the time axis.
            if len(x.shape) == 2:
                x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
                return np.concatenate([x, x_extra], axis=1)
            else:
                assert len(x.shape) == 4
                x_extra = np.concatenate([x[:, -1:, :, :], x[:, :-1, :, :]],
                                         axis=1)
                return np.concatenate([x, x_extra], axis=2)

        bkv_t = look_one_back(bq_t)
        bk = look_one_back(bk)
        bv = look_one_back(bv)

        # Dot-product attention.
        dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

        # Causal masking based on the time indices.
        mask = jax.lax.convert_element_type(
            jax.lax.lt(bq_t[None, :, :, None], bkv_t[None, :, None, :]),
            np.float32)
        dots = dots - 1e9 * mask

        # Mask out attention to self except when no other targets are available.
        if self._share_qk:
            self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
            self_mask = jax.lax.tie_in(dots, self_mask)
            dots = dots - 1e5 * self_mask

        # Softmax.
        dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))

        if self.dropout > 0.0:
            # Dropout is broadcast across the batch+head dimension
            dropout_shape = (1, dots.shape[-3], dots.shape[-2], dots.shape[-1])
            keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
            keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
            multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(
                keep, keep_prob)
            dots = dots * multiplier

        bo = np.matmul(dots, bv)

        output = np.reshape(bo, (bo.shape[0], -1, bo.shape[-1]))
        assert output.shape == v.shape
        return output[..., :original_len, :]
Example #18
def l2_norm(tree):
    """Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
    leaves, _ = tree_flatten(tree)
    return np.sqrt(sum(np.vdot(x, x) for x in leaves))
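A standalone sketch of the same computation (plain NumPy, with a flat list standing in for the flattened pytree leaves): the result is the l2 norm of all parameters concatenated into a single vector.

import numpy as np

leaves = [np.ones((2, 3)), np.full((4,), 2.0)]      # hypothetical parameters
norm = np.sqrt(sum(np.vdot(x, x) for x in leaves))
print(norm)                                         # sqrt(6*1 + 4*4) = sqrt(22) ~ 4.69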
Example #19
 def _z_score(self, x, mean, variance):
     mu = mean.astype(x.dtype)
     sigma = np.sqrt(variance + self._epsilon).astype(x.dtype)
     return (x - mu) / sigma
Example #20
def LayerNorm(x, weights, epsilon=1e-6, **unused_kwargs):  # pylint: disable=invalid-name
    (scale, bias) = weights
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.mean((x - mean)**2, axis=-1, keepdims=True)
    norm_inputs = (x - mean) / np.sqrt(variance + epsilon)
    return norm_inputs * scale + bias
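A hypothetical usage sketch, assuming the np above is NumPy-compatible (plain NumPy or jax.numpy): with scale = 1 and bias = 0 the layer maps each feature vector to zero mean and roughly unit variance.

import numpy as np

x = np.random.randn(2, 5, 16).astype(np.float32)
weights = (np.ones(16, dtype=np.float32), np.zeros(16, dtype=np.float32))
y = LayerNorm(x, weights)
print(y.mean(axis=-1))  # ~0 everywhere
print(y.var(axis=-1))   # ~1 everywhere (slightly below 1 due to epsilon)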
Example #21
def KaimingUniformInitializer(out_dim=-1, in_dim=-2, param=0.):
    """Returns an initializer for random uniform Kaiming-scaled coefficients."""
    return ScaledInitializer(out_dim, in_dim, 2.0 / np.sqrt(1 + param**2),
                             'fan_in', 'uniform')