def entmax_backward(explicit_inputs, all_inputs, forward_operations, outputs,
                    output_grads, alpha=1.3, dim=None, n_iter=50):
    x, = explicit_inputs
    y, = outputs
    dY, = output_grads

    # gppr = y ** (2 - alpha) on the support of y (entries where y > 0), else 0.
    gppr = mtf.where(mtf.greater(y, 0), mtf.pow(y, (2 - alpha)), mtf.zeros_like(y))
    dX = dY * gppr
    # Subtract the gppr-weighted mean of dY so the gradient sums to zero over `dim`.
    q = mtf.reduce_sum(dX, reduced_dim=dim) / mtf.reduce_sum(gppr, reduced_dim=dim)
    dX = dX - q * gppr
    return dX,
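# For intuition, the same backward rule can be checked in plain NumPy.  This is an
# illustrative sketch of the formula only, not the mesh-tensorflow op above; the
# helper name and toy values are ours.
import numpy as np

def entmax_jvp_reference(y, dY, alpha=1.3, axis=-1):
    # gppr = y**(2 - alpha) on the support of y; dX = gppr * (dY - q), where q is
    # the gppr-weighted mean of dY, so each row of dX sums to zero along `axis`.
    gppr = np.where(y > 0, y ** (2 - alpha), 0.0)
    dX = dY * gppr
    q = dX.sum(axis=axis, keepdims=True) / gppr.sum(axis=axis, keepdims=True)
    return dX - q * gppr

y = np.array([[0.7, 0.3, 0.0]])       # a sparse entmax output
dY = np.array([[1.0, -2.0, 0.5]])     # arbitrary upstream gradient
print(entmax_jvp_reference(y, dY))    # rows sum to ~0 along the last axis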
def beam_search(self, inputs, decode_length, variable_dtype=mtf.VariableDType(tf.float32), encoder_output=None, encoder_sequence_id=None, alpha=0.6, shared_params=None, encoder_layer_outputs=None): """Beam search. Args: inputs: an int32 zero-Tensor with shape [<batch_dims>, beam_dim, length_dim]. decode_length: an int32 mtf scalar. Maximum decode length. variable_dtype: a mtf.VariableDType encoder_output: an optional Tensor encoder_sequence_id: an optional Tensor alpha: a floating point value (length bonus) shared_params: an optional dictionary encoder_layer_outputs: optional - readonly list of tensor activations when decoding, one per each input layer + the embedding layer Returns: a Tensor with shape [<batch_dims>, beam_dim, length_dim] """ if not self.autoregressive: raise ValueError("must be autoregressive") batch_dims = inputs.shape.dims[:-2] if len(batch_dims) != 1: raise NotImplementedError( "beam search supports exactly one batch dimension.") beam_dim = inputs.shape.dims[-2] length_dim = inputs.shape.dims[-1] initial_position = mtf.reduce_sum( mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim) sequence_id = 1 if encoder_sequence_id is not None else None context_first_part = Context( mesh=inputs.mesh, batch_dims=batch_dims + [beam_dim], length_dim=length_dim, model_dim=self.model_dim, variable_dtype=variable_dtype, mode="first_part", autoregressive=self.autoregressive, new_states=[], initial_position=initial_position, sequence_id=sequence_id, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, constant_states=[], shared_params=shared_params, layout=self.layout, mesh_shape=self.mesh_shape, encoder_layer_outputs=encoder_layer_outputs) shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False) with tf.variable_scope(self.name): logits = self._call_internal(context_first_part, shifted_inputs) del logits # There are no partial targets. # Replace initial states by zeros to avoid computing them. initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states] constant_states = context_first_part.constant_states def logits_fn(step_num, ids, states): """logits_fn for mtf.beam_search.beam_search().""" context_incremental = Context( mesh=inputs.mesh, batch_dims=batch_dims + [beam_dim], length_dim=length_dim, model_dim=self.model_dim, variable_dtype=variable_dtype, mode="incremental", autoregressive=self.autoregressive, position=step_num, states=states, new_states=[], sequence_id=sequence_id, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, constant_states=constant_states, shared_params=shared_params, layout=self.layout, mesh_shape=self.mesh_shape, encoder_layer_outputs=encoder_layer_outputs) inputs_this_step = mtf.gather(ids, step_num - 1, length_dim) with tf.variable_scope(self.name, reuse=True): logits = self._call_internal(context_incremental, inputs_this_step) return mtf.to_float(logits), context_incremental.new_states beams, unused_scores = mtf.beam_search.beam_search( logits_fn, inputs, alpha, states=initial_states, decode_length=decode_length, use_tpu=True, dtype=tf.float32, mesh_shape=self.mesh_shape, layout=self.layout) return mtf.gather( beams, mtf.constant(inputs.mesh, 0, dtype=tf.int32), beam_dim)
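# The alpha argument is a length bonus applied when scoring finished hypotheses.
# The sketch below assumes the GNMT-style length normalization used by
# tensor2tensor-style beam search; the helper names and numbers are ours.
def length_penalty(length, alpha):
    return ((5.0 + length) / 6.0) ** alpha

def beam_score(log_prob_sum, length, alpha=0.6):
    # alpha = 0 disables normalization; larger alpha favors longer hypotheses.
    return log_prob_sum / length_penalty(length, alpha)

# With normalization, the longer hypothesis here outscores the shorter one even
# though its raw log-probability is lower.
print(beam_score(-4.5, length=5), beam_score(-5.5, length=12))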
def sample_autoregressive(self, partial_sequences, stop_at_token=1, max_steps=None, temperature=1.0, variable_dtype=mtf.VariableDType(tf.float32), encoder_output=None, encoder_sequence_id=None, shared_params=None, has_partial_sequences=True, encoder_layer_outputs=None): """Sample randomly one token at a time. The partial_sequences represent partial sequences to be continued. The first tokens of each sequence are nonzero representing the given partial sequences and the last tokens of each sequence are zeros, representing what needs to be filled in. If there are no partial sequences (you want to sample from the beginning), then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and has_partial_sequences=False (so we can skip computation). Args: partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim] stop_at_token: an optional integer eos id. Stop when we produce it. max_steps: an optional integer temperature: an optional floating point value between 0.0 and 1.0 0.0 means argmax, 1.0 means sample according to predicted distribution. variable_dtype: a mtf.VariableDType encoder_output: an optional Tensor encoder_sequence_id: an optional Tensor shared_params: an optional dictionary has_partial_sequences: a boolean encoder_layer_outputs: optional - readonly list of tensor activations when decoding, one per each input layer + the embedding layer Returns: a Tensor with shape [<batch_dims>, length_dim] """ del max_steps # TODO(noam): implement if not self.autoregressive: raise ValueError("must be autoregressive") inputs = partial_sequences batch_dims = inputs.shape.dims[:-1] length_dim = inputs.shape.dims[-1] initial_position = mtf.reduce_sum( mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim) sequence_id = 1 if encoder_sequence_id is not None else None context_first_part = Context( mesh=inputs.mesh, batch_dims=batch_dims, length_dim=length_dim, model_dim=self.model_dim, variable_dtype=variable_dtype, mode="first_part", autoregressive=self.autoregressive, new_states=[], initial_position=initial_position, sequence_id=sequence_id, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, constant_states=[], shared_params=shared_params, layout=self.layout, mesh_shape=self.mesh_shape, encoder_layer_outputs=encoder_layer_outputs) shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False) with tf.variable_scope(self.name): logits = self._call_internal(context_first_part, shifted_inputs) del logits constant_states = context_first_part.constant_states if not has_partial_sequences: initial_states = [ mtf.zeros_like(t) for t in context_first_part.new_states] else: initial_states = context_first_part.new_states def cond_fn(position, ids, *unused_states): """Should we run another loop iteration.""" past_end = mtf.greater_equal(position, length_dim.size) is_done = past_end if stop_at_token is not None: has_eos = mtf.reduce_any( mtf.equal(ids, stop_at_token), reduced_dim=length_dim) is_done = mtf.logical_or(is_done, has_eos) all_done = mtf.reduce_all(is_done) return mtf.logical_not(all_done) def body_fn(position, ids, *states): """One step in the decode loop.""" context_incremental = Context( mesh=inputs.mesh, batch_dims=batch_dims, length_dim=length_dim, model_dim=self.model_dim, variable_dtype=variable_dtype, mode="incremental", autoregressive=self.autoregressive, position=position, states=states, new_states=[], sequence_id=sequence_id, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, constant_states=constant_states, 
shared_params=shared_params, layout=self.layout, mesh_shape=self.mesh_shape, encoder_layer_outputs=encoder_layer_outputs) inputs_this_step = mtf.gather(ids, position - 1, length_dim) with tf.variable_scope(self.name, reuse=True): logits = self._call_internal(context_incremental, inputs_this_step) ids_this_step = mtf.sample_with_temperature( logits, self.output_vocab_dim, temperature) new_position = position + 1 new_ids = ids + ids_this_step * mtf.one_hot( position, length_dim, dtype=tf.int32) return [new_position, new_ids] + context_incremental.new_states while_loop_inputs = [initial_position, inputs] + initial_states final_position, outputs = mtf.while_loop( cond_fn, body_fn, while_loop_inputs)[:2] del final_position return outputs
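# mtf.sample_with_temperature follows the convention in the docstring above:
# temperature 0.0 means argmax, temperature 1.0 means sampling from the predicted
# distribution.  NumPy sketch of that convention (not of the mtf implementation);
# the helper name is ours.
import numpy as np

def sample_with_temperature_reference(logits, temperature, rng=np.random.default_rng(0)):
    if temperature == 0.0:
        return int(np.argmax(logits))
    scaled = logits / temperature
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    return int(rng.choice(len(logits), p=probs))

logits = np.array([2.0, 1.0, 0.1])
print(sample_with_temperature_reference(logits, 0.0))   # always index 0
print(sample_with_temperature_reference(logits, 1.0))   # usually 0, sometimes 1 or 2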
def _top_2_gating(inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
                  hparams, train, variable_dtype, importance=None,
                  name="top_2_gating"):
  """Compute gating for mixture-of-experts in TensorFlow.

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py.
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_use_second_place_loss: a boolean
    hparams.moe_second_policy_train: a string
    hparams.moe_second_policy_eval: a string
    hparams.moe_second_threshold: a float

  The returned forward assignment is a tensor used to map (via einsum) from the
  inputs to the expert_inputs.  Likewise, the returned combine_tensor is used to
  map (via einsum) from the expert outputs to the outputs.  Both the forward and
  backward assignments are mostly zeros.  The shapes of the tensors are as
  follows.

  inputs: [<batch_dims>, group_size_dim, input_dim]
  importance: [<batch_dims>, group_size_dim]
  dispatch_tensor: [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
  expert_inputs: [<batch_dims>, experts_dim, expert_capacity_dim, input_dim]
  expert_outputs: [<batch_dims>, experts_dim, expert_capacity_dim, output_dim]
  combine_tensor: [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
  outputs: [<batch_dims>, group_size_dim, output_dim]

  "importance" is an optional tensor with one floating-point value for each
  input vector.  If the importance of an input is 1.0, then we send it to up to
  2 experts.  If 0.0 < importance < 1.0, then we send it to at most one expert.
  If importance == 0.0, then we send it to no experts.

  We use "importance" at the second-level gating function of a hierarchical
  mixture of experts.  Inputs to the first-choice expert-group get importance
  1.0.  Inputs to the second-choice expert group get importance 0.5.  Inputs
  that represent padding get importance 0.0.

  Args:
    inputs: a mtf.Tensor with shape [<batch_dims>, group_size_dim, input_dim]
    outer_expert_dims: an optional list of dimensions.  This is for the case
      where we are at an inner level of a hierarchical MoE.
    experts_dim: a Dimension (the number of experts)
    expert_capacity_dim: a Dimension (number of examples per group per expert)
    hparams: model hyperparameters.
    train: a boolean
    variable_dtype: a mtf.VariableDType
    importance: an optional tensor with shape [<batch_dims>, group_size_dim]
    name: an optional string

  Returns:
    dispatch_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    combine_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on illegal hyperparameters
  """
  group_size_dim, unused_input_dim = inputs.shape.dims[-2:]

  raw_gates = mtf.layers.dense(inputs, experts_dim, use_bias=False,
                               expert_dims=outer_expert_dims,
                               variable_dtype=variable_dtype, name=name)
  raw_gates = mtf.softmax(raw_gates, experts_dim)

  # The internals of this function run in float32.
  # bfloat16 seems to reduce quality.
  raw_gates = mtf.to_float(raw_gates)

  expert_capacity_f = float(expert_capacity_dim.size)

  # FIND TOP 2 EXPERTS PER POSITION
  # Find the top expert for each position. shape=[batch, group]
  index_1, gate_1 = mtf.top_1(raw_gates, experts_dim)
  # [batch, group, experts]
  mask_1 = mtf.one_hot(index_1, experts_dim, dtype=raw_gates.dtype)
  density_1_proxy = raw_gates
  if importance is not None:
    mask_1 *= mtf.to_float(mtf.equal(importance, 1.0))
    gate_1 *= mtf.to_float(mtf.equal(importance, 1.0))
    density_1_proxy *= mtf.to_float(mtf.equal(importance, 1.0))
  gates_without_top_1 = raw_gates * (1.0 - mask_1)
  # [batch, group]
  index_2, gate_2 = mtf.top_1(gates_without_top_1, experts_dim)
  # [batch, group, experts]
  mask_2 = mtf.one_hot(index_2, experts_dim, dtype=raw_gates.dtype)
  if importance is not None:
    mask_2 *= mtf.to_float(mtf.greater(importance, 0.0))

  denom = gate_1 + gate_2 + 1e-9
  gate_1 /= denom
  gate_2 /= denom

  # BALANCING LOSSES
  # shape = [batch, experts]
  # We want to equalize the fraction of the batch assigned to each expert
  density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_size_dim)
  # Something continuous that is correlated with what we want to equalize.
  density_1_proxy = mtf.reduce_mean(density_1_proxy, reduced_dim=group_size_dim)
  loss = (mtf.reduce_mean(density_1_proxy * density_1)
          * float(experts_dim.size * experts_dim.size))

  if hparams.moe_use_second_place_loss:
    # Also add a loss to encourage all experts to be used equally also as the
    # second-place expert.  Experimentally, this seems to be a wash.
    # We want to equalize the fraction of the batch assigned to each expert:
    density_2 = mtf.reduce_mean(mask_2, reduced_dim=group_size_dim)
    # As a proxy for density_2, we renormalize the raw gates after the top one
    # has been removed.
    normalized = gates_without_top_1 / (mtf.reduce_sum(
        gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
    density_2_proxy = mtf.reduce_mean(normalized, reduced_dim=group_size_dim)
    loss_2 = (mtf.reduce_mean(density_2_proxy * density_2)
              * float(experts_dim.size * experts_dim.size))
    loss += loss_2 * 0.5

  # Depending on the policy in the hparams, we may drop out some of the
  # second-place experts.
  if train:
    policy = hparams.moe_second_policy_train
    threshold = hparams.moe_second_threshold_train
  else:
    policy = hparams.moe_second_policy_eval
    threshold = hparams.moe_second_threshold_eval
  if policy == "all":
    # Use second-place experts for all examples.
    pass
  elif policy == "none":
    # Never use second-place experts.
    mask_2 = mtf.zeros_like(mask_2)
  elif policy == "threshold":
    # Use second-place experts if gate_2 > threshold.
    mask_2 *= mtf.to_float(mtf.greater(gate_2, threshold))
  elif policy == "random":
    # Use second-place experts with probability min(1.0, gate_2 / threshold).
    mask_2 *= mtf.to_float(
        mtf.less(mtf.random_uniform(gate_2.mesh, gate_2.shape),
                 gate_2 / max(threshold, 1e-9)))
  else:
    raise ValueError("Unknown policy %s" % policy)

  # COMPUTE ASSIGNMENT TO EXPERTS
  # [batch, group, experts]
  # This is the position within the expert's mini-batch for this sequence
  position_in_expert_1 = mtf.cumsum(
      mask_1, group_size_dim, exclusive=True) * mask_1
  # Remove the elements that don't fit. [batch, group, experts]
  mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
  # [batch, experts]
  # How many examples in this sequence go to this expert
  mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_size_dim)
  # [batch, group] - mostly ones, but zeros where something didn't fit
  mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
  # [batch, group]
  position_in_expert_1 = mtf.reduce_sum(
      position_in_expert_1, reduced_dim=experts_dim)
  # Weight assigned to first expert.  [batch, group]
  gate_1 *= mask_1_flat

  # [batch, group, experts]
  position_in_expert_2 = (
      mtf.cumsum(mask_2, group_size_dim, exclusive=True) + mask_1_count)
  position_in_expert_2 *= mask_2
  mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
  # mask_2_count = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
  mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
  gate_2 *= mask_2_flat
  position_in_expert_2 = mtf.reduce_sum(
      position_in_expert_2, reduced_dim=experts_dim)

  # [batch, group, experts, expert_capacity]
  combine_tensor = (
      gate_1 * mask_1_flat
      * mtf.one_hot(index_1, experts_dim)
      * mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim)
      + gate_2 * mask_2_flat
      * mtf.one_hot(index_2, experts_dim)
      * mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)

  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss
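# The first balancing term above multiplies, per expert, the fraction of positions
# routed to that expert as top-1 (density_1) by the mean gate probability it
# receives (density_1_proxy), scaled by num_experts**2 so a perfectly uniform
# router scores about 1.0.  Illustrative NumPy sketch of just that loss; the
# helper name and toy data below are ours, not part of the mtf code.
import numpy as np

def load_balancing_loss(raw_gates, mask_1):
    num_experts = raw_gates.shape[-1]
    density = mask_1.mean(axis=-2)            # fraction of positions whose top-1 is each expert
    density_proxy = raw_gates.mean(axis=-2)   # differentiable proxy: mean gate probability
    return np.mean(density * density_proxy) * num_experts ** 2

rng = np.random.default_rng(0)
gates = rng.dirichlet(np.ones(4), size=(1, 16))   # [batch=1, group=16, experts=4]
top1_mask = np.eye(4)[gates.argmax(-1)]           # one-hot top-1 assignment
print(load_balancing_loss(gates, top1_mask))      # ~1.0 when balanced, larger when skewed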
def model(mtf_features, other_features, params, mesh, variable_dtype, context=None): """A GPT style model implemented in mesh tensorflow.""" x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim = parse_inputs(mtf_features, other_features) if is_incremental_inference(context): # reshape inputs if in inference mode x = mtf.gather(x, context.position - 1, sequence_dim) x = mtf.reshape(x, [batch_dim]) use_axial_pos_emb = params["axial_pos_emb"] is not None if not use_axial_pos_emb: # Use standard position encoding wpe = mtf.get_variable(mesh, "wpe", mtf.Shape([embed_sequence_dim, embd_dim]), initializer=tf.random_normal_initializer(stddev=0.01), master_dtype=variable_dtype.master_dtype, slice_dtype=variable_dtype.slice_dtype, activation_dtype=variable_dtype.activation_dtype) else: wpe = axial_positional_emb(embd_dim, mesh, params, variable_dtype) # Text encoding wte = mtf.get_variable(mesh, "wte", mtf.Shape([vocab_dim, embd_dim]), initializer=tf.random_normal_initializer(stddev=0.02), master_dtype=variable_dtype.master_dtype, slice_dtype=variable_dtype.slice_dtype, activation_dtype=variable_dtype.activation_dtype) with tf.variable_scope("token_embd"): # Text embedding h = mtf.gather(wte, x, vocab_dim) if params["embed_dropout"] > 0 and params["mode"] == "train": h = mtf.dropout(h, rate=params["embed_dropout"], name="wte_dropout") with tf.variable_scope("pos_embd"): # Positional embedding position_indices = mtf.range(mesh, sequence_dim, tf.int64) if not is_incremental_inference(context) else ( context.position - 1) pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0]) if params["embed_dropout"] > 0 and params["mode"] == "train": pos_emb = mtf.dropout(pos_emb, rate=params["embed_dropout"], name="wte_dropout") h += pos_emb aux_losses = 0 # instantiate auxiliary losses (for MOE models) for layer in range(params["n_layer"]): # attn blocks share_parameters = exists(params["share_parameters"]) and params["share_parameters"] == True block_scope = f"h{layer}" if not share_parameters else "" block_fn = block(params=params, scope=block_scope, layer_num=layer, bias=other_features["attn_bias"], sequence_dim=sequence_dim, memory_length_dim=other_features["memory_length_dim"], variable_dtype=variable_dtype, context=context) # If true and in train mode, enable gradient checkpointing recompute_grad = params["recompute_grad"] and (params["mode"] == "train") == True h, loss = block_fn(h) if not recompute_grad else mtf.recompute_grad(block_fn, [h]) aux_losses += loss no_weight_tie_emb = params["no_weight_tie"] == True if no_weight_tie_emb: with tf.variable_scope("wte_final_linear"): logits = linear(h, "linear_out", vocab_dim, variable_dtype=variable_dtype, params=params) else: # Layer normalize & affine transform h = layer_norm(h, "ln_f", variable_dtype=variable_dtype) seq_dim = sequence_dim if not is_incremental_inference(context) else mtf.Dimension("sequence", 1) with tf.variable_scope("wte_final_einsum"): # Equivalent to tf.matmul logits = mtf.einsum([h, wte], output_shape=[batch_dim, seq_dim, vocab_dim]) if params["mode"] in ["train", "eval"]: labels = mtf_features["labels"] z_loss = params.get("z_loss", 1e-4) # an auxiliary loss used to stabilize mtf xentropy # Go to full precision for the logits logits = mtf.cast(logits, tf.float32) use_entmax_loss = params.get("entmax_loss", False) loss_fn = mtf.layers.softmax_cross_entropy_with_logits if not use_entmax_loss else entmax_cross_entropy_with_logits with tf.variable_scope("xentropy_final"): loss_batch = loss_fn(logits=logits, targets=labels, 
vocab_dim=logits.shape[-1], z_loss=z_loss) # For non-autoregressive models (masked language modeling training) # Make sure labels with padding tokens are not counted in the loss if not params["causal"]: padding_id = params.get("padding_id", 0) loss_batch = mtf.where(mtf.not_equal(labels, padding_id), loss_batch, mtf.zeros_like(loss_batch)) with tf.variable_scope("reduce_mean_final"): loss = mtf.reduce_mean(loss_batch) loss += aux_losses # Add on auxiliary losses (currently only used for MoE) loss /= params["num_microbatches"] # Convert to train dtype loss = mtf.cast(loss, variable_dtype.slice_dtype) else: loss = None loss_batch = None # Cast back to checkpoint dtype logits = mtf.cast(logits, variable_dtype.master_dtype) return logits, loss, loss_batch
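# The z_loss handed to the cross-entropy is a small regularizer on the softmax
# log-partition function.  The sketch below assumes the usual z_loss * log(Z)**2
# form (an illustration, not a transcription of
# mtf.layers.softmax_cross_entropy_with_logits); the helper name is ours.
import numpy as np

def xent_with_z_loss(logits, target, z_loss=1e-4):
    log_z = np.log(np.exp(logits).sum())   # softmax normalizer log Z
    nll = log_z - logits[target]           # ordinary cross-entropy for one token
    return nll + z_loss * log_z ** 2       # extra term pulls log Z toward 0

print(xent_with_z_loss(np.array([2.0, 0.5, -1.0]), target=0))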
def act_layer(self, context, x, mask): """Build a Universal Transformer ACT layer.""" state = x act_max_steps = self.act_max_steps threshold = 1.0 - self.act_epsilon state_shape_static = state.shape.dims state_slice = slice(0, 3) if self.act_type == "global": state_slice = slice(0, 2) # Dynamic shape for update tensors below update_shape = state_shape_static[state_slice] # Halting probabilities (p_t^n in the paper) halting_probability = mtf.zeros(context.mesh, update_shape, dtype=context.activation_dtype) # Remainders (R(t) in the paper) remainders = mtf.zeros(context.mesh, update_shape, dtype=context.activation_dtype) # Number of updates performed (N(t) in the paper) n_updates = mtf.zeros(context.mesh, update_shape, dtype=context.activation_dtype) # Previous cell states (s_t in the paper) previous_state = mtf.zeros_like(state) step = mtf.constant(context.mesh, 0, dtype=tf.int32) def ut_function(state, step, halting_probability, remainders, n_updates, previous_state): """implements act (position-wise halting). Args: state: 3-D Tensor: [batch_size, length, channel] step: indicates number of steps taken so far halting_probability: halting probability remainders: act remainders n_updates: act n_updates previous_state: previous state Returns: transformed_state: transformed state step: step+1 halting_probability: halting probability remainders: act remainders n_updates: act n_updates new_state: new state """ state = self.step_preprocess(context, state, step) if self.act_type == "random": # random as halting probability p = mtf.random_uniform(context.mesh, shape=halting_probability.shape.dims, dtype=context.variable_dtype) else: last_dim_name = state.shape.dimension_names[-1] new_dims = [mtf.Dimension(last_dim_name, 1)] with tf.variable_scope("sigmoid_activation_for_pondering", reuse=tf.AUTO_REUSE): p = mtf.layers.dense(state, variable_dtype=context.variable_dtype, reduced_dims=[state.shape.dims[-1]], new_dims=new_dims, activation=mtf.sigmoid, use_bias=True) if self.act_type == "global": # average over all positions (as a global halting prob) p = mtf.reduce_mean(p, reduced_dim=p.shape.dims[1]) p = mtf.squeeze(p) else: # maintain position-wise probabilities new_shape = p.shape.dims[:-1] p = mtf.reshape(p, new_shape) # Mask for inputs which have not halted yet still_running = mtf.cast(mtf.less(halting_probability, 1.0), context.activation_dtype) # Mask of inputs which halted at this step new_halted = mtf.cast( mtf.greater(halting_probability + p * still_running, threshold), context.activation_dtype) * still_running # Mask of inputs which haven't halted, and didn't halt this step still_running = mtf.cast( mtf.less_equal(halting_probability + p * still_running, threshold), context.activation_dtype) * still_running # Add the halting probability for this step to the halting # probabilities for those input which haven't halted yet halting_probability += p * still_running # Compute remainders for the inputs which halted at this step remainders += new_halted * (1 - halting_probability) # Add the remainders to those inputs which halted at this step halting_probability += new_halted * remainders # Increment n_updates for all inputs which are still running n_updates += still_running + new_halted # Compute the weight to be applied to the new state and output # 0 when the input has already halted # p when the input hasn't halted yet # the remainders when it halted this step input_tensor = p * still_running + new_halted * remainders update_weights = input_tensor # apply transformation on the state 
transformed_state = state for _ in range(self.num_inrecurrence_layers): transformed_state = self.vanilla_transformer_layer( context, transformed_state, mask) # update running part in the weighted state and keep the rest new_state = ((transformed_state * update_weights) + (previous_state * (1 - update_weights))) if self.act_type == "accumulated": # Add in the weighted state new_state = (transformed_state * update_weights) + previous_state step += 1 return (transformed_state, step, halting_probability, remainders, n_updates, new_state) for _ in range(act_max_steps + 1): (state, step, halting_probability, remainders, n_updates, previous_state) = ut_function(state, step, halting_probability, remainders, n_updates, previous_state) ponder_times = n_updates mtf.scalar_summary("ponder_times", mtf.reduce_mean(ponder_times)) return previous_state
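# The halting bookkeeping in ut_function is easier to follow on a plain vector of
# positions.  NumPy sketch of one ACT step with the same mask logic; the function
# name and toy inputs are ours, not part of the layer above.
import numpy as np

def act_halting_step(p, halting_probability, remainders, n_updates, threshold):
    # Positions whose cumulative halting probability has not yet reached 1.0.
    still_running = (halting_probability < 1.0).astype(np.float32)
    # Positions that cross the threshold at this step.
    new_halted = ((halting_probability + p * still_running) > threshold
                  ).astype(np.float32) * still_running
    # Positions that are still below the threshold after this step.
    still_running = ((halting_probability + p * still_running) <= threshold
                     ).astype(np.float32) * still_running
    halting_probability = halting_probability + p * still_running
    # Halted positions get the leftover probability mass as their final weight.
    remainders = remainders + new_halted * (1.0 - halting_probability)
    halting_probability = halting_probability + new_halted * remainders
    n_updates = n_updates + still_running + new_halted
    update_weights = p * still_running + new_halted * remainders
    return halting_probability, remainders, n_updates, update_weights

p = np.array([0.6, 0.95])
hp, rem, n, w = act_halting_step(p, np.zeros(2), np.zeros(2), np.zeros(2), threshold=0.99)
print(hp, n, w)   # both positions keep pondering; each has taken one update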
def sample_autoregressive( partial_sequences, other_features, params, stop_at_token=50256, max_steps=None, temperature=0.9, variable_dtype=mtf.VariableDType(tf.float32), encoder_output=None, encoder_sequence_id=None, encoder_inputs=None, shared_params=None, has_partial_sequences=True, encoder_layer_outputs=None, never_end=False, remove_partial_sequences=False, sampling_keep_top_k=-1, bos_id=50256, ): """Sample randomly one token at a time. The partial_sequences represent partial sequences to be continued. The first tokens of each sequence are nonzero representing the given partial sequences and the last tokens of each sequence are zeros, representing what needs to be filled in. If there are no partial sequences (you want to sample from the beginning), then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and has_partial_sequences=False (so we can skip computation). Args: partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim] stop_at_token: an optional integer eos id. Stop when we produce it. max_steps: an optional integer, the max number of steps to decode. temperature: an optional floating point value between 0.0 and 1.0 0.0 means argmax, 1.0 means sample according to predicted distribution. variable_dtype: a mtf.VariableDType encoder_output: an optional Tensor encoder_sequence_id: an optional Tensor encoder_inputs: an optional Tensor shared_params: an optional dictionary has_partial_sequences: a boolean encoder_layer_outputs: optional - readonly list of tensor activations when decoding, one per each input layer + the embedding layer never_end: a boolean - if set, then avoid generating stop_at_token remove_partial_sequences: a boolean - whether to remove the partial sequences from the output sampling_keep_top_k: an integer - if not -1, only sample from the top k logits. bos_id: beginning of sequence id Returns: a Tensor with shape [<batch_dims>, length_dim] """ inputs = partial_sequences # Partial sequences to fill in batch_dims = inputs.shape.dims[:-1] length_dim = inputs.shape.dims[-1] padding_id = params.get("padding_id", 0) slow_sampling = params.get("slow_sampling", False) initial_position = mtf.reduce_sum( mtf.to_int32(mtf.not_equal(inputs, padding_id)), reduced_dim=length_dim) # Gets position where zero padding starts length_range = mtf.range(inputs.mesh, length_dim, tf.int32) input_full_attention = True # for now hardcode this to true bc lazy if input_full_attention: # Vanilla autoregressive model - each position can see previous positions. # Think this feeds in to the loop fn and tells each position where it can attend to? 
read_priority = write_priority = length_range * mtf.to_int32( mtf.greater(length_range, initial_position)) else: read_priority = write_priority = length_range # Builds context to pass around internally # The 'first part' context records initial states of k / v / x if not slow_sampling: context_first_part = mtf_transformer.transformer.Context( model=None, mesh=inputs.mesh, batch_dims=batch_dims, length_dim=length_dim, variable_dtype=variable_dtype, mode="first_part", position=length_range, position_is_default=True, new_states=[], initial_position=initial_position, sequence_id=None, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, constant_states=[], shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs, write_priority=write_priority, read_priority=read_priority, inputs=inputs, encoder_inputs=encoder_inputs) with tf.variable_scope("gpt2"): logits, _, _ = gpt2.model({"inputs": inputs}, other_features, params, inputs.mesh, variable_dtype=variable_dtype, context=context_first_part) if not has_partial_sequences: initial_states = [ mtf.zeros_like(t) for t in context_first_part.new_states ] else: initial_states = context_first_part.new_states else: initial_states = [] if not has_partial_sequences: partial_sequences_eos_count = 0 if stop_at_token is not None: partial_sequences_eos_count = mtf.reduce_sum(mtf.to_int32( mtf.equal(partial_sequences, stop_at_token)), reduced_dim=length_dim) def cond_fn(position, ids, *unused_states): """Should we run another loop iteration?""" past_end = mtf.greater_equal(position, length_dim.size) if max_steps: past_end = mtf.logical_or( past_end, mtf.greater_equal(position - initial_position, max_steps)) is_done = past_end if stop_at_token is not None: eos_count = mtf.reduce_sum(mtf.to_int32( mtf.equal(ids, stop_at_token)), reduced_dim=length_dim) has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count) is_done = mtf.logical_or(is_done, has_additional_eos) all_done = mtf.reduce_all(is_done) return mtf.logical_not(all_done) def body_fn(position, ids, *states): """One step in the decode loop.""" nonlocal sampling_keep_top_k context = mtf_transformer.transformer.Context( model=None, mesh=inputs.mesh, batch_dims=batch_dims, length_dim=length_dim, variable_dtype=variable_dtype, mode="incremental", position=position, position_is_default=True, states=states, new_states=[], initial_position=position, sequence_id=None, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs, write_priority=write_priority, read_priority=read_priority, inputs=ids, encoder_inputs=encoder_inputs) if not slow_sampling else None with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE): logits, _, _ = gpt2.model({"inputs": ids}, other_features, params, inputs.mesh, variable_dtype=variable_dtype, context=context) # By default, do top_k sampling of 0.9 if sampling_keep_top_k == -2: sampling_keep_top_k = int(logits.shape[-1].size * 0.1) if sampling_keep_top_k != -1: if sampling_keep_top_k <= 0: raise ValueError( "sampling_keep_top_k must either be -1 or positive.") k_largest = mtf.nth_largest_element( logits, n=sampling_keep_top_k, reduced_dim=other_features["vocab_dim"]) logits = mtf.where(mtf.less_equal(logits, k_largest), mtf.ones_like(logits) * -1e6, logits) ids_this_step = mtf.sample_with_temperature( logits, other_features["vocab_dim"], temperature) if slow_sampling: ids_this_step = mtf.shift(ids_this_step, offset=1, dim=length_dim, wrap=False) else: 
ids_this_step = mtf.reshape(ids_this_step, (batch_dims)) one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32) one_new_id = ids_this_step * one_hot new_ids = (1 - one_hot) * ids + one_new_id new_position = position + 1 ret = [new_position, new_ids] if context is not None: ret += context.new_states return ret while_loop_inputs = [initial_position, inputs] + initial_states final_position, outputs = mtf.while_loop(cond_fn, body_fn, while_loop_inputs)[:2] del final_position if has_partial_sequences and remove_partial_sequences: # Remove partial sequences from outputs partial_length = mtf.reduce_sum(mtf.to_int32( mtf.not_equal(partial_sequences, padding_id)), reduced_dim=length_dim) outputs = mtf.dynamic_shift(outputs, -partial_length, length_dim, wrap=False) return outputs
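# sampling_keep_top_k truncates the distribution before sampling by pushing every
# logit below the k-th largest down to a large negative value.  NumPy sketch of
# that filtering (the mtf version uses nth_largest_element; exact tie/indexing
# behavior may differ, and keep_top_k is our name for the helper).
import numpy as np

def keep_top_k(logits, k):
    kth_largest = np.sort(logits)[-k]              # value of the k-th largest logit
    return np.where(logits < kth_largest, -1e6, logits)

logits = np.array([3.0, 1.0, 2.5, 0.2, 2.9])
print(keep_top_k(logits, k=2))   # only 3.0 and 2.9 survive; the rest become -1e6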
def beam_search(self, inputs, decode_length, dst_attributes=None, variable_dtype=mtf.VariableDType(tf.float32), encoder_output=None, encoder_sequence_id=None, encoder_inputs=None, alpha=0.6, shared_params=None, encoder_layer_outputs=None, z=None): """Beam search. Args: inputs: an int32 zero-Tensor with shape [<batch_dims>, beam_dim, length_dim]. decode_length: an int32 mtf scalar. Maximum decode length. dst_attributes: an int32 zero-Tensor with shape [<batch_dims>, beam_dim, length_dim] ([<batch_dims>] [<batch_dims>, beam_dim]). variable_dtype: a mtf.VariableDType encoder_output: an optional Tensor encoder_sequence_id: an optional Tensor encoder_inputs: an optional Tensor alpha: a floating point value (length bonus) shared_params: an optional dictionary encoder_layer_outputs: optional - readonly list of tensor activations when decoding, one per each input layer + the embedding layer Returns: a Tensor with shape [<batch_dims>, beam_dim, length_dim] """ attributes = dst_attributes if not self.autoregressive: raise ValueError("must be autoregressive") batch_dims = inputs.shape.dims[:-2] if len(batch_dims) != 1: raise NotImplementedError( "beam search supports exactly one batch dimension.") beam_dim = inputs.shape.dims[-2] length_dim = inputs.shape.dims[-1] length_range = mtf.range(inputs.mesh, length_dim, tf.int32) initial_position = mtf.reduce_sum(mtf.to_int32(mtf.not_equal( inputs, 0)), reduced_dim=length_dim) sequence_id = 1 if encoder_sequence_id is not None else None if self.input_full_attention: # This only makes sense in the case of beam search with given partial # sequences, which is not yet implemented. # TODO(noam): implement raise NotImplementedError( "Beam search for language models not yet implemented") else: read_priority = write_priority = length_range context_first_part = Context( model=self, mesh=inputs.mesh, batch_dims=batch_dims + [beam_dim], length_dim=length_dim, variable_dtype=variable_dtype, mode="first_part", position=length_range, position_is_default=True, new_states=[], initial_position=initial_position, sequence_id=sequence_id, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, constant_states=[], shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs, write_priority=write_priority, read_priority=read_priority, inputs=inputs, encoder_inputs=encoder_inputs) shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False) with tf.variable_scope(self.name): logits = self._call_internal(context_first_part, shifted_inputs, attributes=attributes, z=z) del logits # There are no partial targets. # Replace initial states by zeros to avoid computing them.
initial_states = [ mtf.zeros_like(t) for t in context_first_part.new_states ] constant_states = context_first_part.constant_states def logits_fn(step_num, ids, states): """logits_fn for mtf.beam_search.beam_search().""" inputs_this_step = mtf.gather(ids, step_num - 1, length_dim) if self.attribute_embedding: attributes_this_step = mtf.gather(attributes, step_num - 1, length_dim) else: attributes_this_step = None context_incremental = Context( model=self, mesh=inputs.mesh, batch_dims=batch_dims + [beam_dim], length_dim=length_dim, variable_dtype=variable_dtype, mode="incremental", position=step_num, states=states, new_states=[], sequence_id=sequence_id, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, constant_states=constant_states, shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs, write_priority=write_priority, read_priority=step_num, inputs=inputs_this_step, encoder_inputs=encoder_inputs) with tf.variable_scope(self.name, reuse=True): logits = self._call_internal(context_incremental, inputs_this_step, attributes=attributes_this_step, z=z) return mtf.to_float(logits), context_incremental.new_states beams, unused_scores = mtf.beam_search.beam_search( logits_fn, inputs, alpha, states=initial_states, decode_length=decode_length, use_tpu=True, dtype=tf.float32, mesh_shape=self.mesh_shape, layout=self.layout) return mtf.gather(beams, mtf.constant(inputs.mesh, 0, dtype=tf.int32), beam_dim)
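# mtf.beam_search.beam_search drives decoding through the logits_fn(step_num, ids,
# states) callback, which must return next-step logits together with updated
# decoder states.  The toy sketch below only illustrates that calling contract in
# NumPy (it expands each row greedily and is not a real beam search with
# hypothesis re-ranking); every name in it is ours.
import numpy as np

def toy_logits_fn(step_num, ids, states):
    # Look at the token chosen at step_num - 1, as logits_fn above does,
    # and return logits plus updated states (here just a step counter).
    prev_token = ids[:, step_num - 1]
    vocab = 3
    logits = np.full((ids.shape[0], vocab), -1.0)
    logits[np.arange(ids.shape[0]), (prev_token + 1) % vocab] = 1.0
    return logits, [states[0] + 1]

def toy_decode(logits_fn, beam=2, seq_len=4):
    ids = np.zeros((beam, seq_len), dtype=np.int64)
    states = [np.zeros(beam)]
    for step in range(1, seq_len):
        logits, states = logits_fn(step, ids, states)
        ids[:, step] = logits.argmax(-1)   # greedy expansion per row
    return ids

print(toy_decode(toy_logits_fn))   # each row cycles 0 -> 1 -> 2 -> 0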
def sample_autoregressive(self, partial_sequences, dst_attributes=None, stop_at_token=1, max_steps=None, temperature=0.0, variable_dtype=mtf.VariableDType(tf.float32), encoder_output=None, encoder_sequence_id=None, encoder_inputs=None, shared_params=None, has_partial_sequences=True, encoder_layer_outputs=None, never_end=False, remove_partial_sequences=False, sampling_keep_top_k=-1, z=None): """Sample randomly one token at a time. The partial_sequences represent partial sequences to be continued. The first tokens of each sequence are nonzero representing the given partial sequences and the last tokens of each sequence are zeros, representing what needs to be filled in. If there are no partial sequences (you want to sample from the beginning), then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and has_partial_sequences=False (so we can skip computation). The dst_attributes represent the destination attributes for which we want to generate sequences. Args: partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim] dst_attributes: an int32 Tensor with shape [<batch_dims>, length_dim] ([<batch_dims>]) stop_at_token: an optional integer eos id. Stop when we produce it. max_steps: an optional integer, the max number of steps to decode. temperature: an optional floating point value between 0.0 and 1.0 0.0 means argmax, 1.0 means sample according to predicted distribution. variable_dtype: a mtf.VariableDType encoder_output: an optional Tensor encoder_sequence_id: an optional Tensor encoder_inputs: an optional Tensor shared_params: an optional dictionary has_partial_sequences: a boolean encoder_layer_outputs: optional - readonly list of tensor activations when decoding, one per each input layer + the embedding layer never_end: a boolean - if set, then avoid generating stop_at_token remove_partial_sequences: a boolean - whether to remove the partial sequences from the output sampling_keep_top_k: an integer - if not -1, only sample from the top k logits.
Returns: a Tensor with shape [<batch_dims>, length_dim] """ if not self.autoregressive: raise ValueError("must be autoregressive") inputs = partial_sequences attributes = dst_attributes batch_dims = inputs.shape.dims[:-1] length_dim = inputs.shape.dims[-1] initial_position = mtf.reduce_sum(mtf.to_int32(mtf.not_equal( inputs, 0)), reduced_dim=length_dim) sequence_id = 1 if encoder_sequence_id is not None else None length_range = mtf.range(inputs.mesh, length_dim, tf.int32) if self.input_full_attention: read_priority = write_priority = length_range * mtf.to_int32( mtf.greater(length_range, initial_position)) else: read_priority = write_priority = length_range context_first_part = Context( model=self, mesh=inputs.mesh, batch_dims=batch_dims, length_dim=length_dim, variable_dtype=variable_dtype, mode="first_part", position=length_range, position_is_default=True, new_states=[], initial_position=initial_position, sequence_id=sequence_id, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, constant_states=[], shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs, write_priority=write_priority, read_priority=read_priority, inputs=inputs, encoder_inputs=encoder_inputs) shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False) with tf.variable_scope(self.name): logits = self._call_internal(context_first_part, shifted_inputs, attributes=attributes, z=z) del logits constant_states = context_first_part.constant_states if not has_partial_sequences: initial_states = [ mtf.zeros_like(t) for t in context_first_part.new_states ] partial_sequences_eos_count = 0 else: initial_states = context_first_part.new_states partial_sequences_eos_count = mtf.reduce_sum( mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)), reduced_dim=length_dim) def cond_fn(position, ids, *unused_states): """Should we run another loop iteration.""" past_end = mtf.greater_equal(position, length_dim.size) if max_steps: past_end = mtf.logical_or( past_end, mtf.greater_equal(position - initial_position, max_steps)) is_done = past_end if stop_at_token is not None: eos_count = mtf.reduce_sum(mtf.to_int32( mtf.equal(ids, stop_at_token)), reduced_dim=length_dim) has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count) is_done = mtf.logical_or(is_done, has_additional_eos) all_done = mtf.reduce_all(is_done) return mtf.logical_not(all_done) def body_fn(position, ids, *states): """One step in the decode loop.""" inputs_this_step = mtf.gather(ids, position - 1, length_dim) if self.attribute_embedding: attributes_this_step = mtf.gather(attributes, position - 1, length_dim) else: attributes_this_step = None # raise ValueError("inputs_this_step shape=%s , ids shape=%s, position - 1 shape=%s, length_dim=%s" % (inputs_this_step.shape, ids.shape, (position - 1).shape, length_dim)) context_incremental = Context( model=self, mesh=inputs.mesh, batch_dims=batch_dims, length_dim=length_dim, variable_dtype=variable_dtype, mode="incremental", position=position, states=states, new_states=[], sequence_id=sequence_id, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, constant_states=constant_states, shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs, write_priority=write_priority, read_priority=position, inputs=inputs_this_step, encoder_inputs=encoder_inputs) with tf.variable_scope(self.name, reuse=True): logits = self._call_internal(context_incremental, inputs_this_step, attributes=attributes_this_step, z=z) if never_end: logits += 
mtf.one_hot(mtf.constant(logits.mesh, stop_at_token, dtype=tf.int32), self.output_vocab_dim, on_value=-1e9, off_value=0.0, dtype=logits.dtype) # TBD whether this should be before or after never_end: # Note for adding top_p sampling in the future, in other code bases, the # option to apply temperature is done before the top-k truncation. This # implementation does this in the opposite order. For top-k this doesn't # matter, but for top_p it will. if sampling_keep_top_k != -1: if sampling_keep_top_k <= 0: raise ValueError( "sampling_keep_top_k must either be -1 or positive.") k_largest = mtf.nth_largest_element( logits, n=sampling_keep_top_k, reduced_dim=self.output_vocab_dim) logits = mtf.where(mtf.less_equal(logits, k_largest), mtf.ones_like(logits) * -1e6, logits) ids_this_step = mtf.sample_with_temperature( logits, self.output_vocab_dim, temperature) new_position = position + 1 new_ids = ids + ids_this_step * mtf.one_hot( position, length_dim, dtype=tf.int32) return [new_position, new_ids] + context_incremental.new_states while_loop_inputs = [initial_position, inputs] + initial_states final_position, outputs = mtf.while_loop(cond_fn, body_fn, while_loop_inputs)[:2] del final_position if has_partial_sequences and remove_partial_sequences: # remove partial sequences from outputs partial_length = mtf.reduce_sum(mtf.to_int32( mtf.not_equal(partial_sequences, 0)), reduced_dim=length_dim) outputs = mtf.dynamic_shift(outputs, -partial_length, length_dim, wrap=False) return outputs
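# Two pieces of the loop above are easy to state on their own: never_end buries
# the EOS logit, and cond_fn stops an example once it contains more EOS tokens
# than its partial prefix started with.  NumPy sketch; names and toy data are ours.
import numpy as np

def block_eos(logits, eos_id):
    # never_end: add -1e9 to the EOS logit so it can never be sampled.
    penalty = np.zeros_like(logits)
    penalty[..., eos_id] = -1e9
    return logits + penalty

def has_new_eos(ids, eos_id, partial_eos_count):
    # Done once a sequence holds an EOS beyond those already in its prefix.
    return (ids == eos_id).sum(axis=-1) > partial_eos_count

ids = np.array([[5, 1, 7, 1, 0]])                        # eos_id = 1 appears twice
print(has_new_eos(ids, eos_id=1, partial_eos_count=1))   # [ True]
print(block_eos(np.array([0.2, 3.0, 0.1]), eos_id=1))    # EOS logit becomes ~-1e9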