def __call__(self, inputs, deterministic: Optional[bool] = None, rng=None):
  """Applies a random dropout mask to the input.

  Args:
    inputs: the inputs that should be randomly masked.
    deterministic: if false the inputs are scaled by `1 / (1 - rate)` and
      masked, whereas if true, no mask is applied and the inputs are returned
      as is.
    rng: an optional `jax.random.PRNGKey`. By default
      `self.make_rng('dropout')` will be used.

  Returns:
    The masked inputs reweighted to preserve mean.
  """
  deterministic = merge_param(
      'deterministic', self.deterministic, deterministic)
  # Identity fast paths: nothing to drop, or dropout disabled at call time.
  if (self.rate == 0.) or deterministic:
    return inputs
  # Guard rate == 1.0 explicitly: otherwise keep_prob == 0 and
  # `inputs / keep_prob` produces inf/nan values. The forward result is
  # still masked to zeros by `lax.select`, but nan in the unselected branch
  # propagates through the gradient of `select`.
  if self.rate == 1.0:
    return jnp.zeros_like(inputs)
  keep_prob = 1. - self.rate
  if rng is None:
    rng = self.make_rng('dropout')
  # Collapse the broadcast dims to 1 so one mask value is shared along them.
  broadcast_shape = list(inputs.shape)
  for dim in self.broadcast_dims:
    broadcast_shape[dim] = 1
  mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
  mask = jnp.broadcast_to(mask, inputs.shape)
  # Scale surviving activations by 1/keep_prob to preserve the mean.
  return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
def __call__(self, inputs_q, inputs_kv, mask=None, custom_relative_position=None, deterministic=None):
  """Applies multi-head dot product attention on the input data.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention and project the results to an output vector.

  Args:
    inputs_q: input queries of shape `[batch_sizes..., length, features]`.
    inputs_kv: key/values of shape `[batch_sizes..., length, features]`.
    mask: attention mask of shape `[batch_sizes..., num_heads, query_length,
      key/value_length]`. Attention weights are masked out if their
      corresponding mask value is `False`.
    custom_relative_position: relative positions tensor
      `[batch_sizes..., query_length, key/value_length]'
    deterministic: if false, the attention weight is masked randomly using
      dropout, whereas if true, the attention weights are deterministic.

  Returns:
    output of shape `[batch_sizes..., length, features]`.
  """
  if self.dropout_rate > 0.:  # Require `deterministic` only if using dropout.
    deterministic = module.merge_param('deterministic', self.deterministic,
                                       deterministic)
  features = self.out_features or inputs_q.shape[-1]
  qkv_features = self.qkv_features or inputs_q.shape[-1]
  assert qkv_features % self.num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // self.num_heads
  # Shared projection factory: [..., features] -> [..., num_heads, head_dim].
  dense = functools.partial(linear.DenseGeneral,
                            axis=-1,
                            features=(self.num_heads, head_dim),
                            kernel_init=self.kernel_init,
                            bias_init=self.bias_init,
                            use_bias=self.use_bias,
                            precision=self.precision)
  # Embedding table mapping relative-position buckets to one bias per head
  # (T5-style learned relative attention bias).
  relative_attention_embed = linear.Embed(
      num_embeddings=self.num_relative_position_buckets,
      features=self.num_heads,
      embedding_init=initializers.normal(stddev=1.0),
      dtype=self.dtype)
  # project inputs_q to multi-headed q/k/v
  # dimensions are then [batch..., length, n_heads, n_features_per_head]
  query, key, value = (dense(dtype=self.dtype, name='query')(inputs_q),
                       dense(dtype=self.dtype, name='key')(inputs_kv),
                       dense(dtype=self.dtype, name='value')(inputs_kv))
  if custom_relative_position is None:
    # Default relative positions: memory_pos - query_pos over a full grid.
    query_length = inputs_q.shape[-2]
    key_length = inputs_kv.shape[-2]
    context_position = jnp.arange(query_length, dtype=jnp.int32)[:, None]
    memory_position = jnp.arange(key_length, dtype=jnp.int32)[None, :]
    relative_position = memory_position - context_position
    relative_position_bucket = make_relative_position_bucket(
        relative_position,
        bidirectional=self.bidirectional,
        num_buckets=self.num_relative_position_buckets,
        max_distance=self.max_distance)
    bias = relative_attention_embed(relative_position_bucket)
    # [q_len, k_len, num_heads] -> [num_heads, q_len, k_len].
    bias = bias.transpose((2, 0, 1))
    # Expand batch dimensions.
    bias = jnp.broadcast_to(bias,
                            (1, ) * len(inputs_q.shape[:-2]) + bias.shape)
  else:
    # Caller-supplied relative positions carry their own batch dims.
    relative_position = custom_relative_position
    relative_position_bucket = make_relative_position_bucket(
        relative_position,
        bidirectional=self.bidirectional,
        num_buckets=self.num_relative_position_buckets,
        max_distance=self.max_distance)
    bias = relative_attention_embed(relative_position_bucket)
    # Move the trailing (q_len, k_len, num_heads) axes to
    # (num_heads, q_len, k_len) while keeping all batch dims in front.
    permute = tuple(
        map(lambda i: len(inputs_q.shape) + 1 + i, (-1, -3, -2)))
    bias = bias.transpose(
        tuple(range(len(inputs_q.shape[:-2]))) + permute)

  # During fast autoregressive decoding, we feed one position at a time,
  # and cache the keys and values step by step.
  if self.decode:
    # detect if we're initializing by absence of existing cache data.
    is_initialized = self.has_variable('cache', 'cached_key')
    cached_key = self.variable('cache', 'cached_key', jnp.zeros,
                               key.shape, key.dtype)
    cached_value = self.variable('cache', 'cached_value', jnp.zeros,
                                 value.shape, value.dtype)
    cache_index = self.variable('cache', 'cache_index',
                                lambda: jnp.array(0, dtype=jnp.int32))
    if is_initialized:
      *batch_dims, max_length, num_heads, depth_per_head = (
          cached_key.value.shape)
      # shape check of cached keys against query input
      expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
      if expected_shape != query.shape:
        raise ValueError(
            'Autoregressive cache shape error, '
            'expected query shape %s instead got %s.' %
            (expected_shape, query.shape))
      # update key, value caches with our new 1d spatial slices
      cur_index = cache_index.value
      indices = (0, ) * len(batch_dims) + (cur_index, 0, 0)
      key = lax.dynamic_update_slice(cached_key.value, key, indices)
      value = lax.dynamic_update_slice(cached_value.value, value, indices)
      cached_key.value = key
      cached_value.value = value
      # Advance the write pointer for the next decoding step.
      cache_index.value = cache_index.value + 1
      # causal mask for cached decoder self-attention:
      # our single query position should only attend to those key
      # positions that have already been generated and cached,
      # not the remaining zero elements.
      mask = attention.combine_masks(
          mask,
          jnp.broadcast_to(
              jnp.arange(max_length) <= cur_index,
              tuple(batch_dims) + (1, 1, max_length)))
      # Keep only the relative-bias row for the current query position.
      # NOTE(review): the slice assumes a single leading batch dim of size 1
      # on `bias` — confirm against decode-time callers.
      bias = lax.dynamic_slice(bias, (0, 0, cur_index, 0),
                               (1, self.num_heads, 1, max_length))

  # Convert the boolean attention mask to an attention bias.
  if mask is not None:
    # attention mask in the form of attention bias
    bias += lax.select(mask > 0,
                       jnp.full(mask.shape, 0.).astype(self.dtype),
                       jnp.full(mask.shape, -1e10).astype(self.dtype))
  dropout_rng = None
  if not deterministic and self.dropout_rate > 0.:
    dropout_rng = self.make_rng('dropout')

  # apply attention
  x = attention.dot_product_attention(
      query,
      key,
      value,
      bias=bias,
      dropout_rng=dropout_rng,
      dropout_rate=self.dropout_rate,
      broadcast_dropout=self.broadcast_dropout,
      deterministic=deterministic,
      dtype=self.dtype,
      precision=self.precision)  # pytype: disable=wrong-keyword-args
  # back to the original inputs dimensions
  out = linear.DenseGeneral(features=features,
                            axis=(-2, -1),
                            kernel_init=self.kernel_init,
                            bias_init=self.bias_init,
                            use_bias=self.use_bias,
                            dtype=self.dtype,
                            precision=self.precision,
                            name='out')(x)
  return out
def __call__(self,
             inputs_q: Array,
             inputs_kv: Array,
             mask: Optional[Array] = None,
             deterministic: Optional[bool] = None):
  """Applies multi-head dot product attention on the input data.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention and project the results to an output vector.

  Args:
    inputs_q: input queries of shape `[batch_sizes..., length, features]`.
    inputs_kv: key/values of shape `[batch_sizes..., length, features]`.
    mask: attention mask of shape `[batch_sizes..., num_heads, query_length,
      key/value_length]`.
    deterministic: if false, the attention weight is masked randomly using
      dropout, whereas if true, the attention weights are deterministic.

  Returns:
    output of shape `[batch_sizes..., length, features]`.
  """
  if self.dropout_rate > 0.:  # Require `deterministic` only if using dropout.
    deterministic = merge_param('deterministic', self.deterministic,
                                deterministic)
  features = self.out_features or inputs_q.shape[-1]
  qkv_features = self.qkv_features or inputs_q.shape[-1]
  assert qkv_features % self.num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // self.num_heads
  # Shared projection factory: [..., features] -> [..., num_heads, head_dim].
  dense = partial(DenseGeneral,
                  axis=-1,
                  features=(self.num_heads, head_dim),
                  kernel_init=self.kernel_init,
                  bias_init=self.bias_init,
                  use_bias=self.use_bias,
                  precision=self.precision)
  # project inputs_q to multi-headed q/k/v
  # dimensions are then [batch..., length, n_heads, n_features_per_head]
  query, key, value = (dense(dtype=self.dtype, name='query')(inputs_q),
                       dense(dtype=self.dtype, name='key')(inputs_kv),
                       dense(dtype=self.dtype, name='value')(inputs_kv))

  # During fast autoregressive decoding, we feed one position at a time,
  # and cache the keys and values step by step.
  if self.decode:
    # detect if we're initializing by absence of existing cache data.
    is_initialized = self.has_variable('cache', 'cached_key')
    cached_key = self.variable('cache', 'cached_key', jnp.zeros,
                               key.shape, key.dtype)
    cached_value = self.variable('cache', 'cached_value', jnp.zeros,
                                 value.shape, value.dtype)
    cache_index = self.variable('cache', 'cache_index',
                                lambda: jnp.array(0, dtype=jnp.int32))
    if is_initialized:
      *batch_dims, max_length, num_heads, depth_per_head = (
          cached_key.value.shape)
      # shape check of cached keys against query input
      expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
      if expected_shape != query.shape:
        raise ValueError(
            'Autoregressive cache shape error, '
            'expected query shape %s instead got %s.' %
            (expected_shape, query.shape))
      # update key, value caches with our new 1d spatial slices
      cur_index = cache_index.value
      indices = (0, ) * len(batch_dims) + (cur_index, 0, 0)
      key = lax.dynamic_update_slice(cached_key.value, key, indices)
      value = lax.dynamic_update_slice(cached_value.value, value, indices)
      cached_key.value = key
      cached_value.value = value
      # Advance the write pointer for the next decoding step.
      cache_index.value = cache_index.value + 1
      # causal mask for cached decoder self-attention:
      # our single query position should only attend to those key
      # positions that have already been generated and cached,
      # not the remaining zero elements.
      mask = combine_masks(
          mask,
          jnp.broadcast_to(
              jnp.arange(max_length) <= cur_index,
              tuple(batch_dims) + (1, 1, max_length)))

  # Convert the boolean attention mask to an attention bias.
  if mask is not None:
    # attention mask in the form of attention bias
    attention_bias = lax.select(
        mask > 0,
        jnp.full(mask.shape, 0.).astype(self.dtype),
        jnp.full(mask.shape, -1e10).astype(self.dtype))
  else:
    attention_bias = None
  dropout_rng = None
  if not deterministic and self.dropout_rate > 0.:
    dropout_rng = self.make_rng('dropout')
  # apply attention
  x = self.attention_fn(
      query,
      key,
      value,
      bias=attention_bias,
      dropout_rng=dropout_rng,
      dropout_rate=self.dropout_rate,
      broadcast_dropout=self.broadcast_dropout,
      deterministic=deterministic,
      dtype=self.dtype,
      precision=self.precision)  # pytype: disable=wrong-keyword-args
  # back to the original inputs dimensions
  out = DenseGeneral(features=features,
                     axis=(-2, -1),
                     kernel_init=self.kernel_init,
                     bias_init=self.bias_init,
                     use_bias=self.use_bias,
                     dtype=self.dtype,
                     precision=self.precision,
                     name='out')(x)
  return out
def __call__(self, x, use_running_average: Optional[bool] = None):
  """Normalizes the input using batch statistics.

  Args:
    x: the input to be normalized.
    use_running_average: if true, the statistics stored in batch_stats will
      be used instead of computing the batch statistics on the input.

  Returns:
    Normalized inputs (the same shape as inputs).
  """
  use_running_average = merge_param(
      'use_running_average', self.use_running_average, use_running_average)
  x = jnp.asarray(x, jnp.float32)

  # Resolve the configured feature axes to absolute dimension indices.
  feature_axes = self.axis if isinstance(self.axis, tuple) else (self.axis,)
  feature_axes = _absolute_dims(x.ndim, feature_axes)
  reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
  # Shape of the learned/statistic parameters (feature dims only).
  stats_shape = tuple(d for i, d in enumerate(x.shape) if i in feature_axes)
  # Same values, but broadcastable against `x` (reduced dims become 1).
  broadcast_shape = tuple(
      d if i in feature_axes else 1 for i, d in enumerate(x.shape))

  # we detect if we're in initialization via empty variable tree.
  initializing = not self.has_variable('batch_stats', 'mean')
  ra_mean = self.variable('batch_stats', 'mean',
                          lambda s: jnp.zeros(s, jnp.float32), stats_shape)
  ra_var = self.variable('batch_stats', 'var',
                         lambda s: jnp.ones(s, jnp.float32), stats_shape)

  if use_running_average:
    mean, var = ra_mean.value, ra_var.value
  else:
    # Batch statistics: E[x] and E[x^2] over the reduction axes.
    mean = jnp.mean(x, axis=reduction_axes, keepdims=False)
    mean2 = jnp.mean(lax.square(x), axis=reduction_axes, keepdims=False)
    if self.axis_name is not None and not initializing:
      # Sync both moments across devices with one fused pmean.
      synced = lax.pmean(jnp.concatenate([mean, mean2]),
                         axis_name=self.axis_name,
                         axis_index_groups=self.axis_index_groups)
      mean, mean2 = jnp.split(synced, 2)
    var = mean2 - lax.square(mean)
    if not initializing:
      # Exponential moving average of the batch statistics.
      decay = self.momentum
      ra_mean.value = decay * ra_mean.value + (1 - decay) * mean
      ra_var.value = decay * ra_var.value + (1 - decay) * var

  y = x - mean.reshape(broadcast_shape)
  mul = lax.rsqrt(var + self.epsilon)
  if self.use_scale:
    mul = mul * self.param(
        'scale', self.scale_init, stats_shape).reshape(broadcast_shape)
  y = y * mul
  if self.use_bias:
    y = y + self.param(
        'bias', self.bias_init, stats_shape).reshape(broadcast_shape)
  return jnp.asarray(y, self.dtype)
def __call__(self, inputs_q, inputs_kv, mask=None, deterministic=None):
  """Applies multi-head dot product attention, repeated per query projection.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention and projects the results to an output vector.
  The query projection carries an extra `num_repeat` axis; keys and values
  are tiled so every repeat attends over the same key/value sequence.

  Args:
    inputs_q: input queries of shape `[batch_sizes..., length, features]`.
    inputs_kv: key/values of shape `[batch_sizes..., length, features]`.
    mask: attention mask of shape `[batch_sizes..., num_heads, query_length,
      key/value_length]`.
    deterministic: if false, the attention weight is masked randomly using
      dropout, whereas if true, the attention weights are deterministic.

  Returns:
    output of shape `(batch_size, seq_len, num_repeat, features)`.
  """
  assert inputs_q.ndim == 3 and inputs_kv.ndim == 3
  if self.dropout_rate > 0.:  # Require `deterministic` only if using dropout.
    deterministic = merge_param('deterministic', self.deterministic,
                                deterministic)

  out_features = self.out_features or inputs_q.shape[-1]
  qkv_features = self.qkv_features or inputs_q.shape[-1]
  assert qkv_features % self.num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // self.num_heads

  # Shared projection factory: [..., features] -> [..., num_heads, head_dim].
  dense = partial(DenseGeneral,
                  axis=-1,
                  features=(self.num_heads, head_dim),
                  kernel_init=self.kernel_init,
                  bias_init=self.bias_init,
                  use_bias=self.use_bias,
                  precision=self.precision)

  # The query projection gets a leading `num_repeat` feature axis; keys and
  # values are projected once and replicated across repeats below.
  query = dense(dtype=self.dtype, name='query',
                features=(self.num_repeat, self.num_heads, head_dim))(inputs_q)
  key = dense(dtype=self.dtype, name='key')(inputs_kv)
  value = dense(dtype=self.dtype, name='value')(inputs_kv)

  # Insert a repeat axis on k/v and tile it, then move the repeat axis in
  # front of the sequence axis on all three tensors:
  #   query: (batch_size, num_repeat, query_seq_len, num_head, emb_dim)
  #   k/v:   (batch_size, num_repeat, kv_seq_len, num_head, emb_dim)
  key = jnp.tile(jnp.expand_dims(key, -3), self.to_tile_shape)
  value = jnp.tile(jnp.expand_dims(value, -3), self.to_tile_shape)
  query, key, value = (jnp.swapaxes(t, -3, -4)
                       for t in (query, key, value))

  # Convert the boolean attention mask to an additive attention bias.
  attention_bias = None
  if mask is not None:
    attention_bias = lax.select(
        mask > 0,
        jnp.full(mask.shape, 0.).astype(self.dtype),
        jnp.full(mask.shape, -1e10).astype(self.dtype))

  dropout_rng = None
  if self.dropout_rate > 0. and not deterministic:
    dropout_rng = self.make_rng('dropout')

  # apply attention
  x = self.attention_fn(query,
                        key,
                        value,
                        bias=attention_bias,
                        dropout_rng=dropout_rng,
                        dropout_rate=self.dropout_rate,
                        broadcast_dropout=self.broadcast_dropout,
                        deterministic=deterministic,
                        dtype=self.dtype,
                        precision=self.precision)  # pytype: disable=wrong-keyword-args

  # Project back to the original feature dimension.
  out = DenseGeneral(features=out_features,
                     axis=(-2, -1),
                     kernel_init=self.kernel_init,
                     bias_init=self.bias_init,
                     use_bias=self.use_bias,
                     dtype=self.dtype,
                     precision=self.precision,
                     name='out')(x)
  # (batch_size, num_repeat, seq_len, emb_dim)
  #   -> (batch_size, seq_len, num_repeat, emb_dim)
  return jnp.swapaxes(out, -2, -3)
def __call__(self, x, use_running_average: Optional[bool] = None):
  """Normalizes the input using batch statistics.

  NOTE: During initialization (when parameters are mutable) the running
  average of the batch statistics will not be updated. Therefore, the
  inputs fed during initialization don't need to match that of the actual
  input distribution and the reduction axis (set with `axis_name`) does
  not have to exist.

  Args:
    x: the input to be normalized.
    use_running_average: if true, the statistics stored in batch_stats will
      be used instead of computing the batch statistics on the input.

  Returns:
    Normalized inputs (the same shape as inputs).
  """
  use_running_average = merge_param('use_running_average',
                                    self.use_running_average,
                                    use_running_average)
  x = jnp.asarray(x, jnp.float32)

  # Normalize `self.axis` to a tuple of absolute dimension indices.
  raw_axes = self.axis if isinstance(self.axis, tuple) else (self.axis, )
  feature_axes = _absolute_dims(x.ndim, raw_axes)
  reduce_over = tuple(ax for ax in range(x.ndim) if ax not in feature_axes)
  # Parameter/statistic shape: only the feature dims.
  param_shape = tuple(
      dim for ax, dim in enumerate(x.shape) if ax in feature_axes)
  # Same values arranged to broadcast against `x`.
  bcast_shape = tuple(
      dim if ax in feature_axes else 1 for ax, dim in enumerate(x.shape))

  # see NOTE above on initialization behavior
  initializing = self.is_mutable_collection('params')
  ra_mean = self.variable('batch_stats', 'mean',
                          lambda s: jnp.zeros(s, jnp.float32), param_shape)
  ra_var = self.variable('batch_stats', 'var',
                         lambda s: jnp.ones(s, jnp.float32), param_shape)

  if use_running_average:
    mean = ra_mean.value
    var = ra_var.value
  else:
    # First and second raw moments over the reduction axes.
    mean = jnp.mean(x, axis=reduce_over, keepdims=False)
    mean2 = jnp.mean(lax.square(x), axis=reduce_over, keepdims=False)
    if self.axis_name is not None and not initializing:
      # Cross-device sync of both moments via a single fused pmean.
      mean, mean2 = jnp.split(
          lax.pmean(jnp.concatenate([mean, mean2]),
                    axis_name=self.axis_name,
                    axis_index_groups=self.axis_index_groups), 2)
    var = mean2 - lax.square(mean)
    if not initializing:
      # Update the exponential moving averages.
      m = self.momentum
      ra_mean.value = m * ra_mean.value + (1 - m) * mean
      ra_var.value = m * ra_var.value + (1 - m) * var

  centered = x - mean.reshape(bcast_shape)
  scale_factor = lax.rsqrt(var + self.epsilon)
  if self.use_scale:
    scale_factor = scale_factor * self.param(
        'scale', self.scale_init, param_shape).reshape(bcast_shape)
  result = centered * scale_factor
  if self.use_bias:
    result = result + self.param(
        'bias', self.bias_init, param_shape).reshape(bcast_shape)
  return jnp.asarray(result, self.dtype)