def call(self, context, x, losses=None):
  """Call the layer: masked local self-attention over `self.window_size`.

  Behaviour depends on `context.mode`:
    * "incremental": a single decoding step through
      `mtf.layers.masked_local_attention_1d_incremental`. The two cached
      states (keys, values) are read from `context`, and the updated pair is
      recorded back with `context.record_new_states`.
    * any other mode: full-sequence `mtf.layers.masked_local_attention_1d`.
      In "first_part" mode, keys/values for the `self.window_size` positions
      immediately preceding `context.initial_position` are also gathered
      into window-indexed tensors and appended to `context.new_states`.

  Args:
    context: the layer context (mode, mesh, dimensions, decoding states).
    x: the input tensor.
    losses: not used by this layer.

  Returns:
    The attention output produced by the mtf.layers helper.
  """
  params = mtf.layers.multihead_attention_params(
      context.mesh, self.heads_dim, context.model_dim, self.kv_dim,
      context.variable_dtype)

  if context.mode == "incremental":
    cached_k, cached_v = context.get_states(2)
    output, updated_k, updated_v = (
        mtf.layers.masked_local_attention_1d_incremental(
            x, cached_k, cached_v, context.position, params=params))
    context.record_new_states([updated_k, updated_v])
    return output

  # The helper fills this list; entries 0 and 1 are read back below as the
  # keys and values it computed.
  kv_out = []
  output = mtf.layers.masked_local_attention_1d(
      x, self.kv_dim, self.heads_dim, self.window_size,
      params=params, return_kv=kv_out)

  if context.mode != "first_part":
    return output

  keys, values = kv_out[0], kv_out[1]
  dtype = keys.dtype
  window_dim = mtf.Dimension("window", self.window_size)
  slot_ids = mtf.range(keys.mesh, window_dim, tf.int32)
  positions = mtf.range(keys.mesh, context.length_dim, tf.int32)

  # Position p lands in window slot (p mod window_size).
  slot_match = mtf.cast(
      mtf.equal(slot_ids, mtf.mod(positions, self.window_size)), dtype)
  # Keep only positions strictly before the decode start ...
  before_start = mtf.cast(
      mtf.less(positions, context.initial_position), dtype)
  # ... and no more than window_size positions back from it.
  in_range = mtf.cast(
      mtf.greater_equal(
          positions, context.initial_position - self.window_size), dtype)
  select_recent = slot_match * before_start * in_range

  # NOTE(review): assumes the last two dims of k are (length, kv) so that
  # dims[:-2] are the leading (e.g. batch/heads) dims - TODO confirm.
  state_shape = keys.shape.dims[:-2] + [window_dim, self.kv_dim]
  window_states = []
  for tensor in (keys, values):
    window_states.append(
        mtf.einsum([tensor, select_recent], output_shape=state_shape,
                   reduced_dims=[context.length_dim]))
  context.new_states.extend(window_states)
  return output
def call(self, context, x, losses=None):
  """Call the layer: local self-attention with radius `self.radius`.

  Computes q and k/v from `x`. When `self.shared_kv` is set, a single kv
  tensor serves as both keys and values. The block then attends locally and
  projects back with `params.compute_output(o, output_shape=x.shape)`.

  Paths, chosen from `context`:
    * "incremental": the k/v for the current position are written into slot
      `context.position mod self.radius` of the cached window states. This
      is a circular buffer over `self.window_dim`. The query then attends
      over that window, and slots with index > position are masked because
      nothing has been written there yet.
    * otherwise, short sequences (length <= max(256, 4 * radius)): plain
      attention over the length dim renamed to memory length, masked with
      `self.compute_bias`.
    * otherwise: `attention.local_attention_1d`.
  Finally, in "first_part" mode the k/v of the last `self.radius` positions
  before `context.initial_position` are gathered into window-indexed states
  and appended to `context.new_states` for subsequent incremental decoding.

  NOTE(review): the mod-radius slot indexing assumes `self.window_dim` has
  size `self.radius` - TODO confirm against the window_dim definition.

  Args:
    context: the layer context (mode, mesh, dims, positions, states).
    x: the input tensor.
    losses: not used by this layer.

  Returns:
    a Tensor with the same shape as `x`.
  """
  params = self.make_params(context)
  q = params.compute_q(x)
  if self.shared_kv:
    kv = params.compute_kv(x)
    k = kv
    v = kv
  else:
    k = params.compute_k(x)
    v = params.compute_v(x)
  if context.mode == "incremental":
    if self.shared_kv:
      prev_kv, = context.get_states(1)
    else:
      prev_k, prev_v = context.get_states(2)
    window_pos = mtf.range(context.mesh, self.window_dim, tf.int32)
    # True only at the buffer slot that the current position overwrites.
    current_position = mtf.equal(
        window_pos, mtf.mod(context.position, self.radius))
    if self.shared_kv:
      kv = mtf.where(current_position, kv, prev_kv,
                     output_shape=prev_kv.shape)
      k = kv
      v = kv
      context.record_new_states([kv])
    else:
      # Reuse the k/v projections computed above; calling compute_k /
      # compute_v again here would only build duplicate einsums.
      k = mtf.where(current_position, k, prev_k,
                    output_shape=prev_k.shape)
      v = mtf.where(current_position, v, prev_v,
                    output_shape=prev_v.shape)
      context.record_new_states([k, v])
    # Slot j is first written at position j, so earlier slots are empty.
    visible = mtf.greater_equal(context.position, window_pos)
    bias = attention.visibility_mask_to_attention_bias(
        visible, context.activation_dtype)
    o = attention.attention(
        q, k, v, self.window_dim, self.kv_dim, self.kv_dim, bias,
        **self.attention_kwargs_from_context(context))
  elif context.length_dim.size <= max(256, self.radius * 4):
    # nothing fancy - just do full attention and mask
    memory_length = self.rename_length_to_memory_length(
        context.position, context)
    o = attention.attention(
        q, self.rename_length_to_memory_length(k, context),
        self.rename_length_to_memory_length(v, context),
        self.memory_length(context), self.kv_dim, self.kv_dim,
        self.compute_bias(context, memory_length, x),
        **self.attention_kwargs_from_context(context))
  else:
    # fancy local attention algorithm
    o = attention.local_attention_1d(
        q=q,
        k=k,
        v=None if self.shared_kv else v,
        length_dim=context.length_dim,
        key_dim=self.kv_dim,
        value_dim=self.kv_dim,
        length_dim_num_splits=1,  # TODO(noam): look at the layout
        autoregressive=context.model.fully_autoregressive,
        radius=self.radius,
        sequence_id=context.sequence_id,
        write_priority=context.write_priority,
        read_priority=context.read_priority,
        attention_kwargs=self.attention_kwargs_from_context(context))
  if context.mode == "first_part":
    window_pos = mtf.range(context.mesh, self.window_dim, tf.int32)
    pos = mtf.range(context.mesh, context.length_dim, tf.int32)
    # Map each position to its buffer slot, keeping only the last
    # `radius` positions strictly before the decode start.
    select_recent = mtf.cast(
        mtf.equal(mtf.mod(pos, self.radius), window_pos), x.dtype)
    select_recent *= mtf.cast(
        mtf.less(pos, context.initial_position), x.dtype)
    select_recent *= mtf.cast(
        mtf.greater_equal(
            pos, context.initial_position - self.radius), x.dtype)
    state_shape = (k.shape - [context.length_dim, self.kv_dim]
                   + [self.window_dim, self.kv_dim])
    k_state = mtf.einsum(
        [k, select_recent], output_shape=state_shape,
        reduced_dims=[context.length_dim])
    context.new_states.append(k_state)
    if not self.shared_kv:
      v_state = mtf.einsum(
          [v, select_recent], output_shape=state_shape,
          reduced_dims=[context.length_dim])
      context.new_states.append(v_state)
  return params.compute_output(o, output_shape=x.shape)