def compute_bias(self, context, memory_position, x):
    """Assemble the additive attention bias for this layer.

    Every applicable term is summed: visibility masks derived from the
    layer's minimum / maximum relative position, from the context's
    read / write priorities and from matching (sub)sequence ids, plus a
    relative-attention term when ``self.relative_attention_type`` is set.

    Args:
      context: a transformer.Context
      memory_position: an int32 tensor containing memory_length dimension.
      x: a Tensor - the query antecedent - only read for "contextual"
        relative attention
    Returns:
      a Tensor, or None when no bias term applies
    Raises:
      ValueError: if self.relative_attention_type is set to a value other
        than "bias", "bias_shared" or "contextual".
    """
    lower_bound = self.min_relative_position(context)
    upper_bound = self.max_relative_position(context)
    relative_type = self.relative_attention_type

    # With these two settings the result is stored in context.cache so that
    # similar layers can reuse it instead of rebuilding it.
    can_cache = relative_type is None or relative_type == "bias_shared"
    if can_cache:
        cache_key = (
            "self_attention_mask",
            lower_bound,
            upper_bound,
            relative_type,
            self.num_heads,
        )
        if cache_key in context.cache:
            return context.cache[cache_key]

    terms = []

    def _append_mask(visible):
        # Convert a boolean visibility mask into an additive bias term.
        terms.append(
            attention.visibility_mask_to_attention_bias(
                visible, context.activation_dtype))

    relative_position = memory_position - context.position
    if lower_bound is not None:
        _append_mask(mtf.greater_equal(relative_position, lower_bound))
    if upper_bound is not None:
        _append_mask(mtf.less_equal(relative_position, upper_bound))
    if context.read_priority is not None:
        _append_mask(
            mtf.greater_equal(
                context.read_priority,
                mtf.layers.rename_length_to_memory_length(
                    context.write_priority)))

    # Subsequence id should only be set if we are in the decoder and have
    # multiple targets per input. This will allow each sub-target to only
    # attend to itself. It takes precedence over the plain sequence id.
    if isinstance(context.subsequence_id, mtf.Tensor):
        segment_id = context.subsequence_id
    elif isinstance(context.sequence_id, mtf.Tensor):
        segment_id = context.sequence_id
    else:
        segment_id = None
    if segment_id is not None and context.length_dim in segment_id.shape:
        _append_mask(
            mtf.equal(
                segment_id,
                self.rename_length_to_memory_length(segment_id, context)))

    if relative_type is not None:
        buckets_dim = mtf.Dimension(
            "buckets", self.relative_attention_num_buckets)
        heads_dim = mtf.Dimension("heads", self.num_heads)
        rp_bucket = _relative_position_bucket(
            relative_position,
            bidirectional=not context.model.fully_autoregressive,
            num_buckets=buckets_dim.size)
        if relative_type in ("bias", "bias_shared"):
            # Learned per-head, per-bucket bias table.
            values = mtf.get_variable(
                context.mesh,
                "relative_attention_bias",
                [heads_dim, buckets_dim],
                dtype=context.variable_dtype)
        elif relative_type == "contextual":
            # Bias values are produced from the query antecedent.
            values = layers.dense(
                x,
                [buckets_dim, heads_dim],
                variable_dtype=context.variable_dtype,
                name="relative_attention_contextual")
        else:
            raise ValueError(
                "unrecognized relative_attention_type \"%s\"" % relative_type)
        terms.append(mtf.gather(values, rp_bucket, buckets_dim))

    if terms:
        result = mtf.add_n(terms)
    else:
        result = None
    if can_cache:
        context.cache[cache_key] = result
    return result
def local_attention_1d(q,
                       k,
                       v,
                       length_dim,
                       key_dim,
                       value_dim,
                       fully_autoregressive=True,
                       length_dim_num_splits=1,
                       radius=128,
                       sequence_id=1,
                       write_priority=None,
                       read_priority=None,
                       attention_kwargs=None):
    """Attention to a neighborhood around the source.

    If fully_autoregressive, then query position p can only see memory
    positions in the range (p - radius, p].

    If not fully_autoregressive, then query position p can only see memory
    positions in the range (p - radius, p + radius].

    In addition, if write_priority and read_priority are provided, then
    attention is limited to position pairs where
    read_priority[query position] >= write_priority[memory position]

    Args:
      q: a Tensor containing length_dim
      k: a Tensor containing length_dim
      v: an optional Tensor containing length_dim.  If None then uses v=k.
      length_dim: a Dimension
      key_dim: a Dimension (the channels dimension of q and k)
      value_dim: a Dimension (the channels dimension of v)
      fully_autoregressive: a boolean
      length_dim_num_splits: an optional integer indicating how many ways the
        length dimension is split
      radius: an integer
      sequence_id: a Tensor or an integer; None is treated as 1
      write_priority: an optional Tensor containing length_dim
      read_priority: an optional Tensor containing length_dim
      attention_kwargs: optional dict of keyword arguments for attention();
        None (the default) means no extra keyword arguments

    Returns:
      a Tensor with the shape q.shape - key_dim + value_dim
    """
    # Choose a suitable block size.
    # We choose the greatest divisor of length_per_split less than or equal
    # to max(radius, 128)
    length_per_split = length_dim.size // length_dim_num_splits
    block_length = max(radius, 128)
    while length_per_split % block_length != 0:
        block_length -= 1
    query_block_length = mtf.Dimension("query_block_length", block_length)
    memory_block_length = mtf.Dimension("memory_block_length", block_length)
    # The num_blocks dimension gets the same name as the length dimension,
    # so it will be split in the same way.
    num_blocks = mtf.Dimension(
        length_dim.name, length_dim.size // block_length)

    def _reshape_query(x):
        # Split length_dim into [num_blocks, query_block_length].
        return mtf.replace_dimensions(
            x, length_dim, [num_blocks, query_block_length])

    def _reshape_memory(x):
        # Split length_dim into blocks, then extend each block with a halo of
        # `radius` neighbouring positions (left only when autoregressive).
        x = mtf.replace_dimensions(
            x, length_dim, [num_blocks, memory_block_length])
        exchange = (mtf.left_halo_exchange if fully_autoregressive
                    else mtf.halo_exchange)
        return exchange(x, num_blocks, memory_block_length, radius)

    q = _reshape_query(q)
    k = _reshape_memory(k)
    # Explicit None check: the documented contract is "if None then v=k", and
    # a tensor should not be truth-tested.
    if v is not None:
        v = _reshape_memory(v)
    else:
        v = k

    if sequence_id is None:
        sequence_id = 1
    if (not isinstance(sequence_id, mtf.Tensor) or
            length_dim not in sequence_id.shape.dims):
        # Broadcast a scalar / length-less id along length_dim.
        sequence_id += mtf.zeros(q.mesh, [length_dim], tf.int32)
    q_sequence_id = _reshape_query(sequence_id)
    m_sequence_id = _reshape_memory(sequence_id)

    pos = mtf.range(q.mesh, length_dim, dtype=tf.int32)
    q_pos = _reshape_query(pos)
    m_pos = _reshape_memory(pos)

    # Memory blocks carry one halo (left) or two halos (left and right).
    padded_memory_block_length = mtf.Dimension(
        "memory_block_length",
        (1 if fully_autoregressive else 2) * radius + block_length)

    relative_position = m_pos - q_pos
    visible = mtf.equal(q_sequence_id, m_sequence_id)
    visible = mtf.logical_and(
        visible, mtf.greater(relative_position, -radius))
    visible = mtf.logical_and(
        visible,
        mtf.less_equal(
            relative_position, 0 if fully_autoregressive else radius))
    if read_priority is not None:
        write_priority = _reshape_memory(write_priority)
        read_priority = _reshape_query(read_priority)
        visible = mtf.logical_and(
            visible, mtf.greater_equal(read_priority, write_priority))

    bias = visibility_mask_to_attention_bias(visible, q.dtype)
    # attention_kwargs defaults to None; `**None` would raise TypeError, so
    # fall back to an empty mapping.
    o = attention(q, k, v, padded_memory_block_length,
                  key_dim, value_dim, bias, **(attention_kwargs or {}))
    return mtf.replace_dimensions(
        o, [num_blocks, query_block_length], length_dim)
def ut_function(state, step, halting_probability, remainders, n_updates,
                previous_state):
    """Perform one ACT step (position-wise halting).

    Args:
      state: 3-D Tensor: [batch_size, length, channel]
      step: indicates number of steps taken so far
      halting_probability: halting probability
      remainders: act remainders
      n_updates: act n_updates
      previous_state: previous state

    Returns:
      transformed_state: transformed state
      step: step+1
      halting_probability: halting probability
      remainders: act remainders
      n_updates: act n_updates
      new_state: new state
    """
    state = self.step_preprocess(context, state, step)

    if self.act_type != "random":
        # Learned pondering value: project the last dimension down to size 1
        # through a sigmoid-activated dense layer.
        unit_dim = mtf.Dimension(state.shape.dimension_names[-1], 1)
        with tf.variable_scope("sigmoid_activation_for_pondering",
                               reuse=tf.AUTO_REUSE):
            ponder = mtf.layers.dense(
                state,
                variable_dtype=context.variable_dtype,
                reduced_dims=[state.shape.dims[-1]],
                new_dims=[unit_dim],
                activation=mtf.sigmoid,
                use_bias=True)
            if self.act_type == "global":
                # average over all positions (as a global halting prob)
                ponder = mtf.reduce_mean(
                    ponder, reduced_dim=ponder.shape.dims[1])
                ponder = mtf.squeeze(ponder)
            else:
                # maintain position-wise probabilities: drop the size-1 dim
                ponder = mtf.reshape(ponder, ponder.shape.dims[:-1])
    else:
        # random as halting probability
        ponder = mtf.random_uniform(
            context.mesh,
            shape=halting_probability.shape.dims,
            dtype=context.variable_dtype)

    # Mask for inputs which have not halted yet
    was_running = mtf.cast(
        mtf.less(halting_probability, 1.0), context.activation_dtype)

    # Mask of inputs which halted at this step
    new_halted = mtf.cast(
        mtf.greater(halting_probability + ponder * was_running, threshold),
        context.activation_dtype) * was_running

    # Mask of inputs which haven't halted, and didn't halt this step
    still_running = mtf.cast(
        mtf.less_equal(
            halting_probability + ponder * was_running, threshold),
        context.activation_dtype) * was_running

    # Add the halting probability for this step to the halting
    # probabilities for those input which haven't halted yet
    halting_probability += ponder * still_running

    # Compute remainders for the inputs which halted at this step
    remainders += new_halted * (1 - halting_probability)

    # Add the remainders to those inputs which halted at this step
    halting_probability += new_halted * remainders

    # Increment n_updates for all inputs which are still running
    n_updates += still_running + new_halted

    # Weight applied to the new state and output:
    #   0 when the input has already halted
    #   the pondering value when the input hasn't halted yet
    #   the remainders when it halted this step
    update_weights = ponder * still_running + new_halted * remainders

    # apply transformation on the state
    transformed_state = state
    for _ in range(self.num_inrecurrence_layers):
        transformed_state = self.vanilla_transformer_layer(
            context, transformed_state, mask)

    # update running part in the weighted state and keep the rest
    weighted_update = transformed_state * update_weights
    carried_over = previous_state * (1 - update_weights)
    new_state = weighted_update + carried_over

    if self.act_type == "accumulated":
        # Add in the weighted state
        new_state = (transformed_state * update_weights) + previous_state

    step += 1

    return (transformed_state, step, halting_probability, remainders,
            n_updates, new_state)
def body_fn(position, ids, *states): """One step in the decode loop.""" nonlocal sampling_keep_top_k context = mtf_transformer.transformer.Context( model=None, mesh=inputs.mesh, batch_dims=batch_dims, length_dim=length_dim, variable_dtype=variable_dtype, mode="incremental", position=position, position_is_default=True, states=states, new_states=[], initial_position=position, sequence_id=None, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs, write_priority=write_priority, read_priority=read_priority, inputs=ids, encoder_inputs=encoder_inputs) if not slow_sampling else None with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE): logits, _, _ = gpt2.model({"inputs": ids}, other_features, params, inputs.mesh, variable_dtype=variable_dtype, context=context) # By default, do top_k sampling of 0.9 if sampling_keep_top_k == -2: sampling_keep_top_k = int(logits.shape[-1].size * 0.1) if sampling_keep_top_k != -1: if sampling_keep_top_k <= 0: raise ValueError( "sampling_keep_top_k must either be -1 or positive.") k_largest = mtf.nth_largest_element( logits, n=sampling_keep_top_k, reduced_dim=other_features["vocab_dim"]) logits = mtf.where(mtf.less_equal(logits, k_largest), mtf.ones_like(logits) * -1e6, logits) ids_this_step = mtf.sample_with_temperature( logits, other_features["vocab_dim"], temperature) if slow_sampling: ids_this_step = mtf.shift(ids_this_step, offset=1, dim=length_dim, wrap=False) else: ids_this_step = mtf.reshape(ids_this_step, (batch_dims)) one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32) one_new_id = ids_this_step * one_hot new_ids = (1 - one_hot) * ids + one_new_id new_position = position + 1 ret = [new_position, new_ids] if context is not None: ret += context.new_states return ret
def _switch_gating(inputs, outer_expert_dims, experts_dim, expert_capacity_dim, hparams, train, variable_dtype, importance=None, name="switch_gating", num_microbatches=None):
    """Compute a switch top-1 gating with no-token-left behind behavior.

    Gate probabilities are a softmax over `experts_dim` of a bias-free dense
    projection of `inputs`. The top `hparams.moe_switch_top_k` experts per
    token are then tried in order: a token is routed to its i-th choice only
    if it has not been routed yet and that expert still has capacity.

    Args:
      inputs: a Tensor; its dimensions 0, 1 and -2 are used as the outer
        batch, batch and group-size dimensions respectively.
      outer_expert_dims: passed as `expert_dims` to the gating dense layer.
      experts_dim: a Dimension indexing the experts.
      expert_capacity_dim: a Dimension whose size is the per-expert capacity.
      hparams: reads moe_rand_1_policy_train, moe_rand_1_policy_eval,
        moe_rand_1_jitter and moe_switch_top_k.
      train: a boolean; selects the policy and enables summaries.
      variable_dtype: passed to the gating dense layer.
      importance: an optional Tensor; where it is not equal to 1.0 the expert
        mask, expert gate and density proxy are zeroed.
      name: name of the gating dense layer and prefix of the entropy summary.
      num_microbatches: optional integer; if greater than 1 the loss is
        divided by it.

    Returns:
      dispatch_tensor: combine_tensor cast to bool and back, i.e. 1 wherever
        combine_tensor is non-zero.
      combine_tensor: the accumulated gate values, cast to inputs.dtype.
      loss: the load-balancing loss, cast to inputs.dtype.
    """
    # SELECT EXPERT
    if train:
        policy = hparams.moe_rand_1_policy_train
    else:
        policy = hparams.moe_rand_1_policy_eval

    # Input perturbations (training only, "input_jitter" policy only).
    if train and policy == "input_jitter":
        inputs = mtf.layers.multiplicative_jitter(inputs, hparams.moe_rand_1_jitter)

    gate_logits = mtf.layers.dense(
        inputs,
        experts_dim,
        use_bias=False,
        expert_dims=outer_expert_dims,
        variable_dtype=variable_dtype,
        name=name)
    raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)

    # The internals of this function run in float32.
    # bfloat16 seems to reduce quality.
    raw_gates = mtf.to_float(raw_gates)

    # Top-k operation: gate value and expert index of each token's k best
    # experts, plus a one-hot mask of those indices over experts_dim.
    k_dim = mtf.Dimension("k", hparams.moe_switch_top_k)
    expert_gate, expert_index = mtf.top_k(
        raw_gates, reduced_dim=experts_dim, k_dim=k_dim)
    expert_mask = mtf.one_hot(expert_index, experts_dim)

    # LOAD BALANCING LOSS
    outer_batch_dim = inputs.shape[0]
    batch_dim = inputs.shape[1]
    group_size_dim = inputs.shape[-2]
    # Both densities are averaged over the group; density_1 is taken before
    # the importance mask below is applied to expert_mask.
    density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
    density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
    if importance is not None:
        # Keep only entries whose importance equals 1.0.
        expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
        expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
        density_1_proxy *= mtf.cast(
            mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    # Mean of the density product, scaled by the squared number of experts.
    loss = (
        mtf.reduce_mean(density_1_proxy * density_1) *
        float(experts_dim.size * experts_dim.size))
    if num_microbatches and num_microbatches > 1:
        tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
            num_microbatches))
        loss /= num_microbatches

    # Logging (training only): gate entropy, per-expert routed fraction and
    # the auxiliary loss.
    if train:
        entropy = mtf.reduce_sum(
            -raw_gates * mtf.log(raw_gates + 1e-9), reduced_dim=experts_dim)
        batch_entropy = mtf.reduce_mean(entropy)
        mtf.scalar_summary(name + "/entropy", batch_entropy)

        mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
        total_routed = mtf.reduce_sum(mask_count_experts)
        expert_fraction = mtf.to_float(mask_count_experts / total_routed)
        # One scalar summary per expert.
        split_fractions = mtf.split(
            expert_fraction,
            split_dim=experts_dim,
            num_or_size_splits=experts_dim.size)
        for fraction in split_fractions:
            mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
                               mtf.reduce_mean(fraction))
        mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

    # COMPUTE ASSIGNMENT TO EXPERT
    # Iteratively route tokens (no-token-left-behind). The idea is to route as
    # many tokens as possible to top-i before then trying top-(i+1).
    top_k_masks = mtf.split(
        expert_mask, split_dim=k_dim, num_or_size_splits=k_dim.size)
    top_k_gates = mtf.split(
        expert_gate, split_dim=k_dim, num_or_size_splits=k_dim.size)
    top_k_indices = mtf.split(
        expert_index, split_dim=k_dim, num_or_size_splits=k_dim.size)

    # Tensors holding cumulative values over the iterative process.
    combine_tensor = mtf.constant(
        inputs.mesh,
        value=0,
        shape=[outer_batch_dim, batch_dim, experts_dim, expert_capacity_dim])
    cum_tokens = mtf.constant(
        inputs.mesh, value=0, shape=[outer_batch_dim, batch_dim, experts_dim])
    # 1 for every token that has not been routed yet.
    tokens_left_to_route = mtf.constant(
        inputs.mesh, value=1., shape=[outer_batch_dim, batch_dim, group_size_dim])

    expert_capacity_float = float(expert_capacity_dim.size)
    for (top_i_mask, top_i_gate, top_i_index) in zip(top_k_masks, top_k_gates, top_k_indices):
        top_i_mask = mtf.reshape(
            top_i_mask,
            new_shape=[outer_batch_dim, batch_dim, group_size_dim, experts_dim])
        # Operate only on the unrouted tokens.
        top_i_mask *= tokens_left_to_route

        # Record cumulative number of tokens to each expert across iterations.
        cumulative_tokens_in_expert = cum_tokens + mtf.cumsum(
            top_i_mask, group_size_dim)

        # NOTE: despite its name, expert_overflow is 1.0 where the running
        # count is still within capacity, so multiplying by it drops the
        # tokens that would overflow.
        expert_overflow = mtf.to_float(
            mtf.less_equal(cumulative_tokens_in_expert, expert_capacity_float))
        output_i_tokens = top_i_mask * expert_overflow

        # Update the cumulative tokens routed to each expert.
        cum_tokens += mtf.reduce_sum(output_i_tokens, reduced_dim=group_size_dim)
        # Tokens routed in this iteration are no longer left to route.
        tokens_left_to_route -= (
            mtf.reduce_sum(output_i_tokens, reduced_dim=experts_dim))

        # Combine-tensor for this iteration: the gate value, placed at the
        # chosen expert and at the token's zero-based slot in that expert,
        # and zeroed for tokens that were not routed in this iteration.
        output_i_tokens_flat = mtf.reduce_sum(
            output_i_tokens, reduced_dim=experts_dim)
        position_in_expert = cumulative_tokens_in_expert - 1
        top_i_combine_tensor = (
            top_i_gate * output_i_tokens_flat *
            mtf.one_hot(top_i_index, experts_dim) *
            mtf.one_hot(mtf.to_int32(position_in_expert), expert_capacity_dim))
        combine_tensor += top_i_combine_tensor

    # Match the inputs dtype.
    combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
    loss = mtf.cast(loss, inputs.dtype)
    # Casting through bool turns every non-zero combine weight into 1.
    dispatch_tensor = mtf.cast(
        mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

    return dispatch_tensor, combine_tensor, loss
def body_fn(position, ids, *states):
    """Run one step of the incremental decode loop.

    Args:
      position: current decode position
      ids: the token ids produced so far
      *states: per-step states handed to the incremental Context

    Returns:
      a list [new_position, new_ids] followed by the context's new states
    """
    # The model input for this step is the id stored at the previous position.
    inputs_this_step = mtf.gather(ids, position - 1, length_dim)
    attributes_this_step = (
        mtf.gather(attributes, position - 1, length_dim)
        if self.attribute_embedding else None)

    context_incremental = Context(
        model=self,
        mesh=inputs.mesh,
        batch_dims=batch_dims,
        length_dim=length_dim,
        variable_dtype=variable_dtype,
        mode="incremental",
        position=position,
        states=states,
        new_states=[],
        sequence_id=sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        constant_states=constant_states,
        shared_params=shared_params,
        encoder_layer_outputs=encoder_layer_outputs,
        write_priority=write_priority,
        read_priority=position,
        inputs=inputs_this_step,
        encoder_inputs=encoder_inputs)

    with tf.variable_scope(self.name, reuse=True):
        logits = self._call_internal(
            context_incremental,
            inputs_this_step,
            attributes=attributes_this_step,
            z=z)
        if never_end:
            # Add -1e9 to the logit of stop_at_token.
            stop_id = mtf.constant(
                logits.mesh, stop_at_token, dtype=tf.int32)
            logits += mtf.one_hot(
                stop_id,
                self.output_vocab_dim,
                on_value=-1e9,
                off_value=0.0,
                dtype=logits.dtype)

    # TBD whether this should be before or after never_end:
    # Note for adding top_p sampling in the future, in other code bases, the
    # option to apply temperature is done before the top-k truncation. This
    # implementation does this in the opposite order. For top-k this doesn't
    # matter, but for top_p it will.
    top_k_enabled = sampling_keep_top_k != -1
    if top_k_enabled:
        if sampling_keep_top_k <= 0:
            raise ValueError(
                "sampling_keep_top_k must either be -1 or positive.")
        cutoff = mtf.nth_largest_element(
            logits, n=sampling_keep_top_k, reduced_dim=self.output_vocab_dim)
        # Push every logit at or below the cutoff down to -1e6.
        logits = mtf.where(
            mtf.less_equal(logits, cutoff),
            mtf.ones_like(logits) * -1e6,
            logits)

    sampled = mtf.sample_with_temperature(
        logits, self.output_vocab_dim, temperature)
    new_position = position + 1
    # The sampled id is added into `ids` at `position`.
    position_mask = mtf.one_hot(position, length_dim, dtype=tf.int32)
    new_ids = ids + sampled * position_mask
    return [new_position, new_ids] + context_incremental.new_states