def _mtf_model_fn(self, features, mesh):
    self._original_features = features
    hparams = self._hparams

    def import_feature(key):
        return self._import_feature(features, mesh, key)

    targets = import_feature("targets")
    if self.autoregressive:
        inputs = mtf.shift(targets, offset=1, dim=self.length_dim, wrap=False)
    else:
        inputs = import_feature("inputs")
        # TODO(noam): options for bert-style masking here?
    sequence_id = import_feature("targets_segmentation")
    model = self.model()
    logits, loss = model.call_simple(
        inputs=inputs,
        targets=targets,
        compute_loss=True,
        mode=hparams.mode,
        variable_dtype=self.variable_dtype,
        sequence_id=sequence_id)
    return logits, loss
def _mtf_model_fn(self, features, mesh):
    self._original_features = features
    hparams = self._hparams

    def import_feature(key):
        return self._import_feature(features, mesh, key)

    targets = import_feature("targets")
    sequence_id = import_feature("targets_segmentation")
    position = import_feature("targets_position")
    if self.autoregressive:
        inputs = mtf.shift(targets, offset=1, dim=self.length_dim, wrap=False)
        if position is not None:
            # first input in later sequences should be 0
            inputs *= mtf.to_int32(mtf.not_equal(position, 0))
    else:
        inputs = import_feature("inputs")
        # TODO(noam): options for bert-style masking here?
    model = self.model()
    logits, loss = model.call_simple(
        inputs=inputs,
        targets=targets,
        compute_loss=True,
        mode=hparams.mode,
        variable_dtype=self.variable_dtype,
        sequence_id=sequence_id,
        position=position)
    return logits, loss
def _mtf_model_fn(self, features, mesh):
    self._original_features = features
    hparams = self._hparams

    def import_feature(key):
        return self._import_feature(features, mesh, key)

    targets = import_feature("targets")
    sequence_id = import_feature("targets_segmentation")
    if hparams.use_global_position_in_packed_sequence:
        position = None
    else:
        position = import_feature("targets_position")
    if self.autoregressive:
        inputs = mtf.shift(targets, offset=1, dim=self.length_dim, wrap=False)
        # We should have a 0 at the beginning of each sequence rather than the
        # shifted EOS (1) from the previous sequence.
        inputs -= mtf.to_int32(mtf.equal(inputs, 1))
    else:
        inputs = import_feature("inputs")
        # TODO(noam): options for bert-style masking here?
    model = self.model()
    logits, loss = model.call_simple(
        inputs=inputs,
        targets=targets,
        compute_loss=True,
        mode=hparams.mode,
        variable_dtype=self.variable_dtype,
        sequence_id=sequence_id,
        position=position)
    return logits, loss
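# Illustrative sketch (not part of the original code): the three model functions
# above all build decoder inputs by shifting the targets right by one position
# with mtf.shift(..., wrap=False), so position t sees target t-1 and position 0
# sees a zero. The toy graph below assumes mesh_tensorflow is importable as mtf
# and tensorflow.compat.v1 as tf, and only demonstrates the shift itself.
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")
batch_dim = mtf.Dimension("batch", 2)
length_dim = mtf.Dimension("length", 4)
targets = mtf.import_tf_tensor(
    mesh,
    tf.constant([[5, 6, 7, 1], [8, 9, 1, 0]], dtype=tf.int32),
    shape=mtf.Shape([batch_dim, length_dim]))
# inputs[..., t] == targets[..., t - 1]; the first position becomes 0 because
# wrap=False shifts zeros in rather than wrapping around.
inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)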
def call_simple(self,
                inputs,
                targets,
                compute_loss,
                mode=tf.estimator.ModeKeys.TRAIN,
                variable_dtype=mtf.VariableDType(tf.float32),
                encoder_sequence_id=None,
                decoder_sequence_id=None):
    """Compute logits based on inputs (all positions in parallel).

    This is called during training and evaluation.

    Args:
      inputs: an int32 Tensor with shape [<batch_dims>, length_dim]
      targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
      compute_loss: a boolean
      mode: a tf.estimator.ModeKeys
      variable_dtype: a mtf.VariableDType
      encoder_sequence_id: an optional Tensor
      decoder_sequence_id: an optional Tensor

    Returns:
      logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
      loss: an optional Scalar (if compute_loss=True)
    """
    shared_params = self._shared_params(inputs.mesh, variable_dtype)
    encoder_output, encoder_loss = self.encoder.call_simple(
        inputs,
        None,
        compute_loss,
        mode=mode,
        variable_dtype=variable_dtype,
        sequence_id=encoder_sequence_id,
        shared_params=shared_params)
    encoder_output = mtf.layers.rename_length_to_memory_length(encoder_output)
    if encoder_sequence_id is not None:
        encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
            encoder_sequence_id)
    length_dim = targets.shape.dims[-1]
    shifted_targets = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
    logits, loss = self.decoder.call_simple(
        shifted_targets,
        targets,
        compute_loss,
        mode=mode,
        variable_dtype=variable_dtype,
        sequence_id=decoder_sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        shared_params=shared_params)
    if loss is not None and encoder_loss is not None:
        loss += encoder_loss
    return logits, loss
def halo_reduce(x, blocks_dim, block_size_dim, halo_size, wrap=True):
    """Reduce each block with the margins of adjacent blocks.

    Get left and right blocks_dim and sum overlap along block_size_dim.
    Only supports halo size smaller than block_size/2.

    Args:
      x: a Tensor.
      blocks_dim: a Dimension in x.shape
      block_size_dim: a Dimension in x.shape
      halo_size: an integer
      wrap: a boolean

    Returns:
      a Tensor with the same shape as x, in which the 2*halo_size-wide margins
      of each block have been summed with the overlapping margins of the
      adjacent blocks.
    """
    if halo_size == 0:
        return x

    block_size = block_size_dim.size
    assert halo_size <= block_size // 2

    left_margin = mtf.slice(x, 0, 2 * halo_size, block_size_dim.name)
    right_margin = mtf.slice(x, block_size_dim.size - 2 * halo_size,
                             2 * halo_size, block_size_dim.name)
    center = mtf.slice(x, 2 * halo_size, block_size_dim.size - 4 * halo_size,
                       block_size_dim.name)

    # Perform halo exchange: sum margins with those of the neighbouring blocks.
    left = mtf.shift(right_margin, 1, blocks_dim, wrap) + left_margin
    right = mtf.shift(left_margin, -1, blocks_dim, wrap) + right_margin

    # Recompose the block.
    left = mtf.pad(left, [0, block_size_dim.size - 2 * halo_size],
                   block_size_dim.name)
    right = mtf.pad(right, [block_size_dim.size - 2 * halo_size, 0],
                    block_size_dim.name)
    center = mtf.pad(center, [2 * halo_size, 2 * halo_size],
                     block_size_dim.name)
    x = left + center + right
    return x
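# Hypothetical usage sketch for halo_reduce above (the mesh and dimension names
# here are made up): a [blocks, block_size] tensor whose 2*halo_size-wide
# margins overlap with the neighbouring blocks gets those margins summed back
# together, leaving the shape unchanged.
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "halo_example_mesh")
blocks_dim = mtf.Dimension("blocks", 4)
block_size_dim = mtf.Dimension("block_size", 16)
x = mtf.zeros(mesh, mtf.Shape([blocks_dim, block_size_dim]), dtype=tf.float32)
# wrap=True treats the blocks dimension as circular, so the first and last
# blocks also exchange margins with each other.
y = halo_reduce(x, blocks_dim, block_size_dim, halo_size=2, wrap=True)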
def model_fn(mtf_features):
    """The kind of function we need for mtf.serialize_training_step.

    Args:
      mtf_features: a dictionary

    Returns:
      a dictionary
    """
    targets = mtf_features["targets"]
    if model_type == "lm":
        _, _, length_dim = targets.shape
        inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
    else:
        inputs = mtf_features["inputs"]

    if isinstance(transformer_model, transformer.Unitransformer):
        position_kwargs = dict(
            sequence_id=mtf_features.get("targets_segmentation", None),
            position=mtf_features.get("targets_position", None),
        )
    elif isinstance(transformer_model, transformer.Bitransformer
                    ) or model_type == "bi_student_teacher":
        position_kwargs = dict(
            encoder_sequence_id=mtf_features.get("inputs_segmentation", None),
            decoder_sequence_id=mtf_features.get("targets_segmentation", None),
            encoder_position=mtf_features.get("inputs_position", None),
            decoder_position=mtf_features.get("targets_position", None),
        )
    else:
        raise ValueError("unrecognized class")

    logits, loss = transformer_model.call_simple(
        inputs=inputs,
        targets=targets,
        compute_loss=True,
        mode=mode,
        variable_dtype=get_variable_dtype(),
        **position_kwargs)
    if num_microbatches > 1:
        loss /= float(num_microbatches)
    del logits
    return {"loss": loss}
def causal_depthwise_conv(x, context, kernel_size=3):
    """Causal depthwise convolution."""

    def scale_var(shift_distance):
        return mtf.get_variable(
            context.mesh,
            "conv_%s" % shift_distance,
            mtf.Shape(context.model.ensemble_dims + x.shape.dims[-1:]),
            initializer=tf.constant_initializer(
                0.5 if shift_distance == 0 else 0.5 / kernel_size),
            dtype=context.variable_dtype)

    ret = x * scale_var(0)
    for shift_distance in range(1, kernel_size):
        x = mtf.shift(x, 1, context.length_dim, wrap=False)
        ret += x * scale_var(shift_distance)
    return ret
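# Plain-NumPy reference (illustrative, not from the original code) of what
# causal_depthwise_conv computes per channel: a learned weighted sum over the
# current position and the previous (kernel_size - 1) positions, with zeros
# shifted in at the start of the sequence, mirroring
# mtf.shift(x, 1, context.length_dim, wrap=False).
import numpy as np

def causal_depthwise_conv_reference(x, scales):
    """x: [batch, length, channels]; scales: list of [channels] arrays."""
    out = x * scales[0]
    shifted = x
    for k in range(1, len(scales)):
        # Shift right by one step along the length axis (causal, zero-padded).
        shifted = np.pad(shifted, ((0, 0), (1, 0), (0, 0)))[:, :-1, :]
        out = out + shifted * scales[k]
    return out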
def _mtf_model_fn(self, features, mesh):
    self._original_features = features
    features = copy.copy(features)
    hparams = self._hparams
    targets = tf.to_int32(features["targets"])
    if len(targets.get_shape()) > 2:
        tf.logging.info("targets = %s" % targets)
        targets = tf.squeeze(targets, [2, 3])

    # pad targets to max_length
    def pad_to_length(x):
        extra_length = self.length_dim.size - tf.shape(x)[1]
        x = tf.pad(x, [[0, 0], [0, extra_length]])
        x = tf.reshape(x, [hparams.batch_size, self.length_dim.size])
        return x

    targets = pad_to_length(targets)
    targets = self._import_to_batch_by_length(targets, "targets", mesh)
    for key in [
        "targets_segmentation", "targets_position",
        "inputs_segmentation", "inputs_position"
    ]:
        if key in features:
            features[key] = pad_to_length(features[key])
    if hparams.decoder_type == "autoregressive":
        shifted_targets = mtf.shift(
            targets, offset=1, dim=self.length_dim, wrap=False)
    else:
        raise ValueError(
            "unknown hparams.decoder_type = %s" % hparams.decoder_type)
    model = self.model()
    logits, loss = model.call_simple(
        inputs=shifted_targets,
        targets=targets,
        compute_loss=True,
        mode=hparams.mode,
        variable_dtype=self.variable_dtype)
        # mesh_shape=hparams.mesh_shape,
        # layout=hparams.layout,
    return logits, loss
def body_fn(position, ids, *states):
    """One step in the decode loop."""
    nonlocal sampling_keep_top_k

    context = mtf_transformer.transformer.Context(
        model=None,
        mesh=inputs.mesh,
        batch_dims=batch_dims,
        length_dim=length_dim,
        variable_dtype=variable_dtype,
        mode="incremental",
        position=position,
        position_is_default=True,
        states=states,
        new_states=[],
        initial_position=position,
        sequence_id=None,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        shared_params=shared_params,
        encoder_layer_outputs=encoder_layer_outputs,
        write_priority=write_priority,
        read_priority=read_priority,
        inputs=ids,
        encoder_inputs=encoder_inputs) if not slow_sampling else None

    with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE):
        logits, _, _ = gpt2.model({"inputs": ids},
                                  other_features,
                                  params,
                                  inputs.mesh,
                                  variable_dtype=variable_dtype,
                                  context=context)

    # sampling_keep_top_k == -2 selects the default behaviour: keep the top 10%
    # of the vocabulary.
    if sampling_keep_top_k == -2:
        sampling_keep_top_k = int(logits.shape[-1].size * 0.1)

    if sampling_keep_top_k != -1:
        if sampling_keep_top_k <= 0:
            raise ValueError(
                "sampling_keep_top_k must either be -1 or positive.")
        k_largest = mtf.nth_largest_element(
            logits, n=sampling_keep_top_k,
            reduced_dim=other_features["vocab_dim"])
        logits = mtf.where(mtf.less_equal(logits, k_largest),
                           mtf.ones_like(logits) * -1e6, logits)

    ids_this_step = mtf.sample_with_temperature(
        logits, other_features["vocab_dim"], temperature)
    if slow_sampling:
        ids_this_step = mtf.shift(
            ids_this_step, offset=1, dim=length_dim, wrap=False)
    else:
        ids_this_step = mtf.reshape(ids_this_step, (batch_dims))

    one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32)
    one_new_id = ids_this_step * one_hot
    new_ids = (1 - one_hot) * ids + one_new_id
    new_position = position + 1

    ret = [new_position, new_ids]
    if context is not None:
        ret += context.new_states
    return ret
def call_simple(self, inputs, targets, compute_loss, attributes=None, mode=tf.estimator.ModeKeys.TRAIN, variable_dtype=mtf.VariableDType(tf.float32), sequence_id=None, subsequence_id=None, position=None, encoder_output=None, encoder_sequence_id=None, encoder_inputs=None, shared_params=None, layer_outputs=None, encoder_layer_outputs=None, z=None): """Compute logits based on inputs (all positions in parallel). This is called during training and evaluation. Args: inputs: an int32 Tensor with shape [<batch_dims>, length_dim] For training autoregressive models this should be equal to mtf.shift(targets, offset=1, dim=length_dim, wrap=False) targets: an optional int32 Tensor with shape [<batch_dims>, length_dim] compute_loss: a boolean attributes: an (optional?) int32 Tensor with shape [<batch_dims>, length_dim] ([<batch_dims>]) mode: a tf.estimator.ModeKeys variable_dtype: a mtf.VariableDType sequence_id: an optional Tensor subsequence_id: an optional Tensor position: an optional Tensor encoder_output: an optional Tensor encoder_sequence_id: an optional Tensor encoder_inputs: an optional Tensor shared_params: an optional dictionary layer_outputs: an optional list to append Tensor layer activations to encoder_layer_outputs: optional - readonly list of tensor activations when decoding, one per each input layer + the embedding layer Returns: logits: a Tensor with shape [<batch_dims>, output_vocab_dim] loss: an optional Scalar (if compute_loss=True) """ batch_dims = inputs.shape.dims[:-1] length_dim = inputs.shape.dims[-1] length_range = mtf.range(inputs.mesh, length_dim, dtype=tf.int32) if not self.positional_embedding: # To make relative attention faster, we drop the information about the # position in the subsequence. The relative attention code then # assumes that the positions are given by index in the tensor, # which still leads to the correct computation of relative position. position = None if position is None: position_is_default = True position = length_range else: position_is_default = False if self.input_full_attention: # The inputs part of each sequence can fully attend within itself. full_attention_region = delimited_lm_inputs_mask(targets) # We can include one additional position to the right - the position # where the final EOS of the inputs is read and the first target token # is predicted. full_attention_region = mtf.logical_or( full_attention_region, mtf.shift(full_attention_region, offset=1, dim=length_dim, wrap=False)) # We set read_priority and write_priority to 0 in the full-attention # region and equal to the position elsewhere. read_priority = write_priority = length_range * mtf.cast( mtf.logical_not(full_attention_region), tf.int32) elif self.autoregressive: # Vanilla autoregressive model - each position can see previous positions. 
read_priority = write_priority = length_range else: read_priority = write_priority = None context = Context(model=self, mesh=inputs.mesh, batch_dims=batch_dims, length_dim=length_dim, variable_dtype=variable_dtype, mode=mode, losses=[] if compute_loss else None, sequence_id=sequence_id, subsequence_id=subsequence_id, position=position, position_is_default=position_is_default, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, shared_params=shared_params, layer_outputs=layer_outputs, encoder_layer_outputs=encoder_layer_outputs, write_priority=write_priority, read_priority=read_priority, inputs=inputs, encoder_inputs=encoder_inputs) with tf.variable_scope(self.name): logits = self._call_internal(context, inputs, targets, attributes, z=z) if compute_loss: loss = mtf.add_n(context.losses) else: loss = None return logits, loss
def _mtf_model_fn(self, features, mesh): self._original_features = features features = copy.copy(features) hparams = self._hparams extra_losses = [] targets = tf.to_int32(features["targets"]) mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN) is_training = mode == tf.estimator.ModeKeys.TRAIN if len(targets.get_shape()) > 2: tf.logging.info("targets = %s" % targets) targets = tf.squeeze(targets, [2, 3]) # pad targets to max_length def pad_to_max_length(x): extra_length = hparams.max_length - tf.shape(x)[1] x = tf.pad(x, [[0, 0], [0, extra_length]]) x = tf.reshape(x, [hparams.batch_size, hparams.max_length]) return x targets = pad_to_max_length(targets) targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams) for key in ["targets_segmentation", "targets_position", "inputs_segmentation", "inputs_position"]: if key in features: features[key] = pad_to_max_length(features[key]) if hparams.decoder_type == "autoregressive": shifted_targets = mtf.shift( targets, offset=1, dim=self.length_dim, wrap=False) elif hparams.decoder_type == "denoising": shifted_targets = self._noisy_targets(targets, extra_losses) else: raise ValueError( "unknown hparams.decoder_type = %s" % hparams.decoder_type) if "targets_segmentation" in features: # "Packed" dataset - keep the examples from seeing each other. targets_segmentation = self._import_to_batch_by_length( features["targets_segmentation"], "targets_segmentation", mesh, hparams) targets_position = self._import_to_batch_by_length( features["targets_position"], "targets_position", mesh, hparams) decoder_self_attention_mask = mtf.layers.attention_mask_same_segment( targets_segmentation, dtype=self.activation_dtype) if hparams.decoder_type == "autoregressive": decoder_self_attention_mask += mtf.layers.attention_mask_autoregressive( targets_position, dtype=self.activation_dtype) else: targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32) if hparams.decoder_type == "autoregressive": decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive( targets_position, dtype=self.activation_dtype) else: decoder_self_attention_mask = None def layer_prepostprocess_dropout(x): return mtf.dropout( x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout, noise_shape=mtf.Shape(self.batch_dims + [self.model_dim])) (inputs_embedding_var, targets_embedding_var, softmax_var, positional_embedding_var) = self._embedding_and_softmax_vars(mesh) if hparams.transformer_type == "decoder": encoder_output = None encoder_decoder_attention_mask = None else: inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3]) inputs = pad_to_max_length(inputs) inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams) if "inputs_segmentation" in features: # "Packed" dataset - keep the examples from seeing each other. 
inputs_segmentation = self._import_to_batch_by_length( features["inputs_segmentation"], "inputs_segmentation", mesh, hparams) inputs_position = self._import_to_batch_by_length( features["inputs_position"], "inputs_position", mesh, hparams) encoder_self_attention_mask = ( mtf.layers.attention_mask_same_segment( inputs_segmentation, dtype=self.activation_dtype)) else: inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32) encoder_self_attention_mask = ( mtf.layers.attention_mask_ignore_padding( inputs, dtype=self.activation_dtype)) x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) + mtf.gather(positional_embedding_var, inputs_position, self.max_length_dim)) x = layer_prepostprocess_dropout(x) with tf.variable_scope("encoder"): x = self._layer_stack(x, hparams.encoder_layers, self_attention_mask=encoder_self_attention_mask, losses=extra_losses) if hparams.transformer_type == "encdec": if "inputs_segmentation" in features: encoder_decoder_attention_mask = ( mtf.layers.attention_mask_same_segment( targets_segmentation, inputs_segmentation, dtype=self.activation_dtype)) else: encoder_decoder_attention_mask = encoder_self_attention_mask encoder_output = mtf.rename_dimension( x, self.length_dim.name, self.memory_length_dim.name) if hparams.transformer_type != "encoder": # DECODER x = (mtf.gather( targets_embedding_var, shifted_targets, self.targets_vocab_dim) + mtf.gather( positional_embedding_var, targets_position, self.max_length_dim)) x = layer_prepostprocess_dropout(x) with tf.variable_scope("decoder"): x = self._layer_stack( x, hparams.decoder_layers, encoder_output=encoder_output, self_attention_mask=decoder_self_attention_mask, encdec_attention_mask=encoder_decoder_attention_mask, losses=extra_losses) if (hparams.reshape_logits_hack and hparams.mode == tf.estimator.ModeKeys.TRAIN): # For some reason, the logits computation is extremely slow on TPU # in some cases where the batch size per core is 1. Reshape the logits # and the targets to double the batch size and halve the length. # TODO(noam): file a bug. old_dims = self.batch_dims + [self.length_dim] new_dims = self.batch_dims[:-1] + [ mtf.Dimension(self.batch_dims[-1].name, self.batch_dims[-1].size * 2), mtf.Dimension(self.length_dim.name, self.length_dim.size // 2)] x = mtf.reshape(x, new_dims + [self.model_dim]) targets = mtf.reshape(targets, new_dims) logits = mtf.matmul(x, softmax_var) if hparams.mode == tf.estimator.ModeKeys.TRAIN: logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2) off_value = hparams.label_smoothing / self._targets_vocab_size on_value = 1.0 - hparams.label_smoothing + off_value soft_targets = mtf.one_hot( targets, self.targets_vocab_dim, on_value=on_value, off_value=off_value, dtype=self.activation_dtype) loss = mtf.layers.softmax_cross_entropy_with_logits( logits, soft_targets, self.targets_vocab_dim) weights = mtf.layers.weights_nonzero(targets, dtype=self.activation_dtype) loss = mtf.reduce_mean(loss * weights) for l in extra_losses: loss += l if (hparams.reshape_logits_hack and hparams.mode == tf.estimator.ModeKeys.TRAIN): logits = mtf.reshape(logits, old_dims + [self.targets_vocab_dim]) logits = mtf.to_float(logits) return logits, loss
def logits_and_loss(mtf_features): """Compute logits and loss. Args: mtf_features: a dictionary Returns: logits: a mtf.Tensor loss: a mtf.Tensor """ if model_type == "lm": # TOTRY Adapt that to our case if "inputs" in mtf_features: mtf_features = _dynamic_text2self(mtf_features) _, _, length_dim = mtf_features["targets"].shape inputs = mtf.shift(mtf_features["targets"], offset=1, dim=length_dim, wrap=False) else: inputs = mtf_features["inputs"] if attribute_embedding: attributes = mtf_features["attribute"] else: attributes = None if control_codes: codeprefixedtargets = mtf_features["codeprefixedtargets"] else: codeprefixedtargets = None if isinstance(transformer_model, transformer.Unitransformer): position_kwargs = dict( sequence_id=mtf_features.get("targets_segmentation", None), position=mtf_features.get("targets_position", None), ) elif isinstance(transformer_model, transformer.Bitransformer ) or model_type == "bi_student_teacher": if control_codes: position_kwargs = dict( encoder_sequence_id=mtf_features.get( "inputs_segmentation", None), decoder_sequence_id=mtf_features.get( "codeprefixedtargets_segmentation", None), decoder_subsequence_id=mtf_features.get( "codeprefixedtargets_subsegmentation", None), encoder_position=mtf_features.get( "inputs_position", None), decoder_position=mtf_features.get( "codeprefixedtargets_position", None), ) else: position_kwargs = dict( encoder_sequence_id=mtf_features.get( "inputs_segmentation", None), decoder_sequence_id=mtf_features.get( "targets_segmentation", None), decoder_subsequence_id=mtf_features.get( "targets_subsegmentation", None), encoder_position=mtf_features.get( "inputs_position", None), decoder_position=mtf_features.get( "targets_position", None), ) else: raise ValueError("unrecognized class") if isinstance(transformer_model, Bitransformer_ll): if cycle_consistency_loss: logits_ae, l_ae = transformer_model.call_simple( inputs=inputs, targets=mtf_features["targets"], compute_loss=True, attributes=attributes, codeprefixedtargets=codeprefixedtargets, mode=mode, variable_dtype=get_variable_dtype(), **position_kwargs) if has_partial_sequences: controlcodes = mtf_features["controlcode"] else: controlcodes = None with gin.config_scope('training'): mtf_samples = transformer_model.decode( inputs, attributes=attributes, controlcodes=controlcodes, has_partial_sequences=has_partial_sequences, remove_partial_sequences=remove_partial_sequences, variable_dtype=get_variable_dtype()) # mtf_samples = mtf.anonymize(mtf_samples) outputs = mtf_samples logits_cycle, l_cycle = transformer_model.call_simple( inputs=outputs, targets=mtf_features["targets"], compute_loss=True, attributes=attributes, codeprefixedtargets=codeprefixedtargets, mode=mode, variable_dtype=get_variable_dtype(), **position_kwargs) loss_ae_cycle = lambda_ae * l_ae + lambda_cycle * l_cycle return logits_cycle, loss_ae_cycle else: return transformer_model.call_simple( inputs=inputs, targets=mtf_features["targets"], compute_loss=True, attributes=attributes, codeprefixedtargets=codeprefixedtargets, mode=mode, variable_dtype=get_variable_dtype(), **position_kwargs) else: return transformer_model.call_simple( inputs=inputs, targets=mtf_features["targets"], compute_loss=True, mode=mode, variable_dtype=get_variable_dtype(), num_microbatches=num_microbatches, **position_kwargs)
def sample_autoregressive(self, partial_sequences, dst_attributes=None, stop_at_token=1, max_steps=None, temperature=0.0, variable_dtype=mtf.VariableDType(tf.float32), encoder_output=None, encoder_sequence_id=None, encoder_inputs=None, shared_params=None, has_partial_sequences=True, encoder_layer_outputs=None, never_end=False, remove_partial_sequences=False, sampling_keep_top_k=-1, z=None): """Sample randomly one token at a time. The partial_sequences represent partial sequences to be continued. The first tokens of each sequence are nonzero representing the given partial sequences and the last tokens of each sequence are zeros, representing what needs to be filled in. If there are no partial sequences (you want to sample from the beginning), then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and has_partial_sequences=False (so we can skip computation). The dst_attributes represents the destination attributes in which we want to generate sequences. Args: partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim] dst_attribute: an int32 Tensor with shape [<batch_dims>, length_dim] ([<batch_dims>]) stop_at_token: an optional integer eos id. Stop when we produce it. max_steps: an optional integer, the max number of steps to decode. temperature: an optional floating point value between 0.0 and 1.0 0.0 means argmax, 1.0 means sample according to predicted distribution. variable_dtype: a mtf.VariableDType encoder_output: an optional Tensor encoder_sequence_id: an optional Tensor encoder_inputs: an optional Tensor shared_params: an optional dictionary has_partial_sequences: a boolean encoder_layer_outputs: optional - readonly list of tensor activations when decoding, one per each input layer + the embedding layer never_end: a boolean - if set, then avoid generating stop_at_token remove_partial_sequences: a boolean - whether to remove the partial sequences from the output sampling_keep_top_k: an integer - if not -1, only sample from the top k logits. 
Returns: a Tensor with shape [<batch_dims>, length_dim] """ if not self.autoregressive: raise ValueError("must be autoregressive") inputs = partial_sequences attributes = dst_attributes batch_dims = inputs.shape.dims[:-1] length_dim = inputs.shape.dims[-1] initial_position = mtf.reduce_sum(mtf.to_int32(mtf.not_equal( inputs, 0)), reduced_dim=length_dim) sequence_id = 1 if encoder_sequence_id is not None else None length_range = mtf.range(inputs.mesh, length_dim, tf.int32) if self.input_full_attention: read_priority = write_priority = length_range * mtf.to_int32( mtf.greater(length_range, initial_position)) else: read_priority = write_priority = length_range context_first_part = Context( model=self, mesh=inputs.mesh, batch_dims=batch_dims, length_dim=length_dim, variable_dtype=variable_dtype, mode="first_part", position=length_range, position_is_default=True, new_states=[], initial_position=initial_position, sequence_id=sequence_id, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, constant_states=[], shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs, write_priority=write_priority, read_priority=read_priority, inputs=inputs, encoder_inputs=encoder_inputs) shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False) with tf.variable_scope(self.name): logits = self._call_internal(context_first_part, shifted_inputs, attributes=attributes, z=z) del logits constant_states = context_first_part.constant_states if not has_partial_sequences: initial_states = [ mtf.zeros_like(t) for t in context_first_part.new_states ] partial_sequences_eos_count = 0 else: initial_states = context_first_part.new_states partial_sequences_eos_count = mtf.reduce_sum( mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)), reduced_dim=length_dim) def cond_fn(position, ids, *unused_states): """Should we run another loop iteration.""" past_end = mtf.greater_equal(position, length_dim.size) if max_steps: past_end = mtf.logical_or( past_end, mtf.greater_equal(position - initial_position, max_steps)) is_done = past_end if stop_at_token is not None: eos_count = mtf.reduce_sum(mtf.to_int32( mtf.equal(ids, stop_at_token)), reduced_dim=length_dim) has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count) is_done = mtf.logical_or(is_done, has_additional_eos) all_done = mtf.reduce_all(is_done) return mtf.logical_not(all_done) def body_fn(position, ids, *states): """One step in the decode loop.""" inputs_this_step = mtf.gather(ids, position - 1, length_dim) if self.attribute_embedding: attributes_this_step = mtf.gather(attributes, position - 1, length_dim) else: attributes_this_step = None # raise ValueError("inputs_this_step shape=%s , ids shape=%s, position - 1 shape=%s, length_dim=%s" % (inputs_this_step.shape, ids.shape, (position - 1).shape, length_dim)) context_incremental = Context( model=self, mesh=inputs.mesh, batch_dims=batch_dims, length_dim=length_dim, variable_dtype=variable_dtype, mode="incremental", position=position, states=states, new_states=[], sequence_id=sequence_id, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, constant_states=constant_states, shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs, write_priority=write_priority, read_priority=position, inputs=inputs_this_step, encoder_inputs=encoder_inputs) with tf.variable_scope(self.name, reuse=True): logits = self._call_internal(context_incremental, inputs_this_step, attributes=attributes_this_step, z=z) if never_end: logits += 
mtf.one_hot(mtf.constant(logits.mesh, stop_at_token, dtype=tf.int32), self.output_vocab_dim, on_value=-1e9, off_value=0.0, dtype=logits.dtype) # TBD whether this should be before or after never_end: # Note for adding top_p sampling in the future, in other code bases, the # option to apply temperature is done before the top-k truncation. This # implementation does this in the opposite order. For top-k this doesn't # matter, but for top_p it will. if sampling_keep_top_k != -1: if sampling_keep_top_k <= 0: raise ValueError( "sampling_keep_top_k must either be -1 or positive.") k_largest = mtf.nth_largest_element( logits, n=sampling_keep_top_k, reduced_dim=self.output_vocab_dim) logits = mtf.where(mtf.less_equal(logits, k_largest), mtf.ones_like(logits) * -1e6, logits) ids_this_step = mtf.sample_with_temperature( logits, self.output_vocab_dim, temperature) new_position = position + 1 new_ids = ids + ids_this_step * mtf.one_hot( position, length_dim, dtype=tf.int32) return [new_position, new_ids] + context_incremental.new_states while_loop_inputs = [initial_position, inputs] + initial_states final_position, outputs = mtf.while_loop(cond_fn, body_fn, while_loop_inputs)[:2] del final_position if has_partial_sequences and remove_partial_sequences: # remove partial sequences from outputs partial_length = mtf.reduce_sum(mtf.to_int32( mtf.not_equal(partial_sequences, 0)), reduced_dim=length_dim) outputs = mtf.dynamic_shift(outputs, -partial_length, length_dim, wrap=False) return outputs
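# Reference sketch in NumPy (illustrative only, boundary and tie handling may
# differ slightly from the mtf code above) of the top-k truncation applied to
# the logits before sampling: everything below the k-th largest logit is pushed
# to a large negative value so that sampling effectively ignores it.
import numpy as np

def keep_top_k(logits, k):
    """logits: [..., vocab]; keep the k largest entries per row, mask the rest."""
    # np.partition places the k-th largest value at index -k along the last axis.
    kth_largest = np.partition(logits, -k, axis=-1)[..., -k]
    return np.where(logits < kth_largest[..., np.newaxis], -1e6, logits)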
def gradient_based_subword_tokenization(x, length_dim, max_subword_length=4, downsample=None, use_offsets=False, consider_chars_as_blocks=False, use_block_pos_embedding=False, share_block_kernel=False, memory_embeddings=0, context=None, block_mixing_mode=None, activation="softmax", downsample_function="mean"): """Implements GBSWT from Charformer. Args: x: a Tensor containing length_dim length_dim: a Dimension max_subword_length: integer downsample: integer. use_offsets: boolean. consider_chars_as_blocks: boolean. use_block_pos_embedding: boolean. share_block_kernel: boolean. memory_embeddings: integer. context: Context. block_mixing_mode: Str for block mixing. activation: Str for block ranking. downsample_function: Str, supports mean/linformer for now. Returns: a Tensor with the same shape as x. Raises: ValueError: if channels or depth don't match. """ # don't use this for now. del max_subword_length del memory_embeddings all_blocks = [] all_scores = [] tf.logging.info("GSW block layer") def _tile(x, n, tile_dim): # Simple tile function in MTF. return mtf.concat([x] * n, tile_dim.name) def _repeat(x, n, repeat_dim): # repeat function in MTF tmp_dim = mtf.Dimension("tmp", 1) expand_shape = mtf.Shape(x.shape.dims + [tmp_dim]) x = mtf.reshape(x, expand_shape) x = _tile(x, n, tmp_dim) output_shape = [] for dim in x.shape.dims: if dim.name == "tmp": continue if dim.name == repeat_dim.name: dim = mtf.Dimension(dim.name, dim.size * n) output_shape.append(dim) output_shape = mtf.Shape(output_shape) x = mtf.reshape(x, output_shape) return x def _combined_dim(dims): return mtf.Dimension(dims[0].name, mtf.Shape(dims).size) # compute all subword blocks # TODO(yitay): handle offsets to get all blocks if activation == "sigtanh": # one score for sigmoid tmp_dim = mtf.Dimension("block_score", 2) else: tmp_dim = mtf.Dimension("block_score", 1) model_dim = x.shape[-1] subword_blocks_width = [2, 3, 4] if consider_chars_as_blocks: subword_blocks_width += [1] if share_block_kernel: block_kernel_shape = mtf.Shape([model_dim, tmp_dim]) block_kernel = mtf.get_variable(x.mesh, "block_kernel", block_kernel_shape, initializer=None, dtype=context.variable_dtype) else: block_kernel = None for subword_len in subword_blocks_width: if use_block_pos_embedding: # this is turn off by default. It is meant to support cases like # parameterized pooling or other features. block_len_dim = mtf.Dimension(length_dim.name, subword_len) # TODO(vqtran): Consider other positional embeddings. 
block_pos_emb = sinusoid_positional_embedding_weights( context.mesh, block_len_dim, x.shape[-1], context.variable_dtype.activation_dtype) block_pos_emb = _repeat( block_pos_emb, math.ceil(length_dim.size / float(subword_len)), block_len_dim) if use_offsets: offset_space = subword_len else: offset_space = 1 for offsets in range(offset_space): if offsets > 0: xoff = mtf.shift(x, offsets, length_dim, wrap=False) if use_block_pos_embedding: block_pos_emb = mtf.shift(block_pos_emb, offsets, block_pos_emb.shape[-2], wrap=False) else: xoff = x tf.logging.info("SW len=%d offset=%d", subword_len, offsets) if length_dim.size % subword_len != 0: tf.logging.info("Not divisible by length") # add extra padding tokens pad_amt = int(subword_len) - int(length_dim.size % subword_len) kp = mtf.pad(xoff, [0, pad_amt], length_dim.name) else: kp = xoff if use_block_pos_embedding: kp += block_pos_emb bx = mtf.pool_tensor_1d( kp, pool_dim=kp.shape.get_dim_by_name("length"), reduce_fn=mtf.reduce_mean, pool_size=int(subword_len)) block_score = mtf.layers.dense(bx, [tmp_dim], use_bias=False, name="bx", reduced_dims=[model_dim], variable_dtype=None, kernel_weights=block_kernel) expand_bx = _repeat(bx, subword_len, length_dim) expand_scores = _repeat(block_score, subword_len, length_dim) if offsets > 0: # add offset. expand_bx = mtf.pad(expand_bx, [offsets, 0], length_dim.name) expand_scores = mtf.pad(expand_scores, [offsets, 0], length_dim.name) new_len = expand_bx.shape.get_dim_by_name(length_dim.name) if new_len.size < length_dim.size: pad_amt = new_len.size - length_dim.size expand_bx = mtf.pad(expand_bx, [0, pad_amt], length_dim.name) expand_scores = mtf.pad(expand_scores, [0, pad_amt], length_dim.name) elif new_len.size > length_dim.size: expand_bx = mtf.slice(expand_bx, 0, length_dim.size, length_dim.name) expand_scores = mtf.slice(expand_scores, 0, length_dim.size, length_dim.name) new_tmp_dim = mtf.Dimension("extra_dim", 1) expand_shape = mtf.Shape(expand_bx.shape.dims + [new_tmp_dim]) expand_scores_shape = mtf.Shape(expand_scores.shape.dims + [new_tmp_dim]) expand_bx = mtf.reshape(expand_bx, expand_shape) expand_scores = mtf.reshape(expand_scores, expand_scores_shape) all_blocks.append(expand_bx) all_scores.append(expand_scores) all_blocks = mtf.concat(all_blocks, new_tmp_dim.name) all_scores = mtf.concat(all_scores, new_tmp_dim.name) tf.logging.info(all_blocks) new_tmp_dim = all_blocks.shape.get_dim_by_name("extra_dim") combined_dim = _combined_dim([new_tmp_dim, tmp_dim]) block_net_shape = all_scores.shape - tmp_dim - new_tmp_dim + combined_dim block_net = mtf.reshape(all_scores, block_net_shape) if block_mixing_mode == "score_attention": tf.logging.info("Using score attention") att = mtf.einsum([block_net, block_net], reduced_dims=[new_tmp_dim]) tf.logging.info(block_net) att = mtf.softmax(att, reduced_dim=att.shape[-1]) block_net = mtf.einsum([att, block_net], output_shape=block_net.shape) tf.logging.info(block_net) if activation == "softmax": block_net = mtf.softmax(block_net, reduced_dim=new_tmp_dim) elif activation == "tanh": tf.logging.info("Using tanh") block_net = mtf.tanh(block_net) all_blocks = block_net * all_blocks all_blocks = mtf.reduce_sum(all_blocks, reduced_dim=new_tmp_dim) output = all_blocks if downsample: output_length = output.shape.get_dim_by_name("length") if output_length.size % int(downsample) != 0: pad_amt = int(downsample) - int( output_length.size % int(downsample)) output = mtf.pad(output, [0, pad_amt], output_length.name) if downsample_function == "mean": output = 
mtf.pool_tensor_1d( output, pool_dim=output.shape.get_dim_by_name("length"), reduce_fn=mtf.reduce_mean, pool_size=int(downsample)) else: raise ValueError("Downsampling function not implemented.") return output
def beam_search(self, inputs, decode_length, dst_attributes=None, variable_dtype=mtf.VariableDType(tf.float32), encoder_output=None, encoder_sequence_id=None, encoder_inputs=None, alpha=0.6, shared_params=None, encoder_layer_outputs=None, z=None): """Beam search. Args: inputs: an int32 zero-Tensor with shape [<batch_dims>, beam_dim, length_dim].# decode_length: an int32 mtf scalar. Maximum decode length. attributes: an int32 zero-Tensor with shape [<batch_dims>, beam_dim, length_dim] ([<batch_dims>] [<batch_dims>, beam_dim]). variable_dtype: a mtf.VariableDType encoder_output: an optional Tensor encoder_sequence_id: an optional Tensor encoder_inputs: an optional Tensor alpha: a floating point value (length bonus) shared_params: an optional dictionary encoder_layer_outputs: optional - readonly list of tensor activations when decoding, one per each input layer + the embedding layer Returns: a Tensor with shape [<batch_dims>, beam_dim, length_dim] """ attributes = dst_attributes if not self.autoregressive: raise ValueError("must be autoregressive") batch_dims = inputs.shape.dims[:-2] if len(batch_dims) != 1: raise NotImplementedError( "beam search supports exactly one batch dimension.") beam_dim = inputs.shape.dims[-2] length_dim = inputs.shape.dims[-1] length_range = mtf.range(inputs.mesh, length_dim, tf.int32) initial_position = mtf.reduce_sum(mtf.to_int32(mtf.not_equal( inputs, 0)), reduced_dim=length_dim) sequence_id = 1 if encoder_sequence_id is not None else None if self.input_full_attention: # This only makes sense in the case of beam search with given partial # sequences, which is not yet implemented. # TODO(noam): implement raise NotImplementedError( "Beam search for language models not yet implemented") else: read_priority = write_priority = length_range context_first_part = Context( model=self, mesh=inputs.mesh, batch_dims=batch_dims + [beam_dim], length_dim=length_dim, variable_dtype=variable_dtype, mode="first_part", position=length_range, position_is_default=True, new_states=[], initial_position=initial_position, sequence_id=sequence_id, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, constant_states=[], shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs, write_priority=write_priority, read_priority=read_priority, inputs=inputs, encoder_inputs=encoder_inputs) shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False) with tf.variable_scope(self.name): logits = self._call_internal(context_first_part, shifted_inputs, attributes=attributes, z=z) del logits # There are no partial targets. # Replace initial states by zeros to avoid computing them. 
initial_states = [ mtf.zeros_like(t) for t in context_first_part.new_states ] constant_states = context_first_part.constant_states def logits_fn(step_num, ids, states): """logits_fn for mtf.beam_search.beam_search().""" inputs_this_step = mtf.gather(ids, step_num - 1, length_dim) if self.attribute_embedding: attributes_this_step = mtf.gather(attributes, step_num - 1, length_dim) else: attributes_this_step = None context_incremental = Context( model=self, mesh=inputs.mesh, batch_dims=batch_dims + [beam_dim], length_dim=length_dim, variable_dtype=variable_dtype, mode="incremental", position=step_num, states=states, new_states=[], sequence_id=sequence_id, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, constant_states=constant_states, shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs, write_priority=write_priority, read_priority=step_num, inputs=inputs_this_step, encoder_inputs=encoder_inputs) with tf.variable_scope(self.name, reuse=True): logits = self._call_internal(context_incremental, inputs_this_step, attributes=attributes_this_step, z=z) return mtf.to_float(logits), context_incremental.new_states beams, unused_scores = mtf.beam_search.beam_search( logits_fn, inputs, alpha, states=initial_states, decode_length=decode_length, use_tpu=True, dtype=tf.float32, mesh_shape=self.mesh_shape, layout=self.layout) return mtf.gather(beams, mtf.constant(inputs.mesh, 0, dtype=tf.int32), beam_dim)
def my_model_fn(features, labels, mode, params=None, config=None): """Estimator model function. Args: features: input features dictionary labels: ignored mode: a tf.estimator.ModeKeys params: something config: something Returns: something """ del labels, config global_step = tf.train.get_global_step() if use_tpu: ctx = params["context"] num_hosts = ctx.num_hosts host_placement_fn = ctx.tpu_host_placement_function device_list = [ host_placement_fn(host_id=t) for t in range(num_hosts) ] # TODO(ylc): Better estimation of replica cache size? replica_cache_size = 300 * 1000000 # 300M per replica # Worker 0 caches all the TPU binaries. worker0_mem = replica_cache_size * ctx.num_replicas devices_memeory_usage = [worker0_mem] + [0] * (num_hosts - 1) var_placer = mtf.utils.BalancedVariablePlacer( device_list, devices_memeory_usage) mesh_devices = [""] * mesh_shape.size mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl( mesh_shape, layout_rules, mesh_devices, ctx.device_assignment) else: var_placer = None mesh_devices = [""] * mesh_shape.size mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl( mesh_shape, layout_rules, mesh_devices) graph = mtf.Graph() mesh = mtf.Mesh(graph, "my_mesh", var_placer) def _import_feature(key): """Import a feature from the features dictionary into a mtf.Tensor. Args: key: a string Returns: a mtf.Tensor with dtype int32 and shape [batch_dim, length_dim] """ batch_dim = mtf.Dimension("batch", batch_size) length_dim = mtf.Dimension("length", length) mtf_shape = mtf.Shape([batch_dim, length_dim]) if key not in features: return None x = tf.to_int32(features[key]) if not use_tpu: x = tf.Print(x, [x], "import feature %s" % key, summarize=1000, first_n=1) return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key) # PREDICT mode if mode == tf.estimator.ModeKeys.PREDICT: inputs = _import_feature("inputs") if text2self: mtf_samples = transformer_model.sample_autoregressive( inputs, variable_dtype=variable_dtype, temperature=temperature) else: mtf_samples = transformer_model.decode( inputs, variable_dtype=variable_dtype, beam_size=beam_size, alpha=alpha, temperature=temperature) mtf_samples = mtf.anonymize(mtf_samples) lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack) outputs = lowering.export_to_tf_tensor(mtf_samples) predictions = {"outputs": outputs} return tpu_estimator.TPUEstimatorSpec( mode=tf.estimator.ModeKeys.PREDICT, predictions=predictions, prediction_hooks=[mtf.MtfRestoreHook(lowering)]) targets = _import_feature("targets") anon_targets = mtf.anonymize(targets) if text2self: _, length_dim = targets.shape inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False) else: inputs = _import_feature("inputs") if text2self: position_kwargs = dict( sequence_id=_import_feature("targets_segmentation"), position=_import_feature("targets_position"), ) else: position_kwargs = dict( encoder_sequence_id=_import_feature("inputs_segmentation"), decoder_sequence_id=_import_feature("targets_segmentation"), encoder_position=_import_feature("inputs_position"), decoder_position=_import_feature("targets_position"), ) logits, loss = transformer_model.call_simple( inputs=inputs, targets=targets, compute_loss=True, mode=mode, variable_dtype=variable_dtype, **position_kwargs) if use_tpu and logits is not None: logits = mtf.anonymize(logits) # TRAIN mode if mode == tf.estimator.ModeKeys.TRAIN: var_grads = mtf.gradients( [loss], [v.outputs[0] for v in graph.trainable_variables]) optimizer = mtf.optimize.AdafactorOptimizer() update_ops = optimizer.apply_grads(var_grads, 
graph.trainable_variables) lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack) tf_loss = lowering.export_to_tf_tensor(loss) tf_loss = tf.to_float(tf_loss) if not use_tpu: tf_loss = tf.Print(tf_loss, [tf_loss, tf.train.get_global_step()], "step, tf_loss") if logits and mode != tf.estimator.ModeKeys.TRAIN: tf_logits = lowering.export_to_tf_tensor(logits) if mode == tf.estimator.ModeKeys.TRAIN: tf_update_ops = [ lowering.lowered_operation(op) for op in update_ops ] tf_update_ops.append(tf.assign_add(global_step, 1)) train_op = tf.group(tf_update_ops) with mtf.utils.outside_all_rewrites(): # Copy master variables to slices. Must be called first. restore_hook = mtf.MtfRestoreHook(lowering) saver = tf.train.Saver(tf.global_variables(), sharded=True, max_to_keep=10, keep_checkpoint_every_n_hours=2, defer_build=False, save_relative_paths=True) tf.add_to_collection(tf.GraphKeys.SAVERS, saver) saver_listener = mtf.MtfCheckpointSaverListener(lowering) saver_hook = tf.train.CheckpointSaverHook( model_dir, save_steps=1000, saver=saver, listeners=[saver_listener]) if mode == tf.estimator.ModeKeys.TRAIN: if use_tpu: return tpu_estimator.TPUEstimatorSpec( mode=tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op, training_hooks=[restore_hook, saver_hook]) else: return tf.estimator.EstimatorSpec( tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op, training_chief_hooks=[restore_hook, saver_hook]) elif mode == tf.estimator.ModeKeys.EVAL: def padded_neg_log_perplexity(logits, labels): weights = tf.to_float(tf.not_equal(labels, 0)) xent = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=labels, logits=logits) return { "neg_log_perplexity": tf.metrics.mean(-xent, weights) } labels = lowering.export_to_tf_tensor(anon_targets) eval_metrics = (padded_neg_log_perplexity, [tf_logits, labels]) return tpu_estimator.TPUEstimatorSpec( tf.estimator.ModeKeys.EVAL, evaluation_hooks=[restore_hook], loss=tf_loss, eval_metrics=eval_metrics)
def sample_autoregressive(self, partial_sequences, stop_at_token=1, max_steps=None, temperature=1.0, variable_dtype=mtf.VariableDType(tf.float32), encoder_output=None, encoder_sequence_id=None, shared_params=None, has_partial_sequences=True, encoder_layer_outputs=None): """Sample randomly one token at a time. The partial_sequences represent partial sequences to be continued. The first tokens of each sequence are nonzero representing the given partial sequences and the last tokens of each sequence are zeros, representing what needs to be filled in. If there are no partial sequences (you want to sample from the beginning), then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and has_partial_sequences=False (so we can skip computation). Args: partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim] stop_at_token: an optional integer eos id. Stop when we produce it. max_steps: an optional integer temperature: an optional floating point value between 0.0 and 1.0 0.0 means argmax, 1.0 means sample according to predicted distribution. variable_dtype: a mtf.VariableDType encoder_output: an optional Tensor encoder_sequence_id: an optional Tensor shared_params: an optional dictionary has_partial_sequences: a boolean encoder_layer_outputs: optional - readonly list of tensor activations when decoding, one per each input layer + the embedding layer Returns: a Tensor with shape [<batch_dims>, length_dim] """ del max_steps # TODO(noam): implement if not self.autoregressive: raise ValueError("must be autoregressive") inputs = partial_sequences batch_dims = inputs.shape.dims[:-1] length_dim = inputs.shape.dims[-1] initial_position = mtf.reduce_sum( mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim) sequence_id = 1 if encoder_sequence_id is not None else None context_first_part = Context( mesh=inputs.mesh, batch_dims=batch_dims, length_dim=length_dim, model_dim=self.model_dim, variable_dtype=variable_dtype, mode="first_part", autoregressive=self.autoregressive, new_states=[], initial_position=initial_position, sequence_id=sequence_id, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, constant_states=[], shared_params=shared_params, layout=self.layout, mesh_shape=self.mesh_shape, encoder_layer_outputs=encoder_layer_outputs) shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False) with tf.variable_scope(self.name): logits = self._call_internal(context_first_part, shifted_inputs) del logits constant_states = context_first_part.constant_states if not has_partial_sequences: initial_states = [ mtf.zeros_like(t) for t in context_first_part.new_states] else: initial_states = context_first_part.new_states def cond_fn(position, ids, *unused_states): """Should we run another loop iteration.""" past_end = mtf.greater_equal(position, length_dim.size) is_done = past_end if stop_at_token is not None: has_eos = mtf.reduce_any( mtf.equal(ids, stop_at_token), reduced_dim=length_dim) is_done = mtf.logical_or(is_done, has_eos) all_done = mtf.reduce_all(is_done) return mtf.logical_not(all_done) def body_fn(position, ids, *states): """One step in the decode loop.""" context_incremental = Context( mesh=inputs.mesh, batch_dims=batch_dims, length_dim=length_dim, model_dim=self.model_dim, variable_dtype=variable_dtype, mode="incremental", autoregressive=self.autoregressive, position=position, states=states, new_states=[], sequence_id=sequence_id, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, constant_states=constant_states, 
shared_params=shared_params, layout=self.layout, mesh_shape=self.mesh_shape, encoder_layer_outputs=encoder_layer_outputs) inputs_this_step = mtf.gather(ids, position - 1, length_dim) with tf.variable_scope(self.name, reuse=True): logits = self._call_internal(context_incremental, inputs_this_step) ids_this_step = mtf.sample_with_temperature( logits, self.output_vocab_dim, temperature) new_position = position + 1 new_ids = ids + ids_this_step * mtf.one_hot( position, length_dim, dtype=tf.int32) return [new_position, new_ids] + context_incremental.new_states while_loop_inputs = [initial_position, inputs] + initial_states final_position, outputs = mtf.while_loop( cond_fn, body_fn, while_loop_inputs)[:2] del final_position return outputs
def my_model_fn(features, labels, mode, params=None, config=None):
  """Estimator model function.

  Args:
    features: input features dictionary
    labels: ignored
    mode: a tf.estimator.ModeKeys
    params: dict passed in by the (TPU)Estimator; when running on TPU it
      contains the TPUContext under params["context"]
    config: a RunConfig (ignored)

  Returns:
    a tpu_estimator.TPUEstimatorSpec (or a tf.estimator.EstimatorSpec when
    training without a TPU)
  """
  del labels, config
  global_step = tf.train.get_global_step()
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    physical_shape = list(
        params["context"].device_assignment.topology.mesh_shape)
    logical_to_physical = _logical_to_physical(physical_shape, mesh_shape)
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment,
        logical_to_physical=logical_to_physical)
  else:
    var_placer = None
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  def _import_feature(key, allow_missing=False):
    """Import a feature from the features dictionary into a mtf.Tensor.

    Args:
      key: a string
      allow_missing: a boolean

    Returns:
      a mtf.Tensor with dtype int32 and shape
      [outer_batch_dim, batch_dim, length_dim], or None if allow_missing is
      True and the feature is absent
    """
    outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
    batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
    length_dim = mtf.Dimension("length", sequence_length)
    mtf_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])
    if key not in features:
      if allow_missing:
        return None
      else:
        raise ValueError(
            "feature not found: %s - features: %s" % (key, features))
    tf.logging.info("Import feature %s: %s" % (key, features[key]))
    x = tf.to_int32(features[key])
    x = tf.reshape(x, [outer_batch_size, batch_size // outer_batch_size, -1])
    if not use_tpu:
      x = tf.Print(x, [x], "import feature %s" % key,
                   summarize=1000, first_n=1)
    return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

  if mode == tf.estimator.ModeKeys.PREDICT:
    inputs = _import_feature("inputs")
    inputs = mtf.reshape(
        inputs,
        mtf.Shape([mtf.Dimension("batch", batch_size),
                   mtf.Dimension("length", sequence_length)]))
    if isinstance(transformer_model, transformer.Unitransformer):
      mtf_samples = transformer_model.sample_autoregressive(
          inputs, variable_dtype=get_variable_dtype())
    elif isinstance(transformer_model, transformer.Bitransformer):
      mtf_samples = transformer_model.decode(
          inputs, variable_dtype=get_variable_dtype())
    else:
      raise ValueError("unrecognized class")
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {"outputs": outputs}
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  targets = _import_feature("targets")
  anon_targets = mtf.anonymize(targets)
  if model_type == "lm":
    _, length_dim = targets.shape
    inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
  else:
    inputs = _import_feature("inputs")

  if mode == tf.estimator.ModeKeys.EVAL:
    if isinstance(transformer_model, transformer.Unitransformer):
      mtf_samples = transformer_model.sample_autoregressive(
          inputs, variable_dtype=get_variable_dtype())
    elif isinstance(transformer_model, transformer.Bitransformer):
      mtf_samples = transformer_model.decode(
          inputs, variable_dtype=get_variable_dtype())
    else:
      raise ValueError("unrecognized class")
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    labels = lowering.export_to_tf_tensor(anon_targets)
    restore_hook = mtf.MtfRestoreHook(lowering)
    # Assigning a default to `metric_names` here would turn it into a local
    # variable, so bind the defaulted list to a separate name.
    local_metric_names = metric_names or ["token_accuracy"]

    def metric_fn(labels, outputs):
      return get_metric_fns(local_metric_names, labels, outputs)

    eval_metrics = (metric_fn, [labels, outputs])
    return tpu_estimator.TPUEstimatorSpec(
        tf.estimator.ModeKeys.EVAL,
        # Unfortunately TPUEstimatorSpec requires us to provide a value for
        # loss when in EVAL mode. Since we are sampling or decoding from the
        # model, we don't have a loss to report.
        loss=tf.constant(0.),
        evaluation_hooks=[restore_hook],
        eval_metrics=eval_metrics)

  if isinstance(transformer_model, transformer.Unitransformer):
    position_kwargs = dict(
        sequence_id=_import_feature("targets_segmentation", True),
        position=_import_feature("targets_position", True),
    )
  elif isinstance(transformer_model, transformer.Bitransformer):
    position_kwargs = dict(
        encoder_sequence_id=_import_feature("inputs_segmentation", True),
        decoder_sequence_id=_import_feature("targets_segmentation", True),
        encoder_position=_import_feature("inputs_position", True),
        decoder_position=_import_feature("targets_position", True),
    )
  else:
    raise ValueError("unrecognized class")

  logits, loss = transformer_model.call_simple(
      inputs=inputs,
      targets=targets,
      compute_loss=True,
      mode=mode,
      variable_dtype=get_variable_dtype(),
      **position_kwargs)

  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=learning_rate)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if not use_tpu:
    tf_loss = tf.Print(tf_loss, [tf_loss, tf.train.get_global_step()],
                       "step, tf_loss")

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=checkpoints_to_keep,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        model_dir,
        save_steps=save_steps,
        saver=saver,
        listeners=[saver_listener])
    gin_config_saver_hook = gin.tf.GinConfigSaverHook(
        model_dir, summarize_config=True)

    if mode == tf.estimator.ModeKeys.TRAIN:
      if use_tpu:
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
      else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
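# Usage sketch (added for illustration; not part of the original code).
# It shows, under assumptions, how a model_fn like my_model_fn above is wired
# into a TPUEstimator in TF 1.x.  The import paths and iterations_per_loop are
# assumptions; my_model_fn, model_dir, batch_size, use_tpu and train_input_fn
# are names from the surrounding code assumed to be in scope (in practice the
# mesh/model configuration they close over must be set up first).
import tensorflow as tf  # TF 1.x assumed
from tensorflow.python.tpu import tpu_config     # assumed import path
from tensorflow.python.tpu import tpu_estimator  # assumed import path

run_config = tpu_config.RunConfig(
    model_dir=model_dir,
    # On a real TPU, also pass cluster=TPUClusterResolver(...) or master=...
    tpu_config=tpu_config.TPUConfig(iterations_per_loop=100))

estimator = tpu_estimator.TPUEstimator(
    model_fn=my_model_fn,
    config=run_config,
    train_batch_size=batch_size,
    eval_batch_size=batch_size,
    predict_batch_size=batch_size,
    use_tpu=use_tpu)

# estimator.train(input_fn=train_input_fn, max_steps=10000)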
def beam_search(self,
                inputs,
                decode_length,
                variable_dtype=mtf.VariableDType(tf.float32),
                encoder_output=None,
                encoder_sequence_id=None,
                alpha=0.6,
                shared_params=None,
                encoder_layer_outputs=None):
  """Beam search.

  Args:
    inputs: an int32 Tensor of zeros with shape
      [<batch_dims>, beam_dim, length_dim].
    decode_length: an int32 mtf scalar. Maximum decode length.
    variable_dtype: a mtf.VariableDType
    encoder_output: an optional Tensor
    encoder_sequence_id: an optional Tensor
    alpha: a floating point value (length bonus)
    shared_params: an optional dictionary
    encoder_layer_outputs: optional - readonly list of tensor activations when
      decoding, one for each input layer + the embedding layer

  Returns:
    a Tensor with shape [<batch_dims>, length_dim] containing the
    highest-scoring beam for each sequence
  """
  if not self.autoregressive:
    raise ValueError("must be autoregressive")

  batch_dims = inputs.shape.dims[:-2]
  if len(batch_dims) != 1:
    raise NotImplementedError(
        "beam search supports exactly one batch dimension.")
  beam_dim = inputs.shape.dims[-2]
  length_dim = inputs.shape.dims[-1]
  initial_position = mtf.reduce_sum(
      mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim)
  sequence_id = 1 if encoder_sequence_id is not None else None

  context_first_part = Context(
      mesh=inputs.mesh,
      batch_dims=batch_dims + [beam_dim],
      length_dim=length_dim,
      model_dim=self.model_dim,
      variable_dtype=variable_dtype,
      mode="first_part",
      autoregressive=self.autoregressive,
      new_states=[],
      initial_position=initial_position,
      sequence_id=sequence_id,
      encoder_output=encoder_output,
      encoder_sequence_id=encoder_sequence_id,
      constant_states=[],
      shared_params=shared_params,
      layout=self.layout,
      mesh_shape=self.mesh_shape,
      encoder_layer_outputs=encoder_layer_outputs)

  shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
  with tf.variable_scope(self.name):
    logits = self._call_internal(context_first_part, shifted_inputs)
  del logits  # There are no partial targets.
  # Replace initial states by zeros to avoid computing them.
  initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states]
  constant_states = context_first_part.constant_states

  def logits_fn(step_num, ids, states):
    """logits_fn for mtf.beam_search.beam_search()."""
    context_incremental = Context(
        mesh=inputs.mesh,
        batch_dims=batch_dims + [beam_dim],
        length_dim=length_dim,
        model_dim=self.model_dim,
        variable_dtype=variable_dtype,
        mode="incremental",
        autoregressive=self.autoregressive,
        position=step_num,
        states=states,
        new_states=[],
        sequence_id=sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        constant_states=constant_states,
        shared_params=shared_params,
        layout=self.layout,
        mesh_shape=self.mesh_shape,
        encoder_layer_outputs=encoder_layer_outputs)
    inputs_this_step = mtf.gather(ids, step_num - 1, length_dim)
    with tf.variable_scope(self.name, reuse=True):
      logits = self._call_internal(context_incremental, inputs_this_step)
    return mtf.to_float(logits), context_incremental.new_states

  beams, unused_scores = mtf.beam_search.beam_search(
      logits_fn,
      inputs,
      alpha,
      states=initial_states,
      decode_length=decode_length,
      use_tpu=True,
      dtype=tf.float32,
      mesh_shape=self.mesh_shape,
      layout=self.layout)

  return mtf.gather(
      beams, mtf.constant(inputs.mesh, 0, dtype=tf.int32), beam_dim)
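# Usage sketch (added for illustration; not part of the original code).
# A hedged example of calling beam_search above.  It assumes `model` is an
# already-built autoregressive Unitransformer; the graph/mesh setup and the
# dimension sizes (8 sequences, 4 beams, length 128) are illustrative only.
import mesh_tensorflow as mtf
import tensorflow as tf  # TF 1.x assumed

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")

batch_dim = mtf.Dimension("batch", 8)
beam_dim = mtf.Dimension("beam", 4)
length_dim = mtf.Dimension("length", 128)

# All-zero inputs: decode from scratch (no partial targets), matching the
# "There are no partial targets" assumption in the method above.
inputs = mtf.zeros(
    mesh, mtf.Shape([batch_dim, beam_dim, length_dim]), dtype=tf.int32)
decode_length = mtf.constant(mesh, 128, dtype=tf.int32)

# Returns the top-scoring beam for each sequence: shape [batch, length].
beams = model.beam_search(inputs, decode_length, alpha=0.6)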
def _mtf_model_fn(self, features, mesh):
  self._original_features = features
  features = copy.copy(features)
  hparams = self._hparams
  extra_losses = []
  targets = tf.to_int32(features["targets"])
  if len(targets.get_shape()) > 2:
    tf.logging.info("targets = %s" % targets)
    targets = tf.squeeze(targets, [2, 3])

  # pad targets to max_length
  def pad_to_max_length(x):
    extra_length = hparams.max_length - tf.shape(x)[1]
    x = tf.pad(x, [[0, 0], [0, extra_length]])
    x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
    return x

  targets = pad_to_max_length(targets)
  targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
  for key in ["targets_segmentation", "targets_position",
              "inputs_segmentation", "inputs_position"]:
    if key in features:
      features[key] = pad_to_max_length(features[key])

  if hparams.decoder_type == "autoregressive":
    shifted_targets = mtf.shift(
        targets, offset=1, dim=self.length_dim, wrap=False)
  elif hparams.decoder_type == "denoising":
    shifted_targets = self._noisy_targets(targets, extra_losses)
  else:
    raise ValueError(
        "unknown hparams.decoder_type = %s" % hparams.decoder_type)

  if "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = self._import_to_batch_by_length(
        features["targets_segmentation"], "targets_segmentation", mesh,
        hparams)
    targets_position = self._import_to_batch_by_length(
        features["targets_position"], "targets_position", mesh, hparams)
    decoder_self_attention_mask = mtf.layers.attention_mask_same_segment(
        targets_segmentation, dtype=self.activation_dtype)
    if hparams.decoder_type == "autoregressive":
      decoder_self_attention_mask += mtf.layers.attention_mask_autoregressive(
          targets_position, dtype=self.activation_dtype)
  else:
    targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
    if hparams.decoder_type == "autoregressive":
      decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
          targets_position, dtype=self.activation_dtype)
    else:
      decoder_self_attention_mask = None

  def layer_prepostprocess_dropout(x):
    return mtf.dropout(
        x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
        noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

  (inputs_embedding_var,
   targets_embedding_var,
   softmax_var,
   positional_embedding_var) = self._embedding_and_softmax_vars(mesh)

  if hparams.transformer_type == "decoder":
    encoder_output = None
    encoder_decoder_attention_mask = None
  else:
    inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
    inputs = pad_to_max_length(inputs)
    inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
    if "inputs_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      inputs_segmentation = self._import_to_batch_by_length(
          features["inputs_segmentation"], "inputs_segmentation", mesh,
          hparams)
      inputs_position = self._import_to_batch_by_length(
          features["inputs_position"], "inputs_position", mesh, hparams)
      encoder_self_attention_mask = (
          mtf.layers.attention_mask_same_segment(
              inputs_segmentation, dtype=self.activation_dtype))
    else:
      inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      encoder_self_attention_mask = (
          mtf.layers.attention_mask_ignore_padding(
              inputs, dtype=self.activation_dtype))

    x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
         mtf.gather(positional_embedding_var, inputs_position,
                    self.max_length_dim))
    x = layer_prepostprocess_dropout(x)
    with tf.variable_scope("encoder"):
      x = self._layer_stack(x,
                            hparams.encoder_layers,
                            self_attention_mask=encoder_self_attention_mask,
                            losses=extra_losses)

  if hparams.transformer_type == "encdec":
    if "inputs_segmentation" in features:
      encoder_decoder_attention_mask = (
          mtf.layers.attention_mask_same_segment(
              targets_segmentation, inputs_segmentation,
              dtype=self.activation_dtype))
    else:
      encoder_decoder_attention_mask = encoder_self_attention_mask
    encoder_output = mtf.rename_dimension(
        x, self.length_dim.name, self.memory_length_dim.name)

  if hparams.transformer_type != "encoder":
    # DECODER
    x = (mtf.gather(
        targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
         mtf.gather(
             positional_embedding_var, targets_position, self.max_length_dim))
    x = layer_prepostprocess_dropout(x)
    with tf.variable_scope("decoder"):
      x = self._layer_stack(
          x,
          hparams.decoder_layers,
          encoder_output=encoder_output,
          self_attention_mask=decoder_self_attention_mask,
          encdec_attention_mask=encoder_decoder_attention_mask,
          losses=extra_losses)

  if (hparams.reshape_logits_hack and
      hparams.mode == tf.estimator.ModeKeys.TRAIN):
    # For some reason, the logits computation is extremely slow on TPU
    # in some cases where the batch size per core is 1. Reshape the logits
    # and the targets to double the batch size and halve the length.
    # TODO(noam): file a bug.
    old_dims = self.batch_dims + [self.length_dim]
    new_dims = self.batch_dims[:-1] + [
        mtf.Dimension(self.batch_dims[-1].name,
                      self.batch_dims[-1].size * 2),
        mtf.Dimension(self.length_dim.name, self.length_dim.size // 2)]
    x = mtf.reshape(x, new_dims + [self.model_dim])
    targets = mtf.reshape(targets, new_dims)

  logits = mtf.matmul(x, softmax_var)
  if hparams.mode == tf.estimator.ModeKeys.TRAIN:
    logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
  off_value = hparams.label_smoothing / self._targets_vocab_size
  on_value = 1.0 - hparams.label_smoothing + off_value
  soft_targets = mtf.one_hot(
      targets, self.targets_vocab_dim, on_value=on_value, off_value=off_value,
      dtype=self.activation_dtype)
  loss = mtf.layers.softmax_cross_entropy_with_logits(
      logits, soft_targets, self.targets_vocab_dim)
  weights = mtf.layers.weights_nonzero(targets, dtype=self.activation_dtype)
  loss = mtf.reduce_mean(loss * weights)
  for l in extra_losses:
    loss += l
  if (hparams.reshape_logits_hack and
      hparams.mode == tf.estimator.ModeKeys.TRAIN):
    logits = mtf.reshape(logits, old_dims + [self.targets_vocab_dim])
  logits = mtf.to_float(logits)
  return logits, loss
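# Standalone illustration (added; not part of the model).  The label-smoothed
# targets built above use off_value = label_smoothing / vocab_size and
# on_value = 1 - label_smoothing + off_value, so every row of soft_targets
# still sums to 1.  A small numpy check of that arithmetic; the vocab size,
# smoothing value and target ids here are arbitrary.
import numpy as np

vocab_size = 8
label_smoothing = 0.1
off_value = label_smoothing / vocab_size
on_value = 1.0 - label_smoothing + off_value

hard_targets = np.array([3, 0, 5])
soft_targets = np.full((len(hard_targets), vocab_size), off_value)
soft_targets[np.arange(len(hard_targets)), hard_targets] = on_value

# (vocab_size - 1) * off_value + on_value == 1.0 for each row.
print(soft_targets.sum(axis=1))  # -> [1. 1. 1.]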