Example #1
def _layer_norm_weights(input_signature, **unused_kwargs):
  """Helper: create layer norm parameters."""
  features = input_signature.shape[-1]
  scale = np.ones(features, dtype=input_signature.dtype)
  bias = np.zeros(features, dtype=input_signature.dtype)
  weights = (scale, bias)
  return weights
Example #2
def _layer_norm_weights(input_signature):
  """Helper: create layer norm parameters."""
  features = input_signature.shape[-1]
  scale = np.ones(features)
  bias = np.zeros(features)
  weights = (scale, bias)
  return weights
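Both helpers above only read `.shape` and `.dtype` from the input signature. A minimal usage sketch, assuming plain NumPy and a hypothetical `Signature` namedtuple standing in for the library's own signature object:

import collections
import numpy as np

Signature = collections.namedtuple('Signature', ['shape', 'dtype'])

sig = Signature(shape=(32, 128, 512), dtype=np.float32)
scale, bias = _layer_norm_weights(sig)   # helper from Example #1 or #2
assert scale.shape == (512,) and bias.shape == (512,)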
Example #3
    def forward_with_state(self,
                           inputs,
                           weights=base.EMPTY_WEIGHTS,
                           state=base.EMPTY_STATE,
                           rng=None,
                           **kwargs):
        del weights
        q, k, v = inputs
        if self._mode in ('train', 'eval'):
            mask_size = q.shape[-2]
            # Not all backends define np.tril. However, using onp.tril is inefficient
            # in that it creates a large global constant. TODO(kitaev): try to find an
            # alternative that works across all backends.
            if math.backend_name() == 'jax':
                mask = np.tril(np.ones((1, mask_size, mask_size),
                                       dtype=onp.bool_),
                               k=0)
            else:
                mask = onp.tril(onp.ones((1, mask_size, mask_size),
                                         dtype=onp.bool_),
                                k=0)
        else:
            assert self._mode == 'predict'
            state = _fast_inference_update_state(inputs, state)
            (k, v, mask, _) = state

        res = DotProductAttention(q,
                                  k,
                                  v,
                                  mask,
                                  dropout=self._dropout,
                                  mode=self._mode,
                                  rng=rng)
        return res, state
Example #4
    def forward_with_state(self, inputs, weights, state, rng):
        del weights
        q, k, v = inputs

        if self._mode == 'predict':
            state = _fast_inference_update_state(inputs, state)
            (k, v, mask, _) = state
        else:
            mask_size = q.shape[-2]
            # Not all backends define jnp.tril. However, using np.tril is inefficient
            # in that it creates a large global constant. TODO(kitaev): try to find an
            # alternative that works across all backends.
            if math.backend_name() == 'jax':
                mask = jnp.tril(jnp.ones((1, mask_size, mask_size),
                                         dtype=np.bool_),
                                k=0)
            else:
                mask = np.tril(np.ones((1, mask_size, mask_size),
                                       dtype=np.bool_),
                               k=0)

        res = DotProductAttention(q,
                                  k,
                                  v,
                                  mask,
                                  dropout=self._dropout,
                                  mode=self._mode,
                                  rng=rng)
        return res, state
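Both versions of forward_with_state build the same causal mask in the train/eval branch. A plain-NumPy sketch (backend aliasing dropped, mask_size chosen arbitrarily) of the lower-triangular structure that lets position i attend only to positions 0..i:

import numpy as np

mask_size = 4
mask = np.tril(np.ones((1, mask_size, mask_size), dtype=np.bool_), k=0)
print(mask[0].astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]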
Example #5
 def new_weights_and_state(self, input_signature):
   """Helper to initialize batch norm weights."""
   axis = self._axis
   axis = (axis,) if np.isscalar(axis) else axis
   input_shape = input_signature.shape
   shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
   beta = np.zeros(shape, dtype='float32') if self._center else ()
   gamma = np.ones(shape, dtype='float32') if self._scale else ()
   def get_stats_axis(i, d):
     if i in axis:
       return 1
     else:
       return d
   stats_shape = tuple(get_stats_axis(i, d) for i, d in enumerate(input_shape))
   running_mean = np.zeros(stats_shape, dtype=np.float32)
   running_var = np.ones(stats_shape, dtype=np.float32)
   n_batches = np.zeros((), dtype=np.int64)
   weights = (beta, gamma)
   state = (running_mean, running_var, n_batches)
   return weights, state
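A small sketch of the shape bookkeeping above, assuming an NHWC input signature and axis=(0, 1, 2): the weights are per-channel, while the running statistics keep a broadcastable singleton along every reduced axis.

input_shape = (8, 32, 32, 3)   # illustrative (batch, height, width, channels)
axis = (0, 1, 2)
shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
stats_shape = tuple(1 if i in axis else d for i, d in enumerate(input_shape))
print(shape)        # (3,)
print(stats_shape)  # (1, 1, 1, 3)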
Example #6
  def new_weights(self, input_signature):
    """Helper: create per-channel (gamma, beta) weights and optional epsilon."""
    # Usually (B, W, H, C)
    shape = input_signature.shape
    num_channels = shape[-1]

    gamma = np.ones((num_channels,), dtype=np.float32)
    beta = np.zeros((num_channels,), dtype=np.float32)

    epsilon_l = base.EMPTY_WEIGHTS
    if self._learn_epsilon:
      epsilon_l = (self._init_learnt_epsilon,)

    return gamma, beta, epsilon_l
Example #7
    def init_weights_and_state(self, input_signature):
        """Helper to initialize batch norm weights and state."""
        axis = self._axis
        axis = (axis, ) if jnp.isscalar(axis) else axis
        input_shape = input_signature.shape
        shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
        # TODO(jonni): Should beta and gamma match the dtype in the input signature?
        beta = jnp.zeros(shape, dtype='float32') if self._center else ()
        gamma = jnp.ones(shape, dtype='float32') if self._scale else ()

        def get_stats_axis(i, d):
            if i in axis:
                return 1
            else:
                return d

        stats_shape = tuple(
            get_stats_axis(i, d) for i, d in enumerate(input_shape))
        running_mean = jnp.zeros(stats_shape, dtype=jnp.float32)
        running_var = jnp.ones(stats_shape, dtype=jnp.float32)
        n_batches = jnp.zeros((), dtype=jnp.int64)
        self.weights = (beta, gamma)
        self.state = (running_mean, running_var, n_batches)
Example #8
 def new_weights(self, input_signature):
   """Helper: create layer norm parameters."""
   features = input_signature.shape[-1]
   scale = jnp.ones(features, dtype=input_signature.dtype)
   bias = jnp.zeros(features, dtype=input_signature.dtype)
   return scale, bias
Example #9
 def init_weights_and_state(self, input_signature):
     # Fixed-shape weights; state shaped like the incoming input.
     self.weights = jnp.zeros((2, 3))
     self.state = jnp.ones(input_signature.shape)
Example #10
 def init_weights_and_state(self, input_signature):
     """Helper: create layer norm parameters (scale, bias)."""
     features = input_signature.shape[-1]
     scale = jnp.ones(features, dtype=input_signature.dtype)
     bias = jnp.zeros(features, dtype=input_signature.dtype)
     self.weights = scale, bias