def sample_categorical(x, dim=None):
    dim = x.shape[-1] if dim is None else dim
    cdf = mtf.cumsum(x, dim)
    rand_uniform = mtf.random_uniform(x.mesh, x.shape - dim, minval=0, maxval=1)
    mask = mtf.cast(mtf.greater(cdf, rand_uniform), tf.int32)
    return mtf.argmax(mask, dim)
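# Illustrative sketch (added, not part of the original code): sample_categorical
# draws one sample per row by inverse-CDF sampling -- it counts how many
# cumulative probabilities exceed a single uniform draw and takes the argmax of
# that mask.  The helper name and the `batch` / `vocab` dimensions below are
# assumptions made for the example; `mesh` is assumed to be an existing mtf.Mesh.
def _example_sample_categorical(mesh):
    batch = mtf.Dimension("batch", 8)
    vocab = mtf.Dimension("vocab", 32)
    logits = mtf.random_uniform(mesh, mtf.Shape([batch, vocab]))
    probs = mtf.softmax(logits, vocab)       # each row sums to 1 over vocab
    return sample_categorical(probs, vocab)  # int32 Tensor with shape [batch]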
def compute_mask(self, context, memory_position):
    """Compute attention mask.

    Args:
      context: a transformer.Context
      memory_position: an int32 tensor containing memory_length dimension.

    Returns:
      a Tensor or None
    """
    masks = []
    min_relative_position = self.min_relative_position(context)
    max_relative_position = self.max_relative_position(context)
    if max_relative_position is not None or min_relative_position is not None:
        relative_position = memory_position - context.position
        if min_relative_position is not None:
            illegal = mtf.less(relative_position, min_relative_position)
            masks.append(mtf.cast(illegal, context.activation_dtype) * -1e9)
        if max_relative_position is not None:
            illegal = mtf.greater(relative_position, max_relative_position)
            masks.append(mtf.cast(illegal, context.activation_dtype) * -1e9)
    if (context.sequence_id is not None
            and isinstance(context.sequence_id, mtf.Tensor)
            and context.length_dim in context.sequence_id.shape):
        masks.append(mtf.cast(
            mtf.not_equal(
                context.sequence_id,
                self.rename_length_to_memory_length(
                    context.sequence_id, context)),
            context.activation_dtype) * -1e9)
    return mtf.add_n(masks) if masks else None
def entmax_backward(explicit_inputs, all_inputs, forward_operations, outputs,
                    output_grads, alpha=1.3, dim=None, n_iter=50):
    x, = explicit_inputs
    y, = outputs
    dY, = output_grads

    gppr = mtf.where(mtf.greater(y, 0), mtf.pow(y, (2 - alpha)),
                     mtf.zeros_like(y))
    dX = dY * gppr
    q = mtf.reduce_sum(dX, reduced_dim=dim) / mtf.reduce_sum(gppr,
                                                             reduced_dim=dim)
    dX = dX - q * gppr
    return dX,
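# Note (added): this is the closed-form Jacobian-vector product for
# alpha-entmax.  With g_i = y_i ** (2 - alpha) on the support of y (and 0
# elsewhere), the input gradient is
#     dX_i = g_i * (dY_i - q),   where   q = sum_j(dY_j * g_j) / sum_j(g_j),
# which is what the code above computes: dX = dY * gppr, then subtract
# q * gppr, so the gradient lives only on the support of the sparse output y.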
def cond_fn(position, ids, *unused_states):
    """Should we run another loop iteration?"""
    past_end = mtf.greater_equal(position, length_dim.size)
    if max_steps:
        past_end = mtf.logical_or(
            past_end,
            mtf.greater_equal(position - initial_position, max_steps))

    is_done = past_end
    if stop_at_token is not None:
        eos_count = mtf.reduce_sum(
            mtf.to_int32(mtf.equal(ids, stop_at_token)),
            reduced_dim=length_dim)
        has_additional_eos = mtf.greater(eos_count,
                                         partial_sequences_eos_count)
        is_done = mtf.logical_or(is_done, has_additional_eos)
    all_done = mtf.reduce_all(is_done)
    return mtf.logical_not(all_done)
def _noisy_targets_from_spec(self, targets, noising_spec, losses=None):
    if noising_spec["type"] == "mask":
        # Replace a randomly-chosen noising_spec["prob"] of input tokens with 0.
        return targets * mtf.cast(
            mtf.greater(mtf.random_uniform(targets.mesh, targets.shape),
                        noising_spec["prob"]), targets.dtype)
    elif noising_spec["type"] == "random_zipfian":
        # Replace a randomly-chosen noising_spec["prob"] of input tokens.
        # Rather than drawing the replacement tokens uniformly, we sample from
        # a distribution favoring lower token-ids, assuming that the ids have
        # been assigned in frequency order.  The probability of choosing an
        # id is proportional to 1/(id+10).
        logits = mtf.log(1.0 / (mtf.range(
            targets.mesh, self.targets_vocab_dim, dtype=tf.float32) + 10.0))
        logits = mtf.broadcast(logits, new_shape=targets.shape + logits.shape)
        r = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
        use_noise = mtf.less(
            mtf.random_uniform(targets.mesh, targets.shape),
            noising_spec["prob"])
        return mtf.where(use_noise, r, targets)
    elif noising_spec["type"] == "transformer":
        # Train a small transformer to fill in masked out values, then
        # sample from it.
        hparams = self._hparams
        if hparams.mode != tf.estimator.ModeKeys.TRAIN:
            raise NotImplementedError("Not implemented")
        noiser_hparams = copy.copy(self._hparams)
        noiser_hparams.del_hparam("mode")
        noiser_hparams.override_from_dict(noising_spec["overrides"])
        with tf.variable_scope("noiser"):
            noiser = MtfTransformer(
                noiser_hparams,
                mode=hparams.mode,
                problem_hparams=self._problem_hparams)
            logits, loss = noiser._mtf_model_fn(  # pylint: disable=protected-access
                self._original_features, targets.mesh)
            samples = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
        losses.append(loss)
        return samples
    else:
        raise ValueError("unknown noising spec %s" % noising_spec)
def compute_mask(self, context, memory_position):
    """Compute attention mask.

    Args:
      context: a transformer.Context
      memory_position: an int32 tensor containing memory_length dimension.

    Returns:
      a Tensor or None
    """
    masks = []
    min_relative_position = self.min_relative_position(context)
    max_relative_position = self.max_relative_position(context)
    if max_relative_position is not None or min_relative_position is not None:
        relative_position = memory_position - context.position
        if min_relative_position is not None:
            illegal = mtf.less(relative_position, min_relative_position)
            masks.append(mtf.cast(illegal, context.activation_dtype) * -1e9)
        if max_relative_position is not None:
            illegal = mtf.greater(relative_position, max_relative_position)
            masks.append(mtf.cast(illegal, context.activation_dtype) * -1e9)

    sequence_id = None
    # Subsequence id should only be set if we are in the decoder and have
    # multiple targets per input.  This will allow each sub-target to only
    # attend to itself.
    if isinstance(context.subsequence_id, mtf.Tensor):
        sequence_id = context.subsequence_id
    elif isinstance(context.sequence_id, mtf.Tensor):
        sequence_id = context.sequence_id
    if sequence_id is not None and context.length_dim in sequence_id.shape:
        masks.append(
            mtf.cast(
                mtf.not_equal(
                    sequence_id,
                    self.rename_length_to_memory_length(sequence_id, context)),
                context.activation_dtype) * -1e9)
    return mtf.add_n(masks) if masks else None
def _top_2_gating(inputs,
                  outer_expert_dims,
                  experts_dim,
                  expert_capacity_dim,
                  hparams,
                  train,
                  variable_dtype,
                  importance=None,
                  name="top_2_gating"):
    """Compute gating for mixture-of-experts in TensorFlow.

    Note: until the algorithm and interface solidify, we pass in a
    hyperparameters dictionary in order not to complicate the interface in
    mtf_transformer.py.  Once this code moves out of "research", we should pass
    the hyperparameters separately.

    Hyperparameters used:
      hparams.moe_use_second_place_loss: a boolean
      hparams.moe_second_policy_train: a string
      hparams.moe_second_policy_eval: a string
      hparams.moe_second_threshold: a float

    The returned forward assignment is a tensor used to map (via einsum) from
    the inputs to the expert_inputs.  Likewise, the returned combine_tensor is
    used to map (via einsum) from the expert outputs to the outputs.  Both the
    forward and backward assignments are mostly zeros.  The shapes of the
    tensors are as follows.

      inputs: [<batch_dims>, group_size_dim, input_dim]
      importance: [<batch_dims>, group_size_dim]
      dispatch_tensor:
        [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
      expert_inputs: [<batch_dims>, experts_dim, expert_capacity_dim, input_dim]
      expert_outputs:
        [<batch_dims>, experts_dim, expert_capacity_dim, output_dim]
      combine_tensor:
        [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
      outputs: [<batch_dims>, group_size_dim, output_dim]

    "importance" is an optional tensor with one floating-point value for each
    input vector.  If the importance of an input is 1.0, then we send it to up
    to 2 experts.  If 0.0 < importance < 1.0, then we send it to at most one
    expert.  If importance == 0.0, then we send it to no experts.

    We use "importance" at the second-level gating function of a hierarchical
    mixture of experts.  Inputs to the first-choice expert-group get importance
    1.0.  Inputs to the second-choice expert group get importance 0.5.  Inputs
    that represent padding get importance 0.0.

    Args:
      inputs: a mtf.Tensor with shape [<batch_dims>, group_size_dim, input_dim]
      outer_expert_dims: an optional list of dimensions.  This is for the case
        where we are at an inner level of a hierarchical MoE.
      experts_dim: a Dimension (the number of experts)
      expert_capacity_dim: a Dimension (number of examples per group per expert)
      hparams: model hyperparameters.
      train: a boolean
      variable_dtype: a mtf.VariableDType
      importance: an optional tensor with shape [<batch_dims>, group_size_dim]
      name: an optional string

    Returns:
      dispatch_tensor: a Tensor with shape
        [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
      combine_tensor: a Tensor with shape
        [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
      loss: a mtf scalar

    Raises:
      ValueError: on illegal hyperparameters
    """
    group_size_dim, unused_input_dim = inputs.shape.dims[-2:]

    raw_gates = mtf.layers.dense(inputs, experts_dim, use_bias=False,
                                 expert_dims=outer_expert_dims,
                                 variable_dtype=variable_dtype, name=name)
    raw_gates = mtf.softmax(raw_gates, experts_dim)

    # The internals of this function run in float32.
    # bfloat16 seems to reduce quality.
    raw_gates = mtf.to_float(raw_gates)

    expert_capacity_f = float(expert_capacity_dim.size)

    # FIND TOP 2 EXPERTS PER POSITION
    # Find the top expert for each position.  shape=[batch, group]
    index_1, gate_1 = mtf.top_1(raw_gates, experts_dim)
    # [batch, group, experts]
    mask_1 = mtf.one_hot(index_1, experts_dim, dtype=raw_gates.dtype)
    density_1_proxy = raw_gates
    if importance is not None:
        mask_1 *= mtf.to_float(mtf.equal(importance, 1.0))
        gate_1 *= mtf.to_float(mtf.equal(importance, 1.0))
        density_1_proxy *= mtf.to_float(mtf.equal(importance, 1.0))
    gates_without_top_1 = raw_gates * (1.0 - mask_1)
    # [batch, group]
    index_2, gate_2 = mtf.top_1(gates_without_top_1, experts_dim)
    # [batch, group, experts]
    mask_2 = mtf.one_hot(index_2, experts_dim, dtype=raw_gates.dtype)
    if importance is not None:
        mask_2 *= mtf.to_float(mtf.greater(importance, 0.0))

    denom = gate_1 + gate_2 + 1e-9
    gate_1 /= denom
    gate_2 /= denom

    # BALANCING LOSSES
    # shape = [batch, experts]
    # We want to equalize the fraction of the batch assigned to each expert.
    density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_size_dim)
    # Something continuous that is correlated with what we want to equalize.
    density_1_proxy = mtf.reduce_mean(density_1_proxy,
                                      reduced_dim=group_size_dim)
    loss = (mtf.reduce_mean(density_1_proxy * density_1)
            * float(experts_dim.size * experts_dim.size))

    if hparams.moe_use_second_place_loss:
        # Also add a loss to encourage all experts to be used equally also as
        # the second-place expert.  Experimentally, this seems to be a wash.
        # We want to equalize the fraction of the batch assigned to each expert:
        density_2 = mtf.reduce_mean(mask_2, reduced_dim=group_size_dim)
        # As a proxy for density_2, we renormalize the raw gates after the top
        # one has been removed.
        normalized = gates_without_top_1 / (mtf.reduce_sum(
            gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
        density_2_proxy = mtf.reduce_mean(normalized,
                                          reduced_dim=group_size_dim)
        loss_2 = (mtf.reduce_mean(density_2_proxy * density_2)
                  * float(experts_dim.size * experts_dim.size))
        loss += loss_2 * 0.5

    # Depending on the policy in the hparams, we may drop out some of the
    # second-place experts.
    if train:
        policy = hparams.moe_second_policy_train
        threshold = hparams.moe_second_threshold_train
    else:
        policy = hparams.moe_second_policy_eval
        threshold = hparams.moe_second_threshold_eval
    if policy == "all":
        # Use second-place experts for all examples.
        pass
    elif policy == "none":
        # Never use second-place experts for any examples.
        mask_2 = mtf.zeros_like(mask_2)
    elif policy == "threshold":
        # Use second-place experts if gate_2 > threshold.
        mask_2 *= mtf.to_float(mtf.greater(gate_2, threshold))
    elif policy == "random":
        # Use second-place experts with probability min(1.0, gate_2 / threshold).
        mask_2 *= mtf.to_float(
            mtf.less(mtf.random_uniform(gate_2.mesh, gate_2.shape),
                     gate_2 / max(threshold, 1e-9)))
    else:
        raise ValueError("Unknown policy %s" % policy)

    # COMPUTE ASSIGNMENT TO EXPERTS
    # [batch, group, experts]
    # This is the position within the expert's mini-batch for this sequence.
    position_in_expert_1 = mtf.cumsum(
        mask_1, group_size_dim, exclusive=True) * mask_1
    # Remove the elements that don't fit.  [batch, group, experts]
    mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
    # [batch, experts]
    # How many examples in this sequence go to this expert.
    mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_size_dim)
    # [batch, group] - mostly ones, but zeros where something didn't fit.
    mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
    # [batch, group]
    position_in_expert_1 = mtf.reduce_sum(position_in_expert_1,
                                          reduced_dim=experts_dim)
    # Weight assigned to first expert.  [batch, group]
    gate_1 *= mask_1_flat

    # [batch, group, experts]
    position_in_expert_2 = (
        mtf.cumsum(mask_2, group_size_dim, exclusive=True) + mask_1_count)
    position_in_expert_2 *= mask_2
    mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
    # mask_2_count = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
    mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
    gate_2 *= mask_2_flat
    position_in_expert_2 = mtf.reduce_sum(position_in_expert_2,
                                          reduced_dim=experts_dim)

    # [batch, group, experts, expert_capacity]
    combine_tensor = (
        gate_1 * mask_1_flat
        * mtf.one_hot(index_1, experts_dim)
        * mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim)
        + gate_2 * mask_2_flat
        * mtf.one_hot(index_2, experts_dim)
        * mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

    combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
    loss = mtf.cast(loss, inputs.dtype)

    dispatch_tensor = mtf.cast(
        mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

    return dispatch_tensor, combine_tensor, loss
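# Illustrative sketch (added, not from the original file): how the returned
# dispatch_tensor / combine_tensor are meant to be used, following the shape
# conventions in the docstring above.  The helper name, the `output_dim`
# argument, and the "expert_ff" layer name are assumptions made for the sketch.
def _example_dispatch_and_combine(inputs, experts_dim, expert_capacity_dim,
                                  output_dim, hparams, train, variable_dtype):
    batch_dims = inputs.shape.dims[:-2]
    group_size_dim, input_dim = inputs.shape.dims[-2:]
    dispatch_tensor, combine_tensor, loss = _top_2_gating(
        inputs=inputs, outer_expert_dims=None, experts_dim=experts_dim,
        expert_capacity_dim=expert_capacity_dim, hparams=hparams,
        train=train, variable_dtype=variable_dtype)
    # Route each position to its expert slot(s):
    # [<batch_dims>, experts_dim, expert_capacity_dim, input_dim]
    expert_inputs = mtf.einsum(
        [inputs, dispatch_tensor],
        batch_dims + [experts_dim, expert_capacity_dim, input_dim])
    # Per-expert feed-forward layer (one weight matrix per expert).
    expert_outputs = mtf.layers.dense(
        expert_inputs, output_dim, expert_dims=[experts_dim],
        use_bias=False, variable_dtype=variable_dtype, name="expert_ff")
    # Weighted un-routing back to [<batch_dims>, group_size_dim, output_dim].
    outputs = mtf.einsum(
        [expert_outputs, combine_tensor],
        batch_dims + [group_size_dim, output_dim])
    return outputs, loss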
def transformer_moe_layer_v2(inputs, output_dim, hparams, train,
                             variable_dtype, layout=None, mesh_shape=None,
                             nonpadding=None):
    """2-level mixture of experts.

    Adapted from the paper https://arxiv.org/abs/1701.06538

    Note: until the algorithm and interface solidify, we pass in a
    hyperparameters dictionary in order not to complicate the interface in
    mtf_transformer.py.  Once this code moves out of "research", we should pass
    the hyperparameters separately.

    Hyperparameters used:
      hparams.moe_num_experts: number of experts
      hparams.moe_hidden_size: size of hidden layer in each expert
      hparams.moe_group_size: size of each "group" for gating purposes
      hparams.moe_capacity_factor_train: a float
      hparams.moe_capacity_factor_eval: a float
      hparams.moe_capacity_factor_second_level: a float
      hparams.moe_gating: a string
      + all hyperparameters used by _top_2_gating()

    There is one set of parameters for the experts at the first level, and a
    different set of parameters per expert at the second level.

    The number of parameters in the gating network is:
      (input_dim.size * (hparams.num_experts) +
       (moe_hidden_size * hparams.num_experts) * hparams.num_experts)

    The number of parameters in the experts themselves is:
      (hparams.num_experts * (input_dim.size + output_dim.size) *
       hparams.moe_hidden_size)

    The input is n-dimensional: [<batch_and_length_dims>, input_dim],
    consisting of the representations of all positions in a batch of
    sequences.

    Each position of each sequence is sent to 0-3 experts.  The expert choices
    and the combination weights are determined by a learned gating function.

    This function returns a small auxiliary loss that should be added to the
    training loss of the model.  This loss helps to balance expert usage.
    Without the loss, it is very likely that a few experts will be trained and
    the rest will starve.

    Several hacks are necessary to get around current TPU limitations:

    - To ensure static shapes, we enforce (by truncation/padding) that each
      sequence send the same number of elements to each expert.

      It would make more sense to enforce this equality over the entire batch,
      but due to our hacked-up gather-by-matmul implementation, we need to
      divide the batch into "groups".  For each group, the same number of
      elements are sent to each expert.

    TODO(noam): Factor this code better.  We want to be able to substitute
    different code for the experts themselves.

    Dimensions cheat sheet:
      a, b: batch size
      l: original sequence length
      m: input depth
      n: output depth
      g, h: number of groups
      s, t: group size
      x, y: number of experts
      c, d: expert capacity

      input: [a0, b1, l, m]
      input: [a0, g1, s, m]
      dispatch_tensor_x: [a0, g1, s, x, c]
      expert_input: [a0, g1, x, c, m]
      alltoall: [a0, g, x1, c, m]
      alltoall: [a0, g, x1, c, m]
      transpose: [x1, a0, g, c, m]
      reshape: [x1, h0, s, m]
      assignment2: [x1, h0, t, y, d]
      expert_input2: [x1, h0, y, d, m]
      alltoall: [x1, h, y0, d, m]
      ...
      reverse of that

      gating params 0: [m, x]
      gating params 1: [x1, m, y]

      expert params:
        [x1, y0, m, hidden]
        [x1, y0, hidden, n]

    Args:
      inputs: a mtf.Tensor with shape [a, b, l, m]
      output_dim: a mtf.Dimension (for Transformer, this is input_dim)
      hparams: model hyperparameters
      train: a boolean
      variable_dtype: a mtf.VariableDType
      layout: optional - an input to mtf.convert_to_layout_rules
      mesh_shape: optional - an input to mtf.convert_to_shape
      nonpadding: an optional mtf.Tensor with shape [a, b, l] and the same
        dtype as inputs, consisting of ones(nonpadding) and zeros(padding).

    Returns:
      outputs: a Tensor with shape [a, b, l, n]
      loss: a mtf scalar

    Raises:
      ValueError: on unrecognized hparams.moe_gating
    """
    if nonpadding is not None:
        nonpadding = mtf.zeros(inputs.mesh, inputs.shape.dims[:-1],
                               dtype=inputs.dtype) + nonpadding
    insert_outer_batch_dim = (len(inputs.shape.dims) == 3)
    if insert_outer_batch_dim:
        inputs = mtf.reshape(
            inputs, [mtf.Dimension("outer_batch", 1)] + inputs.shape.dims)

    assert len(hparams.moe_num_experts) == 2
    a0, b1, l, m = inputs.shape.dims
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
    x1 = mtf.Dimension("expert_x", hparams.moe_num_experts[0])
    y0 = mtf.Dimension("expert_y", hparams.moe_num_experts[1])
    x = mtf.Dimension("expert_x_unsplit", hparams.moe_num_experts[0])
    y = mtf.Dimension("expert_y_unsplit", hparams.moe_num_experts[1])
    n = output_dim

    # We "cheat" here and look at the mesh shape and layout.  This is to
    # ensure that the number of groups (g.size) is a multiple of the mesh
    # dimension over which those groups are split.
    num_groups, group_size = _split_into_groups(
        b1.size * l.size, hparams.moe_group_size,
        mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, b1))
    g1 = mtf.Dimension(b1.name, num_groups)
    g = mtf.Dimension(b1.name + "_unsplit", g1.size)
    s = mtf.Dimension("group_size_x", group_size)

    # Each sequence sends (at most?) expert_capacity positions to each expert.
    # A static expert_capacity dimension is needed for expert batch sizes.
    if train:
        capacity_factor = hparams.moe_capacity_factor_train
    else:
        capacity_factor = hparams.moe_capacity_factor_eval
    expert_capacity = min(s.size, int((s.size * capacity_factor) / x.size))
    expert_capacity = max(expert_capacity, 4)
    c = mtf.Dimension("expert_capacity_x", expert_capacity)

    # We "cheat" here and look at the mesh shape and layout.  This is to
    # ensure that the number of groups (h.size) is a multiple of the mesh
    # dimension over which those groups are split.
    num_groups, group_size = _split_into_groups(
        a0.size * g.size * c.size, hparams.moe_group_size,
        mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, a0))
    t = mtf.Dimension("group_size_y", group_size)
    h0 = mtf.Dimension(a0.name, num_groups)
    h = mtf.Dimension(a0.name + "_unsplit", h0.size)

    expert_capacity = min(
        t.size,
        int((t.size * hparams.moe_capacity_factor_second_level) / y.size))
    expert_capacity = max(expert_capacity, 4)
    d = mtf.Dimension("expert_capacity_y", expert_capacity)

    # First level of expert routing.
    # Reshape the inner batch size to a multiple of group_dim g1 and
    # group_size_dim s.
    inputs = mtf.reshape(inputs, [a0, g1, s, m])
    if nonpadding is not None:
        nonpadding = mtf.reshape(nonpadding, [a0, g1, s])

    # Get the assignments for the first level.
    # dispatch_tensor_x has shape [a0, g1, s, x, c]
    if hparams.moe_gating == "top_2":
        dispatch_tensor_x, combine_tensor_x, loss_outer = _top_2_gating(
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=x,
            expert_capacity_dim=c,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            name="outer_gating",
            importance=nonpadding)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # Now create expert_inputs based on the assignments.
    # Put the num_experts dimension first to make the split easier in alltoall.
    expert_inputs_x = mtf.einsum([inputs, dispatch_tensor_x],
                                 [x, a0, g1, c, m])

    # We construct an "importance" Tensor for the inputs to the second-level
    # gating.  The importance of an input is 1.0 if it represents the
    # first-choice expert-group and 0.5 if it represents the second-choice
    # expert group.  This is used by the second-level gating.
    importance = mtf.reduce_sum(combine_tensor_x, output_shape=[x, a0, g1, c])
    importance = 0.5 * (
        mtf.to_float(mtf.greater(importance, 0.5)) +
        mtf.to_float(mtf.greater(importance, 0.0)))

    # First level, all to all.  Here we change the split dimension from g1 to
    # x1.
    expert_inputs_x = mtf.reshape(expert_inputs_x,
                                  mtf.Shape([x1, a0, g, c, m]))
    importance = mtf.reshape(importance, [x1, a0, g, c])

    # Second level of expert routing.
    # Reshape the expert_inputs outer batch dim to be a multiple of group_dim
    # h0 and group_size_dim t.
    inputs_y = mtf.reshape(expert_inputs_x, [x1, h0, t, m])
    importance = mtf.reshape(importance, [x1, h0, t])

    # Get the assignments for the second level.
    # dispatch_tensor_y has shape [x1, h0, t, y, d]
    if hparams.moe_gating == "top_2":
        dispatch_tensor_y, combine_tensor_y, loss_inner = _top_2_gating(
            inputs=inputs_y,
            outer_expert_dims=[x1],
            experts_dim=y,
            expert_capacity_dim=d,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=importance,
            name="inner_gating")
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # Now create expert_inputs based on the assignments.
    # Put the num_experts dimension first to make the split easier in alltoall.
    expert_inputs_y = mtf.einsum([inputs_y, dispatch_tensor_y],
                                 [y, x1, h0, d, m])

    # Second level, all to all.  Here we change the split dimension from h0 to
    # y0.
    expert_inputs_y = mtf.reshape(expert_inputs_y,
                                  mtf.Shape([y0, x1, h, d, m]))

    hidden_output = mtf.layers.dense(
        expert_inputs_y, hidden_dim, expert_dims=[y0, x1],
        activation=mtf.relu, use_bias=False, variable_dtype=variable_dtype,
        name="wi")
    expert_output = mtf.layers.dense(
        hidden_output, output_dim, expert_dims=[y0, x1], use_bias=False,
        variable_dtype=variable_dtype, name="wo")

    # NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
    # expert_output has shape [y0, x1, h, d, n]

    # alltoall
    expert_output = mtf.reshape(expert_output, mtf.Shape([y, x1, h0, d, n]))

    # Combine results from the inner level.
    output_y = mtf.einsum([expert_output, combine_tensor_y], [x1, h0, t, n])

    # Reshape the combined tensor from the inner level to now contain
    # outer_batch_dim a0 and group_dim g.
    output = mtf.reshape(output_y, [x1, a0, g, c, n])

    # alltoall from expert_dim x to group_dim g1
    expert_output_x = mtf.reshape(output, mtf.Shape([x, a0, g1, c, n]))

    # Combine results from the outer level.
    output_x = mtf.einsum([expert_output_x, combine_tensor_x], [a0, g1, s, n])

    # Reshape the combined tensor to now contain inner_batch_dim b1 and the
    # original sequence length.
    output = mtf.reshape(output_x, [a0, b1, l, n])
    if insert_outer_batch_dim:
        output = mtf.reshape(output, [b1, l, n])
    return output, (loss_outer + loss_inner) * hparams.moe_loss_coef
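# Hypothetical hyperparameter bundle (added for illustration): these are the
# fields read by transformer_moe_layer_v2 and _top_2_gating above.  The values
# are placeholder examples rather than recommendations, and a plain
# SimpleNamespace stands in for whatever hparams object the surrounding model
# actually uses.
import types

def _example_moe_hparams():
    return types.SimpleNamespace(
        moe_gating="top_2",
        moe_num_experts=[4, 4],            # [first-level, second-level] experts
        moe_hidden_size=4096,              # hidden size inside each expert
        moe_group_size=1024,               # positions per gating group
        moe_capacity_factor_train=1.25,
        moe_capacity_factor_eval=2.0,
        moe_capacity_factor_second_level=1.0,
        moe_use_second_place_loss=False,
        moe_second_policy_train="random",
        moe_second_policy_eval="random",
        moe_second_threshold_train=0.2,
        moe_second_threshold_eval=0.2,
        moe_loss_coef=1e-2)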
def local_attention_1d(q,
                       k,
                       v,
                       length_dim,
                       key_dim,
                       value_dim,
                       fully_autoregressive=True,
                       length_dim_num_splits=1,
                       radius=128,
                       sequence_id=1,
                       write_priority=None,
                       read_priority=None,
                       attention_kwargs=None):
    """Attention to a neighborhood around the source position.

    If fully_autoregressive, then query position p can only see memory
    positions in the range (p - radius, p].

    If not fully_autoregressive, then query position p can only see memory
    positions in the range (p - radius, p + radius].

    In addition, if write_priority and read_priority are provided, then
    attention is limited to position pairs where
    read_priority[query position] >= write_priority[memory position].

    Args:
      q: a Tensor containing length_dim
      k: a Tensor containing length_dim
      v: an optional Tensor containing length_dim.  If None then uses v=k.
      length_dim: a Dimension
      key_dim: a Dimension (the channels dimension of q and k)
      value_dim: a Dimension (the channels dimension of v)
      fully_autoregressive: a boolean
      length_dim_num_splits: an optional integer indicating how many ways the
        length dimension is split
      radius: an integer
      sequence_id: a Tensor or an integer
      write_priority: an optional Tensor containing length_dim
      read_priority: an optional Tensor containing length_dim
      attention_kwargs: optional keyword arguments for attention()

    Returns:
      a Tensor with the shape q.shape - key_dim + value_dim

    Raises:
      ValueError: if channels or depth don't match.
    """
    # Choose a suitable block size.
    # We choose the greatest divisor of length_per_split less than or equal
    # to max(radius, 128).
    length_per_split = length_dim.size // length_dim_num_splits
    block_length = max(radius, 128)
    while length_per_split % block_length != 0:
        block_length -= 1
    query_block_length = mtf.Dimension("query_block_length", block_length)
    memory_block_length = mtf.Dimension("memory_block_length", block_length)
    # The num_blocks dimension gets the same name as the length dimension,
    # so it will be split in the same way.
    num_blocks = mtf.Dimension(length_dim.name,
                               length_dim.size // block_length)

    def _reshape_query(x):
        return mtf.replace_dimensions(
            x, length_dim, [num_blocks, query_block_length])

    def _reshape_memory(x):
        x = mtf.replace_dimensions(
            x, length_dim, [num_blocks, memory_block_length])
        return (mtf.left_halo_exchange if fully_autoregressive
                else mtf.halo_exchange)(
                    x, num_blocks, memory_block_length, radius)

    q = _reshape_query(q)
    k = _reshape_memory(k)
    if v:
        v = _reshape_memory(v)
    else:
        v = k
    if sequence_id is None:
        sequence_id = 1
    if (not isinstance(sequence_id, mtf.Tensor) or
            length_dim not in sequence_id.shape.dims):
        sequence_id += mtf.zeros(q.mesh, [length_dim], tf.int32)
    q_sequence_id = _reshape_query(sequence_id)
    m_sequence_id = _reshape_memory(sequence_id)
    pos = mtf.range(q.mesh, length_dim, dtype=tf.int32)
    q_pos = _reshape_query(pos)
    m_pos = _reshape_memory(pos)

    padded_memory_block_length = mtf.Dimension(
        "memory_block_length",
        (1 if fully_autoregressive else 2) * radius + block_length)

    relative_position = m_pos - q_pos
    visible = mtf.equal(q_sequence_id, m_sequence_id)
    visible = mtf.logical_and(visible,
                              mtf.greater(relative_position, -radius))
    visible = mtf.logical_and(visible, mtf.less_equal(
        relative_position, 0 if fully_autoregressive else radius))
    if read_priority is not None:
        write_priority = _reshape_memory(write_priority)
        read_priority = _reshape_query(read_priority)
        visible = mtf.logical_and(
            visible, mtf.greater_equal(read_priority, write_priority))

    bias = visibility_mask_to_attention_bias(visible, q.dtype)
    o = attention(q, k, v, padded_memory_block_length, key_dim, value_dim,
                  bias, **attention_kwargs)
    return mtf.replace_dimensions(o, [num_blocks, query_block_length],
                                  length_dim)
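# Worked example (added): the block-size selection above picks the largest
# divisor of length_per_split that is <= max(radius, 128).  The helper below
# just isolates that loop; the concrete numbers are made-up for illustration.
def _example_block_length(length_size=1024, num_splits=2, radius=100):
    length_per_split = length_size // num_splits    # 512
    block_length = max(radius, 128)                 # start at 128
    while length_per_split % block_length != 0:
        block_length -= 1
    return block_length  # 512 % 128 == 0, so this returns 128 (8 blocks of 128)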
def ut_function(state, step, halting_probability, remainders, n_updates,
                previous_state):
    """Implements ACT (position-wise halting).

    Args:
      state: 3-D Tensor: [batch_size, length, channel]
      step: indicates number of steps taken so far
      halting_probability: halting probability
      remainders: act remainders
      n_updates: act n_updates
      previous_state: previous state

    Returns:
      transformed_state: transformed state
      step: step+1
      halting_probability: halting probability
      remainders: act remainders
      n_updates: act n_updates
      new_state: new state
    """
    state = self.step_preprocess(context, state, step)

    if self.act_type == "random":
        # random as halting probability
        p = mtf.random_uniform(context.mesh,
                               shape=halting_probability.shape.dims,
                               dtype=context.variable_dtype)
    else:
        last_dim_name = state.shape.dimension_names[-1]
        new_dims = [mtf.Dimension(last_dim_name, 1)]
        with tf.variable_scope("sigmoid_activation_for_pondering",
                               reuse=tf.AUTO_REUSE):
            p = mtf.layers.dense(state,
                                 variable_dtype=context.variable_dtype,
                                 reduced_dims=[state.shape.dims[-1]],
                                 new_dims=new_dims,
                                 activation=mtf.sigmoid,
                                 use_bias=True)
        if self.act_type == "global":
            # average over all positions (as a global halting prob)
            p = mtf.reduce_mean(p, reduced_dim=p.shape.dims[1])
            p = mtf.squeeze(p)
        else:
            # maintain position-wise probabilities
            new_shape = p.shape.dims[:-1]
            p = mtf.reshape(p, new_shape)

    # Mask for inputs which have not halted yet.
    still_running = mtf.cast(mtf.less(halting_probability, 1.0),
                             context.activation_dtype)
    # Mask of inputs which halted at this step.
    new_halted = mtf.cast(
        mtf.greater(halting_probability + p * still_running, threshold),
        context.activation_dtype) * still_running
    # Mask of inputs which haven't halted, and didn't halt this step.
    still_running = mtf.cast(
        mtf.less_equal(halting_probability + p * still_running, threshold),
        context.activation_dtype) * still_running

    # Add the halting probability for this step to the halting
    # probabilities for those inputs which haven't halted yet.
    halting_probability += p * still_running
    # Compute remainders for the inputs which halted at this step.
    remainders += new_halted * (1 - halting_probability)
    # Add the remainders to those inputs which halted at this step.
    halting_probability += new_halted * remainders
    # Increment n_updates for all inputs which are still running.
    n_updates += still_running + new_halted

    # Compute the weight to be applied to the new state and output:
    #   0 when the input has already halted,
    #   p when the input hasn't halted yet,
    #   the remainders when it halted this step.
    input_tensor = p * still_running + new_halted * remainders
    update_weights = input_tensor

    # Apply the transformation on the state.
    transformed_state = state
    for _ in range(self.num_inrecurrence_layers):
        transformed_state = self.vanilla_transformer_layer(
            context, transformed_state, mask)

    # Update the running part in the weighted state and keep the rest.
    new_state = ((transformed_state * update_weights) +
                 (previous_state * (1 - update_weights)))

    if self.act_type == "accumulated":
        # Add in the weighted state.
        new_state = (transformed_state * update_weights) + previous_state

    step += 1
    return (transformed_state, step, halting_probability, remainders,
            n_updates, new_state)
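# Note (added): per position, the loop above accumulates
#   halting_probability += p_t  while the total stays <= threshold;
# at the step where it would cross the threshold, the remainder
#   R = 1 - halting_probability (the probability accumulated so far)
# is used in place of p_t.  The per-step weight is therefore
#   w_t = p_t (still running), R (halting step), or 0 (already halted).
# With act_type == "accumulated" the output is the sum
#   new_state = sum_t w_t * transformed_state_t;
# otherwise it is the running interpolation
#   new_state_t = w_t * transformed_state_t + (1 - w_t) * new_state_{t-1}.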
def sample_autoregressive(partial_sequences,
                          other_features,
                          params,
                          stop_at_token=50256,
                          max_steps=None,
                          temperature=0.9,
                          variable_dtype=mtf.VariableDType(tf.float32),
                          encoder_output=None,
                          encoder_sequence_id=None,
                          encoder_inputs=None,
                          shared_params=None,
                          has_partial_sequences=True,
                          encoder_layer_outputs=None,
                          never_end=False,
                          remove_partial_sequences=False,
                          sampling_keep_top_k=-1,
                          bos_id=50256):
    """Sample randomly one token at a time.

    The partial_sequences represent partial sequences to be continued.  The
    first tokens of each sequence are nonzero representing the given partial
    sequences and the last tokens of each sequence are zeros, representing
    what needs to be filled in.

    If there are no partial sequences (you want to sample from the beginning),
    then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
    has_partial_sequences=False (so we can skip computation).

    Args:
      partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
      other_features: a dictionary of model features; must contain "vocab_dim"
        (used when sampling from the logits).
      params: a dictionary of decoding parameters; "padding_id" and
        "slow_sampling" are read below.
      stop_at_token: an optional integer eos id.  Stop when we produce it.
      max_steps: an optional integer, the max number of steps to decode.
      temperature: an optional floating point value between 0.0 and 1.0
        0.0 means argmax, 1.0 means sample according to predicted distribution.
      variable_dtype: a mtf.VariableDType
      encoder_output: an optional Tensor
      encoder_sequence_id: an optional Tensor
      encoder_inputs: an optional Tensor
      shared_params: an optional dictionary
      has_partial_sequences: a boolean
      encoder_layer_outputs: optional - readonly list of tensor activations
        when decoding, one per each input layer + the embedding layer
      never_end: a boolean - if set, then avoid generating stop_at_token
      remove_partial_sequences: a boolean - whether to remove the partial
        sequences from the output
      sampling_keep_top_k: an integer - if not -1, only sample from the top k
        logits.
      bos_id: beginning of sequence id

    Returns:
      a Tensor with shape [<batch_dims>, length_dim]
    """
    inputs = partial_sequences  # Partial sequences to fill in
    batch_dims = inputs.shape.dims[:-1]
    length_dim = inputs.shape.dims[-1]
    padding_id = params.get("padding_id", 0)
    slow_sampling = params.get("slow_sampling", False)

    # Gets position where zero padding starts.
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, padding_id)),
        reduced_dim=length_dim)

    length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
    input_full_attention = True  # for now hardcode this to true bc lazy
    if input_full_attention:
        # Vanilla autoregressive model - each position can see previous
        # positions.  Think this feeds in to the loop fn and tells each
        # position where it can attend to?
        read_priority = write_priority = length_range * mtf.to_int32(
            mtf.greater(length_range, initial_position))
    else:
        read_priority = write_priority = length_range

    # Builds context to pass around internally.
    # The 'first part' context records initial states of k / v / x.
    if not slow_sampling:
        context_first_part = mtf_transformer.transformer.Context(
            model=None,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="first_part",
            position=length_range,
            position_is_default=True,
            new_states=[],
            initial_position=initial_position,
            sequence_id=None,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=[],
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=inputs,
            encoder_inputs=encoder_inputs)

        with tf.variable_scope("gpt2"):
            logits, _, _ = gpt2.model({"inputs": inputs}, other_features,
                                      params, inputs.mesh,
                                      variable_dtype=variable_dtype,
                                      context=context_first_part)

        if not has_partial_sequences:
            initial_states = [
                mtf.zeros_like(t) for t in context_first_part.new_states
            ]
        else:
            initial_states = context_first_part.new_states
    else:
        initial_states = []

    if not has_partial_sequences:
        partial_sequences_eos_count = 0

    if stop_at_token is not None:
        partial_sequences_eos_count = mtf.reduce_sum(
            mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
            reduced_dim=length_dim)

    def cond_fn(position, ids, *unused_states):
        """Should we run another loop iteration?"""
        past_end = mtf.greater_equal(position, length_dim.size)
        if max_steps:
            past_end = mtf.logical_or(
                past_end,
                mtf.greater_equal(position - initial_position, max_steps))

        is_done = past_end
        if stop_at_token is not None:
            eos_count = mtf.reduce_sum(
                mtf.to_int32(mtf.equal(ids, stop_at_token)),
                reduced_dim=length_dim)
            has_additional_eos = mtf.greater(eos_count,
                                             partial_sequences_eos_count)
            is_done = mtf.logical_or(is_done, has_additional_eos)
        all_done = mtf.reduce_all(is_done)
        return mtf.logical_not(all_done)

    def body_fn(position, ids, *states):
        """One step in the decode loop."""
        nonlocal sampling_keep_top_k

        context = mtf_transformer.transformer.Context(
            model=None,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="incremental",
            position=position,
            position_is_default=True,
            states=states,
            new_states=[],
            initial_position=position,
            sequence_id=None,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=ids,
            encoder_inputs=encoder_inputs) if not slow_sampling else None

        with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE):
            logits, _, _ = gpt2.model({"inputs": ids}, other_features, params,
                                      inputs.mesh,
                                      variable_dtype=variable_dtype,
                                      context=context)

        # sampling_keep_top_k == -2 selects a default: keep the top 10% of the
        # vocabulary.
        if sampling_keep_top_k == -2:
            sampling_keep_top_k = int(logits.shape[-1].size * 0.1)

        if sampling_keep_top_k != -1:
            if sampling_keep_top_k <= 0:
                raise ValueError(
                    "sampling_keep_top_k must either be -1 or positive.")
            k_largest = mtf.nth_largest_element(
                logits, n=sampling_keep_top_k,
                reduced_dim=other_features["vocab_dim"])
            logits = mtf.where(mtf.less_equal(logits, k_largest),
                               mtf.ones_like(logits) * -1e6, logits)

        ids_this_step = mtf.sample_with_temperature(
            logits, other_features["vocab_dim"], temperature)

        if slow_sampling:
            ids_this_step = mtf.shift(ids_this_step, offset=1,
                                      dim=length_dim, wrap=False)
        else:
            ids_this_step = mtf.reshape(ids_this_step, (batch_dims))

        one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32)
        one_new_id = ids_this_step * one_hot
        new_ids = (1 - one_hot) * ids + one_new_id
        new_position = position + 1

        ret = [new_position, new_ids]
        if context is not None:
            ret += context.new_states
        return ret

    while_loop_inputs = [initial_position, inputs] + initial_states
    final_position, outputs = mtf.while_loop(cond_fn, body_fn,
                                             while_loop_inputs)[:2]
    del final_position
    if has_partial_sequences and remove_partial_sequences:
        # Remove partial sequences from outputs.
        partial_length = mtf.reduce_sum(
            mtf.to_int32(mtf.not_equal(partial_sequences, padding_id)),
            reduced_dim=length_dim)
        outputs = mtf.dynamic_shift(
            outputs, -partial_length, length_dim, wrap=False)
    return outputs
def _elish_fn(x):
    # Piecewise activation in the ELiSH family:
    #   x > 0:  x / (1 + exp(x))
    #   x <= 0: (exp(x) - 1) / (exp(-x) + 1), i.e. (exp(x) - 1) * sigmoid(x)
    cond = mtf.cast(mtf.greater(x, 0), x.dtype)
    exp = mtf.exp(x)
    return cond * x / (1 + exp) + (1 - cond) * (exp - 1) / (1 / exp + 1)
def sample_autoregressive(self,
                          partial_sequences,
                          dst_attributes=None,
                          stop_at_token=1,
                          max_steps=None,
                          temperature=0.0,
                          variable_dtype=mtf.VariableDType(tf.float32),
                          encoder_output=None,
                          encoder_sequence_id=None,
                          encoder_inputs=None,
                          shared_params=None,
                          has_partial_sequences=True,
                          encoder_layer_outputs=None,
                          never_end=False,
                          remove_partial_sequences=False,
                          sampling_keep_top_k=-1,
                          z=None):
    """Sample randomly one token at a time.

    The partial_sequences represent partial sequences to be continued.  The
    first tokens of each sequence are nonzero representing the given partial
    sequences and the last tokens of each sequence are zeros, representing
    what needs to be filled in.

    If there are no partial sequences (you want to sample from the beginning),
    then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
    has_partial_sequences=False (so we can skip computation).

    The dst_attributes represent the destination attributes for which we want
    to generate sequences.

    Args:
      partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
      dst_attributes: an int32 Tensor with shape [<batch_dims>, length_dim]
        ([<batch_dims>])
      stop_at_token: an optional integer eos id.  Stop when we produce it.
      max_steps: an optional integer, the max number of steps to decode.
      temperature: an optional floating point value between 0.0 and 1.0
        0.0 means argmax, 1.0 means sample according to predicted distribution.
      variable_dtype: a mtf.VariableDType
      encoder_output: an optional Tensor
      encoder_sequence_id: an optional Tensor
      encoder_inputs: an optional Tensor
      shared_params: an optional dictionary
      has_partial_sequences: a boolean
      encoder_layer_outputs: optional - readonly list of tensor activations
        when decoding, one per each input layer + the embedding layer
      never_end: a boolean - if set, then avoid generating stop_at_token
      remove_partial_sequences: a boolean - whether to remove the partial
        sequences from the output
      sampling_keep_top_k: an integer - if not -1, only sample from the top k
        logits.
      z: an optional Tensor passed through to _call_internal.

    Returns:
      a Tensor with shape [<batch_dims>, length_dim]
    """
    if not self.autoregressive:
        raise ValueError("must be autoregressive")

    inputs = partial_sequences
    attributes = dst_attributes
    batch_dims = inputs.shape.dims[:-1]
    length_dim = inputs.shape.dims[-1]
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim)
    sequence_id = 1 if encoder_sequence_id is not None else None

    length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
    if self.input_full_attention:
        read_priority = write_priority = length_range * mtf.to_int32(
            mtf.greater(length_range, initial_position))
    else:
        read_priority = write_priority = length_range

    context_first_part = Context(
        model=self,
        mesh=inputs.mesh,
        batch_dims=batch_dims,
        length_dim=length_dim,
        variable_dtype=variable_dtype,
        mode="first_part",
        position=length_range,
        position_is_default=True,
        new_states=[],
        initial_position=initial_position,
        sequence_id=sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        constant_states=[],
        shared_params=shared_params,
        encoder_layer_outputs=encoder_layer_outputs,
        write_priority=write_priority,
        read_priority=read_priority,
        inputs=inputs,
        encoder_inputs=encoder_inputs)

    shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
    with tf.variable_scope(self.name):
        logits = self._call_internal(context_first_part, shifted_inputs,
                                     attributes=attributes, z=z)
    del logits
    constant_states = context_first_part.constant_states
    if not has_partial_sequences:
        initial_states = [
            mtf.zeros_like(t) for t in context_first_part.new_states
        ]
        partial_sequences_eos_count = 0
    else:
        initial_states = context_first_part.new_states
        partial_sequences_eos_count = mtf.reduce_sum(
            mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
            reduced_dim=length_dim)

    def cond_fn(position, ids, *unused_states):
        """Should we run another loop iteration?"""
        past_end = mtf.greater_equal(position, length_dim.size)
        if max_steps:
            past_end = mtf.logical_or(
                past_end,
                mtf.greater_equal(position - initial_position, max_steps))

        is_done = past_end
        if stop_at_token is not None:
            eos_count = mtf.reduce_sum(
                mtf.to_int32(mtf.equal(ids, stop_at_token)),
                reduced_dim=length_dim)
            has_additional_eos = mtf.greater(eos_count,
                                             partial_sequences_eos_count)
            is_done = mtf.logical_or(is_done, has_additional_eos)
        all_done = mtf.reduce_all(is_done)
        return mtf.logical_not(all_done)

    def body_fn(position, ids, *states):
        """One step in the decode loop."""
        inputs_this_step = mtf.gather(ids, position - 1, length_dim)
        if self.attribute_embedding:
            attributes_this_step = mtf.gather(attributes, position - 1,
                                              length_dim)
        else:
            attributes_this_step = None
        # raise ValueError("inputs_this_step shape=%s , ids shape=%s,
        #   position - 1 shape=%s, length_dim=%s"
        #   % (inputs_this_step.shape, ids.shape, (position - 1).shape,
        #      length_dim))
        context_incremental = Context(
            model=self,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="incremental",
            position=position,
            states=states,
            new_states=[],
            sequence_id=sequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=constant_states,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=position,
            inputs=inputs_this_step,
            encoder_inputs=encoder_inputs)

        with tf.variable_scope(self.name, reuse=True):
            logits = self._call_internal(context_incremental,
                                         inputs_this_step,
                                         attributes=attributes_this_step,
                                         z=z)
        if never_end:
            logits += mtf.one_hot(
                mtf.constant(logits.mesh, stop_at_token, dtype=tf.int32),
                self.output_vocab_dim,
                on_value=-1e9,
                off_value=0.0,
                dtype=logits.dtype)

        # TBD whether this should be before or after never_end:
        # Note for adding top_p sampling in the future: in other code bases,
        # the option to apply temperature is done before the top-k truncation.
        # This implementation does it in the opposite order.  For top-k this
        # doesn't matter, but for top_p it will.
        if sampling_keep_top_k != -1:
            if sampling_keep_top_k <= 0:
                raise ValueError(
                    "sampling_keep_top_k must either be -1 or positive.")
            k_largest = mtf.nth_largest_element(
                logits, n=sampling_keep_top_k,
                reduced_dim=self.output_vocab_dim)
            logits = mtf.where(mtf.less_equal(logits, k_largest),
                               mtf.ones_like(logits) * -1e6, logits)

        ids_this_step = mtf.sample_with_temperature(
            logits, self.output_vocab_dim, temperature)
        new_position = position + 1
        new_ids = ids + ids_this_step * mtf.one_hot(
            position, length_dim, dtype=tf.int32)
        return [new_position, new_ids] + context_incremental.new_states

    while_loop_inputs = [initial_position, inputs] + initial_states
    final_position, outputs = mtf.while_loop(cond_fn, body_fn,
                                             while_loop_inputs)[:2]
    del final_position
    if has_partial_sequences and remove_partial_sequences:
        # Remove partial sequences from outputs.
        partial_length = mtf.reduce_sum(
            mtf.to_int32(mtf.not_equal(partial_sequences, 0)),
            reduced_dim=length_dim)
        outputs = mtf.dynamic_shift(
            outputs, -partial_length, length_dim, wrap=False)
    return outputs