def _mtf_model_fn(self, features, mesh):
    """Build the encoder-decoder graph and return its logits and loss.

    Args:
        features: the feature collection handed to ``self._import_feature``;
            it is also stored on ``self._original_features``.
        mesh: the mesh forwarded to ``self._import_feature``.

    Returns:
        The ``(logits, loss)`` pair returned by ``model.call_simple``.
    """
    self._original_features = features
    hparams = self._hparams

    def import_feature(key):
        return self._import_feature(features, mesh, key)

    targets = import_feature("targets")
    inputs = import_feature("inputs")

    # An absent segmentation feature is detected with an explicit None test
    # (same as the decoder branch below) rather than by the truthiness of a
    # tensor object, which is not a meaningful "is it missing?" check.
    encoder_sequence_id = import_feature("inputs_segmentation")
    if encoder_sequence_id is None:
        # Derive ids from the tokens: 1 where the token id is non-zero,
        # 0 elsewhere.
        encoder_sequence_id = mtf.to_int32(mtf.not_equal(inputs, 0))

    decoder_sequence_id = import_feature("targets_segmentation")
    if decoder_sequence_id is None:
        decoder_sequence_id = mtf.to_int32(mtf.not_equal(targets, 0))

    model = self.model()
    logits, loss = model.call_simple(
        inputs=inputs,
        targets=targets,
        compute_loss=True,
        mode=hparams.mode,
        variable_dtype=self.variable_dtype,
        encoder_sequence_id=encoder_sequence_id,
        decoder_sequence_id=decoder_sequence_id)
    return logits, loss
def _mtf_model_fn(self, features, mesh):
    """Build the encoder-decoder graph, including packed-position features.

    Args:
        features: the feature collection handed to ``self._import_feature``;
            it is also stored on ``self._original_features``.
        mesh: the mesh forwarded to ``self._import_feature``.

    Returns:
        The ``(logits, loss)`` pair returned by ``model.call_simple``.

    Raises:
        ValueError: if the "inputs" feature cannot be imported.
    """
    self._original_features = features
    hparams = self._hparams

    def import_feature(key):
        return self._import_feature(features, mesh, key)

    targets = import_feature("targets")
    inputs = import_feature("inputs")
    # Missing features are detected with explicit None tests; the truthiness
    # of a tensor object says nothing about whether the feature exists.
    if inputs is None:
        raise ValueError("inputs feature is missing")

    encoder_sequence_id = import_feature("inputs_segmentation")
    if encoder_sequence_id is None:
        # Derive ids from the tokens: 1 where the token id is non-zero,
        # 0 elsewhere.
        encoder_sequence_id = mtf.to_int32(mtf.not_equal(inputs, 0))

    decoder_sequence_id = import_feature("targets_segmentation")
    if decoder_sequence_id is None:
        decoder_sequence_id = mtf.to_int32(mtf.not_equal(targets, 0))

    if hparams.use_global_position_in_packed_sequence:
        # No per-sequence position features are imported in this mode; None
        # is forwarded to the model instead.
        encoder_position = None
        decoder_position = None
    else:
        encoder_position = import_feature("inputs_position")
        decoder_position = import_feature("targets_position")

    model = self.model()
    logits, loss = model.call_simple(
        inputs=inputs,
        targets=targets,
        compute_loss=True,
        mode=hparams.mode,
        variable_dtype=self.variable_dtype,
        encoder_sequence_id=encoder_sequence_id,
        decoder_sequence_id=decoder_sequence_id,
        encoder_position=encoder_position,
        decoder_position=decoder_position)
    return logits, loss
def compute_mask(self, context, memory_position):
    """Build the additive attention mask for this layer.

    Every disallowed (query, memory) pair contributes -1e9; the individual
    terms are summed into a single mask.

    Args:
        context: a transformer.Context
        memory_position: an int32 tensor containing memory_length dimension.

    Returns:
        a Tensor, or None when no masking term applies
    """

    def as_penalty(illegal):
        # Turn a boolean "illegal" tensor into a large negative bias.
        return mtf.cast(illegal, context.activation_dtype) * -1e9

    penalties = []
    lower = self.min_relative_position(context)
    upper = self.max_relative_position(context)
    if not (lower is None and upper is None):
        offset = memory_position - context.position
        if lower is not None:
            penalties.append(as_penalty(mtf.less(offset, lower)))
        if upper is not None:
            penalties.append(as_penalty(mtf.greater(offset, upper)))

    seq_id = context.sequence_id
    if (seq_id is not None
            and isinstance(seq_id, mtf.Tensor)
            and context.length_dim in seq_id.shape):
        # Positions may only attend to memory slots of the same sequence.
        memory_seq_id = self.rename_length_to_memory_length(seq_id, context)
        penalties.append(as_penalty(mtf.not_equal(seq_id, memory_seq_id)))

    if not penalties:
        return None
    return mtf.add_n(penalties)
def _mtf_model_fn(self, features, mesh):
    """Build the single-stack (unitransformer) graph; return logits and loss.

    In autoregressive mode the model inputs are the targets shifted right by
    one position along ``self.length_dim``; otherwise the "inputs" feature is
    imported.

    Args:
        features: the feature collection handed to ``self._import_feature``;
            it is also stored on ``self._original_features``.
        mesh: the mesh forwarded to ``self._import_feature``.

    Returns:
        The ``(logits, loss)`` pair returned by ``model.call_simple``.
    """
    self._original_features = features
    hparams = self._hparams

    targets = self._import_feature(features, mesh, "targets")
    sequence_id = self._import_feature(features, mesh, "targets_segmentation")
    position = self._import_feature(features, mesh, "targets_position")

    if not self.autoregressive:
        inputs = self._import_feature(features, mesh, "inputs")
        # TODO(noam): options for bert-style masking here?
    else:
        inputs = mtf.shift(
            targets, offset=1, dim=self.length_dim, wrap=False)
        if position is not None:
            # first input in later sequences should be 0
            starts_sequence = mtf.not_equal(position, 0)
            inputs *= mtf.to_int32(starts_sequence)

    model = self.model()
    outputs = model.call_simple(
        inputs=inputs,
        targets=targets,
        compute_loss=True,
        mode=hparams.mode,
        variable_dtype=self.variable_dtype,
        sequence_id=sequence_id,
        position=position)
    logits, loss = outputs
    return logits, loss
def sample(self, features, mesh):
    """Sample output sequences from an autoregressive unitransformer.

    Args:
        features: the feature collection handed to ``self._import_feature``.
        mesh: the mesh used to import features and to build the all-zero
            prompt when no partial targets are supplied.

    Returns:
        The result of ``model.sample_autoregressive`` passed through
        ``self.combine_batch_dims``.

    Raises:
        ValueError: if the model is not autoregressive.
        NotImplementedError: if ``hparams.beam_size`` is greater than 1.
    """
    hparams = self._hparams
    model = self.model()

    def import_feature(key):
        return self._import_feature(features, mesh, key)

    if not self.autoregressive:
        raise ValueError(
            "Don't know how to sample from non-autoregressive unitransformer"
        )

    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    partial_targets = import_feature("inputs")
    if partial_targets is None:
        partial_targets = import_feature("targets")

    # Explicit None test: the truthiness of a tensor object is not a
    # meaningful "was the feature provided?" check.
    if partial_targets is not None:
        # Zero out every position holding token id 1 (the default
        # stop_at_token of sample_autoregressive — presumably EOS).
        partial_targets *= mtf.cast(mtf.not_equal(partial_targets, 1),
                                    partial_targets.dtype)
    else:
        ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
        partial_targets = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)

    if hparams.beam_size > 1:
        raise NotImplementedError(
            "Beam search not implemented for unitransformer.")

    ret = model.sample_autoregressive(
        partial_targets,
        temperature=hparams.sampling_temp,
        variable_dtype=self.variable_dtype)
    return self.combine_batch_dims(ret)
def call(self, context, x):
    """Call the layer stack.

    Applies dropout, optional vanilla transformer layers before and/or after
    the recurrent core selected by ``self.recurrence_type`` ("act", "basic"
    or "highway"), then a final layer norm and dropout.

    Args:
        context: the Context for this call; the embedded input and the final
            output are appended to ``context.layer_outputs`` when it is set.
        x: the input activations.

    Returns:
        The output activations, multiplied by the padding mask when one
        exists.
    """
    if isinstance(context.sequence_id, mtf.Tensor):
        # We use this mask to zero out the padding regions at each layer.
        # This "fixes" a bug where extreme values leak from the padding into
        # the non-padding regions.
        # TODO(noam): understand this better and make a more principled fix.
        mask = mtf.cast(mtf.not_equal(context.sequence_id, 0),
                        context.activation_dtype)
    else:
        mask = None

    x = self._dropout(context, x)
    # Guarded so the stack also works when no layer outputs are collected.
    # NOTE(review): assumes layer_outputs may be None — confirm against Context.
    if context.layer_outputs is not None:
        context.layer_outputs.append(x)

    if self.mix_with_transformer_before_ut:
        for _ in range(self.num_vanilla_transformer_layers):
            x = self.vanilla_transformer_layer(context, x, mask)

    # Run the recurrent core.
    if self.recurrence_type == "act":
        x = self.act_layer(context, x, mask)
    elif self.recurrence_type == "basic":
        x = self.ut_basic(context, x, mask)
    elif self.recurrence_type == "highway":
        layer_inputs = (x, x, x)
        x = self.ut_highway(context, layer_inputs, mask)
    # Any other recurrence_type leaves x untouched, as before.

    if self.mix_with_transformer_after_ut:
        for _ in range(self.num_vanilla_transformer_layers):
            x = self.vanilla_transformer_layer(context, x, mask)

    x = self._layer_norm(context, x, name="final_layer_norm")
    x = self._dropout(context, x)
    # Explicit None test: the truthiness of a tensor object is not a
    # meaningful "is there a mask?" check.
    if mask is not None:
        x *= mask
    if context.layer_outputs is not None:
        context.layer_outputs.append(x)
    return x
def call(self, context, x):
    """Call the layer stack.

    Each layer is applied in pre-norm residual form:
    ``x += dropout(layer(layer_norm(x)))``, followed by a final layer norm
    and dropout after the last layer.

    Args:
        context: the Context for this call; ``context.layer_index`` is
            incremented once per layer and intermediate activations are
            appended to ``context.layer_outputs`` when it is not None.
        x: the input activations.

    Returns:
        The output activations, multiplied by the padding mask when one
        exists.

    Raises:
        ValueError: if a layer returns a tensor whose shape differs from
            the shape of its input.
    """
    if isinstance(context.sequence_id, mtf.Tensor):
        # We use this mask to zero out the padding regions at each layer.
        # This "fixes" a bug where extreme values leak from the padding into
        # the non-padding regions.
        # TODO(noam): understand this better and make a more principled fix.
        mask = mtf.cast(mtf.not_equal(context.sequence_id, 0),
                        context.activation_dtype)
    else:
        mask = None

    x = self._dropout(context, x)
    if context.layer_outputs is not None:
        context.layer_outputs.append(x)

    num_layers = len(self._layers)
    for lnum, layer in enumerate(self._layers):
        with tf.variable_scope("layer_%03d" % lnum):
            # Explicit None test: the truthiness of a tensor object is not
            # a meaningful "is there a mask?" check.
            norm_x = self._layer_norm(
                context, (x * mask) if mask is not None else x)
            with tf.variable_scope(layer.__class__.__name__):
                y = layer.call(context, norm_x)
                if y.shape != x.shape:
                    raise ValueError(
                        "Layer %s returned misshaped output x=%s y=%s"
                        % (layer.__class__.__name__, x, y))
            x += self._dropout(context, y)
        # The last layer's output is recorded after the final layer norm
        # below, so it is skipped here.
        if context.layer_outputs is not None and lnum != num_layers - 1:
            context.layer_outputs.append(x)
        context.layer_index += 1

    x = self._layer_norm(context, x, name="final_layer_norm")
    x = self._dropout(context, x)
    if mask is not None:
        x *= mask
    if context.layer_outputs is not None:
        context.layer_outputs.append(x)
    return x
def nonpadding(self):
    """Return a 0/1 indicator of non-padding positions.

    Returns:
        None when ``self.sequence_id`` is None, the integer 1 when it equals
        1, and otherwise a tensor in ``self.activation_dtype`` that is zero
        where ``sequence_id`` is 0 and one elsewhere.
    """
    seq_id = self.sequence_id
    if seq_id is None:
        return None
    if seq_id == 1:
        return 1
    is_real_token = mtf.not_equal(seq_id, 0)
    return mtf.cast(is_real_token, self.activation_dtype)
def shift_targets_no_offset(targets, bos_id=0, eos_id=1):
    """Turn decoder labels into decoder inputs without shifting positions.

    Positions holding ``eos_id`` are zeroed. When ``bos_id`` is non-zero,
    every position that became 0 but was non-zero in ``targets`` is set to
    ``bos_id``.

    Args:
        targets: decoder labels
        bos_id: begin of sequence id, defaults to 0
        eos_id: end of sequence id, defaults to 1

    Returns:
        Decoder inputs.
    """
    # Unused, but kept so that a tensor without any dimension still fails
    # here exactly as before.
    length_dim = targets.shape.dims[-1]

    # We want a 0 rather than an EOS (e.g. 1) at those positions.
    not_eos = mtf.to_int32(mtf.not_equal(targets, eos_id))
    decoder_inputs = targets * not_eos

    if bos_id:
        was_cleared = mtf.logical_and(
            mtf.equal(decoder_inputs, 0), mtf.not_equal(targets, 0))
        decoder_inputs = decoder_inputs + mtf.to_int32(was_cleared) * bos_id

    return decoder_inputs
def call(self, context, x, losses=None):
    """Apply multi-head self-attention to ``x``.

    In "incremental" mode the key/value computed for the current position is
    written into the cached states at ``context.position``; in "incremental"
    and "first_part" modes the resulting keys and values are recorded as new
    states.

    Args:
        context: the Context for this call.
        x: the input activations.
        losses: unused.

    Returns:
        A Tensor with the same shape as ``x``.
    """
    wq, wk, wv, wo = mtf.layers.multihead_attention_params(
        context.mesh, self.heads_dim, context.model_dim,
        self.kv_dim, context.variable_dtype)
    memory_length = mtf.Dimension("memory_length", context.length_dim.size)
    q = mtf.einsum([x, wq], reduced_dims=[context.model_dim])

    # The memory is x itself when decoding one step at a time; otherwise it
    # is x with its length dimension renamed to "memory_length".
    if context.mode != "incremental":
        memory = mtf.rename_dimension(
            x, context.length_dim.name, "memory_length")
    else:
        memory = x
    k = mtf.einsum([memory, wk], reduced_dims=[context.model_dim])
    v = mtf.einsum([memory, wv], reduced_dims=[context.model_dim])

    if context.mode == "incremental":
        # Overwrite the cache slot at the current position, keep the rest.
        cached_k, cached_v = context.get_states(2)
        here = mtf.one_hot(
            context.position, memory_length, dtype=context.activation_dtype)
        elsewhere = 1.0 - here
        k = cached_k * elsewhere + k * here
        v = cached_v * elsewhere + v * here
    if context.mode in ("incremental", "first_part"):
        context.record_new_states([k, v])

    biases = []
    if context.autoregressive:
        # Forbid attending to memory positions after the query position.
        memory_range = mtf.range(context.mesh, memory_length, dtype=tf.int32)
        in_future = mtf.less(context.position, memory_range)
        biases.append(mtf.cast(in_future, context.activation_dtype) * -1e9)
    seq_id = context.sequence_id
    if (seq_id is not None
            and isinstance(seq_id, mtf.Tensor)
            and context.length_dim in seq_id.shape):
        # Forbid attending across different sequences.
        other_sequence = mtf.not_equal(
            seq_id, mtf.layers.rename_length_to_memory_length(seq_id))
        biases.append(
            mtf.cast(other_sequence, context.activation_dtype) * -1e9)
    if biases:
        mask = mtf.add_n(biases)
    else:
        mask = None

    if context.train:
        dropout_rate = self.dropout_rate
    else:
        dropout_rate = 0.0
    o = mtf.layers.dot_product_attention_v2(
        q, k, v, memory_length, self.kv_dim, self.kv_dim, mask,
        dropout_rate, [context.length_dim])
    return mtf.einsum(
        [o, wo], x.shape, reduced_dims=[self.heads_dim, self.kv_dim])
def call(self, context, x, losses=None):
    """Apply multi-head encoder-decoder attention to ``x``.

    Keys and values are computed from ``context.encoder_output``. They are
    recorded as constant state in "first_part" mode and read back from the
    constant state in "incremental" mode.

    Args:
        context: the Context for this call.
        x: the decoder activations used to compute the queries.
        losses: unused.

    Returns:
        A Tensor with the same shape as ``x``.

    Raises:
        NotImplementedError: if the last dimension of the encoder output
            differs from ``context.model_dim``.
    """
    memory_input_dim = context.encoder_output.shape[-1]
    if memory_input_dim != context.model_dim:
        raise NotImplementedError(
            "TODO(noam): support different model_dim in encoder and decoder."
        )

    wq, wk, wv, wo = mtf.layers.multihead_attention_params(
        context.mesh, self.heads_dim, context.model_dim,
        self.kv_dim, context.variable_dtype)
    q = mtf.einsum([x, wq], reduced_dims=[context.model_dim])

    if context.mode == "incremental":
        k, v, memory_length = context.get_constant_state()
    else:
        m = context.encoder_output
        # Exactly one dimension must be named "memory_length"; the
        # single-element unpacking fails otherwise.
        memory_length, = [
            d for d in m.shape.dims if d.name == "memory_length"
        ]
        k = mtf.einsum([m, wk], reduced_dims=[context.model_dim])
        v = mtf.einsum([m, wv], reduced_dims=[context.model_dim])
        if context.mode == "first_part":
            context.record_constant_state((k, v, memory_length))

    # Explicit None tests: the truthiness of a tensor object is not a
    # meaningful "was it provided?" check.
    if (context.encoder_sequence_id is not None
            and context.sequence_id is not None):
        # Forbid attention between decoder and encoder positions that
        # belong to different sequences.
        mask = mtf.cast(
            mtf.not_equal(context.sequence_id, context.encoder_sequence_id),
            context.activation_dtype) * -1e9
    else:
        mask = None

    o = mtf.layers.dot_product_attention_v2(
        q, k, v, memory_length, self.kv_dim, self.kv_dim, mask,
        dropout=self.dropout_rate if context.train else 0.0,
        dropout_broadcast_dims=[context.length_dim])
    return mtf.einsum(
        [o, wo], x.shape, reduced_dims=[self.heads_dim, self.kv_dim])
def compute_mask(self, context, memory_position):
    """Build the additive attention mask for this layer.

    Every disallowed (query, memory) pair contributes -1e9; the individual
    terms are summed into a single mask.

    Args:
        context: a transformer.Context
        memory_position: an int32 tensor containing memory_length dimension.

    Returns:
        a Tensor, or None when no masking term applies
    """

    def as_penalty(illegal):
        # Turn a boolean "illegal" tensor into a large negative bias.
        return mtf.cast(illegal, context.activation_dtype) * -1e9

    penalties = []
    lower = self.min_relative_position(context)
    upper = self.max_relative_position(context)
    if not (lower is None and upper is None):
        offset = memory_position - context.position
        if lower is not None:
            penalties.append(as_penalty(mtf.less(offset, lower)))
        if upper is not None:
            penalties.append(as_penalty(mtf.greater(offset, upper)))

    # Subsequence id should only be set if we are in the decoder and have
    # multiple targets per input. It takes precedence over sequence_id so
    # that each sub-target only attends to itself.
    if isinstance(context.subsequence_id, mtf.Tensor):
        ids = context.subsequence_id
    elif isinstance(context.sequence_id, mtf.Tensor):
        ids = context.sequence_id
    else:
        ids = None

    if ids is not None and context.length_dim in ids.shape:
        memory_ids = self.rename_length_to_memory_length(ids, context)
        penalties.append(as_penalty(mtf.not_equal(ids, memory_ids)))

    if not penalties:
        return None
    return mtf.add_n(penalties)
def sample(self, features, mesh):
    """Sample sequences autoregressively, optionally continuing a prompt.

    The prompt is taken from ``features["inputs"]`` or, failing that,
    ``features["targets"]``. It is padded to
    ``[hparams.batch_size, self.length_dim.size]`` and token id 1 is zeroed
    out. Without a prompt an all-zero tensor is used.

    Args:
        features: a dictionary that may contain "inputs" or "targets".
        mesh: the mesh on which the prompt tensor is created.

    Returns:
        The result of ``model.sample_autoregressive``.

    Raises:
        NotImplementedError: if ``hparams.beam_size`` is not 1.
    """
    hparams = self._hparams
    model = self.model()

    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    prompt = features.get("inputs", None)
    if prompt is None:
        prompt = features.get("targets", None)

    if prompt is None:
        ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
        partial_targets = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
    else:
        prompt = common_layers.expand_squeeze_to_nd(prompt, 2)
        prompt = tf.to_int32(prompt)
        prompt_batch = tf.shape(prompt)[0]
        prompt_length = tf.shape(prompt)[1]
        # Pad at the end of each axis up to the configured batch size and
        # sequence length.
        paddings = [
            [0, hparams.batch_size - prompt_batch],
            [0, self.length_dim.size - prompt_length],
        ]
        prompt = tf.pad(prompt, paddings)
        partial_targets = self._import_to_batch_by_length(
            prompt, "partial_targets", mesh)
        # strip EOS
        partial_targets *= mtf.to_int32(mtf.not_equal(partial_targets, 1))

    if hparams.beam_size != 1:
        raise NotImplementedError("not implemented")

    partial_targets = mtf.Print(
        partial_targets, [partial_targets], "Partial_Targets",
        summarize=1000)
    return model.sample_autoregressive(
        partial_targets,
        temperature=hparams.sampling_temp,
        variable_dtype=self.variable_dtype)
def call(self, context, x, losses=None):
    """Apply encoder-decoder attention to ``x``.

    Keys and values are computed from ``context.encoder_output`` (a single
    shared projection when ``self.shared_kv`` is set). They are recorded as
    constant state in "first_part" mode and read back from the constant
    state in "incremental" mode.

    Args:
        context: the Context for this call.
        x: the decoder activations used to compute the queries.
        losses: unused.

    Returns:
        The attention output, produced by ``params.compute_output`` with
        ``output_shape=x.shape``.

    Raises:
        NotImplementedError: if the last dimension of the encoder output
            differs from ``context.model_dim``.
    """
    memory_input_dim = context.encoder_output.shape[-1]
    if memory_input_dim != context.model_dim:
        raise NotImplementedError(
            "TODO(noam): support different model_dim in encoder and decoder.")

    params = self.make_params(context)
    q = params.compute_q(x)

    if context.mode == "incremental":
        k, v, memory_length = context.get_constant_state()
    else:
        m = context.encoder_output
        if self.shared_kv:
            kv = params.compute_kv(m)
            k = kv
            v = kv
        else:
            k = params.compute_k(m)
            v = params.compute_v(m)
        # Exactly one dimension must be named "memory_length"; the
        # single-element unpacking fails otherwise.
        memory_length, = [d for d in m.shape.dims if d.name == "memory_length"]
        if context.mode == "first_part":
            context.record_constant_state((k, v, memory_length))

    # Explicit None tests: the truthiness of a tensor object is not a
    # meaningful "was it provided?" check.
    if (context.encoder_sequence_id is not None
            and context.sequence_id is not None):
        # Forbid attention between decoder and encoder positions that
        # belong to different sequences.
        mask = mtf.cast(
            mtf.not_equal(context.sequence_id, context.encoder_sequence_id),
            context.activation_dtype) * -1e9
    else:
        mask = None

    o = attention.attention(
        q, k, v, memory_length, self.kv_dim, self.kv_dim, mask,
        **self.attention_kwargs_from_context(context))
    return params.compute_output(o, output_shape=x.shape)
def model(mtf_features, other_features, params, mesh, variable_dtype, context=None):
    """A GPT style model implemented in mesh tensorflow.

    Embeds tokens and positions, runs ``params["n_layer"]`` blocks, projects
    to vocabulary logits and, in "train"/"eval" mode, computes the loss.

    Args:
        mtf_features: features consumed by ``parse_inputs``; must contain
            "labels" when ``params["mode"]`` is "train" or "eval".
        other_features: dictionary providing "attn_bias" and
            "memory_length_dim" for the blocks.
        params: the model hyperparameter dictionary.
        mesh: the mesh on which variables are created.
        variable_dtype: supplies the master/slice/activation dtypes.
        context: optional context used for incremental inference.

    Returns:
        ``(logits, loss, loss_batch)``; ``loss`` and ``loss_batch`` are None
        outside of "train"/"eval" mode.
    """
    (x, batch_dim, sequence_dim, embd_dim,
     vocab_dim, embed_sequence_dim) = parse_inputs(mtf_features, other_features)

    if is_incremental_inference(context):
        # reshape inputs if in inference mode
        x = mtf.gather(x, context.position - 1, sequence_dim)
        x = mtf.reshape(x, [batch_dim])

    if params["axial_pos_emb"] is None:
        # Use standard position encoding
        wpe = mtf.get_variable(
            mesh, "wpe", mtf.Shape([embed_sequence_dim, embd_dim]),
            initializer=tf.random_normal_initializer(stddev=0.01),
            master_dtype=variable_dtype.master_dtype,
            slice_dtype=variable_dtype.slice_dtype,
            activation_dtype=variable_dtype.activation_dtype)
    else:
        wpe = axial_positional_emb(embd_dim, mesh, params, variable_dtype)

    # Text encoding
    wte = mtf.get_variable(
        mesh, "wte", mtf.Shape([vocab_dim, embd_dim]),
        initializer=tf.random_normal_initializer(stddev=0.02),
        master_dtype=variable_dtype.master_dtype,
        slice_dtype=variable_dtype.slice_dtype,
        activation_dtype=variable_dtype.activation_dtype)

    with tf.variable_scope("token_embd"):
        # Text embedding
        h = mtf.gather(wte, x, vocab_dim)
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            h = mtf.dropout(h, rate=params["embed_dropout"], name="wte_dropout")

    with tf.variable_scope("pos_embd"):
        # Positional embedding
        if is_incremental_inference(context):
            position_indices = context.position - 1
        else:
            position_indices = mtf.range(mesh, sequence_dim, tf.int64)
        pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            pos_emb = mtf.dropout(
                pos_emb, rate=params["embed_dropout"], name="wte_dropout")
        h += pos_emb

    # Running sum of the auxiliary losses returned by the blocks
    # (for MOE models).
    aux_losses = 0

    for layer_idx in range(params["n_layer"]):
        # attn blocks
        share_parameters = (exists(params["share_parameters"])
                            and params["share_parameters"] == True)
        # With shared parameters every block uses the same (empty) scope.
        if share_parameters:
            block_scope = ""
        else:
            block_scope = f"h{layer_idx}"

        block_fn = block(
            params=params,
            scope=block_scope,
            layer_num=layer_idx,
            bias=other_features["attn_bias"],
            sequence_dim=sequence_dim,
            memory_length_dim=other_features["memory_length_dim"],
            variable_dtype=variable_dtype,
            context=context)

        # If true and in train mode, enable gradient checkpointing
        recompute_grad = (params["recompute_grad"]
                          and (params["mode"] == "train") == True)
        if recompute_grad:
            h, block_loss = mtf.recompute_grad(block_fn, [h])
        else:
            h, block_loss = block_fn(h)
        aux_losses += block_loss

    if params["no_weight_tie"] == True:
        with tf.variable_scope("wte_final_linear"):
            logits = linear(h, "linear_out", vocab_dim,
                            variable_dtype=variable_dtype, params=params)
    else:
        # Layer normalize & affine transform
        h = layer_norm(h, "ln_f", variable_dtype=variable_dtype)
        if is_incremental_inference(context):
            seq_dim = mtf.Dimension("sequence", 1)
        else:
            seq_dim = sequence_dim
        with tf.variable_scope("wte_final_einsum"):
            # Equivalent to tf.matmul
            logits = mtf.einsum(
                [h, wte], output_shape=[batch_dim, seq_dim, vocab_dim])

    if params["mode"] not in ["train", "eval"]:
        loss = None
        loss_batch = None
    else:
        labels = mtf_features["labels"]
        # an auxiliary loss used to stabilize mtf xentropy
        z_loss = params.get("z_loss", 1e-4)

        # Go to full precision for the logits
        logits = mtf.cast(logits, tf.float32)

        if params.get("entmax_loss", False):
            loss_fn = entmax_cross_entropy_with_logits
        else:
            loss_fn = mtf.layers.softmax_cross_entropy_with_logits

        with tf.variable_scope("xentropy_final"):
            loss_batch = loss_fn(logits=logits, targets=labels,
                                 vocab_dim=logits.shape[-1], z_loss=z_loss)

        # For non-autoregressive models (masked language modeling training)
        # Make sure labels with padding tokens are not counted in the loss
        if not params["causal"]:
            padding_id = params.get("padding_id", 0)
            is_real_label = mtf.not_equal(labels, padding_id)
            loss_batch = mtf.where(
                is_real_label, loss_batch, mtf.zeros_like(loss_batch))

        with tf.variable_scope("reduce_mean_final"):
            loss = mtf.reduce_mean(loss_batch)

        # Add on auxiliary losses (currently only used for MoE)
        loss += aux_losses
        loss /= params["num_microbatches"]
        # Convert to train dtype
        loss = mtf.cast(loss, variable_dtype.slice_dtype)

    # Cast back to checkpoint dtype
    logits = mtf.cast(logits, variable_dtype.master_dtype)
    return logits, loss, loss_batch
def sample_autoregressive(self,
                          partial_sequences,
                          dst_attributes=None,
                          stop_at_token=1,
                          max_steps=None,
                          temperature=0.0,
                          variable_dtype=mtf.VariableDType(tf.float32),
                          encoder_output=None,
                          encoder_sequence_id=None,
                          encoder_inputs=None,
                          shared_params=None,
                          has_partial_sequences=True,
                          encoder_layer_outputs=None,
                          never_end=False,
                          remove_partial_sequences=False,
                          sampling_keep_top_k=-1,
                          z=None):
    """Sample randomly one token at a time.

    The partial_sequences represent partial sequences to be continued. The
    first tokens of each sequence are nonzero representing the given partial
    sequences and the last tokens of each sequence are zeros, representing
    what needs to be filled in.

    If there are no partial sequences (you want to sample from the
    beginning), then pass partial_sequences=mtf.zeros(mesh, shape,
    dtype=tf.int32) and has_partial_sequences=False (so we can skip
    computation).

    The dst_attributes represents the destination attributes in which we
    want to generate sequences.

    Args:
        partial_sequences: an int32 Tensor with shape
            [<batch_dims>, length_dim]
        dst_attributes: an int32 Tensor with shape
            [<batch_dims>, length_dim] ([<batch_dims>])
        stop_at_token: an optional integer eos id. Stop when we produce it.
        max_steps: an optional integer, the max number of steps to decode.
        temperature: an optional floating point value between 0.0 and 1.0
            0.0 means argmax, 1.0 means sample according to predicted
            distribution.
        variable_dtype: a mtf.VariableDType
        encoder_output: an optional Tensor
        encoder_sequence_id: an optional Tensor
        encoder_inputs: an optional Tensor
        shared_params: an optional dictionary
        has_partial_sequences: a boolean
        encoder_layer_outputs: optional - readonly list of tensor
            activations when decoding, one per each input layer + the
            embedding layer
        never_end: a boolean - if set, then avoid generating stop_at_token
        remove_partial_sequences: a boolean - whether to remove the partial
            sequences from the output
        sampling_keep_top_k: an integer - if not -1, only sample from the
            top k logits.
        z: optional value forwarded unchanged to self._call_internal.

    Returns:
        a Tensor with shape [<batch_dims>, length_dim]

    Raises:
        ValueError: if the model is not autoregressive, or if
            sampling_keep_top_k is neither -1 nor positive.
    """
    if not self.autoregressive:
        raise ValueError("must be autoregressive")

    inputs = partial_sequences
    attributes = dst_attributes
    batch_dims = inputs.shape.dims[:-1]
    length_dim = inputs.shape.dims[-1]
    # Number of nonzero tokens per sequence, i.e. where decoding starts.
    initial_position = mtf.reduce_sum(mtf.to_int32(mtf.not_equal(
        inputs, 0)), reduced_dim=length_dim)
    # The decoder gets the scalar id 1 whenever encoder ids are supplied.
    sequence_id = 1 if encoder_sequence_id is not None else None

    length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
    if self.input_full_attention:
        # Positions up to and including initial_position get priority 0;
        # later positions keep their index as priority.
        read_priority = write_priority = length_range * mtf.to_int32(
            mtf.greater(length_range, initial_position))
    else:
        read_priority = write_priority = length_range

    context_first_part = Context(
        model=self,
        mesh=inputs.mesh,
        batch_dims=batch_dims,
        length_dim=length_dim,
        variable_dtype=variable_dtype,
        mode="first_part",
        position=length_range,
        position_is_default=True,
        new_states=[],
        initial_position=initial_position,
        sequence_id=sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        constant_states=[],
        shared_params=shared_params,
        encoder_layer_outputs=encoder_layer_outputs,
        write_priority=write_priority,
        read_priority=read_priority,
        inputs=inputs,
        encoder_inputs=encoder_inputs)

    shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
    with tf.variable_scope(self.name):
        logits = self._call_internal(context_first_part, shifted_inputs,
                                     attributes=attributes, z=z)
    # The first pass is run only for the states it records on the context;
    # its logits are discarded.
    del logits
    constant_states = context_first_part.constant_states
    if not has_partial_sequences:
        # No prompt: start from all-zero states.
        initial_states = [
            mtf.zeros_like(t) for t in context_first_part.new_states
        ]
        partial_sequences_eos_count = 0
    else:
        initial_states = context_first_part.new_states
        # Stop tokens already present in the prompt must not end decoding.
        partial_sequences_eos_count = mtf.reduce_sum(
            mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
            reduced_dim=length_dim)

    def cond_fn(position, ids, *unused_states):
        """Should we run another loop iteration."""
        past_end = mtf.greater_equal(position, length_dim.size)
        if max_steps:
            past_end = mtf.logical_or(
                past_end,
                mtf.greater_equal(position - initial_position, max_steps))

        is_done = past_end
        if stop_at_token is not None:
            eos_count = mtf.reduce_sum(mtf.to_int32(
                mtf.equal(ids, stop_at_token)), reduced_dim=length_dim)
            has_additional_eos = mtf.greater(eos_count,
                                             partial_sequences_eos_count)
            is_done = mtf.logical_or(is_done, has_additional_eos)
        # Keep looping until every sequence is done.
        all_done = mtf.reduce_all(is_done)
        return mtf.logical_not(all_done)

    def body_fn(position, ids, *states):
        """One step in the decode loop."""
        # The model input for this step is the token at the previous position.
        inputs_this_step = mtf.gather(ids, position - 1, length_dim)
        if self.attribute_embedding:
            attributes_this_step = mtf.gather(attributes, position - 1,
                                              length_dim)
        else:
            attributes_this_step = None
        context_incremental = Context(
            model=self,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="incremental",
            position=position,
            states=states,
            new_states=[],
            sequence_id=sequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=constant_states,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=position,
            inputs=inputs_this_step,
            encoder_inputs=encoder_inputs)
        # reuse=True: the variables were created by the first pass above.
        with tf.variable_scope(self.name, reuse=True):
            logits = self._call_internal(context_incremental,
                                         inputs_this_step,
                                         attributes=attributes_this_step,
                                         z=z)
        if never_end:
            # Add -1e9 to the logit of stop_at_token.
            logits += mtf.one_hot(mtf.constant(logits.mesh,
                                               stop_at_token,
                                               dtype=tf.int32),
                                  self.output_vocab_dim,
                                  on_value=-1e9,
                                  off_value=0.0,
                                  dtype=logits.dtype)

        # TBD whether this should be before or after never_end:
        # Note for adding top_p sampling in the future, in other code bases, the
        # option to apply temperature is done before the top-k truncation. This
        # implementation does this in the opposite order. For top-k this doesn't
        # matter, but for top_p it will.
        if sampling_keep_top_k != -1:
            if sampling_keep_top_k <= 0:
                raise ValueError(
                    "sampling_keep_top_k must either be -1 or positive.")
            k_largest = mtf.nth_largest_element(
                logits,
                n=sampling_keep_top_k,
                reduced_dim=self.output_vocab_dim)
            # Logits less than or equal to the threshold are replaced by -1e6.
            logits = mtf.where(mtf.less_equal(logits, k_largest),
                               mtf.ones_like(logits) * -1e6, logits)

        ids_this_step = mtf.sample_with_temperature(
            logits, self.output_vocab_dim, temperature)
        new_position = position + 1
        # Add the sampled id into the slot at `position`.
        new_ids = ids + ids_this_step * mtf.one_hot(
            position, length_dim, dtype=tf.int32)
        return [new_position, new_ids] + context_incremental.new_states

    while_loop_inputs = [initial_position, inputs] + initial_states
    # Only the first two loop outputs (position, ids) are needed.
    final_position, outputs = mtf.while_loop(cond_fn, body_fn,
                                             while_loop_inputs)[:2]
    del final_position
    if has_partial_sequences and remove_partial_sequences:
        # remove partial sequences from outputs
        partial_length = mtf.reduce_sum(mtf.to_int32(
            mtf.not_equal(partial_sequences, 0)), reduced_dim=length_dim)
        outputs = mtf.dynamic_shift(outputs,
                                    -partial_length,
                                    length_dim,
                                    wrap=False)
    return outputs
def beam_search(self,
                inputs,
                decode_length,
                dst_attributes=None,
                variable_dtype=mtf.VariableDType(tf.float32),
                encoder_output=None,
                encoder_sequence_id=None,
                encoder_inputs=None,
                alpha=0.6,
                shared_params=None,
                encoder_layer_outputs=None,
                z=None):
    """Beam search.

    Args:
        inputs: an int32 zero-Tensor with shape
            [<batch_dims>, beam_dim, length_dim].
        decode_length: an int32 mtf scalar. Maximum decode length.
        dst_attributes: an int32 zero-Tensor with shape
            [<batch_dims>, beam_dim, length_dim]
            ([<batch_dims>] [<batch_dims>, beam_dim]).
        variable_dtype: a mtf.VariableDType
        encoder_output: an optional Tensor
        encoder_sequence_id: an optional Tensor
        encoder_inputs: an optional Tensor
        alpha: a floating point value (length bonus)
        shared_params: an optional dictionary
        encoder_layer_outputs: optional - readonly list of tensor
            activations when decoding, one per each input layer + the
            embedding layer
        z: optional value forwarded unchanged to self._call_internal.

    Returns:
        a Tensor with shape [<batch_dims>, beam_dim, length_dim]
        NOTE(review): the final gather selects index 0 along beam_dim, so
        the returned tensor presumably no longer carries beam_dim — confirm.

    Raises:
        ValueError: if the model is not autoregressive.
        NotImplementedError: if there is not exactly one batch dimension, or
            if self.input_full_attention is set.
    """
    attributes = dst_attributes
    if not self.autoregressive:
        raise ValueError("must be autoregressive")

    # The last two dimensions are beam and length; everything before them
    # is treated as batch.
    batch_dims = inputs.shape.dims[:-2]
    if len(batch_dims) != 1:
        raise NotImplementedError(
            "beam search supports exactly one batch dimension.")
    beam_dim = inputs.shape.dims[-2]
    length_dim = inputs.shape.dims[-1]
    length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
    # Number of nonzero tokens per sequence.
    initial_position = mtf.reduce_sum(mtf.to_int32(mtf.not_equal(
        inputs, 0)), reduced_dim=length_dim)
    # The decoder gets the scalar id 1 whenever encoder ids are supplied.
    sequence_id = 1 if encoder_sequence_id is not None else None

    if self.input_full_attention:
        # This only makes sense in the case of beam search with given partial
        # sequences, which is not yet implemented.
        # TODO(noam): implement
        raise NotImplementedError(
            "Beam search for language models not yet implemented")
    else:
        read_priority = write_priority = length_range

    context_first_part = Context(
        model=self,
        mesh=inputs.mesh,
        batch_dims=batch_dims + [beam_dim],
        length_dim=length_dim,
        variable_dtype=variable_dtype,
        mode="first_part",
        position=length_range,
        position_is_default=True,
        new_states=[],
        initial_position=initial_position,
        sequence_id=sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        constant_states=[],
        shared_params=shared_params,
        encoder_layer_outputs=encoder_layer_outputs,
        write_priority=write_priority,
        read_priority=read_priority,
        inputs=inputs,
        encoder_inputs=encoder_inputs)

    shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
    with tf.variable_scope(self.name):
        logits = self._call_internal(context_first_part, shifted_inputs,
                                     attributes=attributes, z=z)
    # The first pass is run only for the states it records on the context;
    # its logits are discarded.
    del logits
    # There are no partial targets.
    # Replace initial states by zeros to avoid computing them.
    initial_states = [
        mtf.zeros_like(t) for t in context_first_part.new_states
    ]
    constant_states = context_first_part.constant_states

    def logits_fn(step_num, ids, states):
        """logits_fn for mtf.beam_search.beam_search()."""
        # The model input for this step is the token at the previous position.
        inputs_this_step = mtf.gather(ids, step_num - 1, length_dim)

        if self.attribute_embedding:
            attributes_this_step = mtf.gather(attributes, step_num - 1,
                                              length_dim)
        else:
            attributes_this_step = None

        context_incremental = Context(
            model=self,
            mesh=inputs.mesh,
            batch_dims=batch_dims + [beam_dim],
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="incremental",
            position=step_num,
            states=states,
            new_states=[],
            sequence_id=sequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=constant_states,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=step_num,
            inputs=inputs_this_step,
            encoder_inputs=encoder_inputs)
        # reuse=True: the variables were created by the first pass above.
        with tf.variable_scope(self.name, reuse=True):
            logits = self._call_internal(context_incremental,
                                         inputs_this_step,
                                         attributes=attributes_this_step,
                                         z=z)
        return mtf.to_float(logits), context_incremental.new_states

    beams, unused_scores = mtf.beam_search.beam_search(
        logits_fn,
        inputs,
        alpha,
        states=initial_states,
        decode_length=decode_length,
        use_tpu=True,
        dtype=tf.float32,
        mesh_shape=self.mesh_shape,
        layout=self.layout)
    # Keep only the entry at index 0 along beam_dim (presumably the
    # highest-scoring beam — confirm against mtf.beam_search).
    return mtf.gather(beams, mtf.constant(inputs.mesh, 0, dtype=tf.int32),
                      beam_dim)
def local_attention_1d(q,
                       k,
                       v,
                       length_dim,
                       key_dim,
                       value_dim,
                       autoregressive=True,
                       length_dim_num_splits=1,
                       radius=128,
                       sequence_id=1,
                       attention_kwargs=None):
    """Attention to a neighborhood around the source.

    If autoregressive, then query position p can only see memory positions
    in the range (p - radius, p].

    If not autoregressive, then query position p can only see memory
    positions in the range (p - radius, p + radius].

    Args:
      q: a Tensor containing length_dim
      k: a Tensor containing length_dim
      v: an optional Tensor containing length_dim.  If None then uses v=k.
      length_dim: a Dimension
      key_dim: a Dimension (the channels dimension of q and k)
      value_dim: a Dimension (the channels dimension of v)
      autoregressive: a boolean
      length_dim_num_splits: an optional integer indicating how many ways the
        length dimension is split
      radius: an integer
      sequence_id: a Tensor or an integer.  None is treated as 1.
      attention_kwargs: optional dict of keyword arguments forwarded to
        attention().  None (the default) forwards no extra arguments.

    Returns:
      The result of attention() with the [num_blocks, query_block_length]
      dimensions merged back into length_dim (presumably
      q.shape - key_dim + value_dim; confirm against attention()).

    Raises:
      ValueError: nothing in this function raises it directly; it may
        propagate from attention() or the mtf ops on mismatched dimensions.
    """
    # Choose a suitable block size: the greatest divisor of length_per_split
    # that is less than or equal to max(radius, 128).
    length_per_split = length_dim.size // length_dim_num_splits
    block_length = max(radius, 128)
    while length_per_split % block_length != 0:
        block_length -= 1
    query_block_length = mtf.Dimension("query_block_length", block_length)
    memory_block_length = mtf.Dimension("memory_block_length", block_length)
    # The num_blocks dimension gets the same name as the length dimension,
    # so it will be split in the same way.
    num_blocks = mtf.Dimension(length_dim.name, length_dim.size // block_length)

    def _reshape_query(x):
        return mtf.replace_dimensions(
            x, length_dim, [num_blocks, query_block_length])

    def _reshape_memory(x):
        x = mtf.replace_dimensions(
            x, length_dim, [num_blocks, memory_block_length])
        # Autoregressive attention only needs the left neighbours; otherwise
        # exchange halos on both sides.
        exchange = mtf.left_halo_exchange if autoregressive else mtf.halo_exchange
        return exchange(x, num_blocks, memory_block_length, radius)

    q = _reshape_query(q)
    k = _reshape_memory(k)
    # Explicit None check: "v is optional" must not depend on how a Tensor
    # object happens to evaluate in a boolean context.
    if v is not None:
        v = _reshape_memory(v)
    else:
        v = k

    if sequence_id is None:
        sequence_id = 1
    if (not isinstance(sequence_id, mtf.Tensor) or
            length_dim not in sequence_id.shape.dims):
        # Broadcast a scalar / length-less sequence_id along length_dim so it
        # can be reshaped into blocks like q and k.
        sequence_id = sequence_id + mtf.zeros(q.mesh, [length_dim], tf.int32)
    q_sequence_id = _reshape_query(sequence_id)
    m_sequence_id = _reshape_memory(sequence_id)

    pos = mtf.range(q.mesh, length_dim, dtype=tf.int32)
    q_pos = _reshape_query(pos)
    m_pos = _reshape_memory(pos)

    # Memory blocks are widened by the halo: one radius on the left, plus one
    # on the right when not autoregressive.
    padded_memory_block_length = mtf.Dimension(
        "memory_block_length",
        (1 if autoregressive else 2) * radius + block_length)

    relative_position = m_pos - q_pos
    # A (query, memory) pair is illegal if it crosses sequences or falls
    # outside the allowed window.
    illegal = mtf.not_equal(q_sequence_id, m_sequence_id)
    illegal = mtf.logical_or(
        illegal, mtf.less_equal(relative_position, -radius))
    illegal = mtf.logical_or(
        illegal,
        mtf.greater(relative_position, 0 if autoregressive else radius))
    mask = mtf.cast(illegal, q.dtype) * -1e9

    # attention_kwargs defaults to None, and `**None` raises TypeError, so
    # fall back to an empty mapping.
    o = attention(q, k, v, padded_memory_block_length,
                  key_dim, value_dim, mask, **(attention_kwargs or {}))
    return mtf.replace_dimensions(o, [num_blocks, query_block_length], length_dim)
def sample_autoregressive(self,
                          partial_sequences,
                          stop_at_token=1,
                          max_steps=None,
                          temperature=1.0,
                          variable_dtype=mtf.VariableDType(tf.float32),
                          encoder_output=None,
                          encoder_sequence_id=None,
                          shared_params=None,
                          has_partial_sequences=True,
                          encoder_layer_outputs=None):
    """Fill in partial sequences by sampling one token per loop step.

    `partial_sequences` holds the prefixes to continue: the leading tokens
    of each sequence are nonzero (the given prefix) and the trailing tokens
    are zeros (the slots to be filled in).

    To sample from scratch, pass
    partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) together with
    has_partial_sequences=False, in which case the initial states are
    replaced by zeros.

    Args:
      partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
      stop_at_token: an optional integer eos id.  The loop stops for a
        sequence once this id appears anywhere in it.
      max_steps: an optional integer.  Currently ignored (not implemented).
      temperature: an optional floating point value between 0.0 and 1.0
        0.0 means argmax, 1.0 means sample according to predicted
        distribution.
      variable_dtype: a mtf.VariableDType
      encoder_output: an optional Tensor
      encoder_sequence_id: an optional Tensor
      shared_params: an optional dictionary
      has_partial_sequences: a boolean
      encoder_layer_outputs: optional - readonly list of tensor activations
        when decoding, one per each input layer + the embedding layer

    Returns:
      a Tensor with shape [<batch_dims>, length_dim]

    Raises:
      ValueError: if the model is not autoregressive.
    """
    del max_steps  # TODO(noam): implement
    if not self.autoregressive:
        raise ValueError("must be autoregressive")

    inputs = partial_sequences
    batch_dims = inputs.shape.dims[:-1]
    length_dim = inputs.shape.dims[-1]
    # Count of nonzero tokens per sequence, i.e. where decoding starts.
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim)
    sequence_id = None if encoder_sequence_id is None else 1

    def _make_context(mode, **mode_kwargs):
        """Builds a Context with the settings shared by both decode phases."""
        return Context(
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            model_dim=self.model_dim,
            variable_dtype=variable_dtype,
            mode=mode,
            autoregressive=self.autoregressive,
            sequence_id=sequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            shared_params=shared_params,
            layout=self.layout,
            mesh_shape=self.mesh_shape,
            encoder_layer_outputs=encoder_layer_outputs,
            **mode_kwargs)

    # First pass over the whole (shifted) prefix; only the states recorded
    # on the context are kept, the logits are discarded.
    prefix_context = _make_context(
        "first_part",
        new_states=[],
        initial_position=initial_position,
        constant_states=[])
    shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
    with tf.variable_scope(self.name):
        prefix_logits = self._call_internal(prefix_context, shifted_inputs)
    del prefix_logits
    constant_states = prefix_context.constant_states

    if has_partial_sequences:
        initial_states = prefix_context.new_states
    else:
        initial_states = [mtf.zeros_like(s) for s in prefix_context.new_states]

    def cond_fn(position, ids, *unused_states):
        """Should we run another loop iteration."""
        finished = mtf.greater_equal(position, length_dim.size)
        if stop_at_token is not None:
            emitted_eos = mtf.reduce_any(
                mtf.equal(ids, stop_at_token), reduced_dim=length_dim)
            finished = mtf.logical_or(finished, emitted_eos)
        return mtf.logical_not(mtf.reduce_all(finished))

    def body_fn(position, ids, *states):
        """One step in the decode loop."""
        step_context = _make_context(
            "incremental",
            position=position,
            states=states,
            new_states=[],
            constant_states=constant_states)
        previous_ids = mtf.gather(ids, position - 1, length_dim)
        with tf.variable_scope(self.name, reuse=True):
            step_logits = self._call_internal(step_context, previous_ids)
        sampled_ids = mtf.sample_with_temperature(
            step_logits, self.output_vocab_dim, temperature)
        next_position = position + 1
        # Write the sampled id into the slot at `position`.
        updated_ids = ids + sampled_ids * mtf.one_hot(
            position, length_dim, dtype=tf.int32)
        return [next_position, updated_ids] + step_context.new_states

    loop_vars = [initial_position, inputs] + initial_states
    final_position, outputs = mtf.while_loop(cond_fn, body_fn, loop_vars)[:2]
    del final_position
    return outputs
def beam_search(self,
                inputs,
                decode_length,
                variable_dtype=mtf.VariableDType(tf.float32),
                encoder_output=None,
                encoder_sequence_id=None,
                alpha=0.6,
                shared_params=None,
                encoder_layer_outputs=None):
    """Decode with beam search and return the beam at index 0.

    Args:
      inputs: an int32 zero-Tensor with shape
        [<batch_dims>, beam_dim, length_dim].
      decode_length: an int32 mtf scalar.  Maximum decode length.
      variable_dtype: a mtf.VariableDType
      encoder_output: an optional Tensor
      encoder_sequence_id: an optional Tensor
      alpha: a floating point value (length bonus)
      shared_params: an optional dictionary
      encoder_layer_outputs: optional - readonly list of tensor activations
        when decoding, one per each input layer + the embedding layer

    Returns:
      The beams produced by mtf.beam_search.beam_search(), gathered at
      index 0 along beam_dim.

    Raises:
      ValueError: if the model is not autoregressive.
      NotImplementedError: if `inputs` does not have exactly one batch
        dimension in front of beam_dim and length_dim.
    """
    if not self.autoregressive:
        raise ValueError("must be autoregressive")

    dims = inputs.shape.dims
    batch_dims = dims[:-2]
    if len(batch_dims) != 1:
        raise NotImplementedError(
            "beam search supports exactly one batch dimension.")
    beam_dim = dims[-2]
    length_dim = dims[-1]
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim)
    sequence_id = None if encoder_sequence_id is None else 1

    def _make_context(mode, **mode_kwargs):
        """Builds a Context with the settings shared by both decode phases."""
        return Context(
            mesh=inputs.mesh,
            batch_dims=batch_dims + [beam_dim],
            length_dim=length_dim,
            model_dim=self.model_dim,
            variable_dtype=variable_dtype,
            mode=mode,
            autoregressive=self.autoregressive,
            sequence_id=sequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            shared_params=shared_params,
            layout=self.layout,
            mesh_shape=self.mesh_shape,
            encoder_layer_outputs=encoder_layer_outputs,
            **mode_kwargs)

    # The first pass is run only for the states it records on the context;
    # its logits are thrown away.
    prefix_context = _make_context(
        "first_part",
        new_states=[],
        initial_position=initial_position,
        constant_states=[])
    shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
    with tf.variable_scope(self.name):
        prefix_logits = self._call_internal(prefix_context, shifted_inputs)
    del prefix_logits
    # There are no partial targets.
    # Replace initial states by zeros to avoid computing them.
    initial_states = [mtf.zeros_like(s) for s in prefix_context.new_states]
    constant_states = prefix_context.constant_states

    def logits_fn(step_num, ids, states):
        """logits_fn for mtf.beam_search.beam_search()."""
        step_context = _make_context(
            "incremental",
            position=step_num,
            states=states,
            new_states=[],
            constant_states=constant_states)
        previous_ids = mtf.gather(ids, step_num - 1, length_dim)
        with tf.variable_scope(self.name, reuse=True):
            step_logits = self._call_internal(step_context, previous_ids)
        return mtf.to_float(step_logits), step_context.new_states

    beams, unused_scores = mtf.beam_search.beam_search(
        logits_fn,
        inputs,
        alpha,
        states=initial_states,
        decode_length=decode_length,
        use_tpu=True,
        dtype=tf.float32,
        mesh_shape=self.mesh_shape,
        layout=self.layout)
    first_beam = mtf.constant(inputs.mesh, 0, dtype=tf.int32)
    return mtf.gather(beams, first_beam, beam_dim)
def sample_autoregressive(
        partial_sequences,
        other_features,
        params,
        stop_at_token=50256,
        max_steps=None,
        temperature=0.9,
        variable_dtype=mtf.VariableDType(tf.float32),
        encoder_output=None,
        encoder_sequence_id=None,
        encoder_inputs=None,
        shared_params=None,
        has_partial_sequences=True,
        encoder_layer_outputs=None,
        never_end=False,
        remove_partial_sequences=False,
        sampling_keep_top_k=-1,
        bos_id=50256,
):
    """Sample randomly one token at a time.

    The partial_sequences represent partial sequences to be continued.  The
    first tokens of each sequence are non-padding tokens representing the
    given partial sequences and the last tokens of each sequence are padding
    (params["padding_id"], default 0), representing what needs to be filled
    in.

    If there are no partial sequences (you want to sample from the
    beginning), then pass partial_sequences=mtf.zeros(mesh, shape,
    dtype=tf.int32) and has_partial_sequences=False (so we can skip
    computation).

    Args:
      partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
      other_features: a dict forwarded to gpt2.model(); must contain
        "vocab_dim", the Dimension that is sampled over.
      params: a dict forwarded to gpt2.model().  This function reads
        "padding_id" (default 0) and "slow_sampling" (default False) from it.
      stop_at_token: an optional integer eos id.  Stop when we produce it.
      max_steps: an optional integer, the max number of steps to decode.
      temperature: an optional floating point value between 0.0 and 1.0
        0.0 means argmax, 1.0 means sample according to predicted
        distribution.
      variable_dtype: a mtf.VariableDType
      encoder_output: an optional Tensor
      encoder_sequence_id: an optional Tensor
      encoder_inputs: an optional Tensor
      shared_params: an optional dictionary
      has_partial_sequences: a boolean
      encoder_layer_outputs: optional - readonly list of tensor activations
        when decoding, one per each input layer + the embedding layer
      never_end: a boolean - if set (and stop_at_token is not None), the
        logit of stop_at_token is pushed down by 1e9 at every step so that
        it is effectively never generated.
      remove_partial_sequences: a boolean - whether to remove the partial
        sequences from the output
      sampling_keep_top_k: an integer - if not -1, only sample from the top k
        logits.  The special value -2 keeps the top 10% of the vocabulary.
      bos_id: beginning of sequence id (currently unused by this function)

    Returns:
      a Tensor with shape [<batch_dims>, length_dim]

    Raises:
      ValueError: if sampling_keep_top_k resolves to a value that is neither
        -1 nor positive.
    """
    inputs = partial_sequences  # Partial sequences to fill in
    batch_dims = inputs.shape.dims[:-1]
    length_dim = inputs.shape.dims[-1]
    padding_id = params.get("padding_id", 0)
    slow_sampling = params.get("slow_sampling", False)

    # Number of non-padding tokens per sequence, i.e. where padding starts.
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, padding_id)),
        reduced_dim=length_dim)

    length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
    input_full_attention = True  # for now hardcode this to true bc lazy
    if input_full_attention:
        # Priorities are 0 for positions <= initial_position and equal to the
        # position index after that.
        # NOTE(review): presumably this lets all prompt positions attend to
        # each other -- confirm against the attention layer's use of
        # read/write priorities.
        read_priority = write_priority = length_range * mtf.to_int32(
            mtf.greater(length_range, initial_position))
    else:
        read_priority = write_priority = length_range

    # Builds context to pass around internally.
    # The 'first part' context records initial states of k / v / x.
    if not slow_sampling:
        context_first_part = mtf_transformer.transformer.Context(
            model=None,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="first_part",
            position=length_range,
            position_is_default=True,
            new_states=[],
            initial_position=initial_position,
            sequence_id=None,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=[],
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=inputs,
            encoder_inputs=encoder_inputs)

        # Run only for the states recorded on the context; the outputs of
        # this pass are not used.
        with tf.variable_scope("gpt2"):
            gpt2.model({"inputs": inputs},
                       other_features,
                       params,
                       inputs.mesh,
                       variable_dtype=variable_dtype,
                       context=context_first_part)

        if not has_partial_sequences:
            initial_states = [
                mtf.zeros_like(t) for t in context_first_part.new_states
            ]
        else:
            initial_states = context_first_part.new_states
    else:
        # Slow sampling re-runs the full model every step and carries no
        # incremental states through the loop.
        initial_states = []

    # Always bind the baseline EOS count so cond_fn can never hit an unbound
    # name; it is only consulted when stop_at_token is not None.
    partial_sequences_eos_count = 0
    if stop_at_token is not None:
        partial_sequences_eos_count = mtf.reduce_sum(
            mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
            reduced_dim=length_dim)

    def cond_fn(position, ids, *unused_states):
        """Should we run another loop iteration?"""
        past_end = mtf.greater_equal(position, length_dim.size)
        if max_steps:
            past_end = mtf.logical_or(
                past_end,
                mtf.greater_equal(position - initial_position, max_steps))

        is_done = past_end
        if stop_at_token is not None:
            # Only an EOS beyond those already present in the prompt counts.
            eos_count = mtf.reduce_sum(
                mtf.to_int32(mtf.equal(ids, stop_at_token)),
                reduced_dim=length_dim)
            has_additional_eos = mtf.greater(
                eos_count, partial_sequences_eos_count)
            is_done = mtf.logical_or(is_done, has_additional_eos)
        all_done = mtf.reduce_all(is_done)
        return mtf.logical_not(all_done)

    def body_fn(position, ids, *states):
        """One step in the decode loop."""
        context = None
        if not slow_sampling:
            context = mtf_transformer.transformer.Context(
                model=None,
                mesh=inputs.mesh,
                batch_dims=batch_dims,
                length_dim=length_dim,
                variable_dtype=variable_dtype,
                mode="incremental",
                position=position,
                position_is_default=True,
                states=states,
                new_states=[],
                initial_position=position,
                sequence_id=None,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                shared_params=shared_params,
                encoder_layer_outputs=encoder_layer_outputs,
                write_priority=write_priority,
                read_priority=read_priority,
                inputs=ids,
                encoder_inputs=encoder_inputs)

        with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE):
            logits, _, _ = gpt2.model({"inputs": ids},
                                      other_features,
                                      params,
                                      inputs.mesh,
                                      variable_dtype=variable_dtype,
                                      context=context)
        vocab_dim = other_features["vocab_dim"]

        if never_end and stop_at_token is not None:
            # Make the stop token (practically) impossible to sample.
            logits += mtf.one_hot(
                mtf.constant(logits.mesh, stop_at_token, dtype=tf.int32),
                vocab_dim,
                on_value=-1e9,
                off_value=0.0,
                dtype=logits.dtype)

        # Resolve the top-k setting into a local instead of rebinding the
        # enclosing function's argument.  -2 means "keep the top 10% of the
        # vocabulary".
        keep_top_k = sampling_keep_top_k
        if keep_top_k == -2:
            keep_top_k = int(logits.shape[-1].size * 0.1)

        if keep_top_k != -1:
            if keep_top_k <= 0:
                raise ValueError(
                    "sampling_keep_top_k must either be -1 or positive.")
            k_largest = mtf.nth_largest_element(
                logits, n=keep_top_k, reduced_dim=vocab_dim)
            # Push everything at or below the threshold far down so it is
            # effectively never sampled.
            logits = mtf.where(mtf.less_equal(logits, k_largest),
                               mtf.ones_like(logits) * -1e6, logits)

        ids_this_step = mtf.sample_with_temperature(
            logits, vocab_dim, temperature)

        if slow_sampling:
            ids_this_step = mtf.shift(
                ids_this_step, offset=1, dim=length_dim, wrap=False)
        else:
            ids_this_step = mtf.reshape(ids_this_step, batch_dims)

        # Overwrite (rather than add to) the slot at `position`.
        one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32)
        one_new_id = ids_this_step * one_hot
        new_ids = (1 - one_hot) * ids + one_new_id
        new_position = position + 1

        ret = [new_position, new_ids]
        if context is not None:
            ret += context.new_states
        return ret

    while_loop_inputs = [initial_position, inputs] + initial_states
    final_position, outputs = mtf.while_loop(
        cond_fn, body_fn, while_loop_inputs)[:2]
    del final_position

    if has_partial_sequences and remove_partial_sequences:
        # Remove partial sequences from outputs
        partial_length = mtf.reduce_sum(
            mtf.to_int32(mtf.not_equal(partial_sequences, padding_id)),
            reduced_dim=length_dim)
        outputs = mtf.dynamic_shift(
            outputs, -partial_length, length_dim, wrap=False)
    return outputs