Example #1
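 # call() method of a masked local self-attention layer built on Mesh
 # TensorFlow (mtf) ops; self.heads_dim, self.kv_dim, and self.window_size
 # are assumed to be attributes set in the layer's __init__.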
 def call(self, context, x, losses=None):
     """Call the layer."""
     params = mtf.layers.multihead_attention_params(context.mesh,
                                                    self.heads_dim,
                                                    context.model_dim,
                                                    self.kv_dim,
                                                    context.variable_dtype)
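     # Incremental decoding: read the cached window of keys/values, fold in
     # this step's key/value, and attend only within that window.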
     if context.mode == "incremental":
         prev_k, prev_v = context.get_states(2)
         y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
             x, prev_k, prev_v, context.position, params=params)
         context.record_new_states([new_k, new_v])
         return y
     else:
         kv = []
         y = mtf.layers.masked_local_attention_1d(x,
                                                  self.kv_dim,
                                                  self.heads_dim,
                                                  self.window_size,
                                                  params=params,
                                                  return_kv=kv)
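         # In "first_part" mode, export the keys/values of the window_size
         # positions just before initial_position as initial decoding states.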
         if context.mode == "first_part":
             k = kv[0]
             v = kv[1]
             window_dim = mtf.Dimension("window", self.window_size)
             mesh = k.mesh
             window_pos = mtf.range(mesh, window_dim, tf.int32)
             pos = mtf.range(mesh, context.length_dim, tf.int32)
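             # One-hot mask over (window_pos, pos): for each window slot,
             # select the single position in
             # [initial_position - window_size, initial_position) whose value
             # mod window_size matches that slot.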
             select_recent = mtf.cast(
                 mtf.equal(window_pos, mtf.mod(pos, self.window_size)),
                 k.dtype)
             select_recent *= mtf.cast(
                 mtf.less(pos, context.initial_position), k.dtype)
             select_recent *= mtf.cast(
                 mtf.greater_equal(
                     pos, context.initial_position - self.window_size),
                 k.dtype)
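             # The einsums below act as a gather: summing over length_dim with
             # the one-hot mask packs the selected keys/values into
             # window-shaped state tensors.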
             state_shape = k.shape.dims[:-2] + [window_dim, self.kv_dim]
             k_state = mtf.einsum([k, select_recent],
                                  output_shape=state_shape,
                                  reduced_dims=[context.length_dim])
             v_state = mtf.einsum([v, select_recent],
                                  output_shape=state_shape,
                                  reduced_dims=[context.length_dim])
             context.new_states.extend([k_state, v_state])
         return y
Example #2
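 # call() method of a local self-attention layer variant with an optional
 # shared key/value projection; self.radius, self.window_dim, self.kv_dim,
 # self.shared_kv, and helpers such as make_params() are assumed to be
 # defined on the layer, and `attention` presumably refers to the Mesh
 # TensorFlow transformer attention module.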
 def call(self, context, x, losses=None):
     """Call the layer."""
     params = self.make_params(context)
     q = params.compute_q(x)
     if self.shared_kv:
         kv = params.compute_kv(x)
         k = kv
         v = kv
     else:
         k = params.compute_k(x)
         v = params.compute_v(x)
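     # Incremental decoding: the cached keys/values are indexed by
     # self.window_dim (presumably of size self.radius) and used as a circular
     # buffer; overwrite the slot for position mod radius with this step's
     # projections.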
     if context.mode == "incremental":
         if self.shared_kv:
             prev_kv, = context.get_states(1)
         else:
             prev_k, prev_v = context.get_states(2)
         current_position = mtf.equal(
             mtf.range(context.mesh, self.window_dim, dtype=tf.int32),
             mtf.mod(context.position, self.radius))
         if self.shared_kv:
             kv = mtf.where(current_position,
                            kv,
                            prev_kv,
                            output_shape=prev_kv.shape)
             k = kv
             v = kv
             context.record_new_states([kv])
         else:
             k = mtf.where(current_position,
                           params.compute_k(x),
                           prev_k,
                           output_shape=prev_k.shape)
             v = mtf.where(current_position,
                           params.compute_v(x),
                           prev_v,
                           output_shape=prev_v.shape)
             context.record_new_states([k, v])
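         # Mask out window slots with an index greater than the current
         # position, since early in decoding those slots are not yet written.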
         window_pos = mtf.range(context.mesh, self.window_dim, tf.int32)
         visible = mtf.greater_equal(context.position, window_pos)
         bias = attention.visibility_mask_to_attention_bias(
             visible, context.activation_dtype)
         o = attention.attention(
             q, k, v, self.window_dim, self.kv_dim, self.kv_dim, bias,
             **self.attention_kwargs_from_context(context))
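     # For short sequences the blocked local-attention path below is
     # presumably not worth its overhead, so fall back to ordinary attention
     # with a masking bias.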
     elif context.length_dim.size <= max(256, self.radius * 4):
         # nothing fancy - just do full attention and mask
         memory_length = self.rename_length_to_memory_length(
             context.position, context)
         o = attention.attention(
             q, self.rename_length_to_memory_length(k, context),
             self.rename_length_to_memory_length(v, context),
             self.memory_length(context), self.kv_dim, self.kv_dim,
             self.compute_bias(context, memory_length, x),
             **self.attention_kwargs_from_context(context))
     else:
         # fancy local attention algorithm
         o = attention.local_attention_1d(
             q=q,
             k=k,
             v=None if self.shared_kv else v,
             length_dim=context.length_dim,
             key_dim=self.kv_dim,
             value_dim=self.kv_dim,
             length_dim_num_splits=1,  # TODO(noam): look at the layout
             autoregressive=context.model.fully_autoregressive,
             radius=self.radius,
             sequence_id=context.sequence_id,
             write_priority=context.write_priority,
             read_priority=context.read_priority,
             attention_kwargs=self.attention_kwargs_from_context(context))
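     # In "first_part" mode, pack the keys/values of the `radius` positions
     # just before initial_position into window-shaped states via a one-hot
     # gather over the length dimension.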
     if context.mode == "first_part":
         window_pos = mtf.range(context.mesh, self.window_dim, tf.int32)
         pos = mtf.range(context.mesh, context.length_dim, tf.int32)
         select_recent = mtf.cast(
             mtf.equal(mtf.mod(pos, self.radius), window_pos), x.dtype)
         select_recent *= mtf.cast(mtf.less(pos, context.initial_position),
                                   x.dtype)
         select_recent *= mtf.cast(
             mtf.greater_equal(pos, context.initial_position - self.radius),
             x.dtype)
         state_shape = (k.shape - [context.length_dim, self.kv_dim] +
                        [self.window_dim, self.kv_dim])
         k_state = mtf.einsum([k, select_recent],
                              output_shape=state_shape,
                              reduced_dims=[context.length_dim])
         context.new_states.append(k_state)
         if not self.shared_kv:
             v_state = mtf.einsum([v, select_recent],
                                  output_shape=state_shape,
                                  reduced_dims=[context.length_dim])
             context.new_states.append(v_state)
     return params.compute_output(o, output_shape=x.shape)