Example #1
def _layer_norm_params(input_shape, input_dtype, rng):
    """Helper: create layer norm parameters."""
    del input_dtype, rng
    features = input_shape[-1]
    scale = np.ones(features)
    bias = np.zeros(features)
    return (scale, bias)
Example #2
def LayerNormParams(input_shape, input_dtype, rng, epsilon=1e-6):
    """Helper: create layer norm parameters."""
    del input_dtype, rng, epsilon
    features = input_shape[-1]
    scale = np.ones(features)
    bias = np.zeros(features)
    return (scale, bias)
Example #3
def _layer_norm_new_params(input_shape, rng, epsilon=1e-6):  # pylint: disable=invalid-name
    """Helper: create layer norm parameters."""
    del rng, epsilon
    features = input_shape[-1]
    scale = np.ones(features)
    bias = np.zeros(features)
    return (scale, bias)
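For reference, the (scale, bias) pair these helpers produce feeds a standard layer-norm forward pass; a minimal plain-NumPy sketch (layer_norm_apply is an illustrative name, not from the source, and the epsilon the helpers accept and discard reappears here as the numerical-stability term):

import numpy as np

def layer_norm_apply(params, x, epsilon=1e-6):
    """Normalize x over its last axis, then apply learned scale and bias."""
    scale, bias = params
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon) * scale + bias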
Example #4
def EncoderDecoderMask(x, **unused_kwargs):
    """Make encoder-decoder mask from a padding mask and decoder input."""
    (padding_mask, decoder_input) = x
    padding_mask = np.reshape(
        padding_mask, (padding_mask.shape[0], 1, 1, padding_mask.shape[-1]))
    # Final mask shape is [batch, 1 for heads, decoder-len, encoder-len].
    # Multiplying by an all-ones tensor broadcasts the mask across decoder
    # positions without changing its 0/1 values (adding ones would make every
    # entry nonzero and disable the masking).
    return padding_mask * np.ones((1, 1, decoder_input.shape[1], 1))
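A quick shape check of the broadcast above (illustrative sizes; np is taken as plain NumPy here, while in trax it is the backend's numpy):

import numpy as np

padding_mask = np.ones((4, 10))          # [batch, encoder-len]
decoder_input = np.zeros((4, 7, 512))    # [batch, decoder-len, d_model]
mask = EncoderDecoderMask((padding_mask, decoder_input))
assert mask.shape == (4, 1, 7, 10)       # [batch, heads, decoder-len, encoder-len]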
Example #5
def call(self, inputs, params=(), rng=None, **kwargs):
  del params
  q, k, v = inputs
  mask_size = q.shape[-2]
  mask = np.tril(np.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
  res = tl.DotProductAttention(
      q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=rng)
  return res
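For intuition, the mask built above keeps only the lower triangle, so position i may attend to positions j <= i; a standalone check with plain NumPy:

import numpy as onp

print(onp.tril(onp.ones((3, 3), dtype=onp.bool_), k=0))
# [[ True False False]
#  [ True  True False]
#  [ True  True  True]]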
Example #6
def _batch_norm_new_params(input_shape, rng, axis=(0, 1, 2),
                           center=True, scale=True, **kwargs):
  """Helper to initialize batch norm params."""
  del rng, kwargs
  axis = (axis,) if np.isscalar(axis) else axis
  shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
  beta = np.zeros(shape, dtype='float32') if center else ()
  gamma = np.ones(shape, dtype='float32') if scale else ()
  return (beta, gamma)
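A quick shape check (illustrative NHWC sizes; np taken as plain NumPy): reducing over axes (0, 1, 2) leaves one beta/gamma pair per channel.

import numpy as np

beta, gamma = _batch_norm_new_params((8, 32, 32, 64), rng=None)
assert beta.shape == gamma.shape == (64,)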
Example #7
    def new_parameters(self, input_shape, input_dtype, rng):
        """Helper to initialize batch norm params."""
        del input_dtype, rng
        axis = self._axis
        axis = (axis, ) if np.isscalar(axis) else axis
        shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
        beta = np.zeros(shape, dtype='float32') if self._center else ()
        gamma = np.ones(shape, dtype='float32') if self._scale else ()

        def get_stats_axis(i, d):
            if i in axis:
                return 1
            else:
                return d

        stats_shape = tuple(
            get_stats_axis(i, d) for i, d in enumerate(input_shape))
        running_mean = np.zeros(stats_shape, dtype=np.float32)
        running_var = np.ones(stats_shape, dtype=np.float32)
        num_batches = np.zeros((), dtype=np.int32)
        return (beta, gamma), (running_mean, running_var, num_batches)
Example #8
def call(self, inputs, params=(), state=(), rng=None, **kwargs):
  del params
  q, k, v = inputs
  mask_size = q.shape[-2]
  # Not all backends define np.tril. However, using onp.tril is inefficient
  # in that it creates a large global constant. TODO(kitaev): try to find an
  # alternative that works across all backends.
  if backend.get_name() == 'jax':
    mask = np.tril(np.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
  else:
    mask = onp.tril(onp.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
  res = DotProductAttention(
      q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=rng)
  return res, state
Example #9
def ChunkedAttentionSelector(x, params, selector=None, **kwargs):
    """Select which chunks to attend to in chunked attention.

    Args:
      x: inputs, a list of ((q, k, v), mask) pairs, one per chunk.
      params: parameters (unused).
      selector: a function from chunk_number -> list of chunk numbers that
        says which other chunks should be appended to the given one
        (previous chunk if None).
      **kwargs: unused other arguments.

    Returns:
      a list of elements of the form (q, k', v', mask') where k', v' and
      mask' are concatenations of k, v and identity-extended masks from the
      selected chunks.
    """
    del params, kwargs
    selector = selector or (lambda x: [] if x < 1 else [x - 1])
    triples, masks = zip(*x)
    (queries, keys, values) = zip(*triples)
    result = []
    for i in range(len(x)):
        selected = selector(i)
        # Since keys and values are [batch, length, depth] we concatenate on axis=1.
        # We also always include the current key or value at the end.
        new_key_list = [keys[j] for j in selected]
        new_key = np.concatenate(new_key_list + [keys[i]], axis=1)
        new_value = np.concatenate([values[j] for j in selected] + [values[i]],
                                   axis=1)
        # Masks are (1, query-len, key-len) so we concatenate on axis=2.
        new_mask_shapes = [(1, queries[i].shape[1], key.shape[1])
                           for key in new_key_list]
        cur_mask = masks[i]
        # Masks are all-1 for the added chunks (no masking).
        new_mask_list = [
            np.ones(s, dtype=cur_mask.dtype) for s in new_mask_shapes
        ]
        # We still use the current (often causal) mask for the final chunk.
        new_mask = np.concatenate(new_mask_list + [cur_mask], axis=2)
        result.append((queries[i], new_key, new_value, new_mask))
    return tuple(result)
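As an illustration of the selector argument (a hypothetical policy, not from the source), letting every chunk attend to its two predecessors would be:

selector = lambda i: [j for j in (i - 2, i - 1) if j >= 0]
# chunk 0 -> [], chunk 1 -> [0], chunk 5 -> [3, 4]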
Example #10
  def forward(self, inputs, params=(), state=(), rng=None, **kwargs):
    del params
    q, k, v = inputs
    if self._mode in ('train', 'eval'):
      mask_size = q.shape[-2]
      # Not all backends define np.tril. However, using onp.tril is inefficient
      # in that it creates a large global constant. TODO(kitaev): try to find an
      # alternative that works across all backends.
      if backend.get_name() == 'jax':
        mask = np.tril(
            np.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
      else:
        mask = onp.tril(
            onp.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
    else:
      assert self._mode == 'predict'
      state = _fast_inference_update_state(inputs, state)
      (k, v, mask, _) = state

    res = DotProductAttention(
        q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=rng)
    return res, state
Example #11
def rescale(outputs, inputs):
    # `dims`, `spatial_strides` and `padding` come from the enclosing
    # pooling layer's configuration (rescale is an inner closure).
    one = np.ones(inputs.shape[1:-1], dtype=inputs.dtype)
    window_sizes = lax.reduce_window(one, 0., lax.add, dims,
                                     spatial_strides, padding)
    return outputs / window_sizes[..., np.newaxis]
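For intuition: summing a tensor of ones over each pooling window yields the per-window element count, so dividing a sum-pool's output by it turns the sum into a padding-aware average. A standalone sketch with jax.lax (illustrative 1-D sizes):

import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0).reshape(1, 6, 1)
dims, strides = (1, 3, 1), (1, 1, 1)
sums = lax.reduce_window(x, 0., lax.add, dims, strides, 'SAME')
counts = lax.reduce_window(jnp.ones_like(x), 0., lax.add, dims, strides, 'SAME')
avg = sums / counts  # edge windows divide by 2, interior windows by 3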
Example #12
def randn(stddev=1e-2):
    """An initializer function for random normal coefficients."""
    def init(rng, shape):
        # Assumed body: the snippet begins mid-function, and randn() is
        # used as the default b_init for Dense below.
        return (stddev * backend.random.normal(rng, shape)).astype('float32')

    return init


def glorot(out_dim=0, in_dim=1, scale=onp.sqrt(2)):
    """An initializer function for random Glorot-scaled coefficients."""
    def init(rng, shape):
        fan_in, fan_out = shape[in_dim], shape[out_dim]
        size = onp.prod(onp.delete(shape, [in_dim, out_dim]))
        std = scale / np.sqrt((fan_in + fan_out) / 2. * size)
        return (std * backend.random.normal(rng, shape)).astype('float32')

    return init


zeros = lambda rng, shape: np.zeros(shape, dtype='float32')
ones = lambda rng, shape: np.ones(shape, dtype='float32')

# Layers

# Each layer constructor function returns an (init_fun, apply_fun) pair, where
#   init_fun: takes an rng key and an input shape and returns an
#     (output_shape, params) pair,
#   apply_fun: takes params, inputs, and an rng key and applies the layer.


def Dense(out_dim, W_init=glorot(), b_init=randn()):
    """Layer constructor function for a dense (fully-connected) layer."""
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim,)
        w, b = W_init(rng, (input_shape[-1], out_dim)), b_init(rng, (out_dim,))
        return output_shape, (w, b)

    def apply_fun(params, inputs, **kwargs):  # assumed tail; snippet was cut off
        w, b = params
        return np.dot(inputs, w) + b

    return init_fun, apply_fun
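A hedged usage sketch of the (init_fun, apply_fun) convention described above (shapes are illustrative; rng is a JAX PRNG key, which backend.random expects):

from jax import random

init_fun, apply_fun = Dense(out_dim=16)
rng = random.PRNGKey(0)
out_shape, params = init_fun(rng, (8, 32))  # -> ((8, 16), (w, b))
y = apply_fun(params, np.ones((8, 32)))     # y.shape == (8, 16)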
Example #13
def init_fun(_, input_shape):
    features = input_shape[-1]
    scale = np.ones(features)
    bias = np.zeros(features)
    return input_shape, (scale, bias)