def c2r3d(cfield, dims, norm=None, dtype=tf.float32, name=None):
  """Converts a complex Fourier-domain field to a real field.

  Parameters:
  -----------
  cfield: tensor (batch_size, nc, nc, nc)
    Complex 3D Fourier-domain field
  dims: list of mtf.Dimension
    Output (real-space) dimensions
  norm: float
    Normalization factor
  dtype: tf.dtype
    Type of output tensor

  Return:
  -------
  rfield: tensor (batch_size, nc, nc, nc)
    Real valued field
  """
  x_dim, y_dim, z_dim = cfield.shape[-3:]
  if norm is None:
    norm = mtf.constant(cfield.mesh, x_dim.size * y_dim.size * z_dim.size)
  rfield = mtf.cast(mesh_ops.ifft3d(cfield, dims), dtype) * norm
  return rfield
def positional_encoding(x):
  seq_len_dim, model_dim = x.shape[1:]
  seq_len_size = seq_len_dim.size
  model_size = model_dim.size

  # mtf.constant is only meant to create a tensor with a constant scalar
  # value, but as long as the tensor is fully replicated, initializing it
  # with a full tensor works.
  assert not seq_len_dim.name.startswith('axis')
  assert not model_dim.name.startswith('axis')

  # Values for the positional encoder.
  pos = np.arange(seq_len_size).reshape(-1, 1)
  val = np.power(10000, (2 * np.arange(model_size)) / model_size, dtype=float)
  pos_enc_values = pos / val
  np.sin(pos_enc_values[:, ::2], out=pos_enc_values[:, ::2], dtype=np.float32)
  np.cos(pos_enc_values[:, 1::2], out=pos_enc_values[:, 1::2], dtype=np.float32)

  # Positional encoder.
  pos_enc = mtf.constant(x.mesh,
                         pos_enc_values,
                         shape=mtf.Shape([seq_len_dim, model_dim]),
                         dtype=tf.float32)
  return (x * math.sqrt(model_size)) + pos_enc
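# Illustrative sketch (not part of the original code): the same sinusoidal
# positional-encoding table computed with plain numpy, assuming seq_len=4 and
# model_size=8. Even columns get sin, odd columns get cos, matching the
# in-place np.sin / np.cos calls above.
import numpy as np

seq_len_size, model_size = 4, 8
pos = np.arange(seq_len_size).reshape(-1, 1)
val = np.power(10000, (2 * np.arange(model_size)) / model_size, dtype=float)
pos_enc_values = pos / val
pos_enc_values[:, ::2] = np.sin(pos_enc_values[:, ::2])
pos_enc_values[:, 1::2] = np.cos(pos_enc_values[:, 1::2])
print(pos_enc_values.shape)  # (4, 8)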
def r2c3d(rfield, k_dims, norm=None, dtype=tf.complex64):
  """Converts a real field to its complex Fourier transform.

  Parameters:
  -----------
  rfield: tensor (batch_size, nc, nc, nc)
    Input 3D real field
  k_dims: list of mtf.Dimension
    Output (Fourier-space) dimensions
  norm: float
    Normalization factor
  dtype: tf.dtype
    Type of output tensor

  Return:
  -------
  cfield: tensor (batch_size, nc, nc, nc)
    Complex field
  """
  x_dim, y_dim, z_dim = rfield.shape[-3:]
  if norm is None:
    norm = mtf.constant(rfield.mesh, x_dim.size * y_dim.size * z_dim.size)
  cfield = mesh_ops.fft3d(mtf.cast(rfield / norm, dtype), k_dims)
  return cfield
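# Illustrative sketch (not part of the original code): the default norm used by
# r2c3d/c2r3d above is nc**3, so dividing before the forward FFT and
# multiplying after the inverse FFT round-trips back to the input. Shown here
# with numpy on a small random field; the mtf versions do the same thing on
# mesh tensors.
import numpy as np

nc = 8
rfield = np.random.randn(1, nc, nc, nc).astype(np.float32)
norm = nc ** 3
cfield = np.fft.fftn(rfield / norm, axes=(-3, -2, -1))                # r2c3d analogue
recovered = np.real(np.fft.ifftn(cfield, axes=(-3, -2, -1))) * norm   # c2r3d analogue
assert np.allclose(recovered, rfield, atol=1e-4)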
def test_conv1d_update_state(self):
  batch = 2
  d_model = 6
  filter_size = 3
  batch_dim = mtf.Dimension("batch", batch)
  filter_dim = mtf.Dimension("filter", filter_size)

  x = np.random.randn(batch, d_model)
  x_mtf = self.converter.convert_np_array_to_mtf_tensor(
      x, dtype=tf.float32, dim_names=["batch", "d_model"])

  old_state = np.random.randn(batch, filter_size, d_model)
  old_state_mtf = self.converter.convert_np_array_to_mtf_tensor(
      old_state, dtype=tf.float32, dim_names=["batch", "filter", "d_model"])

  position_mtf = mtf.constant(self.converter.mesh,
                              filter_size - 1,
                              shape=mtf.Shape([batch_dim]),
                              dtype=tf.int32)
  conv_layer = transformer_layers.Conv1D()
  output_mtf = conv_layer.update_state(
      old_state_mtf, x_mtf, position_mtf, filter_dim, dtype=tf.float32)
  actual = self.converter.convert_mtf_tensor_to_np_array(output_mtf)

  expected = np.empty(shape=old_state.shape)
  expected[:, :filter_size - 1, :] = old_state[:, 1:, :]
  expected[:, -1, :] = x
  self.assertAllClose(actual, expected)
def computation_fn():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, 'my_mesh')
  mesh_shape = mtf.convert_to_shape('all:2')
  layout = 'none:all'
  mesh_devices = [''] * mesh_shape.size
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape, mtf.convert_to_layout_rules(layout),
      mesh_devices, device_assignment)
  hidden_dim = mtf.Dimension('hidden', 3)
  w = mtf.get_variable(mesh,
                       'w',
                       shape=[hidden_dim],
                       initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
  x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim], dtype=tf.float32)
  loss = mtf.reduce_mean(mtf.square(x - w))

  lr, update_ops = optimization_lib.create_optimizer(loss, 0.2, 100, 10)

  self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_update_ops = [
      self.lowering.lowered_operation(op) for op in update_ops
  ]
  tf_update_ops.append(
      tf.assign_add(tf.train.get_or_create_global_step(), 1))
  train_op = tf.group(tf_update_ops)

  return lr, train_op
def sample(self, features, mesh):
  hparams = self._hparams
  model = self.model()

  def import_feature(key):
    return self._import_feature(features, mesh, key)

  if self.autoregressive:
    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    partial_targets = import_feature("inputs")
    if partial_targets is None:
      partial_targets = import_feature("targets")
    if partial_targets:
      partial_targets *= mtf.cast(
          mtf.not_equal(partial_targets, 1), partial_targets.dtype)
    else:
      ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
      partial_targets = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
    if hparams.beam_size > 1:
      raise NotImplementedError(
          "Beam search not implemented for unitransformer.")
    ret = model.sample_autoregressive(
        partial_targets,
        temperature=hparams.sampling_temp,
        variable_dtype=self.variable_dtype)
    return self.combine_batch_dims(ret)
  else:
    raise ValueError(
        "Don't know how to sample from non-autoregressive unitransformer")
def computation_fn():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, 'my_mesh')
  mesh_shape = mtf.convert_to_shape('all:2')
  layout = 'none:all'
  mesh_devices = [''] * mesh_shape.size
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape, mtf.convert_to_layout_rules(layout),
      mesh_devices, device_assignment)
  hidden_dim = mtf.Dimension('hidden', 3)
  w = mtf.get_variable(mesh,
                       'w',
                       shape=[hidden_dim],
                       initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
  x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim], dtype=tf.float32)
  loss = mtf.reduce_mean(mtf.square(x - w))

  var_grads = mtf.gradients(
      [loss], [v.outputs[0] for v in graph.trainable_variables])
  optimizer = mtf_optimize.AdamWeightDecayOptimizer(learning_rate=0.2)
  update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_update_ops = [
      self.lowering.lowered_operation(op) for op in update_ops
  ]
  return tf.group(tf_update_ops)
def test_convert_mtf_tensor_to_np_array(self):
  x_np = np.array([[1, 2, 3], [4, 5, 6]])
  converter = test_utils.NumpyConverter()
  shape = mtf.Shape([mtf.Dimension("dim0", 2), mtf.Dimension("dim1", 3)])
  x_mtf = mtf.constant(converter.mesh, x_np, shape=shape, dtype=tf.int32)
  actual = converter.convert_mtf_tensor_to_np_array(x_mtf)
  self.assertAllEqual(x_np, actual)
def testWhileLoopOperation(self):
  # This test case implements the following:
  #   for i in range(10):
  #     x = x * 2
  i = mtf.constant(self.mesh, 0, mtf.Shape([]))
  cond_fn = lambda i, x: mtf.less(i, 10)
  body_fn = lambda i, x: [mtf.add(i, 1), mtf.multiply(x, 2)]

  while_loop_operation = mtf.WhileLoopOperation(cond_fn, body_fn, [i, self.x])
  self.assertEqual(while_loop_operation.splittable_dims,
                   frozenset(["a", "b"]))
  self.assertEqual(while_loop_operation.unsplittable_dims, frozenset())
def convert_np_array_to_mtf_tensor(self, x, dim_names=None, dtype=tf.int32):
  """Convert a numpy array to an equivalent mtf.Tensor."""
  dim_sizes = x.shape
  if not dim_names:
    dim_names = [f"dim{i}" for i in range(len(dim_sizes))]

  dims = []
  for dim_size, dim_name in zip(dim_sizes, dim_names):
    dims.append(mtf.Dimension(dim_name, dim_size))
  shape = mtf.Shape(dims)
  x_mtf = mtf.constant(self.mesh, x, shape=shape, dtype=dtype)
  return x_mtf
def get_dummy_decoder_context(converter,
                              batch=2,
                              d_model=6,
                              length=4,
                              mode="incremental",
                              initial_position=None,
                              state=None,
                              inputs=None):
  batch_dim = mtf.Dimension("batch", batch)
  length_dim = mtf.Dimension("length", length)

  # Set up a dummy model
  layer_stack = transformer.LayerStack(layers=[])
  model = transformer.Unitransformer(
      d_model=d_model,
      input_vocab_size=10,  # dummy values
      output_vocab_size=10,  # dummy values
      autoregressive=True,
      max_length=length,
      layer_stack=layer_stack)

  if state is not None:
    state_mtf = converter.convert_np_array_to_mtf_tensor(
        state, dtype=tf.float32, dim_names=["batch", "length", "d_model"])
    states = [state_mtf]
  else:
    states = None

  if initial_position:
    initial_position = mtf.constant(converter.mesh,
                                    initial_position,
                                    shape=mtf.Shape([batch_dim]),
                                    dtype=tf.int32)

  if inputs is not None:
    inputs = converter.convert_np_array_to_mtf_tensor(
        inputs, dim_names=["batch", "length"])

  context = transformer.Context(model=model,
                                mode=mode,
                                states=states,
                                new_states=[],
                                mesh=converter.mesh,
                                batch_dims=[batch_dim],
                                length_dim=length_dim,
                                variable_dtype=mtf.VariableDType(tf.float32),
                                sequence_id=1,
                                inputs=inputs,
                                initial_position=initial_position)
  return context
def sample(self, features, mesh):
  hparams = self._hparams
  model = self.model()

  # Prepare partial targets.
  # In either features["inputs"] or features["targets"].
  # We force the outputs to begin with these sequences.
  partial_targets = features.get("inputs", None)
  if partial_targets is None:
    partial_targets = features.get("targets", None)
  if partial_targets is not None:
    partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
    partial_targets = tf.to_int32(partial_targets)
    partial_targets_batch = tf.shape(partial_targets)[0]
    partial_targets_length = tf.shape(partial_targets)[1]
    partial_targets = tf.pad(
        partial_targets,
        [[0, hparams.batch_size - partial_targets_batch],
         [0, self.length_dim.size - partial_targets_length]])
    partial_targets = self._import_to_batch_by_length(
        partial_targets, "partial_targets", mesh)
    # strip EOS
    partial_targets *= mtf.to_int32(mtf.not_equal(partial_targets, 1))
  else:
    ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
    partial_targets = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)

  if hparams.beam_size == 1:
    pass
  else:
    raise NotImplementedError("not implemented")
    # beam_dim = mtf.Dimension("beam", hparams.beam_size)
    # ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim])

  partial_targets = mtf.Print(partial_targets, [partial_targets],
                              "Partial_Targets", summarize=1000)
  return model.sample_autoregressive(
      partial_targets,
      temperature=hparams.sampling_temp,
      variable_dtype=self.variable_dtype)
def body_fn(position, ids, *states):
  """One step in the decode loop."""
  context_incremental = Context(
      mesh=inputs.mesh,
      batch_dims=batch_dims,
      length_dim=length_dim,
      model_dim=self.model_dim,
      variable_dtype=variable_dtype,
      mode="incremental",
      autoregressive=self.autoregressive,
      position=position,
      states=states,
      new_states=[],
      sequence_id=sequence_id,
      encoder_output=encoder_output,
      encoder_sequence_id=encoder_sequence_id,
      constant_states=constant_states,
      shared_params=shared_params,
      layout=self.layout,
      mesh_shape=self.mesh_shape,
      encoder_layer_outputs=encoder_layer_outputs)
  inputs_this_step = mtf.gather(ids, position - 1, length_dim)
  with tf.variable_scope(self.name, reuse=True):
    logits = self._call_internal(context_incremental, inputs_this_step)
  if never_end:
    logits += mtf.one_hot(
        mtf.constant(logits.mesh, stop_at_token, dtype=tf.int32),
        self.output_vocab_dim,
        on_value=-1e9,
        off_value=0.0,
        dtype=logits.dtype)
  ids_this_step = mtf.sample_with_temperature(
      logits, self.output_vocab_dim, temperature)
  new_position = position + 1
  new_ids = ids + ids_this_step * mtf.one_hot(
      position, length_dim, dtype=tf.int32)
  return [new_position, new_ids] + context_incremental.new_states
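# Illustrative sketch (not part of the original code): how the last two lines
# above write the sampled id into the current position. With a one-hot mask
# over the length dimension, `ids + ids_this_step * one_hot(position)` leaves
# every other position untouched. Plain numpy, single example, length 6.
import numpy as np

length = 6
ids = np.array([7, 3, 9, 0, 0, 0])   # ids decoded so far (0 = not yet written)
position = 3                         # current decode step
ids_this_step = 5                    # freshly sampled token id
one_hot = np.eye(length, dtype=int)[position]
new_ids = ids + ids_this_step * one_hot
print(new_ids)  # [7 3 9 5 0 0]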
def get_masked_lm_output(self, positions, label_ids, label_weights):
  """Get loss and logits for the masked LM."""
  input_tensor = self.get_sequence_output()
  output_weights = self.get_embedding_table()

  # [batch_size, num_position, hidden]
  input_tensor = mtf.gather(input_tensor, positions, self.seq_dim)

  with tf.variable_scope("cls/predictions"):
    # We apply one more non-linear transformation before the output layer.
    # This matrix is not used after pre-training.
    with tf.variable_scope("transform"):
      input_tensor = mtf.layers.dense(
          input_tensor,
          reduced_dims=[self.model_dim],
          new_dims=[self.model_dim],
          activation=get_activation(self.config.feedforward_intermediate_act),
          kernel_initializer=self.dense_initializer,
          use_bias=self.config.use_bias)
      input_tensor = self.normalize(input_tensor)

    # The output weights are the same as the input embeddings, but there is
    # an output-only bias for each token.
    output_bias = mtf.get_variable(
        input_tensor.mesh,
        name="output_bias",
        shape=[self.vocab_dim],
        initializer=tf.zeros_initializer())
    logits = mtf.einsum([input_tensor, output_weights],
                        reduced_dims=[self.model_dim]) + output_bias
    per_example_loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, label_ids, self.vocab_dim, z_loss=1e-4)

    # The `positions` tensor might be zero-padded (if the sequence is too
    # short to have the maximum number of predictions). The `label_weights`
    # tensor has a value of 1.0 for every real prediction and 0.0 for the
    # padding predictions.
    numerator = mtf.reduce_sum(label_weights * per_example_loss)
    denominator = mtf.reduce_sum(label_weights) + mtf.constant(
        input_tensor.mesh, 1e-5, dtype=tf.float32)
    loss = numerator / denominator

  return (loss, per_example_loss, logits)
def _sample(self, features, mesh):
  hparams = self._hparams
  (inputs_embedding_var, targets_embedding_var, softmax_var,
   positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
  if hparams.transformer_type == "encdec":
    inputs = features["inputs"]
    while len(inputs.shape.as_list()) > 2:
      inputs = tf.squeeze(inputs, axis=2)
    actual_batch_size = tf.shape(inputs)[0]
    actual_length = tf.shape(inputs)[1]
    inputs = tf.pad(inputs,
                    [[0, hparams.batch_size - actual_batch_size],
                     [0, hparams.max_length - actual_length]])
    inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
    x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
         mtf.reshape(positional_embedding_var,
                     mtf.Shape([self.length_dim, self.model_dim])))
    encoder_attention_mask = (
        mtf.layers.attention_mask_ignore_padding(
            inputs, dtype=self.activation_dtype))
    with tf.variable_scope("encoder"):
      x = self._layer_stack(x,
                            hparams.encoder_layers,
                            self_attention_mask=encoder_attention_mask)
    encoder_output = mtf.rename_dimension(x, self.length_dim.name,
                                          self.memory_length_dim.name)
    encdec_tensors = []
    for layer_num, layer_type in enumerate(hparams.decoder_layers):
      if layer_type == "enc_att":
        with tf.variable_scope("decoder/enc_att_%d/enc_att" % layer_num):
          q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars(
              mesh, self.heads_dim, self.model_dim, self.kv_dim,
              self.master_dtype, self.slice_dtype, self.activation_dtype)
          k = mtf.einsum([encoder_output, k_var],
                         mtf.Shape(self.batch_dims + [
                             self.heads_dim, self.memory_length_dim,
                             self.kv_dim
                         ]))
          v = mtf.einsum([encoder_output, v_var],
                         mtf.Shape(self.batch_dims + [
                             self.heads_dim, self.memory_length_dim,
                             self.kv_dim
                         ]))
        encdec_tensors.append((q_var, o_var, k, v))
      else:
        encdec_tensors.append(None)
    partial_targets = None
  elif hparams.transformer_type == "decoder":
    encdec_tensors = None
    encoder_output = None
    encoder_attention_mask = None
    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    partial_targets = features.get("inputs", None)
    if partial_targets is None:
      partial_targets = features.get("targets", None)
    if partial_targets is not None:
      partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
      partial_targets = tf.to_int32(partial_targets)
      partial_targets_batch = tf.shape(partial_targets)[0]
      partial_targets_length = tf.shape(partial_targets)[1]
      partial_targets = tf.pad(
          partial_targets,
          [[0, hparams.batch_size - partial_targets_batch],
           [0, hparams.max_length - partial_targets_length]])
      partial_targets = self._import_to_batch_by_length(
          partial_targets, "partial_targets", mesh, hparams)
  else:
    raise ValueError("hparams.model_type = %s not yet supported" %
                     hparams.transformer_type)

  local_attention_window = mtf.Dimension(
      "local_attention_window", hparams.local_attention_window_size)
  if hparams.beam_size == 1:
    ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
    kv_shape = mtf.Shape(
        self.batch_dims +
        [self.heads_dim, self.memory_length_dim, self.kv_dim])
    local_kv_shape = mtf.Shape(
        self.batch_dims +
        [self.heads_dim, local_attention_window, self.kv_dim])
  else:
    beam_dim = mtf.Dimension("beam", hparams.beam_size)
    ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim])
    kv_shape = mtf.Shape(self.batch_dims + [
        beam_dim, self.heads_dim, self.memory_length_dim, self.kv_dim
    ])
    local_kv_shape = mtf.Shape(self.batch_dims + [
        beam_dim, self.heads_dim, local_attention_window, self.kv_dim
    ])

  initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
  initial_states = []
  for layer in hparams.decoder_layers:
    if layer == "att":
      initial_states.extend(
          [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2)
    elif layer == "local_att":
      initial_states.extend(
          [mtf.zeros(mesh, local_kv_shape, dtype=self.activation_dtype)] * 2)

  def logits_fn(step_num, ids, states):
    """Produce logits for this step, and new states."""
    ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
    x = (mtf.gather(targets_embedding_var, ids_this_step,
                    self.targets_vocab_dim) +
         mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
    with tf.variable_scope("decoder"):
      x, new_states = self._layer_stack(
          x,
          hparams.decoder_layers,
          encdec_attention_mask=encoder_attention_mask,
          step_num=step_num,
          encdec_tensors=encdec_tensors,
          states=states)
    logits = mtf.matmul(x, softmax_var)
    return logits, new_states

  if hparams.beam_size == 1:
    temperature = (0.0 if hparams.sampling_method == "argmax" else
                   hparams.sampling_temp)
    return mtf.beam_search.greedy_decode(
        logits_fn,
        initial_ids,
        temperature=temperature,
        initial_states=initial_states,
        forced_ids=partial_targets,
        use_tpu=hparams.use_tpu)
  else:
    if hparams.transformer_type == "encdec":
      input_length = mtf.reduce_sum(
          mtf.to_float(mtf.cast(inputs, tf.bool)),
          reduced_dim=self.length_dim)
      max_input_length = mtf.reduce_max(input_length)
      decode_length = mtf.cast(
          max_input_length * hparams.decode_length_multiplier +
          hparams.decode_length_constant, tf.int32)
    else:
      decode_length = None
    beams, unused_scores = mtf.beam_search.beam_search(
        logits_fn,
        initial_ids,
        hparams.alpha,
        states=initial_states,
        decode_length=decode_length,
        use_tpu=hparams.use_tpu,
        dtype=self.activation_dtype)
    return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
def decode(self,
           inputs,
           variable_dtype=mtf.VariableDType(tf.float32),
           beam_size=1,
           alpha=0.6,
           temperature=1.0,
           decode_length_multiplier=1.5,
           decode_length_constant=10):
  """Sampling or beam search.

  TODO(noam): should we make the output length dimension different from the
  input length dimension?

  Args:
    inputs: a Tensor with shape [<batch_dims>, beam_dim, length_dim]
    variable_dtype: a mtf.VariableDType
    beam_size: an integer >= 1
    alpha: a floating point value (length bonus for beam search)
    temperature: a value between 0 and 1 (must be 0 if beam_size > 1)
    decode_length_multiplier: a float
    decode_length_constant: a float

  Returns:
    a Tensor with shape [<batch_dims>, beam_dim, length_dim]
  """
  shared_params = self._shared_params(inputs.mesh, variable_dtype)
  encoder_sequence_id = mtf.minimum(inputs, 1)
  encoder_output, encoder_loss = self.encoder.call_simple(
      inputs=inputs,
      targets=None,
      compute_loss=False,
      mode=tf.estimator.ModeKeys.PREDICT,
      variable_dtype=variable_dtype,
      sequence_id=encoder_sequence_id,
      shared_params=shared_params)
  del encoder_loss
  encoder_output = mtf.layers.rename_length_to_memory_length(encoder_output)
  encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
      encoder_sequence_id)
  if beam_size == 1:
    ids_shape = inputs.shape
    partial_targets = mtf.constant(inputs.mesh, 0, ids_shape, dtype=tf.int32)
    return self.decoder.sample_autoregressive(
        partial_targets,
        temperature=temperature,
        variable_dtype=variable_dtype,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        shared_params=shared_params)
  else:
    if temperature != 0:
      raise ValueError(
          "don't know how to beam search with nonzero temperature")
    # beam search
    beam_dim = mtf.Dimension("beam", beam_size)
    batch_dims = inputs.shape[:-1]
    length_dim = inputs.shape[-1]
    ids_shape = mtf.Shape(batch_dims + [beam_dim, length_dim])
    partial_targets = mtf.constant(inputs.mesh, 0, ids_shape, dtype=tf.int32)
    input_length = mtf.reduce_sum(
        mtf.to_float(mtf.cast(inputs, tf.bool)), reduced_dim=length_dim)
    max_input_length = mtf.reduce_max(input_length)
    decode_length = mtf.cast(
        max_input_length * decode_length_multiplier +
        decode_length_constant, tf.int32)
    return self.decoder.beam_search(
        partial_targets,
        decode_length,
        variable_dtype=variable_dtype,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        alpha=alpha,
        shared_params=shared_params)
def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None):
  """Creates and returns an optimizer training op."""
  global_step = tf.train.get_or_create_global_step()

  # set defaults
  end_step = params.get("lr_decay_end", params["train_steps"])
  lr_decay = params.get("lr_decay", "cosine")
  warmup_steps = params.get("warmup_steps", 3000)
  gradient_clipping = params.get("gradient_clipping", 1.0)
  optimizer_name = params.get("optimizer", "adam")

  learning_rate = tf.constant(
      value=params["lr"], shape=[], dtype=variable_dtype.slice_dtype)
  clip_value = mtf.constant(
      mesh, gradient_clipping, dtype=variable_dtype.slice_dtype)

  if inp_var_grads is None:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in mesh.graph.trainable_variables])
  else:
    var_grads = inp_var_grads

  valid_grads_vars = list(
      filter(lambda grad_var: grad_var[0] is not None,
             zip(var_grads, mesh.graph.trainable_variables)))
  valid_vars = [var for grad, var in valid_grads_vars]
  valid_grad = [grad for grad, var in valid_grads_vars]

  tf.logging.info([
      v for v in zip(var_grads,
                     [v.outputs[0] for v in mesh.graph.trainable_variables])
  ])

  # Cast to full precision
  var_grads_fp = [
      mtf.cast(v, variable_dtype.slice_dtype) for v in valid_grad
  ]

  if lr_decay == "linear":
    learning_rate = tf.train.polynomial_decay(
        learning_rate,
        global_step,
        end_step,
        # Decrease to 10% of initial LR according to GPT-3 paper
        end_learning_rate=params["lr"] * 0.1,
        power=1.0,
        cycle=False)
  elif lr_decay == "cosine":
    learning_rate = tf.train.cosine_decay(
        learning_rate,
        global_step,
        end_step,
        alpha=0.1  # Alpha is min lr value as a fraction of init lr.
    )

  if warmup_steps > 0:
    global_steps_int = tf.cast(global_step, tf.int32)
    warmup_steps_int = tf.constant(warmup_steps, dtype=tf.int32)
    dtype = variable_dtype.slice_dtype

    global_steps_float = tf.cast(global_steps_int, dtype)
    warmup_steps_float = tf.cast(warmup_steps_int, dtype)

    warmup_percent_done = global_steps_float / warmup_steps_float
    warmup_learning_rate = learning_rate * warmup_percent_done

    is_warmup = tf.cast(global_steps_int < warmup_steps_int, dtype)
    learning_rate = ((1.0 - is_warmup) * learning_rate +
                     is_warmup * warmup_learning_rate)

  learning_rate = mtf.import_fully_replicated(
      mesh, learning_rate, mtf.Shape([]), name="learning_rate")
  scalar_summary("lr", learning_rate)

  if optimizer_name.lower() == "adam":
    optimizer = mtf.optimize.AdamWeightDecayOptimizer(
        learning_rate=learning_rate,
        weight_decay_rate=params.get("weight_decay", 0.0),
        beta_1=params.get("beta_1", 0.9),
        beta_2=params.get("beta_2", 0.999),
        epsilon=params.get("epsilon", 1e-6),
        exclude_from_weight_decay=["norm", "bias"])
  elif optimizer_name.lower() == "adafactor":
    optimizer = mtf.optimize.AdafactorOptimizer(
        learning_rate=learning_rate,
        decay_rate=params.get("weight_decay", 0.0),
        beta1=params.get("beta_1", 0.9),
        epsilon1=params.get("epsilon_1", 1e-30),
        epsilon2=params.get("epsilon_2", 1e-3))
  else:
    raise ValueError(f"{optimizer_name} not recognized")

  if gradient_clipping is not None:
    (var_grads_fp, _) = clip_by_global_norm(var_grads_fp, clip_norm=clip_value)

  update_ops = optimizer.apply_grads(var_grads_fp, valid_vars)
  return learning_rate, update_ops, var_grads_fp
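# Illustrative sketch (not part of the original code): the learning-rate
# schedule assembled above (linear warmup followed by cosine decay to
# alpha * lr), written as a plain-Python function so the shape of the curve
# is easy to see. The parameter names mirror the params used above, but the
# values here are made up.
import math

def schedule(step, lr=2.5e-4, warmup_steps=3000, end_step=100000, alpha=0.1):
  # Cosine decay toward alpha * lr over end_step steps.
  progress = min(step, end_step) / end_step
  cosine_decayed = 0.5 * (1 + math.cos(math.pi * progress))
  decayed_lr = lr * ((1 - alpha) * cosine_decayed + alpha)
  # Linear warmup scales the decayed value for the first warmup_steps.
  if step < warmup_steps:
    return decayed_lr * step / warmup_steps
  return decayed_lr

print(schedule(0), schedule(3000), schedule(100000))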
def scalar(v, dtype):
  return mtf.constant(mesh, v, shape=[], dtype=dtype)
def test_get_indices(self):
  key_size = 2
  n_keys = 3
  product_size = 2
  head_size = 2
  batch = 2
  seq_len = 2
  knn = 2

  n_key_dim = mtf.Dimension("n_keys", n_keys)
  key_dim = mtf.Dimension("key", key_size // 2)
  seq_dim = mtf.Dimension("length", seq_len)
  batch_dim = mtf.Dimension("batch", batch)
  head_dim = mtf.Dimension("n_heads", head_size)
  product_dim = mtf.Dimension("product_key", product_size)
  knn_dim = mtf.Dimension("knn", knn)

  query_shape = mtf.Shape(
      [batch_dim, seq_dim, head_dim, product_dim, key_dim])
  keys_shape = mtf.Shape([head_dim, product_dim, n_key_dim, key_dim])
  query = mtf.ones(self.mesh, query_shape)

  keys_vals = [
      [
          [[4], [1], [2]],
          [[2], [-1], [2]],
      ],
      [
          [[1], [2], [5]],
          [[6], [1], [4]],
      ],
  ]
  # h1:
  #   First scores:
  #     [4, 2]
  #     [2, 2]
  #   Cartesian added scores:
  #     [6, 6]
  #   Indices:
  #     [0, 2]  [0*n_k + 0, 0*n_k + 2]
  # h2:
  #   First scores:
  #     [5, 2]
  #     [6, 4]
  #   Cartesian added scores:
  #     [11, 9]
  #   Indices:
  #     [6, 8]  [2*n_k + 0, 2*n_k + 2]
  expected_scores = np.broadcast_to(
      np.array([[6, 6], [11, 9]]), [batch, seq_len, head_size, knn])
  expected_indices = np.broadcast_to(
      np.array([[0, 2], [6, 8]]), [batch, seq_len, head_size, knn])

  keys = mtf.constant(self.mesh, keys_vals, keys_shape)
  pkm = memory_layers.ProductKeyValueMemory(key_size, n_keys, head_size, knn)
  mtf_scores, mtf_indices = pkm.get_indices(keys, query)

  # Shapes.
  expected_shape = mtf.Shape([batch_dim, seq_dim, head_dim, knn_dim])
  self.assertEqual(expected_shape, mtf_scores.shape)
  self.assertEqual(expected_shape, mtf_indices.shape)

  # Values.
  lowering_s, scores = self._export_to_tf_tensor(mtf_scores)
  lowering_i, indices = self._export_to_tf_tensor(mtf_indices)
  self.evaluate(tf.global_variables_initializer())
  self.evaluate(lowering_s.copy_masters_to_slices())
  self.evaluate(lowering_i.copy_masters_to_slices())
  scores, indices = self.evaluate([scores, indices])

  self.assertAllEqual(expected_scores, scores)
  self.assertAllEqual(expected_indices, indices)
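# Illustrative sketch (not part of the original code): the product-key scoring
# that the expected values in the test above encode, reproduced with numpy for
# one head. Each query half is scored against its own half of the keys, the
# top-k of each half is taken, and the Cartesian sums of those top-k scores
# give the candidate full-key scores; the final index is i1 * n_keys + i2.
import numpy as np

n_keys, knn = 3, 2
# Head h2 from the test: half-1 and half-2 key values, key_dim = 1, query = 1.
scores_half1 = np.array([1, 2, 5])   # query . keys, first half
scores_half2 = np.array([6, 1, 4])   # query . keys, second half
top1_idx = np.argsort(-scores_half1)[:knn]   # [2, 1]
top2_idx = np.argsort(-scores_half2)[:knn]   # [0, 2]
cart_scores = scores_half1[top1_idx][:, None] + scores_half2[top2_idx][None, :]
cart_indices = top1_idx[:, None] * n_keys + top2_idx[None, :]
flat_order = np.argsort(-cart_scores.ravel())[:knn]
print(cart_scores.ravel()[flat_order])   # [11  9]
print(cart_indices.ravel()[flat_order])  # [6 8]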
def body_fn(position, ids, *states):
  """One step in the decode loop."""
  inputs_this_step = mtf.gather(ids, position - 1, length_dim)
  if self.attribute_embedding:
    attributes_this_step = mtf.gather(attributes, position - 1, length_dim)
  else:
    attributes_this_step = None
  # raise ValueError("inputs_this_step shape=%s , ids shape=%s, "
  #                  "position - 1 shape=%s, length_dim=%s" % (
  #                      inputs_this_step.shape, ids.shape,
  #                      (position - 1).shape, length_dim))
  context_incremental = Context(
      model=self,
      mesh=inputs.mesh,
      batch_dims=batch_dims,
      length_dim=length_dim,
      variable_dtype=variable_dtype,
      mode="incremental",
      position=position,
      states=states,
      new_states=[],
      sequence_id=sequence_id,
      encoder_output=encoder_output,
      encoder_sequence_id=encoder_sequence_id,
      constant_states=constant_states,
      shared_params=shared_params,
      encoder_layer_outputs=encoder_layer_outputs,
      write_priority=write_priority,
      read_priority=position,
      inputs=inputs_this_step,
      encoder_inputs=encoder_inputs)

  with tf.variable_scope(self.name, reuse=True):
    logits = self._call_internal(context_incremental,
                                 inputs_this_step,
                                 attributes=attributes_this_step,
                                 z=z)
  if never_end:
    logits += mtf.one_hot(
        mtf.constant(logits.mesh, stop_at_token, dtype=tf.int32),
        self.output_vocab_dim,
        on_value=-1e9,
        off_value=0.0,
        dtype=logits.dtype)

  # TBD whether this should be before or after never_end:
  # Note for adding top_p sampling in the future: in other code bases, the
  # option to apply temperature is done before the top-k truncation. This
  # implementation does it in the opposite order. For top-k this doesn't
  # matter, but for top_p it will.
  if sampling_keep_top_k != -1:
    if sampling_keep_top_k <= 0:
      raise ValueError("sampling_keep_top_k must either be -1 or positive.")
    k_largest = mtf.nth_largest_element(
        logits, n=sampling_keep_top_k, reduced_dim=self.output_vocab_dim)
    logits = mtf.where(mtf.less_equal(logits, k_largest),
                       mtf.ones_like(logits) * -1e6, logits)

  ids_this_step = mtf.sample_with_temperature(
      logits, self.output_vocab_dim, temperature)
  new_position = position + 1
  new_ids = ids + ids_this_step * mtf.one_hot(
      position, length_dim, dtype=tf.int32)
  return [new_position, new_ids] + context_incremental.new_states
def beam_search(self,
                inputs,
                decode_length,
                dst_attributes=None,
                variable_dtype=mtf.VariableDType(tf.float32),
                encoder_output=None,
                encoder_sequence_id=None,
                encoder_inputs=None,
                alpha=0.6,
                shared_params=None,
                encoder_layer_outputs=None,
                z=None):
  """Beam search.

  Args:
    inputs: an int32 zero-Tensor with shape
      [<batch_dims>, beam_dim, length_dim].
    decode_length: an int32 mtf scalar. Maximum decode length.
    dst_attributes: an int32 zero-Tensor with shape
      [<batch_dims>, beam_dim, length_dim] ([<batch_dims>] or
      [<batch_dims>, beam_dim]).
    variable_dtype: a mtf.VariableDType
    encoder_output: an optional Tensor
    encoder_sequence_id: an optional Tensor
    encoder_inputs: an optional Tensor
    alpha: a floating point value (length bonus)
    shared_params: an optional dictionary
    encoder_layer_outputs: optional - readonly list of tensor activations when
      decoding, one per each input layer + the embedding layer
    z: an optional Tensor

  Returns:
    a Tensor with shape [<batch_dims>, beam_dim, length_dim]
  """
  attributes = dst_attributes
  if not self.autoregressive:
    raise ValueError("must be autoregressive")

  batch_dims = inputs.shape.dims[:-2]
  if len(batch_dims) != 1:
    raise NotImplementedError(
        "beam search supports exactly one batch dimension.")
  beam_dim = inputs.shape.dims[-2]
  length_dim = inputs.shape.dims[-1]
  length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
  initial_position = mtf.reduce_sum(
      mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim)
  sequence_id = 1 if encoder_sequence_id is not None else None

  if self.input_full_attention:
    # This only makes sense in the case of beam search with given partial
    # sequences, which is not yet implemented.
    # TODO(noam): implement
    raise NotImplementedError(
        "Beam search for language models not yet implemented")
  else:
    read_priority = write_priority = length_range

  context_first_part = Context(
      model=self,
      mesh=inputs.mesh,
      batch_dims=batch_dims + [beam_dim],
      length_dim=length_dim,
      variable_dtype=variable_dtype,
      mode="first_part",
      position=length_range,
      position_is_default=True,
      new_states=[],
      initial_position=initial_position,
      sequence_id=sequence_id,
      encoder_output=encoder_output,
      encoder_sequence_id=encoder_sequence_id,
      constant_states=[],
      shared_params=shared_params,
      encoder_layer_outputs=encoder_layer_outputs,
      write_priority=write_priority,
      read_priority=read_priority,
      inputs=inputs,
      encoder_inputs=encoder_inputs)

  shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
  with tf.variable_scope(self.name):
    logits = self._call_internal(context_first_part,
                                 shifted_inputs,
                                 attributes=attributes,
                                 z=z)
  del logits  # There are no partial targets.
  # Replace initial states by zeros to avoid computing them.
  initial_states = [
      mtf.zeros_like(t) for t in context_first_part.new_states
  ]
  constant_states = context_first_part.constant_states

  def logits_fn(step_num, ids, states):
    """logits_fn for mtf.beam_search.beam_search()."""
    inputs_this_step = mtf.gather(ids, step_num - 1, length_dim)
    if self.attribute_embedding:
      attributes_this_step = mtf.gather(attributes, step_num - 1, length_dim)
    else:
      attributes_this_step = None
    context_incremental = Context(
        model=self,
        mesh=inputs.mesh,
        batch_dims=batch_dims + [beam_dim],
        length_dim=length_dim,
        variable_dtype=variable_dtype,
        mode="incremental",
        position=step_num,
        states=states,
        new_states=[],
        sequence_id=sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        constant_states=constant_states,
        shared_params=shared_params,
        encoder_layer_outputs=encoder_layer_outputs,
        write_priority=write_priority,
        read_priority=step_num,
        inputs=inputs_this_step,
        encoder_inputs=encoder_inputs)
    with tf.variable_scope(self.name, reuse=True):
      logits = self._call_internal(context_incremental,
                                   inputs_this_step,
                                   attributes=attributes_this_step,
                                   z=z)
    return mtf.to_float(logits), context_incremental.new_states

  beams, unused_scores = mtf.beam_search.beam_search(
      logits_fn,
      inputs,
      alpha,
      states=initial_states,
      decode_length=decode_length,
      use_tpu=True,
      dtype=tf.float32,
      mesh_shape=self.mesh_shape,
      layout=self.layout)
  return mtf.gather(
      beams, mtf.constant(inputs.mesh, 0, dtype=tf.int32), beam_dim)
def act_layer(self, context, x, mask):
  """Build a Universal Transformer ACT layer."""
  state = x
  act_max_steps = self.act_max_steps
  threshold = 1.0 - self.act_epsilon
  state_shape_static = state.shape.dims

  state_slice = slice(0, 3)
  if self.act_type == "global":
    state_slice = slice(0, 2)

  # Dynamic shape for update tensors below
  update_shape = state_shape_static[state_slice]

  # Halting probabilities (p_t^n in the paper)
  halting_probability = mtf.zeros(
      context.mesh, update_shape, dtype=context.activation_dtype)

  # Remainders (R(t) in the paper)
  remainders = mtf.zeros(
      context.mesh, update_shape, dtype=context.activation_dtype)

  # Number of updates performed (N(t) in the paper)
  n_updates = mtf.zeros(
      context.mesh, update_shape, dtype=context.activation_dtype)

  # Previous cell states (s_t in the paper)
  previous_state = mtf.zeros_like(state)
  step = mtf.constant(context.mesh, 0, dtype=tf.int32)

  def ut_function(state, step, halting_probability, remainders, n_updates,
                  previous_state):
    """Implements ACT (position-wise halting).

    Args:
      state: 3-D Tensor: [batch_size, length, channel]
      step: indicates number of steps taken so far
      halting_probability: halting probability
      remainders: act remainders
      n_updates: act n_updates
      previous_state: previous state

    Returns:
      transformed_state: transformed state
      step: step+1
      halting_probability: halting probability
      remainders: act remainders
      n_updates: act n_updates
      new_state: new state
    """
    state = self.step_preprocess(context, state, step)

    if self.act_type == "random":
      # random as halting probability
      p = mtf.random_uniform(context.mesh,
                             shape=halting_probability.shape.dims,
                             dtype=context.variable_dtype)
    else:
      last_dim_name = state.shape.dimension_names[-1]
      new_dims = [mtf.Dimension(last_dim_name, 1)]
      with tf.variable_scope("sigmoid_activation_for_pondering",
                             reuse=tf.AUTO_REUSE):
        p = mtf.layers.dense(state,
                             variable_dtype=context.variable_dtype,
                             reduced_dims=[state.shape.dims[-1]],
                             new_dims=new_dims,
                             activation=mtf.sigmoid,
                             use_bias=True)
        if self.act_type == "global":
          # average over all positions (as a global halting prob)
          p = mtf.reduce_mean(p, reduced_dim=p.shape.dims[1])
          p = mtf.squeeze(p)
        else:
          # maintain position-wise probabilities
          new_shape = p.shape.dims[:-1]
          p = mtf.reshape(p, new_shape)

    # Mask for inputs which have not halted yet
    still_running = mtf.cast(
        mtf.less(halting_probability, 1.0), context.activation_dtype)

    # Mask of inputs which halted at this step
    new_halted = mtf.cast(
        mtf.greater(halting_probability + p * still_running, threshold),
        context.activation_dtype) * still_running

    # Mask of inputs which haven't halted, and didn't halt this step
    still_running = mtf.cast(
        mtf.less_equal(halting_probability + p * still_running, threshold),
        context.activation_dtype) * still_running

    # Add the halting probability for this step to the halting
    # probabilities for those inputs which haven't halted yet
    halting_probability += p * still_running

    # Compute remainders for the inputs which halted at this step
    remainders += new_halted * (1 - halting_probability)

    # Add the remainders to those inputs which halted at this step
    halting_probability += new_halted * remainders

    # Increment n_updates for all inputs which are still running
    n_updates += still_running + new_halted

    # Compute the weight to be applied to the new state and output:
    #   0 when the input has already halted,
    #   p when the input hasn't halted yet,
    #   the remainders when it halted this step.
    input_tensor = p * still_running + new_halted * remainders
    update_weights = input_tensor

    # apply transformation on the state
    transformed_state = state
    for _ in range(self.num_inrecurrence_layers):
      transformed_state = self.vanilla_transformer_layer(
          context, transformed_state, mask)

    # update running part in the weighted state and keep the rest
    new_state = ((transformed_state * update_weights) +
                 (previous_state * (1 - update_weights)))

    if self.act_type == "accumulated":
      # Add in the weighted state
      new_state = (transformed_state * update_weights) + previous_state

    step += 1

    return (transformed_state, step, halting_probability, remainders,
            n_updates, new_state)

  for _ in range(act_max_steps + 1):
    (state, step, halting_probability, remainders, n_updates,
     previous_state) = ut_function(state, step, halting_probability,
                                   remainders, n_updates, previous_state)
  ponder_times = n_updates
  mtf.scalar_summary("ponder_times", mtf.reduce_mean(ponder_times))
  return previous_state
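# Illustrative sketch (not part of the original code): the ACT bookkeeping in
# ut_function above for a single position, using plain-Python scalars. With a
# per-step halting probability p, a position keeps running until its
# accumulated probability would cross the threshold; at that step the
# remainder is used as the final mixing weight. The p values here are made up.
threshold = 1.0 - 0.01      # 1 - act_epsilon
halting_probability = 0.0
remainder = 0.0
n_updates = 0
for p in [0.3, 0.4, 0.5]:   # per-step halting probabilities
  still_running = float(halting_probability < 1.0)
  new_halted = float(halting_probability + p * still_running > threshold) * still_running
  still_running = float(halting_probability + p * still_running <= threshold) * still_running
  halting_probability += p * still_running
  remainder += new_halted * (1 - halting_probability)
  halting_probability += new_halted * remainder
  n_updates += still_running + new_halted
  update_weight = p * still_running + new_halted * remainder
  print(n_updates, round(halting_probability, 2), round(update_weight, 2))
# Steps 1-2 use weight p; step 3 halts and uses the remainder 0.3.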
def get_timing_signal_1d(self,
                         context,
                         length,
                         channels,
                         min_timescale=1.0,
                         max_timescale=1.0e4,
                         start_index=0):
  """Gets a bunch of sinusoids of different frequencies.

  Each channel of the input Tensor is incremented by a sinusoid of a
  different frequency and phase. This allows attention to learn to use
  absolute and relative positions. Timing signals should be added to some
  precursors of both the query and the memory inputs to attention. The use
  of relative position is possible because sin(x+y) and cos(x+y) can be
  expressed in terms of y, sin(x) and cos(x). In particular, we use a
  geometric sequence of timescales starting with min_timescale and ending
  with max_timescale. The number of different timescales is equal to
  channels / 2. For each timescale, we generate the two sinusoidal signals
  sin(timestep/timescale) and cos(timestep/timescale). All of these
  sinusoids are concatenated in the channels dimension.

  Args:
    context: mtf context.
    length: a mtf.Dimension, length of timing signal sequence.
    channels: a mtf.Dimension, size of timing embeddings to create. The
      number of different timescales is equal to channels / 2.
    min_timescale: a float
    max_timescale: a float
    start_index: index of first position

  Returns:
    a Tensor of timing signals [1, length, channels]
  """
  position = context.get_position() + start_index
  num_timescales = mtf.constant(context.mesh, channels.size // 2)
  log_timescale_increment = (
      math.log(float(max_timescale) / float(min_timescale)) /
      mtf.maximum(num_timescales - 1, 1))
  channel_dim_name = channels.name
  inv_timescales = (
      min_timescale *
      mtf.exp(
          mtf.mtf_range(context.mesh,
                        mtf.Dimension(channel_dim_name, channels.size // 2),
                        context.activation_dtype) * -log_timescale_increment))
  scaled_time = position * inv_timescales
  # Please note that this slightly differs from the published paper.
  # See a discussion here:
  # https://github.com/tensorflow/tensor2tensor/pull/177
  # concat_dim_name = scaled_time.shape.dimension_names[1]
  concat_dim_name = channels.name
  signal = mtf.concat(
      [mtf.sin(scaled_time), mtf.cos(scaled_time)],
      concat_dim_name=concat_dim_name)
  if channels.size % 2 != 0:
    raise NotImplementedError("Odd channel size not implemented.")
  new_dims = [mtf.Dimension("expanded", 1)
             ] + length.shape.dims + channels.shape.dim
  signal = mtf.reshape(signal, mtf.Shape(new_dims))
  return signal
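# Illustrative sketch (not part of the original code): the geometric sequence
# of inverse timescales and the sin/cos concatenation used above, in numpy.
# Assumes length=4 and channels=8, so there are channels // 2 = 4 timescales.
import numpy as np

length, channels = 4, 8
min_timescale, max_timescale = 1.0, 1.0e4
num_timescales = channels // 2
log_timescale_increment = (
    np.log(max_timescale / min_timescale) / max(num_timescales - 1, 1))
inv_timescales = min_timescale * np.exp(
    np.arange(num_timescales) * -log_timescale_increment)
position = np.arange(length)[:, None]
scaled_time = position * inv_timescales[None, :]
# First half of the channels is sin, second half is cos (not interleaved).
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
print(signal.shape)  # (4, 8)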
def create_optimizer(loss,
                     init_lr,
                     num_train_steps,
                     num_warmup_steps,
                     max_optimized_variable_size=None,
                     optimizer="adam",
                     clip_gradients=True):
  """Creates an optimizer training op."""
  global_step = tf.train.get_or_create_global_step()
  mesh = loss.mesh

  if init_lr:
    # Implements linear decay of the learning rate.
    learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
    learning_rate = tf.train.polynomial_decay(
        learning_rate,
        global_step,
        num_train_steps,
        end_learning_rate=0.0,
        power=1.0,
        cycle=False)
    # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
    # learning rate will be `global_step/num_warmup_steps * init_lr`.
    if num_warmup_steps:
      global_steps_int = tf.cast(global_step, tf.int32)
      warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

      global_steps_float = tf.cast(global_steps_int, tf.float32)
      warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

      warmup_percent_done = global_steps_float / warmup_steps_float
      warmup_learning_rate = init_lr * warmup_percent_done

      is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
      learning_rate = ((1.0 - is_warmup) * learning_rate +
                       is_warmup * warmup_learning_rate)
    mtf_learning_rate = mtf.import_tf_tensor(mesh, learning_rate, [])
  else:
    if optimizer == "adam":
      raise ValueError("Adam does not have a default learning rate")
    learning_rate = None
    mtf_learning_rate = None

  # It is recommended that you use this optimizer for fine tuning, since this
  # is how the model was trained (note that the Adam m/v variables are NOT
  # loaded from init_checkpoint.)
  if optimizer == "adam":
    optimizer = mtf_optimize.AdamWeightDecayOptimizer(
        learning_rate=mtf_learning_rate,
        weight_decay_rate=0.01,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
  elif optimizer == "adafactor":
    optimizer = mtf_optimize.AdafactorOptimizer(
        learning_rate=learning_rate,
        min_dim_size_to_factor=32)
  else:
    raise ValueError("unknown optimizer")

  trainable_variables = mesh.graph.trainable_variables
  if max_optimized_variable_size:
    trainable_variables = [
        t for t in trainable_variables
        if t.shape.size <= max_optimized_variable_size
    ]

  var_grads = mtf.gradients([loss],
                            [v.outputs[0] for v in trainable_variables])

  # This is how the model was pre-trained.
  if clip_gradients:
    (var_grads, _) = clip_by_global_norm(
        var_grads, clip_norm=mtf.constant(mesh, 1.0, dtype=tf.float32))

  update_ops = optimizer.apply_grads(var_grads, trainable_variables)
  return learning_rate, update_ops
def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None):
  """Creates and returns an optimizer training op."""
  global_step = tf.train.get_or_create_global_step()

  learning_rate = tf.constant(
      value=params["lr"], shape=[], dtype=variable_dtype.slice_dtype)
  clip_value = mtf.constant(
      mesh, params["gradient_clipping"], dtype=variable_dtype.slice_dtype)

  if inp_var_grads is None:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in mesh.graph.trainable_variables])
  else:
    var_grads = inp_var_grads

  # Cast to full precision
  var_grads_fp = [mtf.cast(v, variable_dtype.slice_dtype) for v in var_grads]

  # decrease LR to final lr (lr*0.1) by this step - defaults to train_steps
  end_step = params.get("lr_decay_end", params["train_steps"])
  if params["lr_decay"] == "linear":
    learning_rate = tf.train.polynomial_decay(
        learning_rate,
        global_step,
        end_step,
        # Decrease to 10% of initial LR according to GPT-3 paper
        end_learning_rate=params["lr"] * 0.1,
        power=1.0,
        cycle=False)
  elif params["lr_decay"] == "cosine":
    learning_rate = tf.train.cosine_decay(
        learning_rate,
        global_step,
        end_step,
        alpha=0.1  # Alpha is min lr value as a fraction of init lr.
    )

  if params["warmup_steps"] > 0:
    global_steps_int = tf.cast(global_step, tf.int32)
    warmup_steps_int = tf.constant(params["warmup_steps"], dtype=tf.int32)
    dtype = variable_dtype.slice_dtype

    global_steps_float = tf.cast(global_steps_int, dtype)
    warmup_steps_float = tf.cast(warmup_steps_int, dtype)

    warmup_percent_done = global_steps_float / warmup_steps_float
    warmup_learning_rate = learning_rate * warmup_percent_done

    is_warmup = tf.cast(global_steps_int < warmup_steps_int, dtype)
    learning_rate = ((1.0 - is_warmup) * learning_rate +
                     is_warmup * warmup_learning_rate)

  learning_rate = mtf.import_fully_replicated(
      mesh, learning_rate, mtf.Shape([]), name="learning_rate")
  mtf.scalar_summary("lr", learning_rate)

  if params["opt_name"].lower() == "adam":
    optimizer = AdamWeightDecayOptimizer(
        learning_rate=learning_rate,
        weight_decay_rate=params["weight_decay"],
        beta_1=params["beta1"],
        beta_2=params["beta2"],
        epsilon=params["epsilon"],
        exclude_from_weight_decay=["norm", "bias"],
        variable_dtype=variable_dtype)
  else:
    optimizer = mtf.optimize.AdafactorOptimizer(
        learning_rate=params["lr"],
        decay_rate=params["weight_decay"],
        beta1=params["beta1"],
        epsilon1=params["ada_epsilon1"],
        epsilon2=params["ada_epsilon2"])

  if params["use_tpu"]:
    optimizer = tf.tpu.CrossShardOptimizer(optimizer)

  if params["gradient_clipping"] is not None:
    (var_grads_fp, _) = clip_by_global_norm(var_grads_fp, clip_norm=clip_value)

  update_ops = optimizer.apply_grads(var_grads_fp,
                                     mesh.graph.trainable_variables)
  return learning_rate, update_ops, var_grads_fp
def _switch_gating(inputs,
                   outer_expert_dims,
                   experts_dim,
                   expert_capacity_dim,
                   hparams,
                   train,
                   variable_dtype,
                   importance=None,
                   name="switch_gating",
                   num_microbatches=None):
  """Compute a switch top-1 gating with no-token-left-behind behavior."""
  # SELECT EXPERT
  if train:
    policy = hparams.moe_rand_1_policy_train
  else:
    policy = hparams.moe_rand_1_policy_eval

  # Input perturbations
  if train and policy == "input_jitter":
    inputs = mtf.layers.multiplicative_jitter(inputs,
                                              hparams.moe_rand_1_jitter)

  gate_logits = mtf.layers.dense(
      inputs,
      experts_dim,
      use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)
  raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)

  # The internals of this function run in float32.
  # bfloat16 seems to reduce quality.
  raw_gates = mtf.to_float(raw_gates)

  # Top-k operation
  k_dim = mtf.Dimension("k", hparams.moe_switch_top_k)
  expert_gate, expert_index = mtf.top_k(
      raw_gates, reduced_dim=experts_dim, k_dim=k_dim)
  expert_mask = mtf.one_hot(expert_index, experts_dim)

  # LOAD BALANCING LOSS
  outer_batch_dim = inputs.shape[0]
  batch_dim = inputs.shape[1]
  group_size_dim = inputs.shape[-2]
  density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
  density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
  if importance is not None:
    expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    density_1_proxy *= mtf.cast(
        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
  loss = (
      mtf.reduce_mean(density_1_proxy * density_1) *
      float(experts_dim.size * experts_dim.size))
  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Logging
  if train:
    entropy = mtf.reduce_sum(
        -raw_gates * mtf.log(raw_gates + 1e-9), reduced_dim=experts_dim)
    batch_entropy = mtf.reduce_mean(entropy)
    mtf.scalar_summary(name + "/entropy", batch_entropy)

    mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
    total_routed = mtf.reduce_sum(mask_count_experts)
    expert_fraction = mtf.to_float(mask_count_experts / total_routed)
    split_fractions = mtf.split(
        expert_fraction,
        split_dim=experts_dim,
        num_or_size_splits=experts_dim.size)
    for fraction in split_fractions:
      mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
                         mtf.reduce_mean(fraction))
    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

  # COMPUTE ASSIGNMENT TO EXPERT
  # Iteratively route tokens (no-token-left-behind). The idea is to route as
  # many tokens as possible to top-i before then trying top-(i+1).
  top_k_masks = mtf.split(
      expert_mask, split_dim=k_dim, num_or_size_splits=k_dim.size)
  top_k_gates = mtf.split(
      expert_gate, split_dim=k_dim, num_or_size_splits=k_dim.size)
  top_k_indices = mtf.split(
      expert_index, split_dim=k_dim, num_or_size_splits=k_dim.size)

  # Tensors accumulating values over the iterative process.
  combine_tensor = mtf.constant(
      inputs.mesh,
      value=0,
      shape=[outer_batch_dim, batch_dim, experts_dim, expert_capacity_dim])
  cum_tokens = mtf.constant(
      inputs.mesh, value=0, shape=[outer_batch_dim, batch_dim, experts_dim])
  tokens_left_to_route = mtf.constant(
      inputs.mesh,
      value=1.,
      shape=[outer_batch_dim, batch_dim, group_size_dim])

  expert_capacity_float = float(expert_capacity_dim.size)
  for (top_i_mask, top_i_gate, top_i_index) in zip(top_k_masks, top_k_gates,
                                                   top_k_indices):
    top_i_mask = mtf.reshape(
        top_i_mask,
        new_shape=[outer_batch_dim, batch_dim, group_size_dim, experts_dim])
    # Operate only on the unrouted tokens.
    top_i_mask *= tokens_left_to_route

    # Record cumulative number of tokens to each expert across iterations.
    cumulative_tokens_in_expert = cum_tokens + mtf.cumsum(
        top_i_mask, group_size_dim)

    expert_overflow = mtf.to_float(
        mtf.less_equal(cumulative_tokens_in_expert, expert_capacity_float))
    output_i_tokens = top_i_mask * expert_overflow

    # Update the cumulative tokens routed to each expert.
    cum_tokens += mtf.reduce_sum(output_i_tokens, reduced_dim=group_size_dim)
    tokens_left_to_route -= (
        mtf.reduce_sum(output_i_tokens, reduced_dim=experts_dim))

    # Combine-tensor for this iteration
    output_i_tokens_flat = mtf.reduce_sum(
        output_i_tokens, reduced_dim=experts_dim)
    position_in_expert = cumulative_tokens_in_expert - 1
    top_i_combine_tensor = (
        top_i_gate * output_i_tokens_flat *
        mtf.one_hot(top_i_index, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert), expert_capacity_dim))
    combine_tensor += top_i_combine_tensor

  # Match the inputs dtype.
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)
  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss
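# Illustrative sketch (not part of the original code): the capacity check used
# above for a single expert and a single group, in numpy. A running count of
# tokens assigned to the expert (via cumsum over the group) is compared with
# the expert capacity; tokens past the capacity are dropped from this
# iteration and remain available for the next top-i pass. Values are made up.
import numpy as np

expert_capacity = 2
# 1.0 where the token picked this expert in the current top-i pass.
top_i_mask = np.array([1., 0., 1., 1., 0., 1.])
cumulative_tokens_in_expert = np.cumsum(top_i_mask)
within_capacity = (cumulative_tokens_in_expert <= expert_capacity).astype(float)
routed = top_i_mask * within_capacity
position_in_expert = cumulative_tokens_in_expert - 1
print(routed)                                     # [1. 0. 1. 0. 0. 0.] -> only the first two fit
print(position_in_expert[routed.astype(bool)])    # [0. 1.] -> capacity slots used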
def beam_search(self,
                inputs,
                decode_length,
                variable_dtype=mtf.VariableDType(tf.float32),
                encoder_output=None,
                encoder_sequence_id=None,
                alpha=0.6,
                shared_params=None,
                encoder_layer_outputs=None):
  """Beam search.

  Args:
    inputs: an int32 zero-Tensor with shape
      [<batch_dims>, beam_dim, length_dim].
    decode_length: an int32 mtf scalar. Maximum decode length.
    variable_dtype: a mtf.VariableDType
    encoder_output: an optional Tensor
    encoder_sequence_id: an optional Tensor
    alpha: a floating point value (length bonus)
    shared_params: an optional dictionary
    encoder_layer_outputs: optional - readonly list of tensor activations when
      decoding, one per each input layer + the embedding layer

  Returns:
    a Tensor with shape [<batch_dims>, beam_dim, length_dim]
  """
  if not self.autoregressive:
    raise ValueError("must be autoregressive")

  batch_dims = inputs.shape.dims[:-2]
  if len(batch_dims) != 1:
    raise NotImplementedError(
        "beam search supports exactly one batch dimension.")
  beam_dim = inputs.shape.dims[-2]
  length_dim = inputs.shape.dims[-1]
  initial_position = mtf.reduce_sum(
      mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim)
  sequence_id = 1 if encoder_sequence_id is not None else None

  context_first_part = Context(
      mesh=inputs.mesh,
      batch_dims=batch_dims + [beam_dim],
      length_dim=length_dim,
      model_dim=self.model_dim,
      variable_dtype=variable_dtype,
      mode="first_part",
      autoregressive=self.autoregressive,
      new_states=[],
      initial_position=initial_position,
      sequence_id=sequence_id,
      encoder_output=encoder_output,
      encoder_sequence_id=encoder_sequence_id,
      constant_states=[],
      shared_params=shared_params,
      layout=self.layout,
      mesh_shape=self.mesh_shape,
      encoder_layer_outputs=encoder_layer_outputs)

  shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
  with tf.variable_scope(self.name):
    logits = self._call_internal(context_first_part, shifted_inputs)
  del logits  # There are no partial targets.
  # Replace initial states by zeros to avoid computing them.
  initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states]
  constant_states = context_first_part.constant_states

  def logits_fn(step_num, ids, states):
    """logits_fn for mtf.beam_search.beam_search()."""
    context_incremental = Context(
        mesh=inputs.mesh,
        batch_dims=batch_dims + [beam_dim],
        length_dim=length_dim,
        model_dim=self.model_dim,
        variable_dtype=variable_dtype,
        mode="incremental",
        autoregressive=self.autoregressive,
        position=step_num,
        states=states,
        new_states=[],
        sequence_id=sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        constant_states=constant_states,
        shared_params=shared_params,
        layout=self.layout,
        mesh_shape=self.mesh_shape,
        encoder_layer_outputs=encoder_layer_outputs)
    inputs_this_step = mtf.gather(ids, step_num - 1, length_dim)
    with tf.variable_scope(self.name, reuse=True):
      logits = self._call_internal(context_incremental, inputs_this_step)
    return mtf.to_float(logits), context_incremental.new_states

  beams, unused_scores = mtf.beam_search.beam_search(
      logits_fn,
      inputs,
      alpha,
      states=initial_states,
      decode_length=decode_length,
      use_tpu=True,
      dtype=tf.float32,
      mesh_shape=self.mesh_shape,
      layout=self.layout)
  return mtf.gather(
      beams, mtf.constant(inputs.mesh, 0, dtype=tf.int32), beam_dim)
def get_optimizer(loss, params, summary, inp_var_grads=None):
  """Creates and returns an optimizer training op."""
  global_step = tf.train.get_or_create_global_step()  # get global step
  mesh = loss.mesh  # get mesh info from loss
  graph = mesh.graph  # get graph info from mesh

  if inp_var_grads is None:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
  else:
    var_grads = inp_var_grads

  learning_rate = tf.constant(
      value=params["lr"], shape=[], dtype=tf.float32)  # grab lr param

  if params["lr_decay"] == "linear":
    learning_rate = tf.train.polynomial_decay(
        learning_rate,
        global_step,
        params["train_steps"],
        # decrease to 10% of initial LR according to GPT-3 paper
        end_learning_rate=params["lr"] * 0.1,
        power=1.0,
        cycle=False,
    )
  elif params["lr_decay"] == "cosine":
    learning_rate = tf.train.cosine_decay(
        learning_rate,
        global_step,
        params["train_steps"],
        alpha=0.1,  # alpha is min lr value as a fraction of init lr.
    )

  if params["warmup_steps"] > 0:
    global_steps_int = tf.cast(global_step, tf.int32)
    warmup_steps_int = tf.constant(params["warmup_steps"], dtype=tf.int32)

    global_steps_float = tf.cast(global_steps_int, tf.float32)
    warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

    warmup_percent_done = global_steps_float / warmup_steps_float
    warmup_learning_rate = learning_rate * warmup_percent_done

    is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
    learning_rate = ((1.0 - is_warmup) * learning_rate +
                     is_warmup * warmup_learning_rate)

  summary.scalar("lr", learning_rate)

  if params["opt_name"].lower() == "adam":
    optimizer = mtf.optimize.AdamWeightDecayOptimizer(
        learning_rate=learning_rate,
        weight_decay_rate=params["weight_decay"],
        beta_1=params["beta1"],
        beta_2=params["beta2"],
        epsilon=params["epsilon"],
        exclude_from_weight_decay=["norm", "bias"],
    )
  else:
    optimizer = mtf.optimize.AdafactorOptimizer(
        learning_rate=params["lr"],
        decay_rate=params["weight_decay"],
        beta1=params["beta1"],
        epsilon1=params["ada_epsilon1"],
        epsilon2=params["ada_epsilon2"],
    )

  if params["gradient_clipping"] is not None:
    clip_value = mtf.constant(
        mesh, params["gradient_clipping"], dtype=tf.float32)
    (var_grads, _) = clip_by_global_norm(var_grads, clip_norm=clip_value)

  update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
  return learning_rate, update_ops