def benchmark_model(mesh):
  """Initializes a 3D volume with random noise and executes forward/inverse FFTs."""
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  x_dim = mtf.Dimension("nx", FLAGS.cube_size)
  y_dim = mtf.Dimension("ny", FLAGS.cube_size)
  z_dim = mtf.Dimension("nz", FLAGS.cube_size)

  tx_dim = mtf.Dimension("tnx", FLAGS.cube_size)
  ty_dim = mtf.Dimension("tny", FLAGS.cube_size)
  tz_dim = mtf.Dimension("tnz", FLAGS.cube_size)

  # Create field
  field = mtf.random_normal(mesh, [batch_dim, x_dim, y_dim, z_dim])
  input_field = field
  field = mtf.cast(field, tf.complex64)
  err = 0
  # Performs several back and forth FFTs in the same session
  for i in range(FLAGS.n_ffts):
    # Apply FFT
    fft_field = mpm.fft3d(field, [tx_dim, ty_dim, tz_dim])
    # Inverse FFT
    field = mpm.ifft3d(fft_field * 1, [x_dim, y_dim, z_dim])
    err += mtf.reduce_max(mtf.abs(mtf.cast(field, tf.float32) - input_field))

  field = mtf.cast(field, tf.float32)
  # Compute errors
  err += mtf.reduce_max(mtf.abs(field - input_field))
  return err
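# A minimal sketch of the command-line flags that benchmark_model above relies
# on. The flag names match the FLAGS.* attributes referenced in the function;
# the defaults and help strings are illustrative assumptions, not values from
# the original project.
from absl import flags

flags.DEFINE_integer("batch_size", 8,
                     "Number of volumes per batch (assumed default).")
flags.DEFINE_integer("cube_size", 128,
                     "Extent of the 3D volume along each axis (assumed default).")
flags.DEFINE_integer("n_ffts", 10,
                     "Number of forward/inverse FFT round trips (assumed default).")

FLAGS = flags.FLAGS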
def entmax_forward(x, alpha=1.3, dim=None, n_iter=50):
  """Computes alpha-entmax of `x` over `dim` by bisecting on the threshold tau."""
  assert alpha > 1 and alpha < 2, 'alpha must be between 1 and 2'

  # Generating function of the alpha-entmax mapping and its inverse.
  _gp = lambda x, alpha: x ** (alpha - 1)
  _gp_inv = lambda x, alpha: mtf.pow(x, (1 / (alpha - 1)))
  _p = lambda x, alpha: _gp_inv(mtf.relu(x), alpha)

  dim = x.shape[-1] if dim is None else dim
  d = dim.size

  x = x * (alpha - 1)

  # tau lies in [max - _gp(1), max - _gp(1/d)]; bisect until sum(p) == 1.
  max_val = mtf.reduce_max(x, reduced_dim=dim)
  tau_lo = max_val - _gp(1, alpha)
  tau_hi = max_val - _gp(1 / d, alpha)
  f_lo = mtf.reduce_sum(_p(x - tau_lo, alpha), reduced_dim=dim) - 1

  dm = tau_hi - tau_lo
  for _ in range(n_iter):
    dm = dm / 2
    tau_m = tau_lo + dm
    p_m = _p(x - tau_m, alpha)
    f_m = mtf.reduce_sum(p_m, reduced_dim=dim) - 1
    # Keep the half of the interval whose endpoint has the same sign as f_lo.
    mask = mtf.greater_equal((f_m * f_lo), 0)
    tau_lo = mtf.where(mask, tau_m, tau_lo)

  # Renormalize to absorb the residual bisection error.
  p_m = p_m / mtf.reduce_sum(p_m, reduced_dim=dim)
  return p_m
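# A standalone NumPy sketch of the same bisection, handy for checking the Mesh
# TensorFlow implementation above on small inputs. This is an illustrative
# re-implementation of the alpha-entmax bisection, not code from the original
# project; `entmax_bisect_np` and its defaults are assumptions.
import numpy as np

def entmax_bisect_np(x, alpha=1.3, n_iter=50):
  """Reference alpha-entmax over the last axis via bisection on tau."""
  x = np.asarray(x, dtype=np.float64) * (alpha - 1)
  d = x.shape[-1]
  max_val = x.max(axis=-1, keepdims=True)

  # p(tau) = relu(x - tau) ** (1 / (alpha - 1)); we want sum(p(tau)) == 1.
  p = lambda tau: np.maximum(x - tau, 0.0) ** (1.0 / (alpha - 1))

  tau_lo = max_val - 1.0                       # _gp(1, alpha) == 1
  tau_hi = max_val - (1.0 / d) ** (alpha - 1)  # _gp(1/d, alpha)
  f_lo = p(tau_lo).sum(axis=-1, keepdims=True) - 1

  dm = tau_hi - tau_lo
  for _ in range(n_iter):
    dm = dm / 2
    tau_m = tau_lo + dm
    p_m = p(tau_m)
    f_m = p_m.sum(axis=-1, keepdims=True) - 1
    # Keep the half-interval whose midpoint has the same sign as f_lo.
    tau_lo = np.where(f_m * f_lo >= 0, tau_m, tau_lo)

  # The result sums to 1 and zeroes out sufficiently low-scoring entries.
  return p_m / p_m.sum(axis=-1, keepdims=True)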
def benchmark_model(mesh):
  """Initializes a 3D volume with random noise and executes a forward FFT."""
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  x_dim = mtf.Dimension("nx", FLAGS.cube_size)
  y_dim = mtf.Dimension("ny", FLAGS.cube_size)
  z_dim = mtf.Dimension("nz", FLAGS.cube_size)

  tx_dim = mtf.Dimension("tnx", FLAGS.cube_size)
  ty_dim = mtf.Dimension("tny", FLAGS.cube_size)
  tz_dim = mtf.Dimension("tnz", FLAGS.cube_size)

  # Create field
  field = mtf.random_uniform(mesh, [batch_dim, x_dim, y_dim, z_dim])

  # Apply FFT
  fft_field = mpm.fft3d(mtf.cast(field, tf.complex64), [tx_dim, ty_dim, tz_dim])

  # Inverse FFT
  rfield = mtf.cast(mpm.ifft3d(fft_field, [x_dim, y_dim, z_dim]), tf.float32)

  # Compute errors
  err = mtf.reduce_max(mtf.abs(field - rfield))
  return err
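# A minimal sketch of how the error tensor built by benchmark_model might be
# lowered and evaluated on a single-process placement mesh. The mesh shape,
# layout rules and device list below are illustrative assumptions (a real run
# would typically use a device mesh spanning GPUs or a SimdMeshImpl on TPU),
# and the imports may duplicate module-level ones.
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "fft_mesh")
err = benchmark_model(mesh)

mesh_shape = mtf.convert_to_shape("row:2;col:2")  # assumed 2x2 mesh
layout_rules = mtf.convert_to_layout_rules(
    "nx:row;tny:row;ny:col;tnz:col")              # assumed layout
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    mesh_shape, layout_rules, [""] * mesh_shape.size)

lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_err = lowering.export_to_tf_tensor(err)

with tf.Session() as sess:
  print("max FFT round-trip error:", sess.run(tf_err))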
def decode(self,
           inputs,
           variable_dtype=mtf.VariableDType(tf.float32),
           beam_size=1,
           alpha=0.6,
           temperature=0.0,
           decode_length_multiplier=1.5,
           decode_length_constant=10):
  """Sampling or beam search.

  TODO(noam): should we make the output length dimension different from the
  input length dimension?

  Args:
    inputs: a Tensor with shape [<batch_dims>, beam_dim, length_dim]
    variable_dtype: a mtf.VariableDType
    beam_size: an integer >= 1
    alpha: a floating point value (length bonus for beam search)
    temperature: a value between 0 and 1 (must be 0 if beam_size > 1); 0.0
      means argmax, 1.0 means sample according to predicted distribution.
    decode_length_multiplier: a float
    decode_length_constant: a float

  Returns:
    a Tensor with shape [<batch_dims>, beam_dim, length_dim]
  """
  encoder_layer_outputs = []
  shared_params = self._shared_params(inputs.mesh, variable_dtype)
  encoder_sequence_id = mtf.minimum(inputs, 1)
  encoder_output, encoder_loss = self.encoder.call_simple(
      inputs=inputs,
      targets=None,
      compute_loss=False,
      mode=tf.estimator.ModeKeys.PREDICT,
      variable_dtype=variable_dtype,
      sequence_id=encoder_sequence_id,
      shared_params=shared_params,
      layer_outputs=encoder_layer_outputs)
  del encoder_loss
  encoder_output = mtf.layers.rename_length_to_memory_length(encoder_output)
  encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
      encoder_sequence_id)
  if beam_size == 1:
    ids_shape = inputs.shape
    partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32)
    return self.decoder.sample_autoregressive(
        partial_sequences,
        temperature=temperature,
        variable_dtype=variable_dtype,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        shared_params=shared_params,
        has_partial_sequences=False,
        encoder_layer_outputs=encoder_layer_outputs)
  else:
    if temperature != 0:
      raise ValueError(
          "don't know how to beam search with nonzero temperature")
    # beam search
    beam_dim = mtf.Dimension("beam", beam_size)
    batch_dims = inputs.shape[:-1]
    length_dim = inputs.shape[-1]
    ids_shape = mtf.Shape(batch_dims + [beam_dim, length_dim])
    partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32)
    input_length = mtf.reduce_sum(
        mtf.to_float(mtf.cast(inputs, tf.bool)),
        reduced_dim=length_dim)
    max_input_length = mtf.reduce_max(input_length)
    decode_length = mtf.cast(
        max_input_length * decode_length_multiplier + decode_length_constant,
        tf.int32)
    return self.decoder.beam_search(
        partial_sequences,
        decode_length,
        variable_dtype=variable_dtype,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        alpha=alpha,
        shared_params=shared_params,
        encoder_layer_outputs=encoder_layer_outputs)
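# An illustrative sketch of calling decode() on an already-built
# Bitransformer-style model; `model` and `inputs` are assumed to exist.
# `inputs` would be an int32 mtf.Tensor of token ids shaped
# [<batch_dims>, length_dim], padded with 0s (which is why
# mtf.minimum(inputs, 1) above doubles as the sequence-id mask).

# Sampling path (beam_size == 1); temperature 0.0 is greedy argmax decoding.
sampled_ids = model.decode(inputs, temperature=0.0)

# Beam-search path (temperature must be 0); with the defaults, the decode
# length is capped at 1.5 * max_input_length + 10 tokens.
beam_ids = model.decode(inputs, beam_size=4, alpha=0.6)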
def _sample(self, features, mesh):
  hparams = self._hparams
  (inputs_embedding_var,
   targets_embedding_var,
   softmax_var,
   positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
  if hparams.transformer_type == "encdec":
    inputs = features["inputs"]
    while len(inputs.shape.as_list()) > 2:
      inputs = tf.squeeze(inputs, axis=2)
    actual_batch_size = tf.shape(inputs)[0]
    actual_length = tf.shape(inputs)[1]
    inputs = tf.pad(
        inputs, [[0, hparams.batch_size - actual_batch_size],
                 [0, hparams.max_length - actual_length]])
    inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
    x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
         mtf.reshape(positional_embedding_var,
                     mtf.Shape([self.length_dim, self.model_dim])))
    encoder_attention_mask = (
        mtf.layers.attention_mask_ignore_padding(
            inputs, dtype=self.activation_dtype))
    with tf.variable_scope("encoder"):
      x = self._layer_stack(
          x,
          hparams.encoder_layers,
          self_attention_mask=encoder_attention_mask)
    encoder_output = mtf.rename_dimension(
        x, self.length_dim.name, self.memory_length_dim.name)
    encdec_tensors = []
    for layer_num, layer_type in enumerate(hparams.decoder_layers):
      if layer_type == "enc_att":
        with tf.variable_scope("decoder/enc_att_%d/enc_att" % layer_num):
          q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars(
              mesh, self.heads_dim, self.model_dim,
              self.kv_dim, self.master_dtype, self.slice_dtype,
              self.activation_dtype)
          k = mtf.einsum(
              [encoder_output, k_var],
              mtf.Shape(
                  self.batch_dims +
                  [self.heads_dim, self.memory_length_dim, self.kv_dim]))
          v = mtf.einsum(
              [encoder_output, v_var],
              mtf.Shape(
                  self.batch_dims +
                  [self.heads_dim, self.memory_length_dim, self.kv_dim]))
        encdec_tensors.append((q_var, o_var, k, v))
      else:
        encdec_tensors.append(None)
    partial_targets = None
  elif hparams.transformer_type == "decoder":
    encdec_tensors = None
    encoder_output = None
    encoder_attention_mask = None
    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    partial_targets = features.get("inputs", None)
    if partial_targets is None:
      partial_targets = features.get("targets", None)
    if partial_targets is not None:
      partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
      partial_targets = tf.to_int32(partial_targets)
      partial_targets_batch = tf.shape(partial_targets)[0]
      partial_targets_length = tf.shape(partial_targets)[1]
      partial_targets = tf.pad(
          partial_targets,
          [[0, hparams.batch_size - partial_targets_batch],
           [0, hparams.max_length - partial_targets_length]])
      partial_targets = self._import_to_batch_by_length(
          partial_targets, "partial_targets", mesh, hparams)
  else:
    raise ValueError(
        "hparams.transformer_type = %s not yet supported"
        % hparams.transformer_type)

  local_attention_window = mtf.Dimension(
      "local_attention_window", hparams.local_attention_window_size)
  if hparams.beam_size == 1:
    ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
    kv_shape = mtf.Shape(
        self.batch_dims +
        [self.heads_dim, self.memory_length_dim, self.kv_dim])
    local_kv_shape = mtf.Shape(
        self.batch_dims +
        [self.heads_dim, local_attention_window, self.kv_dim])
  else:
    beam_dim = mtf.Dimension("beam", hparams.beam_size)
    ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim])
    kv_shape = mtf.Shape(
        self.batch_dims +
        [beam_dim, self.heads_dim, self.memory_length_dim, self.kv_dim])
    local_kv_shape = mtf.Shape(
        self.batch_dims +
        [beam_dim, self.heads_dim, local_attention_window, self.kv_dim])

  initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
  initial_states = []
  for layer in hparams.decoder_layers:
    if layer == "att":
      initial_states.extend(
          [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2)
    elif layer == "local_att":
      initial_states.extend(
          [mtf.zeros(mesh, local_kv_shape, dtype=self.activation_dtype)] * 2)

  def logits_fn(step_num, ids, states):
    """Produce logits for this step, and new states."""
    ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
    x = (mtf.gather(targets_embedding_var, ids_this_step,
                    self.targets_vocab_dim) +
         mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
    with tf.variable_scope("decoder"):
      x, new_states = self._layer_stack(
          x,
          hparams.decoder_layers,
          encdec_attention_mask=encoder_attention_mask,
          step_num=step_num,
          encdec_tensors=encdec_tensors,
          states=states)
    logits = mtf.matmul(x, softmax_var)
    return logits, new_states

  if hparams.beam_size == 1:
    temperature = (0.0 if hparams.sampling_method == "argmax"
                   else hparams.sampling_temp)
    return mtf.beam_search.greedy_decode(
        logits_fn,
        initial_ids,
        temperature=temperature,
        initial_states=initial_states,
        forced_ids=partial_targets,
        use_tpu=hparams.use_tpu)
  else:
    if hparams.transformer_type == "encdec":
      input_length = mtf.reduce_sum(
          mtf.to_float(mtf.cast(inputs, tf.bool)),
          reduced_dim=self.length_dim)
      max_input_length = mtf.reduce_max(input_length)
      decode_length = mtf.cast(
          max_input_length * hparams.decode_length_multiplier +
          hparams.decode_length_constant, tf.int32)
    else:
      decode_length = None
    beams, unused_scores = mtf.beam_search.beam_search(
        logits_fn,
        initial_ids,
        hparams.alpha,
        states=initial_states,
        decode_length=decode_length,
        use_tpu=hparams.use_tpu,
        dtype=self.activation_dtype)
    return mtf.gather(
        beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
def decode(self,
           inputs,
           variable_dtype=mtf.VariableDType(tf.float32),
           beam_size=1,
           alpha=0.6,
           temperature=0.0,
           sampling_keep_top_k=-1,
           decode_length_multiplier=1.5,
           decode_length_constant=10,
           max_decode_length=None):
  """Sampling or beam search for Funnel Transformer.

  Args:
    inputs: a Tensor with shape [<batch_dims>, beam_dim, length_dim]
    variable_dtype: a mtf.VariableDType
    beam_size: an integer >= 1
    alpha: a floating point value (length bonus for beam search)
    temperature: a value between 0 and 1 (must be 0 if beam_size > 1); 0.0
      means argmax, 1.0 means sample according to predicted distribution.
    sampling_keep_top_k: a value between 1 and vocab_size used to sample from
      only the k most likely logits. Set to -1 to sample from all logits.
    decode_length_multiplier: a float
    decode_length_constant: a float
    max_decode_length: an optional integer

  Returns:
    a Tensor with shape [<batch_dims>, beam_dim, length_dim]
  """
  encoder_layer_outputs = []
  shared_params = self._shared_params(inputs.mesh, variable_dtype)
  encoder_sequence_id = mtf.minimum(inputs, 1)
  encoder_output, encoder_loss = self.encoder.call_simple(
      inputs=inputs,
      targets=None,
      compute_loss=False,
      mode=tf.estimator.ModeKeys.PREDICT,
      variable_dtype=variable_dtype,
      sequence_id=encoder_sequence_id,
      shared_params=shared_params,
      layer_outputs=encoder_layer_outputs)
  del encoder_loss
  encoder_output = mtf.layers.rename_length_to_memory_length(encoder_output)
  # The sequence_id is updated inside the layer_stack due to pooling, so we
  # need to use the updated sequence_id stored in the context.
  encoder_sequence_id = self.encoder.layer_stack.context.sequence_id
  encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
      encoder_sequence_id)
  batch_dims = inputs.shape[:-1]
  length_dim = inputs.shape[-1]
  if max_decode_length is None:
    decode_length_dim = length_dim
  else:
    decode_length_dim = mtf.Dimension("length", max_decode_length)
  if beam_size == 1:
    ids_shape = mtf.Shape(batch_dims + [decode_length_dim])
    partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32)
    return self.decoder.sample_autoregressive(
        partial_sequences,
        temperature=temperature,
        sampling_keep_top_k=sampling_keep_top_k,
        variable_dtype=variable_dtype,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        encoder_inputs=mtf.layers.rename_length_to_memory_length(inputs),
        shared_params=shared_params,
        has_partial_sequences=False,
        encoder_layer_outputs=encoder_layer_outputs)
  else:
    if temperature != 0:
      raise ValueError(
          "don't know how to beam search with nonzero temperature")
    if sampling_keep_top_k != -1:
      raise ValueError(
          "don't know how to beam search with top-k value other than -1.")
    # beam search
    beam_dim = mtf.Dimension("beam", beam_size)
    ids_shape = mtf.Shape(batch_dims + [beam_dim, decode_length_dim])
    partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32)
    input_length = mtf.reduce_sum(
        mtf.to_float(mtf.cast(inputs, tf.bool)),
        reduced_dim=length_dim)
    max_input_length = mtf.reduce_max(input_length)
    decode_length = mtf.cast(
        max_input_length * decode_length_multiplier + decode_length_constant,
        tf.int32)
    return self.decoder.beam_search(
        partial_sequences,
        decode_length,
        variable_dtype=variable_dtype,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        encoder_inputs=inputs,
        alpha=alpha,
        shared_params=shared_params,
        encoder_layer_outputs=encoder_layer_outputs)
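# An illustrative sketch of the extra sampling controls in the Funnel variant
# above; `funnel_model` and `inputs` are assumed to exist. Top-k filtering only
# applies to the sampling path: beam search requires temperature == 0 and
# sampling_keep_top_k == -1.
sampled_ids = funnel_model.decode(
    inputs, temperature=0.9, sampling_keep_top_k=40, max_decode_length=64)
beam_ids = funnel_model.decode(inputs, beam_size=4, alpha=0.6)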