def __call__(self, input_qkv):
  cfg = self.config
  assert cfg.max_len % cfg.max_seg_len == 0
  bsize = input_qkv.shape[0]
  features = self.out_features or input_qkv.shape[-1]
  num_seg = cfg.max_len // cfg.max_seg_len
  # Fold the sequence into (num_seg, max_seg_len) blocks.
  x_sqr = input_qkv.reshape(
      [bsize, num_seg, cfg.max_seg_len, input_qkv.shape[-1]])
  q_row_local, key_row_local, value_row_local, head_dim = get_qkv(cfg, x_sqr)
  # Unmasked attention within each segment, used to build per-segment summaries.
  local_logits = jnp.einsum('...qhd,...khd->...qhk', q_row_local,
                            key_row_local)
  row_probs = jax.nn.softmax(local_logits)
  if not cfg.deterministic and cfg.attention_dropout_rate > 0.:
    dropout_rng = self.make_rng('dropout')
    row_probs = dropatt(row_probs, dropout_rng, 1 - cfg.attention_dropout_rate)
  row_attn_out = jnp.einsum('...qhk,...khd->...qhd', row_probs,
                            value_row_local)
  # Project the summaries back to the model dimension, then apply dropout,
  # a residual connection and layer norm before re-projecting to heads.
  key_row = DenseGeneral(features=input_qkv.shape[-1],
                         axis=(-2, -1),
                         kernel_init=cfg.kernel_init,
                         bias_init=cfg.bias_init,
                         use_bias=False,
                         dtype=cfg.dtype)(row_attn_out)
  key_row = nn.Dropout(rate=cfg.dropout_rate)(key_row,
                                              deterministic=cfg.deterministic)
  key_row = key_row + x_sqr
  key_row = nn.LayerNorm(dtype=cfg.dtype)(key_row)
  key_row = DenseGeneral(axis=-1,
                         features=(cfg.num_heads, head_dim),
                         kernel_init=cfg.kernel_init,
                         bias_init=cfg.bias_init,
                         use_bias=False,
                         dtype=cfg.dtype)(key_row)
  # Causal mask within a segment: mask value 1 marks positions to suppress.
  idx_cols = jnp.arange(cfg.max_seg_len)
  local_mask = nn.make_attention_mask(idx_cols, idx_cols, jnp.less,
                                      extra_batch_dims=1)
  local_mask = jnp.expand_dims(local_mask, axis=-2) * -1e10
  local_logits = local_logits + local_mask
  # Cross-segment (global) logits against the summarized keys; only strictly
  # earlier segments may be attended to.
  global_logits = jnp.einsum('bqlhd,bklhd->bqlhk', q_row_local, key_row)
  idx_rows = jnp.arange(num_seg)
  global_mask = nn.make_attention_mask(idx_rows, idx_rows, jnp.less_equal)
  global_mask = global_mask[:, :, jnp.newaxis, jnp.newaxis, :] * -1e10
  global_logits = global_logits + global_mask
  # Jointly normalize local and global logits, then split the weights back.
  joint_logits = jnp.concatenate((local_logits, global_logits), axis=-1)
  attn_probs = jax.nn.softmax(joint_logits, axis=-1)
  local_att, global_att = jnp.split(attn_probs, [cfg.max_seg_len], axis=-1)
  if not cfg.deterministic and cfg.attention_dropout_rate > 0.:
    dropout_rng = self.make_rng('dropout')
    local_att = dropatt(local_att, dropout_rng, 1 - cfg.attention_dropout_rate)
  local_merged = jnp.einsum('bsqhk,bskhd->bsqhd', local_att, value_row_local)
  global_merged = jnp.einsum('bqlhv,bvlhd->bqlhd', global_att, row_attn_out)
  joint_merged = jnp.reshape(local_merged + global_merged,
                             [bsize, cfg.max_len, cfg.num_heads, head_dim])
  x = DenseGeneral(features=features,
                   axis=(-2, -1),
                   kernel_init=cfg.kernel_init,
                   bias_init=cfg.bias_init,
                   use_bias=False,
                   dtype=cfg.dtype)(joint_merged)
  return x
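# Illustrative, self-contained shape sketch (toy sizes, not part of the module
# above): the sequence is folded into (num_seg, max_seg_len) blocks so the
# quadratic local logits are computed only within each segment.
import jax.numpy as jnp

bsize, max_len, max_seg_len, num_heads, head_dim = 2, 8, 4, 2, 3
num_seg = max_len // max_seg_len
x = jnp.zeros((bsize, max_len, num_heads * head_dim))
x_sqr = x.reshape(bsize, num_seg, max_seg_len, x.shape[-1])      # (2, 2, 4, 6)
q = k = jnp.zeros((bsize, num_seg, max_seg_len, num_heads, head_dim))
local_logits = jnp.einsum('...qhd,...khd->...qhk', q, k)         # (2, 2, 4, 2, 4)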
def __call__(self,
             inputs_q: Array,
             inputs_kv: Array,
             mask: Optional[Array] = None):
  """Applies multi-head dot product attention on the input data.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention and projects the results to an output vector.

  Args:
    inputs_q: input queries of shape `[batch_sizes..., length, features]`.
    inputs_kv: key/values of shape `[batch_sizes..., length, features]`.
    mask: attention mask of shape `[batch_sizes..., num_heads, query_length,
      key/value_length]`.

  Returns:
    output of shape `[batch_sizes..., length, features]`.
  """
  features = self.out_features or inputs_q.shape[-1]
  qkv_features = self.qkv_features or inputs_q.shape[-1]
  assert qkv_features % self.num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // self.num_heads

  dense = partial(DenseGeneral,
                  axis=-1,
                  features=(self.num_heads, head_dim),
                  kernel_init=self.kernel_init,
                  bias_init=self.bias_init,
                  use_bias=self.use_bias,
                  precision=self.precision)
  # project inputs_q to multi-headed q/k/v
  # dimensions are then [batch..., length, n_heads, n_features_per_head]
  query, key, value = (dense(dtype=self.dtype, name='query')(inputs_q),
                       dense(dtype=self.dtype, name='key')(inputs_kv),
                       dense(dtype=self.dtype, name='value')(inputs_kv))

  # During fast autoregressive decoding, we feed one position at a time,
  # and cache the keys and values step by step.
  if self.decode:
    # detect if we're initializing by absence of existing cache data.
    is_initialized = self.has_variable('cache', 'cached_key')
    cached_key = self.variable('cache', 'cached_key', jnp.zeros, key.shape,
                               key.dtype)
    cached_value = self.variable('cache', 'cached_value', jnp.zeros,
                                 value.shape, value.dtype)
    cache_index = self.variable('cache', 'cache_index',
                                lambda: jnp.array(0, dtype=jnp.int32))
    if is_initialized:
      *batch_dims, max_length, num_heads, depth_per_head = (
          cached_key.value.shape)
      # shape check of cached keys against query input
      expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
      if expected_shape != query.shape:
        raise ValueError('Autoregressive cache shape error, '
                         'expected query shape %s instead got %s.' %
                         (expected_shape, query.shape))
      # update key, value caches with our new 1d spatial slices
      cur_index = cache_index.value
      indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
      key = lax.dynamic_update_slice(cached_key.value, key, indices)
      value = lax.dynamic_update_slice(cached_value.value, value, indices)
      cached_key.value = key
      cached_value.value = value
      cache_index.value = cache_index.value + 1
      # causal mask for cached decoder self-attention:
      # our single query position should only attend to those key
      # positions that have already been generated and cached,
      # not the remaining zero elements.
      mask = combine_masks(
          mask,
          jnp.broadcast_to(
              jnp.arange(max_length) <= cur_index,
              tuple(batch_dims) + (1, 1, max_length)))

  # Convert the boolean attention mask to an attention bias.
  if mask is not None:
    # attention mask in the form of attention bias
    attention_bias = lax.select(
        mask > 0,
        jnp.full(mask.shape, 0.).astype(self.dtype),
        jnp.full(mask.shape, -1e10).astype(self.dtype))
  else:
    attention_bias = None

  dropout_rng = None
  if not self.deterministic and self.dropout_rate > 0.:
    dropout_rng = self.make_rng('dropout')

  # apply attention
  x = self.attention_fn(query,
                        key,
                        value,
                        bias=attention_bias,
                        dropout_rng=dropout_rng,
                        dropout_rate=self.dropout_rate,
                        broadcast_dropout=self.broadcast_dropout,
                        deterministic=self.deterministic,
                        dtype=self.dtype,
                        precision=self.precision)
  # back to the original inputs dimensions
  out = DenseGeneral(features=features,
                     axis=(-2, -1),
                     kernel_init=self.kernel_init,
                     bias_init=self.bias_init,
                     use_bias=self.use_bias,
                     dtype=self.dtype,
                     precision=self.precision,
                     name='out')(x)
  return out
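# Minimal, self-contained sketch (not part of the module above) of the
# boolean-mask-to-bias conversion used there: allowed positions get a 0. bias,
# masked positions get -1e10, which drives their softmax weight to ~0.
import jax.numpy as jnp
from jax import lax

mask = jnp.array([[1, 1, 0]])                     # 1 = attend, 0 = masked
attention_bias = lax.select(mask > 0,
                            jnp.full(mask.shape, 0.),
                            jnp.full(mask.shape, -1e10))
# attention_bias == [[0., 0., -1e10]]; it is added to the logits before softmax.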
def __call__(self, input_qkv):
  cfg = self.config
  assert cfg.max_len % cfg.max_seg_len == 0
  bsize = input_qkv.shape[0]
  features = self.out_features or input_qkv.shape[-1]
  query, key, value, head_dim = get_qkv(cfg, input_qkv)
  num_seg = cfg.max_len // cfg.max_seg_len
  cur_query = query.reshape(
      [-1, cfg.max_seg_len, query.shape[-2], query.shape[-1]])
  # Summarize each segment with a max-pooled query.
  merged_query = jnp.max(cur_query, axis=1, keepdims=True) * jnp.sqrt(head_dim)
  cur_key = key.reshape([-1, cfg.max_seg_len, key.shape[-2], key.shape[-1]])
  cur_value = value.reshape(
      [-1, cfg.max_seg_len, value.shape[-2], value.shape[-1]])

  dropout_rng = None
  if not cfg.deterministic and cfg.attention_dropout_rate > 0.:
    dropout_rng = self.make_rng('dropout')
  s = dot_product_attention(merged_query,
                            cur_key,
                            cur_value,
                            dropout_rng=dropout_rng,
                            dropout_rate=cfg.attention_dropout_rate,
                            broadcast_dropout=False,
                            deterministic=cfg.deterministic,
                            dtype=cfg.dtype)
  span_val = jnp.reshape(s, [bsize, -1, s.shape[-2], s.shape[-1]])
  span_key = jnp.max(cur_key, axis=1, keepdims=True)
  # (bsize, n_seg, n_head, dim_per_head)
  span_key = jnp.reshape(span_key,
                         [bsize, -1, span_key.shape[-2], span_key.shape[-1]])

  local_mask = make_causal_mask(cur_query, length_axis=1).transpose(
      [0, 2, 1, 3])
  local_bias = lax.select(local_mask > 0,
                          jnp.full(local_mask.shape, 0.).astype(cfg.dtype),
                          jnp.full(local_mask.shape, -1e10).astype(cfg.dtype))
  # (bsize * n_seg, seg_len, n_head, seg_len)
  local_logits = jnp.einsum('...qhd,...khd->...qhk', cur_query,
                            cur_key) + local_bias
  local_logits = jnp.reshape(local_logits,
                             [bsize, -1, cfg.num_heads, cfg.max_seg_len])

  idx = jnp.broadcast_to(
      jnp.arange(span_key.shape[1], dtype=jnp.int32), span_key.shape[:2])
  prev_mask = nn.make_attention_mask(idx,
                                     idx,
                                     jnp.greater,
                                     extra_batch_dims=0,
                                     dtype=jnp.float32).transpose([0, 2, 1, 3])
  prev_mask = jnp.repeat(prev_mask, cfg.max_seg_len, axis=-3)
  prev_bias = lax.select(prev_mask > 0,
                         jnp.full(prev_mask.shape, 0.).astype(cfg.dtype),
                         jnp.full(prev_mask.shape, -1e10).astype(cfg.dtype))
  # (bsize, max_len, n_head, num_segs)
  prev_logits = jnp.einsum('...qhd,...khd->...qhk', query, span_key) + prev_bias

  joint_logits = jnp.concatenate((local_logits, prev_logits), axis=-1)
  # (bsize x max_len, n_head, seg_len + num_segs)
  attn_weights = jax.nn.softmax(joint_logits).astype(cfg.dtype)
  local_att, prev_att = jnp.split(attn_weights, [cfg.max_seg_len], axis=-1)
  local_att = local_att.reshape(
      [bsize * num_seg, cfg.max_seg_len, cfg.num_heads, cfg.max_seg_len])
  local_merged = jnp.einsum('...qhk,...khd->...qhd', local_att, cur_value)
  prev_merged = jnp.einsum('...qhk,...khd->...qhd', prev_att, span_val)
  joint_merged = jnp.reshape(local_merged, prev_merged.shape) + prev_merged
  x = DenseGeneral(features=features,
                   axis=(-2, -1),
                   kernel_init=cfg.kernel_init,
                   bias_init=cfg.bias_init,
                   use_bias=False,
                   dtype=cfg.dtype)(joint_merged)
  return x
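# Self-contained sketch (toy sizes) of the concatenate-then-split pattern above:
# local and previous-segment logits are normalized by a single softmax so the
# weights over both groups sum to one, and are then split back apart.
import jax
import jax.numpy as jnp

max_seg_len, num_segs = 4, 2
joint_logits = jnp.zeros((1, 3, 2, max_seg_len + num_segs))   # (..., seg_len + num_segs)
attn_weights = jax.nn.softmax(joint_logits)                   # normalizes across both groups
local_att, prev_att = jnp.split(attn_weights, [max_seg_len], axis=-1)
# local_att.shape == (1, 3, 2, 4); prev_att.shape == (1, 3, 2, 2)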
def __call__(self, input_qkv):
  cfg = self.config
  log_len = log_2_ceil(cfg.max_len - 1)
  bsize = input_qkv.shape[0]
  features = self.out_features or input_qkv.shape[-1]
  query, key, value, head_dim = get_qkv(cfg, input_qkv)

  joint_logits = []
  list_vals = []
  for l in range(log_len):
    ctx_len = 2**l
    last_pos = cfg.max_len - cfg.max_len % ctx_len
    num_ctx = cfg.max_len // ctx_len
    if l == 0:
      # Level 0: each position is its own span; also add the self-attention logit.
      span_key = jnp.reshape(key, [-1, 1, cfg.num_heads, head_dim])
      span_val = value.reshape(span_key.shape)
      self_logits = jnp.expand_dims(jnp.sum(query * key, axis=-1), -1)
      joint_logits.append(self_logits)
    else:
      # Level l: summarize contexts of 2**l tokens with a max-pooled query.
      left_query = query[:, :last_pos, :, :].reshape(
          [-1, ctx_len, cfg.num_heads, head_dim])
      span_query = jnp.max(left_query, axis=1, keepdims=True)
      left_key = key[:, :last_pos, :, :].reshape(left_query.shape)
      left_val = value[:, :last_pos, :, :].reshape(left_query.shape)
      span_val = dot_product_attention(
          span_query * jnp.sqrt(head_dim),
          left_key,
          left_val,
          dropout_rng=self.get_dropout_rng(cfg),
          dropout_rate=cfg.attention_dropout_rate,
          broadcast_dropout=False,
          deterministic=cfg.deterministic,
          dtype=cfg.dtype)
      span_key = jnp.max(left_key, axis=1, keepdims=True)
    # Logits of each query against the summary of the previous context.
    rolled_q = jnp.roll(query, -ctx_len, axis=1)[:, :last_pos, :, :].reshape(
        [-1, ctx_len, cfg.num_heads, head_dim])
    rolled_mask = jnp.concatenate(
        [(jnp.arange(cfg.max_len - ctx_len) // ctx_len) % 2,
         jnp.ones(last_pos + ctx_len - cfg.max_len, dtype=jnp.int32)],
        axis=0)
    rolled_mask = jnp.reshape(rolled_mask, [1, -1, 1, 1])
    rolled_logits = jnp.einsum('...qhd,...khd->...qhk', rolled_q, span_key)
    # bsize, last_pos, h, 1
    rolled_logits = jnp.reshape(
        rolled_logits,
        [bsize, -1, cfg.num_heads, 1]) + rolled_mask.astype(
            rolled_q.dtype) * -1e9
    orig_logits = jnp.pad(rolled_logits,
                          [(0, 0), (0, cfg.max_len - last_pos), (0, 0),
                           (0, 0)],
                          constant_values=-1e9)
    orig_logits = jnp.roll(orig_logits, ctx_len, axis=1)
    joint_logits.append(orig_logits)
    list_vals.append(span_val)

  joint_logits = jnp.concatenate(joint_logits, axis=-1)
  attn_weights = jax.nn.softmax(joint_logits).astype(cfg.dtype)
  local_weights = jnp.split(attn_weights, log_len + 1, axis=-1)
  local_weighted_sums = []
  # Start with the self term, then add each context level's contribution.
  joint_merged = local_weights[0] * value
  for l in range(log_len):
    ctx_len = 2**l
    last_pos = cfg.max_len - cfg.max_len % ctx_len
    num_ctx = cfg.max_len // ctx_len
    rolled_w = jnp.roll(local_weights[l + 1], -ctx_len,
                        axis=1)[:, :last_pos, :, :].reshape(
                            bsize * num_ctx, ctx_len, cfg.num_heads, 1)
    rolled_v = jnp.reshape(rolled_w * list_vals[l],
                           [bsize, -1, cfg.num_heads, head_dim])
    rolled_v = jnp.pad(rolled_v,
                       [(0, 0), (0, cfg.max_len - last_pos), (0, 0), (0, 0)])
    orig_v = jnp.roll(rolled_v, ctx_len, axis=1)
    joint_merged = joint_merged + orig_v
  x = DenseGeneral(features=features,
                   axis=(-2, -1),
                   kernel_init=cfg.kernel_init,
                   bias_init=cfg.bias_init,
                   use_bias=False,
                   dtype=cfg.dtype)(joint_merged)
  return x
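# Sketch of the hierarchy sizes used above, with an assumed log_2_ceil helper
# (the real definition is not shown in this snippet): level l attends over
# contexts of 2**l tokens, and last_pos is the largest multiple of the context
# length that fits in max_len.
import math

def log_2_ceil(x):
  # Assumed definition: the smallest integer l such that 2**l >= x.
  return int(math.ceil(math.log2(x)))

max_len = 10
for l in range(log_2_ceil(max_len - 1)):
  ctx_len = 2**l
  last_pos = max_len - max_len % ctx_len
  print(l, ctx_len, last_pos)   # -> (0, 1, 10), (1, 2, 10), (2, 4, 8), (3, 8, 8)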
def __call__(self, input_qkv):
  cfg = self.config
  assert cfg.max_len % cfg.max_seg_len == 0
  assert input_qkv.ndim == 3
  bsize = input_qkv.shape[0]
  features = self.out_features or input_qkv.shape[-1]
  qkv_features = cfg.qkv_dim or input_qkv.shape[-1]
  assert qkv_features % cfg.num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // cfg.num_heads
  dense = partial(DenseGeneral,
                  axis=-1,
                  features=(cfg.num_heads, head_dim),
                  kernel_init=cfg.kernel_init,
                  bias_init=cfg.bias_init,
                  use_bias=False)
  query, key, value = (dense(dtype=cfg.dtype, name='query')(input_qkv) /
                       jnp.sqrt(head_dim),
                       dense(dtype=cfg.dtype, name='key')(input_qkv),
                       dense(dtype=cfg.dtype, name='value')(input_qkv))
  num_seg = cfg.max_len // cfg.max_seg_len

  ##################
  cur_query = query.reshape(
      [bsize, num_seg, cfg.max_seg_len, cfg.num_heads, head_dim])
  cur_key = key.reshape(
      [bsize, num_seg, cfg.max_seg_len, cfg.num_heads, head_dim])
  cur_value = value.reshape(
      [bsize, num_seg, cfg.max_seg_len, cfg.num_heads, head_dim])
  num_attn_dims = 2
  col_logit_expr = 'BSUNK,BTUNK->BUNST'
  col_attn_expr = 'BUNST,BTUNK->BSUNK'
  # Strict lower-triangular mask so that a token does not attend to itself twice.
  col_strict_mask = make_causal_mask(cur_query, length_axis=1, strict=True)
  # (bsize, 1, 1, num_seg, num_seg)
  col_strict_mask = jnp.expand_dims(col_strict_mask, axis=1)
  col_strict_bias = lax.select(
      col_strict_mask > 0,
      jnp.full(col_strict_mask.shape, 0.).astype(cfg.dtype),
      jnp.full(col_strict_mask.shape, -1e10).astype(cfg.dtype))
  row_logit_expr = 'BUSNK,BUTNK->BUNST'
  row_attn_expr = 'BUNST,BUTNK->BUSNK'
  # (bsize, 1, 1, max_seg_len, max_seg_len)
  row_mask = make_causal_mask(cur_query, length_axis=2)[:, 0:1, :, :, :]
  row_bias = lax.select(
      row_mask > 0,
      jnp.full(row_mask.shape, 0.).astype(cfg.dtype),
      jnp.full(row_mask.shape, -1e10).astype(cfg.dtype))
  # (bsize, max_seg_len, num_head, num_seg, num_seg)
  col_logits = jnp.einsum(col_logit_expr, cur_query, cur_key) + col_strict_bias
  # (bsize, num_seg, num_head, max_seg_len, max_seg_len)
  row_logits = jnp.einsum(row_logit_expr, cur_query, cur_key) + row_bias

  ###############################
  col_up2down_query = jax.lax.cummax(cur_query, axis=1)
  # Shift down by one segment so each row only sees summaries of earlier rows.
  col_up2down_key = shift_right(jax.lax.cummax(cur_key, axis=1), axis=1)
  col_mask = make_causal_mask(cur_query, length_axis=1)
  col_mask = jnp.expand_dims(col_mask, axis=1)
  col_bias = lax.select(
      col_mask > 0,
      jnp.full(col_mask.shape, 0.).astype(cfg.dtype),
      jnp.full(col_mask.shape, -1e10).astype(cfg.dtype))
  col_up2down_logits = jnp.einsum(col_logit_expr, col_up2down_query,
                                  cur_key) + col_bias
  col_up2down_attn_weights = jax.nn.softmax(col_up2down_logits).astype(
      cfg.dtype)
  col_up2down_summary = jnp.einsum(col_attn_expr, col_up2down_attn_weights,
                                   cur_value)
  # Shift down by one segment.
  col_up2down_summary = shift_right(col_up2down_summary, axis=1)

  row_only_myself_mask = jnp.expand_dims(jnp.eye(cur_query.shape[2]),
                                         (0, 1, 2))
  # Attend to all tokens in the previous row except the one directly above the
  # current token, since that one is already covered by the local column
  # attention.
  row_without_myself_bias = lax.select(
      row_only_myself_mask == 0,
      jnp.full(row_only_myself_mask.shape, 0.).astype(cfg.dtype),
      jnp.full(row_only_myself_mask.shape, -1e10).astype(cfg.dtype))
  all_maskout = jnp.full(row_only_myself_mask.shape, -1e10).astype(cfg.dtype)
  # The first row has no previous row to attend to, so mask out all its logits.
  row_without_myself_bias = jnp.concatenate(
      [all_maskout] + [row_without_myself_bias] * (cur_query.shape[1] - 1),
      axis=1)
  previous_row_logits = jnp.einsum(row_logit_expr, cur_query,
                                   col_up2down_key) + row_without_myself_bias

  row_left2right_query = jax.lax.cummax(cur_query, axis=2)
  row_left2right_key = shift_right(jax.lax.cummax(cur_key, axis=2), axis=2)
  row_left2right_logits = jnp.einsum(row_logit_expr, row_left2right_query,
                                     cur_key) + row_bias
  row_left2right_attn_weights = jax.nn.softmax(row_left2right_logits).astype(
      cfg.dtype)
  row_left2right_summary = jnp.einsum(row_attn_expr,
                                      row_left2right_attn_weights, cur_value)
  row_left2right_summary = shift_right(row_left2right_summary, axis=2)

  all_maskout = jnp.full(col_strict_bias.shape, -1e10).astype(cfg.dtype)
  col_strict_without_first_bias = jnp.concatenate(
      [all_maskout] + [col_strict_bias] * (cur_query.shape[2] - 1), axis=1)
  top_left_col_logits = jnp.einsum(
      col_logit_expr, cur_query,
      row_left2right_key) + col_strict_without_first_bias

  ##################################
  row_right2left_query = jax.lax.cummax(cur_query, axis=2, reverse=True)
  row_right2left_key = shift_left(
      jax.lax.cummax(cur_key, axis=2, reverse=True), axis=2)
  # (bsize, 1, 1, max_seg_len, max_seg_len)
  row_strict_mask = make_causal_mask(
      cur_query, length_axis=2, strict=True)[:, 0:1, :, :, :]
  # Upper-triangular bias, since here we attend to all tokens on the right.
  row_upper_bias = lax.select(
      row_strict_mask == 0,
      jnp.full(row_strict_mask.shape, 0.).astype(cfg.dtype),
      jnp.full(row_strict_mask.shape, -1e10).astype(cfg.dtype))
  row_right2left_logits = jnp.einsum(row_logit_expr, row_right2left_query,
                                     cur_key) + row_upper_bias
  row_right2left_attn_weights = jax.nn.softmax(
      row_right2left_logits).astype(cfg.dtype)
  row_right2left_summary = jnp.einsum(row_attn_expr,
                                      row_right2left_attn_weights, cur_value)
  row_right2left_summary = shift_left(row_right2left_summary, axis=2)
  col_strict_without_last_bias = jnp.concatenate(
      [col_strict_bias] * (cur_query.shape[2] - 1) + [all_maskout], axis=1)
  top_right_col_logits = jnp.einsum(
      col_logit_expr, cur_query,
      row_right2left_key) + col_strict_without_last_bias

  ####
  # Row-major layout; the shape is (bsize, num_seg, num_head, max_seg_len,
  # num_seg + max_seg_len + max_seg_len + num_seg + num_seg).
  joint_logits = jnp.concatenate(
      (col_logits.transpose([0, 3, 2, 1, 4]), row_logits, previous_row_logits,
       top_left_col_logits.transpose([0, 3, 2, 1, 4]),
       top_right_col_logits.transpose([0, 3, 2, 1, 4])),
      axis=-1)
  attn_weights = jax.nn.softmax(joint_logits).astype(cfg.dtype)
  (col_att, row_att, previous_row_att, top_left_col_att,
   top_right_col_att) = jnp.split(attn_weights, [
       num_seg, num_seg + cfg.max_seg_len, num_seg + cfg.max_seg_len * 2,
       num_seg * 2 + cfg.max_seg_len * 2
   ], axis=-1)
  col_att = col_att.transpose([0, 3, 2, 1, 4])
  top_left_col_att = top_left_col_att.transpose([0, 3, 2, 1, 4])
  top_right_col_att = top_right_col_att.transpose([0, 3, 2, 1, 4])
  col_merged = jnp.einsum(col_attn_expr, col_att, cur_value)
  row_merged = jnp.einsum(row_attn_expr, row_att, cur_value)
  previous_row_merged = jnp.einsum(row_attn_expr, previous_row_att,
                                   col_up2down_summary)
  top_left_merged = jnp.einsum(col_attn_expr, top_left_col_att,
                               row_left2right_summary)
  top_right_merged = jnp.einsum(col_attn_expr, top_right_col_att,
                                row_right2left_summary)
  joint_merged = (col_merged + row_merged + previous_row_merged +
                  top_left_merged + top_right_merged).reshape(
                      [bsize, num_seg * cfg.max_seg_len, cfg.num_heads,
                       head_dim])
  x = DenseGeneral(features=features,
                   axis=(-2, -1),
                   kernel_init=cfg.kernel_init,
                   bias_init=cfg.bias_init,
                   use_bias=False,
                   dtype=cfg.dtype)(joint_merged)
  return x
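# Self-contained sketch (toy sizes) of the cumulative-max summaries above, with
# an assumed shift_right helper since the real one is not shown here: after the
# shift, segment i only sees a summary of segments 0..i-1.
import jax
import jax.numpy as jnp

def shift_right(x, axis):
  # Assumed behaviour: pad with zeros at the start of `axis`, drop the last slice.
  pad_width = [(0, 0)] * x.ndim
  pad_width[axis] = (1, 0)
  index = [slice(None)] * x.ndim
  index[axis] = slice(0, x.shape[axis])
  return jnp.pad(x, pad_width)[tuple(index)]

keys = jnp.arange(1., 5.).reshape(1, 4, 1)                     # (batch, num_seg, dim)
seg_summary = shift_right(jax.lax.cummax(keys, axis=1), axis=1)
# seg_summary[0, :, 0] == [0., 1., 2., 3.]: segment 0 gets an empty (zero) summary.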
def __call__(self, inputs_q, inputs_kv, mask_k=None):
  """Multi-head dot product attention with a 1D attention mask.

  Applies multi-head dot product attention on the input data. Projects the
  inputs into multi-headed query, key, and value vectors, applies dot-product
  attention and projects the results to an output vector.

  Args:
    inputs_q: input queries of shape `[batch_sizes..., length, features]`.
    inputs_kv: key/values of shape `[batch_sizes..., length, features]`.
    mask_k: attention mask of shape `[batch_sizes..., key/value_length]`.

  Returns:
    output of shape `[batch_sizes..., length, features]`.
  """
  features = self.out_features or inputs_q.shape[-1]
  qkv_features = self.qkv_features or inputs_q.shape[-1]
  assert qkv_features % self.num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // self.num_heads

  dense = functools.partial(DenseGeneral,
                            axis=-1,
                            features=(self.num_heads, head_dim),
                            kernel_init=self.kernel_init,
                            bias_init=self.bias_init,
                            use_bias=self.use_bias,
                            precision=self.precision)
  # Project inputs_q to multi-headed q/k/v
  # Dimensions are then [batch, ..., length, n_heads, n_features_per_head]
  query, key, value = (dense(dtype=self.dtype, name='query')(inputs_q),
                       dense(dtype=self.dtype, name='key')(inputs_kv),
                       dense(dtype=self.dtype, name='value')(inputs_kv))

  # TODO(tomprom): Enforce use of unidirectional mask for decoding
  if self.decode:
    raise NotImplementedError

  # Apply the 1D key mask directly to the values instead of converting it to
  # an attention bias.
  if mask_k is not None:
    # Element-wise multiply the key mask into the value matrix.
    # Dim: batch, sequence length, heads, qkv dimension (per head)
    value = jnp.einsum('...l,...lhd->...lhd', mask_k, value)
  else:
    # If no mask is provided, leave the values as they are.
    pass

  dropout_rng = None
  if not self.deterministic and self.dropout_rate > 0.:
    dropout_rng = self.make_rng('dropout')

  # Apply attention
  x = self.attention_fn(query,
                        key,
                        value,
                        bias=None,
                        dropout_rng=dropout_rng,
                        dropout_rate=self.dropout_rate,
                        broadcast_dropout=self.broadcast_dropout,
                        deterministic=self.deterministic,
                        dtype=self.dtype,
                        precision=self.precision)
  # Back to the original inputs' dimensions.
  out = DenseGeneral(features=features,
                     axis=(-2, -1),
                     kernel_init=self.kernel_init,
                     bias_init=self.bias_init,
                     use_bias=self.use_bias,
                     dtype=self.dtype,
                     precision=self.precision,
                     name='out')(x)
  return out
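# Hypothetical illustration (names and pad id are assumptions) of the 1D key
# mask expected above: a [batch, kv_length] array that is 1. for real tokens
# and 0. for padding, multiplied into the values via '...l,...lhd->...lhd'.
import jax.numpy as jnp

token_ids = jnp.array([[5, 9, 0, 0]])                  # 0 is an assumed pad id
mask_k = (token_ids != 0).astype(jnp.float32)          # shape (1, 4)
value = jnp.ones((1, 4, 2, 3))                         # (batch, length, heads, head_dim)
masked_value = jnp.einsum('...l,...lhd->...lhd', mask_k, value)
# masked_value[:, 2:] is all zeros, so the padded positions contribute zero
# vectors to the attention-weighted sum.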
def __call__(self, inputs_q, inputs_kv, mask=None, deterministic=None):
  """Applies multi-head dot product attention on the input data.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention and projects the results to an output vector.

  Args:
    inputs_q: input queries of shape `[batch_sizes..., length, features]`.
    inputs_kv: key/values of shape `[batch_sizes..., length, features]`.
    mask: attention mask of shape `[batch_sizes..., num_heads, query_length,
      key/value_length]`.
    deterministic: if false, the attention weight is masked randomly using
      dropout, whereas if true, the attention weights are deterministic.

  Returns:
    output of shape `[batch_sizes..., length, features]`.
  """
  assert inputs_q.ndim == 3 and inputs_kv.ndim == 3
  if self.dropout_rate > 0.:  # Require `deterministic` only if using dropout.
    deterministic = merge_param('deterministic', self.deterministic,
                                deterministic)
  features = self.out_features or inputs_q.shape[-1]
  qkv_features = self.qkv_features or inputs_q.shape[-1]
  assert qkv_features % self.num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // self.num_heads

  dense = partial(DenseGeneral,
                  axis=-1,
                  features=(self.num_heads, head_dim),
                  kernel_init=self.kernel_init,
                  bias_init=self.bias_init,
                  use_bias=self.use_bias,
                  precision=self.precision)
  # project inputs_q to multi-headed q/k/v
  # dimensions are then [batch..., length, n_heads, n_features_per_head]
  query, key, value = (dense(dtype=self.dtype,
                             name='query',
                             features=(self.num_repeat, self.num_heads,
                                       head_dim))(inputs_q),
                       dense(dtype=self.dtype, name='key')(inputs_kv),
                       dense(dtype=self.dtype, name='value')(inputs_kv))

  # Tile key/value across the repeat axis so every repeated query bank attends
  # to the same keys and values.
  key = jnp.expand_dims(key, -3)
  value = jnp.expand_dims(value, -3)
  key = jnp.tile(key, self.to_tile_shape)
  value = jnp.tile(value, self.to_tile_shape)
  query = jnp.swapaxes(query, -3, -4)
  key = jnp.swapaxes(key, -3, -4)
  value = jnp.swapaxes(value, -3, -4)
  # query shape: (batch_size, num_repeat, query_seq_len, num_head, emb_dim)
  # kv shape:    (batch_size, num_repeat, kv_seq_len, num_head, emb_dim)

  # Convert the boolean attention mask to an attention bias.
  if mask is not None:
    # attention mask in the form of attention bias
    attention_bias = lax.select(
        mask > 0,
        jnp.full(mask.shape, 0.).astype(self.dtype),
        jnp.full(mask.shape, -1e10).astype(self.dtype))
  else:
    attention_bias = None

  dropout_rng = None
  if not deterministic and self.dropout_rate > 0.:
    dropout_rng = self.make_rng('dropout')

  # apply attention
  x = self.attention_fn(
      query,
      key,
      value,
      bias=attention_bias,
      dropout_rng=dropout_rng,
      dropout_rate=self.dropout_rate,
      broadcast_dropout=self.broadcast_dropout,
      deterministic=deterministic,
      dtype=self.dtype,
      precision=self.precision)  # pytype: disable=wrong-keyword-args
  # back to the original inputs dimensions
  out = DenseGeneral(features=features,
                     axis=(-2, -1),
                     kernel_init=self.kernel_init,
                     bias_init=self.bias_init,
                     use_bias=self.use_bias,
                     dtype=self.dtype,
                     precision=self.precision,
                     name='out')(x)
  # swap out from (batch_size, num_repeat, seq_len, emb_dim)
  # to (batch_size, seq_len, num_repeat, emb_dim)
  out = jnp.swapaxes(out, -2, -3)
  return out
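# Shape sketch (toy sizes; to_tile_shape is assumed to be (1, 1, num_repeat, 1, 1)
# here) of the key/value tiling above: keys gain a num_repeat axis so each
# repeated query bank attends to the same kv sequence.
import jax.numpy as jnp

batch, length, num_repeat, heads, dim = 2, 5, 3, 2, 4
key = jnp.zeros((batch, length, heads, dim))
key = jnp.expand_dims(key, -3)                     # (2, 5, 1, 2, 4)
key = jnp.tile(key, (1, 1, num_repeat, 1, 1))      # (2, 5, 3, 2, 4)
key = jnp.swapaxes(key, -3, -4)                    # (2, 3, 5, 2, 4) = (b, repeat, l, h, d)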
def __call__(self, inputs_q, inputs_kv, mask=None, deterministic=None):
  """Applies multi-head dot product attention on the input data.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention and projects the results to an output vector.

  Args:
    inputs_q: input queries of shape `[batch_sizes..., length, features]`.
    inputs_kv: key/values of shape `[batch_sizes..., length, features]`.
    mask: attention mask of shape `[batch_sizes..., num_heads, query_length,
      key/value_length]`. Attention weights are masked out if their
      corresponding mask value is `False`.
    deterministic: if false, the attention weight is masked randomly using
      dropout, whereas if true, the attention weights are deterministic.

  Returns:
    output of shape `[batch_sizes..., length, features]`.
  """
  if self.dropout_rate > 0.:  # Require `deterministic` only if using dropout.
    deterministic = merge_param('deterministic', self.deterministic,
                                deterministic)
  features = self.out_features or inputs_q.shape[-1]
  qkv_features = self.qkv_features or inputs_q.shape[-1]
  assert qkv_features % self.num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // self.num_heads

  dense = functools.partial(DenseGeneral,
                            axis=-1,
                            dtype=self.dtype,
                            param_dtype=self.param_dtype,
                            features=(self.num_heads, head_dim),
                            kernel_init=self.kernel_init,
                            bias_init=self.bias_init,
                            use_bias=self.use_bias,
                            precision=self.precision)
  # project inputs_q to multi-headed q/k/v
  # dimensions are then [batch..., length, n_heads, n_features_per_head]
  query, key, value = (dense(name='query')(inputs_q),
                       dense(name='key')(inputs_kv),
                       dense(name='value')(inputs_kv))

  # During fast autoregressive decoding, we feed one position at a time,
  # and cache the keys and values step by step.
  if self.decode:
    # detect if we're initializing by absence of existing cache data.
    is_initialized = self.has_variable('cache', 'cached_key')
    cached_key = self.variable('cache', 'cached_key', jnp.zeros, key.shape,
                               key.dtype)
    cached_value = self.variable('cache', 'cached_value', jnp.zeros,
                                 value.shape, value.dtype)
    cache_index = self.variable('cache', 'cache_index',
                                lambda: jnp.array(0, dtype=jnp.int32))
    if is_initialized:
      *batch_dims, max_length, num_heads, depth_per_head = (
          cached_key.value.shape)
      # shape check of cached keys against query input
      expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
      if expected_shape != query.shape:
        raise ValueError('Autoregressive cache shape error, '
                         'expected query shape %s instead got %s.' %
                         (expected_shape, query.shape))
      # update key, value caches with our new 1d spatial slices
      cur_index = cache_index.value
      indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
      key = lax.dynamic_update_slice(cached_key.value, key, indices)
      value = lax.dynamic_update_slice(cached_value.value, value, indices)
      cached_key.value = key
      cached_value.value = value
      cache_index.value = cache_index.value + 1
      # causal mask for cached decoder self-attention:
      # our single query position should only attend to those key
      # positions that have already been generated and cached,
      # not the remaining zero elements.
      mask = combine_masks(
          mask,
          jnp.broadcast_to(
              jnp.arange(max_length) <= cur_index,
              tuple(batch_dims) + (1, 1, max_length)))

  dropout_rng = None
  if not deterministic and self.dropout_rate > 0.:
    dropout_rng = self.make_rng('dropout')

  toeplitz_params = self.param('toeplitz_params', ones,
                               (query.shape[-2],
                                4 * self.nb_x_patches * self.nb_y_patches))

  # apply attention
  x = self.attention_fn(
      query,
      key,
      value,
      toeplitz_params,
      mask=mask,
      dropout_rng=dropout_rng,
      dropout_rate=self.dropout_rate,
      broadcast_dropout=self.broadcast_dropout,
      deterministic=deterministic,
      dtype=self.dtype,
      precision=self.precision,
      nb_x_patches=self.nb_x_patches,
      nb_y_patches=self.nb_y_patches)  # pytype: disable=wrong-keyword-args
  # back to the original inputs dimensions
  out = DenseGeneral(features=features,
                     axis=(-2, -1),
                     kernel_init=self.kernel_init,
                     bias_init=self.bias_init,
                     use_bias=self.use_bias,
                     dtype=self.dtype,
                     param_dtype=self.param_dtype,
                     precision=self.precision,
                     name='out')(x)
  return out
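# Self-contained sketch (toy sizes) of the single-step causal mask built during
# cached decoding above: at step cur_index, the one live query may only attend
# to cache slots that have already been written.
import jax.numpy as jnp

max_length, cur_index = 6, 2
step_mask = jnp.broadcast_to(jnp.arange(max_length) <= cur_index,
                             (1, 1, max_length))
# step_mask == [[[True, True, True, False, False, False]]]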