def widedeep(id_hldr, wt_hldr, vocab_dim, embed_dim, outdim, float16=None):
    logger.debug("[input tensor] (name,shape):({},{})".format(id_hldr.name, id_hldr.shape))
    logger.debug("[input tensor] (name,shape):({},{})".format(wt_hldr.name, wt_hldr.shape))
    if float16:
        deep_output = mtf.layers.embedding(id_hldr,
                                           vocab_dim=vocab_dim,
                                           output_dim=embed_dim,
                                           variable_dtype=float16,
                                           name="deep_embedding")
    else:
        fp32 = mtf.VariableDType(tf.float32, tf.float32, tf.float32)
        deep_output = mtf.layers.embedding(id_hldr,
                                           vocab_dim=vocab_dim,
                                           output_dim=embed_dim,
                                           variable_dtype=fp32,
                                           name="deep_embedding")
    logger.debug("[output tensor] (name,shape):({},{})".format(deep_output.name, deep_output.shape))

    expend_dim = mtf.Dimension('expend', size=1)
    embed_dim_one = mtf.Dimension('embed_dim_one', size=1)
    mask = mtf.reshape(wt_hldr,
                       new_shape=[wt_hldr.shape.dims[0], wt_hldr.shape.dims[1], expend_dim],
                       name='mask_reshape')
    logger.debug("[output tensor] (name,shape):({},{})".format(mask.name, mask.shape))

    if float16:
        wide_output = mtf.layers.embedding(id_hldr,
                                           vocab_dim=vocab_dim,
                                           output_dim=embed_dim_one,
                                           variable_dtype=float16,
                                           name="wide_embedding")
    else:
        fp32 = mtf.VariableDType(tf.float32, tf.float32, tf.float32)
        wide_output = mtf.layers.embedding(id_hldr,
                                           vocab_dim=vocab_dim,
                                           output_dim=embed_dim_one,
                                           variable_dtype=fp32,
                                           name="wide_embedding")
    logger.debug("[output tensor] (name,shape):({},{})".format(wide_output.name, wide_output.shape))

    wide_output = wide(wide_output, mask=mask, float16=float16)
    deep_output = deep(deep_output, mask=mask, float16=float16)
    result = mtf.add(wide_output, deep_output)
    result = mtf.reshape(result,
                         new_shape=[wide_output.shape.dims[0], outdim],
                         name='result_reshape')
    logger.debug("[output tensor] (name,shape):({},{})".format(result.name, result.shape))
    return result
def get_variable_dtype(bf_16=True):
    # Trainable variable precision: store checkpoints in master type,
    # train in slice type, compute in activation type.
    if bf_16:
        return mtf.VariableDType(master_dtype=tf.bfloat16,
                                 slice_dtype=tf.float32,
                                 activation_dtype=tf.bfloat16)
    else:
        return mtf.VariableDType(master_dtype=tf.float32,
                                 slice_dtype=tf.float32,
                                 activation_dtype=tf.float32)
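
# Hedged usage sketch (not part of the original source): how the
# mtf.VariableDType returned by get_variable_dtype() is typically threaded
# into a layer call. `x` is assumed to be an existing mtf.Tensor; the
# function and dimension names here are illustrative.
def _example_dense_with_dtype(x):
    # bf16 checkpoints/activations, fp32 trainable slices
    variable_dtype = get_variable_dtype(bf_16=True)
    return mtf.layers.dense(x,
                            mtf.Dimension(name='example_hidden', size=64),
                            name="example_dense",
                            variable_dtype=variable_dtype)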
def deep(x, mask, float16=None):
    x = mtf.einsum([x, mask], output_shape=x.shape.dims, name='deep_mul')
    logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))

    # Following the MindSpore implementation, compute the dense layers below in fp16.
    x = mtf.cast(x, dtype=tf.float16)
    x = mtf.layers.dense(x,
                         mtf.Dimension(name='dense_dim0', size=1024),
                         name="deep-dense-0",
                         reduced_dims=x.shape.dims[-2:],
                         activation=mtf.relu,
                         variable_dtype=mtf.VariableDType(tf.float16, tf.float16, tf.float16))
    logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
    x = mtf.layers.dense(x,
                         mtf.Dimension(name='dense_dim1', size=512),
                         name="deep-dense-1",
                         activation=mtf.relu,
                         variable_dtype=mtf.VariableDType(tf.float16, tf.float16, tf.float16))
    logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
    x = mtf.layers.dense(x,
                         mtf.Dimension(name='dense_dim2', size=256),
                         name="deep-dense-2",
                         activation=mtf.relu,
                         variable_dtype=mtf.VariableDType(tf.float16, tf.float16, tf.float16))
    logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
    x = mtf.layers.dense(x,
                         mtf.Dimension(name='dense_dim3', size=128),
                         name="deep-dense-3",
                         activation=mtf.relu,
                         variable_dtype=mtf.VariableDType(tf.float16, tf.float16, tf.float16))
    logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
    x = mtf.layers.dense(x,
                         mtf.Dimension(name='dense_dim4', size=1),
                         name="deep-dense-4",
                         variable_dtype=mtf.VariableDType(tf.float16, tf.float16, tf.float16))
    logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))

    if not float16:
        x = mtf.cast(x, dtype=tf.float32)
    return x
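
# Hedged, minimal restatement of the mixed-precision pattern used in deep()
# above: cast activations to fp16, run the dense layer with an all-fp16
# VariableDType, then cast back to fp32 unless the whole model runs in fp16.
# `_fp16_dense` and its dimension name are illustrative, not from the source.
def _fp16_dense(x, out_size, keep_fp16=False):
    x = mtf.cast(x, dtype=tf.float16)
    x = mtf.layers.dense(x,
                         mtf.Dimension(name='fp16_dense_dim', size=out_size),
                         name="fp16-dense",
                         variable_dtype=mtf.VariableDType(tf.float16, tf.float16, tf.float16))
    return x if keep_fp16 else mtf.cast(x, dtype=tf.float32)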
def test_sampling():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    batch_dim = mtf.Dimension("batch", 1)
    sequence_dim = mtf.Dimension("sequence", 1)

    inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32)
    inputs = mtf.pad(inputs, [0, 3], sequence_dim.name)

    # create mask
    seq_len = params["n_ctx"]
    num_mem_kv = params.get('num_mem_kv', 0)
    length_dim = mtf.Dimension('sequence', seq_len)
    memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
    embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])

    other_features = {}
    other_features["attn_bias"] = biasmask_attn_weights(
        mesh, length_dim, memory_length_dim, mtf.VariableDType(tf.float32))
    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    params["mode"] = "predict"

    with not_raises(Exception):
        samples = sample_autoregressive(
            inputs,
            other_features=other_features,
            params=params,
            variable_dtype=mtf.VariableDType(),
            remove_partial_sequences=params["remove_partial_sequences"],
            stop_at_token=params["eos_id"],
            sampling_use_entmax=True)

        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        samples = lowering.export_to_tf_tensor(samples)
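
# Hedged sketch of the lower-and-export pattern the tests here rely on:
# build mtf ops, lower them onto a single-device placement mesh, then pull
# the result out as an ordinary tf.Tensor. `_lower_and_export` is an
# illustrative helper, not from the source.
def _lower_and_export(graph, mesh, mtf_tensor):
    mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    return lowering.export_to_tf_tensor(mtf_tensor)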
def call_simple(self,
                inputs,
                targets,
                compute_loss,
                mode=tf.estimator.ModeKeys.TRAIN,
                variable_dtype=mtf.VariableDType(tf.float32),
                encoder_sequence_id=None,
                decoder_sequence_id=None):
    """Compute logits based on inputs (all positions in parallel).

    This is called during training and evaluation.

    Args:
        inputs: an int32 Tensor with shape [<batch_dims>, length_dim]
        targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
        compute_loss: a boolean
        mode: a tf.estimator.ModeKeys
        variable_dtype: a mtf.VariableDType
        encoder_sequence_id: an optional Tensor
        decoder_sequence_id: an optional Tensor

    Returns:
        logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
        loss: an optional Scalar (if compute_loss=True)
    """
    shared_params = self._shared_params(inputs.mesh, variable_dtype)
    encoder_output, encoder_loss = self.encoder.call_simple(
        inputs,
        None,
        compute_loss,
        mode=mode,
        variable_dtype=variable_dtype,
        sequence_id=encoder_sequence_id,
        shared_params=shared_params)
    encoder_output = mtf.layers.rename_length_to_memory_length(encoder_output)
    if encoder_sequence_id is not None:
        encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
            encoder_sequence_id)
    length_dim = targets.shape.dims[-1]
    shifted_targets = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
    logits, loss = self.decoder.call_simple(
        shifted_targets,
        targets,
        compute_loss,
        mode=mode,
        variable_dtype=variable_dtype,
        sequence_id=decoder_sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        shared_params=shared_params)
    if loss is not None and encoder_loss is not None:
        loss += encoder_loss
    return logits, loss
def call_simple(self,
                inputs,
                targets,
                compute_loss,
                mode=tf.estimator.ModeKeys.TRAIN,
                variable_dtype=mtf.VariableDType(tf.float32),
                sequence_id=None,
                position=None,
                encoder_output=None,
                encoder_sequence_id=None,
                shared_params=None):
    """Compute logits based on inputs (all positions in parallel).

    This is called during training and evaluation.

    Args:
        inputs: an int32 Tensor with shape [<batch_dims>, length_dim].
            For training autoregressive models this should be equal to
            mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
        targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
        compute_loss: a boolean
        mode: a tf.estimator.ModeKeys
        variable_dtype: a mtf.VariableDType
        sequence_id: an optional Tensor
        position: an optional Tensor
        encoder_output: an optional Tensor
        encoder_sequence_id: an optional Tensor
        shared_params: an optional dictionary

    Returns:
        logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
        loss: an optional Scalar (if compute_loss=True)
    """
    context = Context(mesh=inputs.mesh,
                      batch_dims=inputs.shape.dims[:-1],
                      length_dim=inputs.shape.dims[-1],
                      model_dim=self.model_dim,
                      variable_dtype=variable_dtype,
                      mode=mode,
                      autoregressive=self.autoregressive,
                      losses=[] if compute_loss else None,
                      sequence_id=sequence_id,
                      position=position,
                      encoder_output=encoder_output,
                      encoder_sequence_id=encoder_sequence_id,
                      shared_params=shared_params,
                      layout=self.layout,
                      mesh_shape=self.mesh_shape)
    with tf.variable_scope(self.name):
        logits = self._call_internal(context, inputs, targets)
    if compute_loss:
        loss = mtf.add_n(context.losses)
    else:
        loss = None
    return logits, loss
def setUp(self):
    super(FlatKeyValueMemoryTest, self).setUp()
    self.graph = mtf.Graph()
    self.mesh = mtf.Mesh(self.graph, "mtf_mesh")
    self.variable_dtype = mtf.VariableDType(activation_dtype=tf.float32)

    self.addCleanup(mock.patch.stopall)
    self.initializer_mock = mock.MagicMock()
    random_normal_initializer_mock = mock.patch.object(
        tf, "random_normal_initializer").start()
    random_normal_initializer_mock.return_value = self.initializer_mock
def setUp(self):
    super(AdaptiveSoftmaxTest, self).setUp()
    self.graph = mtf.Graph()
    self.mesh = mtf.Mesh(self.graph, 'mtf_mesh')
    self.variable_dtype = mtf.VariableDType(activation_dtype=tf.float32)

    self.addCleanup(mock.patch.stopall)
    self.initializer_mock = mock.MagicMock()
    random_normal_initializer_mock = mock.patch.object(
        tf, 'random_normal_initializer').start()
    random_normal_initializer_mock.return_value = self.initializer_mock
def model_backbone(features, labels, mesh):
    """The wide & deep model.

    Args:
        features: a pair (id_hldr, wt_hldr) of tf.Tensors, each reshapeable to
            [batch, field]
        labels: a tf.Tensor with shape [batch] and dtype tf.int32
        mesh: a mtf.Mesh
    Returns:
        logits: a mtf.Tensor of logits
        loss: a mtf.Tensor with shape [] (None if labels is None)
    """
    id_hldr, wt_hldr = features

    batch_dim = mtf.Dimension("batch", args_opt.batch_size)
    field_dim = mtf.Dimension("field", size=39)
    vocab_dim = mtf.Dimension("vocab_size", 200000)
    embed_dim = mtf.Dimension("embed_size", 80)
    outdim = mtf.Dimension("outdim", 1)
    id_hldr = mtf.import_tf_tensor(
        mesh,
        tf.reshape(id_hldr, [args_opt.batch_size, field_dim.size]),
        mtf.Shape([batch_dim, field_dim]))
    wt_hldr = mtf.import_tf_tensor(
        mesh,
        tf.reshape(wt_hldr, [args_opt.batch_size, field_dim.size]),
        mtf.Shape([batch_dim, field_dim]))
    if args_opt.fp16:
        float16 = mtf.VariableDType(tf.float16, tf.float16, tf.float16)
        # id_hldr = mtf.cast(id_hldr, dtype=tf.int32)
        wt_hldr = mtf.cast(wt_hldr, dtype=tf.float16)
    else:
        float16 = None

    logits, embedding_table = network[args_opt.model](id_hldr,
                                                      wt_hldr,
                                                      vocab_dim,
                                                      embed_dim,
                                                      outdim,
                                                      float16=float16)
    logits = mtf.cast(logits, dtype=tf.float32)
    embedding_table = mtf.cast(embedding_table, dtype=tf.float32)

    if labels is None:
        # No labels, no loss (the original added None + None here, which raises).
        return logits, None

    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [args_opt.batch_size]), mtf.Shape([batch_dim]))
    wide_loss = mtf.layers.sigmoid_cross_entropy_with_logits(logits, labels)
    deep_loss = mtf.reduce_mean(mtf.square(embedding_table)) / 2
    deep_loss = mtf.reduce_mean(wide_loss) + 8e-5 * deep_loss
    wide_loss = mtf.reduce_mean(wide_loss)

    return logits, wide_loss + deep_loss
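
# Hedged restatement of the loss composition in model_backbone above. Note
# that mtf.reduce_mean(wide_loss) contributes to both returned terms, so the
# total is 2 * mean(sigmoid_ce) + 8e-5 * mean(embedding**2) / 2.
# `_wide_deep_total_loss` is an illustrative helper, not from the source.
def _wide_deep_total_loss(logits, labels, embedding_table):
    wide_loss = mtf.layers.sigmoid_cross_entropy_with_logits(logits, labels)
    l2 = mtf.reduce_mean(mtf.square(embedding_table)) / 2
    deep_loss = mtf.reduce_mean(wide_loss) + 8e-5 * l2
    return mtf.reduce_mean(wide_loss) + deep_loss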
def get_dummy_decoder_context(converter,
                              batch=2,
                              d_model=6,
                              length=4,
                              mode="incremental",
                              initial_position=None,
                              state=None,
                              inputs=None):
    batch_dim = mtf.Dimension("batch", batch)
    length_dim = mtf.Dimension("length", length)

    # Set up a dummy model
    layer_stack = transformer.LayerStack(layers=[])
    model = transformer.Unitransformer(
        d_model=d_model,
        input_vocab_size=10,  # dummy values
        output_vocab_size=10,  # dummy values
        autoregressive=True,
        max_length=length,
        layer_stack=layer_stack)

    if state is not None:
        state_mtf = converter.convert_np_array_to_mtf_tensor(
            state, dtype=tf.float32, dim_names=["batch", "length", "d_model"])
        states = [state_mtf]
    else:
        states = None

    if initial_position:
        initial_position = mtf.constant(converter.mesh,
                                        initial_position,
                                        shape=mtf.Shape([batch_dim]),
                                        dtype=tf.int32)

    if inputs is not None:
        inputs = converter.convert_np_array_to_mtf_tensor(
            inputs, dim_names=["batch", "length"])

    context = transformer.Context(model=model,
                                  mode=mode,
                                  states=states,
                                  new_states=[],
                                  mesh=converter.mesh,
                                  batch_dims=[batch_dim],
                                  length_dim=length_dim,
                                  variable_dtype=mtf.VariableDType(tf.float32),
                                  sequence_id=1,
                                  inputs=inputs,
                                  initial_position=initial_position)
    return context
def test_model():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    seq_len = params["n_ctx"]
    batch_dim = mtf.Dimension("batch", 1)
    sequence_dim = mtf.Dimension("sequence", seq_len)

    features = {
        'inputs': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32),
        'labels': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32)
    }

    # create mask
    num_mem_kv = params.get('num_mem_kv', 0)
    length_dim = mtf.Dimension('sequence', seq_len)
    memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
    embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])

    other_features = {}
    variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)

    other_features["attn_bias"] = biasmask_attn_weights(
        mesh, length_dim, memory_length_dim, variable_dtype)
    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    with not_raises(Exception):
        logits, _, _ = gpt2.model(features,
                                  other_features,
                                  params,
                                  mesh,
                                  variable_dtype=variable_dtype)

        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        logits = lowering.export_to_tf_tensor(logits)
def get_variable_dtype(master_dtype=tf.bfloat16,
                       slice_dtype=tf.float32,
                       activation_dtype=tf.float32):
    """Datatypes to use for the run.

    Args:
        master_dtype: string, datatype for checkpoints.
            Keep this the same between training and eval/inference.
        slice_dtype: string, datatype for variables in memory.
            Must be tf.float32 for training.
        activation_dtype: string, datatype for activations.
            Less memory usage if tf.bfloat16, but possible numerical issues.
    Returns:
        a mtf.VariableDType
    """
    return mtf.VariableDType(master_dtype=tf.as_dtype(master_dtype),
                             slice_dtype=tf.as_dtype(slice_dtype),
                             activation_dtype=tf.as_dtype(activation_dtype))
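
# Hedged example (not from the original source): because every argument goes
# through tf.as_dtype, both dtype names as strings and tf.DType objects are
# accepted. The values below simply mirror the documented defaults.
example_variable_dtype = get_variable_dtype(master_dtype="bfloat16",
                                            slice_dtype="float32",
                                            activation_dtype="float32")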
def testVariableOperations(self):
    var = mtf.Variable(self.mesh,
                       "test_variable",
                       self.ab_shape,
                       mtf.VariableDType(tf.int32, tf.int32, tf.int32),
                       initializer=tf.zeros_initializer(),
                       trainable=True)
    self.assertEqual(var.splittable_dims, frozenset(["a", "b"]))
    self.assertEqual(var.unsplittable_dims, frozenset())

    read_variable = mtf.ReadVariable(var)
    self.assertEqual(read_variable.splittable_dims, frozenset(["a", "b"]))
    self.assertEqual(read_variable.unsplittable_dims, frozenset())

    assign = mtf.Assign([var], [self.x])
    self.assertEqual(assign.splittable_dims, frozenset(["a", "b"]))
    self.assertEqual(assign.unsplittable_dims, frozenset())

    depend = mtf.Depend(read_variable.outputs[0], [assign])
    self.assertEqual(depend.splittable_dims, frozenset(["a", "b"]))
    self.assertEqual(depend.unsplittable_dims, frozenset())
def model_backbone(image, labels, mesh):
    """The model.

    Args:
        image: a tf.Tensor reshapeable to [batch, 224, 224, 3]
        labels: a tf.Tensor with shape [batch] and dtype tf.int32
        mesh: a mtf.Mesh
    Returns:
        logits: a mtf.Tensor with shape [batch, classesnum]
        loss: a mtf.Tensor with shape []
    """
    batch_dim = mtf.Dimension("batch", args_opt.batch_size)
    rows_dim = mtf.Dimension("rows_size", 224)
    cols_dim = mtf.Dimension("cols_size", 224)
    channel_dim = mtf.Dimension("image_channel", 3)
    classes_dim = mtf.Dimension(name='classesnum', size=args_opt.class_num)
    x = mtf.import_tf_tensor(
        mesh,
        tf.reshape(image, [args_opt.batch_size, 224, 224, 3]),
        mtf.Shape([batch_dim, rows_dim, cols_dim, channel_dim]))
    if args_opt.fp16:
        float16 = mtf.VariableDType(tf.float16, tf.float16, tf.float16)
    else:
        float16 = None

    logits = network[args_opt.model](x,
                                     classes_dim=classes_dim,
                                     float16=float16,
                                     batch_norm=False if 'vgg' in args_opt.model else True)
    logits = mtf.cast(logits, dtype=tf.float32)

    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(
            mesh, tf.reshape(labels, [args_opt.batch_size]), mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
def model_fn(features, labels, mode, params):
    # Get global step
    global_step = tf.train.get_global_step()

    # Construct mtf graph + mesh from params
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    layout_rules = mtf.convert_to_layout_rules(params["layout"])

    # Mesh setup
    if params["use_tpu"]:
        var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape, layout_rules)
    else:
        var_placer = None
        gpu_ids = params["gpu_ids"]
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, gpu_ids)

    # Trainable variable precision: store checkpoints in master type,
    # train in slice type, compute in activation type.
    if params["precision"] == "bfloat16":
        variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.bfloat16)
    else:
        variable_dtype = mtf.VariableDType(master_dtype=tf.float32,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.float32)

    # Build mtf mesh object
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # Build mtf_features & seq length dict for getting number of microbatches.
    # We need to pack inputs into a dict to pass into serialize_training_step.
    features_dict = {"inputs": features, "labels": labels}
    sequence_length_dict = {"inputs": params["n_ctx"], "labels": params["n_ctx"]}

    params = add_mode_to_params(params, mode)
    batch_size = get_batch_size(params)

    batch_dim = mtf.Dimension("batch", batch_size)
    batch_dims = [batch_dim]
    feature_length = sequence_length_dict["inputs"]
    length_dim = mtf.Dimension("sequence", feature_length)

    mtf_features = {}
    for key, x in features_dict.items():
        if x is not None:
            feature_shape = mtf.Shape(batch_dims + [length_dim])
            if type(features_dict[key]) == dict:
                features_dict[key] = features_dict[key]["feature"]
            x = tf.cast(features_dict[key], tf.int32)
            x = tf.reshape(x, feature_shape.to_integer_list)
            mtf_features[key] = mtf.import_fully_replicated(
                mesh, x, feature_shape, name=key)

    # Instantiate dict for dimensions, bias, etc that can be calculated here
    # once and then passed into the model.
    other_features = {}
    memory_length_dim = mtf.Dimension("memory_length", length_dim.size)

    attn_bias = biasmask_attn_weights(
        mesh, length_dim, memory_length_dim,
        variable_dtype) if params["causal"] else None

    # Add attn_bias into mtf_features
    other_features["attn_bias"] = attn_bias

    # Define other Dimensions that we'll need inside the model
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
    # We need this because gathering when both the args have the same dimension
    # in them breaks things. This dim is specifically for the weights; it
    # prevents the "Einsum has lhs dimension without corresponding rhs or
    # output dimension." error.
    embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])

    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Set up the model for prediction
        inputs = mtf_features["inputs"]
        if params["remove_partial_sequences"] is None:
            params["remove_partial_sequences"] = False

        export = params.get("export", False)

        if not export:
            mtf_samples = sample_autoregressive(
                inputs,
                other_features=other_features,
                params=params,
                variable_dtype=variable_dtype,
                remove_partial_sequences=params["remove_partial_sequences"],
                stop_at_token=params["eos_id"],
                sampling_use_entmax=params['sampling_use_entmax'])
        else:
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    mtf_samples, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype,
                        context=None)

        mtf_samples = mtf.anonymize(mtf_samples)
        inputs = mtf.anonymize(inputs)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
        inputs = lowering.export_to_tf_tensor(inputs)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {"inputs": inputs, "outputs": outputs}

        def scaffold_fn():
            return tf.train.Scaffold(
                local_init_op=tf.group(
                    tf.train.Scaffold.default_local_init_op(),
                    lowering.copy_masters_to_slices(),
                    name="mtf_local_init_op"),
                ready_op=tf.concat([
                    tf.report_uninitialized_variables(),
                    resources.report_uninitialized_resources()
                ], axis=0, name="mtf_ready_op"))

        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            scaffold_fn=scaffold_fn,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    # We're not predicting, so we better be training or evaluating
    assert (mode == tf.estimator.ModeKeys.TRAIN
            or mode == tf.estimator.ModeKeys.EVAL)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Gets number of microbatches per batch for serialized training.
        # If param tokens_per_mb_per_replica is None, this defaults to 1 and
        # no microbatching is performed.
        num_microbatches = int(
            mtf_transformer.utils.serialize_num_microbatches(
                batch_dim=batch_dim,
                sequence_length=sequence_length_dict,
                mesh_shape=mesh_shape,
                layout_rules=layout_rules,
                tokens_per_microbatch_per_replica=params["tokens_per_mb_per_replica"]))
    else:
        num_microbatches = 1

    params["num_microbatches"] = num_microbatches  # Add num microbatches to params

    if num_microbatches > 1:
        # For serialize_training_step we need to modify the model to output
        # results in a dict.
        def serialized_fn(mtf_features):
            if params["model"] == "GPT":
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype)
                return {"logits": logits, "loss": loss, "loss_batch": loss_batch}
            else:
                raise Exception(
                    f"'{params['model']}' is not a valid model - please select from [GPT]")

        # Serialize the training step - gradients are accumulated locally and
        # reduced once.
        var_grads, output_dict = mtf.serialize_training_step(
            mtf_features, serialized_fn, batch_dim, num_microbatches)
        loss = output_dict["loss"]
        loss_batch = output_dict["loss_batch"]
        logits = output_dict["logits"]
    else:
        # If we're not splitting into microbatches, return logits & loss as is
        if params["model"] == "GPT":
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype,
                        context=None)
        else:
            raise Exception(
                f"'{params['model']}' is not a valid model - please select from [GPT]")

    # Auto layout generation
    if params["auto_layout"]:
        auto_layout(graph, mesh_shape, logits, loss)
    if params["auto_layout_and_mesh_shape"]:
        auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # In TRAIN mode, get optimizer
        if params["num_microbatches"] > 1:
            # If we are splitting the batch into microbatches, var grads are
            # created in the serialize_training_step fn, so we pass them in here.
            _, update_ops, var_grads = get_optimizer(
                mesh, loss, params,
                variable_dtype=variable_dtype,
                inp_var_grads=var_grads)
        else:
            # Otherwise, they are created in the get_optimizer fn, so we leave
            # inp_var_grads blank.
            _, update_ops, var_grads = get_optimizer(
                mesh, loss, params, variable_dtype=variable_dtype)
        # Log summaries to tensorboard
        mtf.scalar_summary("loss", loss)
        # Log gradients if in params
        if params["log_grads"] not in [None, False]:
            for g in var_grads:
                grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
                mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm)
    else:
        # For now, we can only export fully-replicated tensors.
        # This has to be done before lowering or they will not be included in
        # the graph.
        mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
        max_logits = mtf.argmax(logits, vocab_dim)
        del logits
        fully_replicated_mean_logits = mtf.anonymize(mean_logits)
        fully_replicated_max_logits = mtf.anonymize(max_logits)
        fully_replicated_loss_batch = mtf.anonymize(loss_batch)

    # Gets & prints info about no. trainable vars in the model & dimension names
    get_graph_info(graph)

    # 'lowers' mtf tensors into a tf graph - this enables us to export results
    # as tf tensors.
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.cast(tf_loss, tf.float32)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Use our patched version until mtf updates theirs
        host_call = create_host_call(params['model_path'])
        mtf.utils.remove_summaries()

        # Creates train_op
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))  # Need to manually increment global_step
        tf.logging.info(f"tf_update_ops: {tf_update_ops}")
        train_op = tf.group(tf_update_ops)
    else:
        tf_mean_logits = lowering.export_to_tf_tensor(fully_replicated_mean_logits)
        tf_max_logits = lowering.export_to_tf_tensor(fully_replicated_max_logits)
        tf_loss_batch = tf.to_float(
            lowering.export_to_tf_tensor(fully_replicated_loss_batch))

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)

        if mode == tf.estimator.ModeKeys.TRAIN:
            # Set up the checkpoint server and return the TPUEstimatorSpec
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                params["model_path"],
                save_steps=params["steps_per_checkpoint"],
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                host_call=host_call,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])

        elif mode == tf.estimator.ModeKeys.EVAL:
            # Evaluation metrics
            def _perplexity(loss):
                perplexity = tf.exp(loss)
                return tf.metrics.mean(perplexity)

            def _bits_per_byte(loss):
                bpb = loss * (0.29335 / math.log(2))
                return tf.metrics.mean(bpb)

            def _metric_fn(tf_mean_logits, tf_loss_batch):
                mean_logits = tf.metrics.mean(tf_mean_logits)
                loss = tf.reduce_mean(tf_loss_batch)
                perp = _perplexity(loss)
                bpb = _bits_per_byte(loss)
                return {
                    "mean_logits": mean_logits,
                    "perplexity": perp,
                    "bits per byte": bpb
                }

            def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
                eos_token = params["eos_id"]
                answer_positions = tf.where(tf.math.not_equal(labels, eos_token))
                correct_answers = tf.gather_nd(
                    tf.math.equal(tf_max_logits, labels), answer_positions)
                accuracy = tf.metrics.mean(tf.cast(correct_answers, tf.float32))

                # tf_loss_batch has z_loss and maybe other stuff added to it,
                # so maybe this should be calculated separately in the future.
                answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
                log_perplexity = tf.metrics.mean(answer_loss)
                return {
                    "lambada_acc": accuracy,
                    "lambada_log_ppl": log_perplexity
                }

            eval_task = params["eval_task"]
            if eval_task == "lambada":
                eval_metrics = (_lambada_metric_fn,
                                [labels, tf_max_logits, tf_loss_batch])
            else:
                eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
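
# Hedged note on _bits_per_byte in model_fn above: the loss is mean nats per
# token, so multiplying by a tokens-per-byte ratio (0.29335 here, presumably
# measured for the tokenizer/dataset in use) and dividing by ln(2) converts
# nats/token into bits/byte. Standalone restatement for clarity; the helper
# name is illustrative.
import math

def _bits_per_byte_from_nats(nats_per_token, tokens_per_byte=0.29335):
    return nats_per_token * tokens_per_byte / math.log(2)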
def unet_with_spatial_partition(mesh, mesh_impl, dataset_str, images, labels):
    """Builds the UNet model graph, train op and eval metrics.

    Args:
        mesh: a MeshTensorflow.mesh object.
        mesh_impl: a mesh implementation, such as SimdMeshImpl and PlacementMeshImpl.
        dataset_str: a string of either train or eval. This is used for batch_norm.
        images: a laid out Tensor with shape [batch, x, y, num_channels]
            or [batch, x, y, z, num_channels].
        labels: a laid out Tensor with shape [batch, x, y, num_classes]
            or [batch, x, y, z, num_classes].

    Returns:
        Prediction and loss.
    """
    is_training = (dataset_str == 'train')
    if dataset_str == 'train':
        batch_dim = mtf.Dimension('batch', FLAGS.batch_size_train)
    else:
        assert dataset_str == 'eval'
        batch_dim = mtf.Dimension('batch', FLAGS.batch_size_eval)
    image_nx_dim = mtf.Dimension('image_nx_block', FLAGS.image_nx_block)
    image_ny_dim = mtf.Dimension('image_ny_block', FLAGS.image_ny_block)
    image_sx_dim = mtf.Dimension('image_sx_block',
                                 FLAGS.ct_resolution // FLAGS.image_nx_block)
    image_sy_dim = mtf.Dimension('image_sy_block',
                                 FLAGS.ct_resolution // FLAGS.image_ny_block)
    image_sz_dim = mtf.Dimension('image_sz_block', FLAGS.ct_resolution)
    image_c_dim = mtf.Dimension('image_c', FLAGS.image_c)
    label_c_dim = mtf.Dimension('label_c', FLAGS.label_c)
    mtf_images_shape, mtf_labels_shape = get_input_mtf_shapes(dataset_str)

    mtf_dtype = tf.as_dtype(FLAGS.mtf_dtype)
    variable_dtype = mtf.VariableDType(mtf_dtype, mtf_dtype, mtf_dtype)

    # Import input features.
    x = mtf.import_laid_out_tensor(mesh, mesh_impl.LaidOutTensor(images),
                                   mtf_images_shape)
    x = mtf.cast(x, mtf_dtype)

    # Import ground truth labels.
    t = mtf.import_laid_out_tensor(mesh, mesh_impl.LaidOutTensor(labels),
                                   mtf_labels_shape)
    t = mtf.cast(t, mtf_dtype)

    # Transpose the blocks.
    if FLAGS.sampled_2d_slices:
        x = mtf.transpose(x, [batch_dim, image_nx_dim, image_ny_dim,
                              image_sx_dim, image_sy_dim, image_c_dim])
        t = mtf.transpose(t, [batch_dim, image_nx_dim, image_ny_dim,
                              image_sx_dim, image_sy_dim, label_c_dim])
    else:
        x = mtf.transpose(x, [batch_dim, image_nx_dim, image_ny_dim,
                              image_sx_dim, image_sy_dim, image_sz_dim, image_c_dim])
        t = mtf.transpose(t, [batch_dim, image_nx_dim, image_ny_dim,
                              image_sx_dim, image_sy_dim, image_sz_dim, label_c_dim])

    # Network.
    levels = []
    all_bn_update_ops = []
    # add levels with convolution or down-sampling
    for depth in range(FLAGS.network_depth):
        for n_conv in range(FLAGS.n_conv_per_block):
            if depth == 0 and n_conv == 0:
                # no dropout in 1st layer.
                dropout_keep_p = 1.0
            else:
                dropout_keep_p = FLAGS.dropout_keep_p
            x, bn_update_ops = conv_with_spatial_partition(
                x, FLAGS.sampled_2d_slices,
                image_nx_dim, image_ny_dim,
                FLAGS.n_base_filters * (2**depth),
                dropout_keep_p,
                FLAGS.with_batch_norm,
                is_training,
                'conv_{}_{}'.format(depth, n_conv),
                variable_dtype,
                'conv_down_{}_{}'.format(depth, n_conv))
            all_bn_update_ops.extend(bn_update_ops)
        levels.append(x)

        if depth < FLAGS.network_depth - 1:
            if FLAGS.sampled_2d_slices:
                x = mtf.layers.max_pool2d(x, ksize=(2, 2))
            else:
                x = mtf.layers.max_pool3d(x, ksize=(2, 2, 2))

    # add levels with up-convolution or up-sampling
    for depth in range(FLAGS.network_depth - 1)[::-1]:
        x = deconv_with_spatial_partition(
            x, FLAGS.sampled_2d_slices,
            image_nx_dim, image_ny_dim,
            FLAGS.n_base_filters * (2**depth),
            FLAGS.dropout_keep_p,
            'conv_{}_{}'.format(depth, FLAGS.n_conv_per_block - 1),
            variable_dtype,
            'deconv_{}_0'.format(depth))
        x = mtf.concat([x, levels[depth]],
                       concat_dim_name='conv_{}_{}'.format(
                           depth, FLAGS.n_conv_per_block - 1))

        for n_conv in range(FLAGS.n_conv_per_block):
            x, bn_update_ops = conv_with_spatial_partition(
                x, FLAGS.sampled_2d_slices,
                image_nx_dim, image_ny_dim,
                FLAGS.n_base_filters * (2**depth),
                FLAGS.dropout_keep_p,
                FLAGS.with_batch_norm,
                is_training,
                'conv_{}_{}'.format(depth, n_conv),
                variable_dtype,
                'conv_up_{}_{}'.format(depth, n_conv))
            all_bn_update_ops.extend(bn_update_ops)

    # no dropout in the final layer.
    if FLAGS.sampled_2d_slices:
        y = mtf.layers.conv2d_with_blocks(
            x,
            mtf.Dimension('label_c', FLAGS.label_c),
            filter_size=(1, 1),
            strides=(1, 1),
            padding='SAME',
            h_blocks_dim=image_nx_dim,
            w_blocks_dim=image_ny_dim,
            variable_dtype=variable_dtype,
            name='final_conv_{}'.format(FLAGS.label_c),
        )
    else:
        y = mtf.layers.conv3d_with_blocks(
            x,
            mtf.Dimension('label_c', FLAGS.label_c),
            filter_size=(1, 1, 1),
            strides=(1, 1, 1),
            padding='SAME',
            d_blocks_dim=image_nx_dim,
            h_blocks_dim=image_ny_dim,
            variable_dtype=variable_dtype,
            name='final_conv_{}'.format(FLAGS.label_c),
        )

    # use mtf.constant to make sure there is no CPU-side constants.
    def scalar(v, dtype):
        return mtf.constant(mesh, v, shape=[], dtype=dtype)

    argmax_t = mtf.argmax(t, label_c_dim)
    liver_t = mtf.cast(mtf.equal(argmax_t, scalar(1, tf.int32)), mtf_dtype)
    lesion_t = mtf.cast(mtf.equal(argmax_t, scalar(2, tf.int32)), mtf_dtype)
    argmax_y = mtf.argmax(y, label_c_dim)
    lesion_y = mtf.cast(mtf.equal(argmax_y, scalar(2, tf.int32)), mtf_dtype)

    # summary of class ratios.
    lesion_pred_ratio = mtf.reduce_mean(lesion_y)
    lesion_label_ratio = mtf.reduce_mean(lesion_t)

    # summary of accuracy.
    accuracy = mtf.reduce_mean(
        mtf.cast(mtf.equal(argmax_y, argmax_t), mtf_dtype))

    # Cross-entropy loss. Up-weight the liver region.
    pixel_loss = mtf.layers.softmax_cross_entropy_with_logits(y, t, label_c_dim)
    pixel_weight = scalar(1, mtf_dtype) + \
        liver_t * scalar(FLAGS.xen_liver_weight - 1, mtf_dtype) + \
        lesion_t * scalar(FLAGS.xen_lesion_weight - FLAGS.xen_liver_weight, mtf_dtype)
    loss_xen = mtf.reduce_mean(pixel_loss * pixel_weight)

    # Dice loss
    y_prob = mtf.softmax(y, reduced_dim=label_c_dim)
    lesion_prob = mtf.reduce_sum(mtf.slice(y_prob, 2, 1, 'label_c'),
                                 reduced_dim=mtf.Dimension('label_c', 1))
    prob_intersect = mtf.reduce_sum(lesion_prob * lesion_t,
                                    output_shape=mtf.Shape([batch_dim]))
    prob_area_sum = mtf.reduce_sum(lesion_prob + lesion_t,
                                   output_shape=mtf.Shape([batch_dim]))
    loss_dice_per_case = mtf.reduce_mean(
        scalar(-2, mtf_dtype) * prob_intersect /
        (prob_area_sum + scalar(FLAGS.dice_epsilon, mtf_dtype)))
    loss_dice_global = scalar(-2, mtf_dtype) * mtf.reduce_sum(prob_intersect) / (
        mtf.reduce_sum(prob_area_sum) + scalar(FLAGS.dice_epsilon, mtf_dtype))
    loss_dice = (loss_dice_per_case + loss_dice_global) * scalar(0.5, mtf_dtype)

    loss = scalar(FLAGS.dice_loss_weight, mtf_dtype) * loss_dice + scalar(
        1 - FLAGS.dice_loss_weight, mtf_dtype) * loss_xen

    intersect = mtf.reduce_sum(lesion_y * lesion_t,
                               output_shape=mtf.Shape([batch_dim]))
    area_sum = mtf.reduce_sum(lesion_y + lesion_t,
                              output_shape=mtf.Shape([batch_dim]))
    # summary of dice.
    dice_per_case = mtf.reduce_mean(
        scalar(2, mtf_dtype) * intersect /
        (area_sum + scalar(0.000001, mtf_dtype)))
    dice_global = scalar(2, mtf_dtype) * mtf.reduce_sum(intersect) / (
        mtf.reduce_sum(area_sum) + scalar(0.000001, mtf_dtype))

    eval_metrics = {
        'lesion_pred_ratio': lesion_pred_ratio,
        'lesion_label_ratio': lesion_label_ratio,
        'accuracy_of_all_classes': accuracy,
        'lesion_dice_per_case': dice_per_case,
        'lesion_dice_global': dice_global,
        'loss_xen': loss_xen,
        'loss_dice': loss_dice,
        'loss_dice_per_case': loss_dice_per_case,
        'loss_dice_global': loss_dice_global,
    }

    if FLAGS.sampled_2d_slices:
        y_prob_downsampled = mtf.layers.avg_pool2d(
            y_prob, ksize=(FLAGS.pred_downsample,) * 2)
        if FLAGS.output_ground_truth:
            lesion_gt_downsampled = mtf.layers.avg_pool2d(
                mtf.slice(t, 2, 1, 'label_c'),
                ksize=(FLAGS.pred_downsample,) * 2)
    else:
        y_prob_downsampled = mtf.layers.avg_pool3d(
            y_prob, ksize=(FLAGS.pred_downsample,) * 3)
        if FLAGS.output_ground_truth:
            lesion_gt_downsampled = mtf.layers.avg_pool3d(
                mtf.slice(t, 2, 1, 'label_c'),
                ksize=(FLAGS.pred_downsample,) * 3)

    liver_prob_downsampled = mtf.slice(y_prob_downsampled, 1, 1, 'label_c')
    lesion_prob_downsampled = mtf.slice(y_prob_downsampled, 2, 1, 'label_c')
    preds = [
        mtf.reduce_sum(liver_prob_downsampled,
                       reduced_dim=mtf.Dimension('label_c', 1)),
        mtf.reduce_sum(lesion_prob_downsampled,
                       reduced_dim=mtf.Dimension('label_c', 1))
    ]

    if FLAGS.output_ground_truth:
        preds.append(
            mtf.reduce_sum(lesion_gt_downsampled,
                           reduced_dim=mtf.Dimension('label_c', 1)))

    preds.extend([intersect, area_sum])

    return preds, loss, eval_metrics, all_bn_update_ops
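
# Hedged restatement of the global dice term in unet_with_spatial_partition
# above: with p = predicted lesion probability and t = lesion ground truth,
#     loss_dice_global = -2 * sum(p * t) / (sum(p + t) + eps)
# so perfect overlap approaches -1 and no overlap approaches 0. The helper
# below is illustrative (the source builds its constants with mtf.constant
# via scalar() to avoid CPU-side constants); `eps` plays the role of
# FLAGS.dice_epsilon.
def _soft_dice_loss(prob, truth, eps):
    return -2.0 * mtf.reduce_sum(prob * truth) / (mtf.reduce_sum(prob + truth) + eps)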
def variable_dtype(self):
    return mtf.VariableDType(tf.as_dtype(self._hparams.master_dtype),
                             tf.as_dtype(self._hparams.slice_dtype),
                             tf.as_dtype(self._hparams.activation_dtype))
def call_simple(self,
                inputs,
                targets,
                compute_loss,
                mode=tf.estimator.ModeKeys.TRAIN,
                variable_dtype=mtf.VariableDType(tf.float32),
                encoder_sequence_id=None,
                decoder_sequence_id=None,
                decoder_subsequence_id=None,
                encoder_position=None,
                decoder_position=None,
                num_microbatches=1):
    """Compute logits based on inputs (all positions in parallel).

    This is called during training and evaluation.

    This class inherits transformer.Bitransformer with one difference: the
    encoder is a Funnel Transformer, so the length dimension is reduced and
    the decoder needs to use the updated `encoder_sequence_id`.

    Args:
        inputs: an int32 Tensor with shape [<batch_dims>, length_dim]
        targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
        compute_loss: a boolean
        mode: a tf.estimator.ModeKeys
        variable_dtype: a mtf.VariableDType
        encoder_sequence_id: an optional Tensor
        decoder_sequence_id: an optional Tensor
        decoder_subsequence_id: an optional Tensor
        encoder_position: an optional Tensor
        decoder_position: an optional Tensor
        num_microbatches: integer - greater than one if the step has been
            serialized into multiple microbatches to save memory.

    Returns:
        logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
        loss: an optional Scalar (if compute_loss=True)
    """
    # encoder_sequence_id and decoder_sequence_id are used to delineate packed
    # examples but are also necessary to indicate padding where sequence_id==0.
    # If they are absent, then we assume that padding is indicated by zeros in
    # the inputs/targets, and we make up sequence_id tensors to indicate this.
    if encoder_sequence_id is None:
        encoder_sequence_id = mtf.minimum(inputs, 1)
    if decoder_sequence_id is None:
        decoder_sequence_id = mtf.minimum(targets, 1)

    encoder_layer_outputs = []
    shared_params = self._shared_params(inputs.mesh, variable_dtype)
    encoder_output, encoder_loss = self.encoder.call_simple(
        inputs,
        None,
        compute_loss,
        mode=mode,
        variable_dtype=variable_dtype,
        sequence_id=encoder_sequence_id,
        position=encoder_position,
        shared_params=shared_params,
        layer_outputs=encoder_layer_outputs,
        num_microbatches=num_microbatches)
    encoder_output = mtf.layers.rename_length_to_memory_length(encoder_output)

    # The sequence_id is updated inside the layer_stack due to pooling, so we
    # need to use the updated sequence_id stored in the context.
    encoder_sequence_id = self.encoder.layer_stack.context.sequence_id
    encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
        encoder_sequence_id)

    logits, loss = self.decoder.call_simple(
        transformer.autoregressive_inputs(targets, sequence_id=decoder_sequence_id),
        targets,
        compute_loss,
        mode=mode,
        variable_dtype=variable_dtype,
        sequence_id=decoder_sequence_id,
        subsequence_id=decoder_subsequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        encoder_inputs=mtf.layers.rename_length_to_memory_length(inputs),
        position=decoder_position,
        shared_params=shared_params,
        encoder_layer_outputs=encoder_layer_outputs,
        num_microbatches=num_microbatches)
    if loss is not None and encoder_loss is not None:
        loss += encoder_loss
    return logits, loss
def decode(self,
           inputs,
           variable_dtype=mtf.VariableDType(tf.float32),
           beam_size=1,
           alpha=0.6,
           temperature=0.0,
           sampling_keep_top_k=-1,
           decode_length_multiplier=1.5,
           decode_length_constant=10,
           max_decode_length=None):
    """Sampling or beam search for Funnel Transformer.

    Args:
        inputs: a Tensor with shape [<batch_dims>, beam_dim, length_dim]
        variable_dtype: a mtf.VariableDType
        beam_size: an integer >= 1
        alpha: a floating point value (length bonus for beam search)
        temperature: a value between 0 and 1 (must be 0 if beam_size > 1);
            0.0 means argmax, 1.0 means sample according to predicted distribution.
        sampling_keep_top_k: a value between 1 and vocab_size used to sample
            from only the k most likely logits. Set to -1 to sample from all logits.
        decode_length_multiplier: a float
        decode_length_constant: a float
        max_decode_length: an optional integer

    Returns:
        a Tensor with shape [<batch_dims>, beam_dim, length_dim]
    """
    encoder_layer_outputs = []
    shared_params = self._shared_params(inputs.mesh, variable_dtype)
    encoder_sequence_id = mtf.minimum(inputs, 1)
    encoder_output, encoder_loss = self.encoder.call_simple(
        inputs=inputs,
        targets=None,
        compute_loss=False,
        mode=tf.estimator.ModeKeys.PREDICT,
        variable_dtype=variable_dtype,
        sequence_id=encoder_sequence_id,
        shared_params=shared_params,
        layer_outputs=encoder_layer_outputs)
    del encoder_loss
    encoder_output = mtf.layers.rename_length_to_memory_length(encoder_output)

    # The sequence_id is updated inside the layer_stack due to pooling, so we
    # need to use the updated sequence_id stored in the context.
    encoder_sequence_id = self.encoder.layer_stack.context.sequence_id
    encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
        encoder_sequence_id)

    batch_dims = inputs.shape[:-1]
    length_dim = inputs.shape[-1]
    if max_decode_length is None:
        decode_length_dim = length_dim
    else:
        decode_length_dim = mtf.Dimension("length", max_decode_length)

    if beam_size == 1:
        ids_shape = mtf.Shape(batch_dims + [decode_length_dim])
        partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32)
        return self.decoder.sample_autoregressive(
            partial_sequences,
            temperature=temperature,
            sampling_keep_top_k=sampling_keep_top_k,
            variable_dtype=variable_dtype,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            encoder_inputs=mtf.layers.rename_length_to_memory_length(inputs),
            shared_params=shared_params,
            has_partial_sequences=False,
            encoder_layer_outputs=encoder_layer_outputs)
    else:
        if temperature != 0:
            raise ValueError(
                "don't know how to beam search with nonzero temperature")
        if sampling_keep_top_k != -1:
            raise ValueError(
                "don't know how to beam search with top-k value other than -1.")
        # beam search
        beam_dim = mtf.Dimension("beam", beam_size)
        ids_shape = mtf.Shape(batch_dims + [beam_dim, decode_length_dim])
        partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32)
        input_length = mtf.reduce_sum(
            mtf.to_float(mtf.cast(inputs, tf.bool)),
            reduced_dim=length_dim)
        max_input_length = mtf.reduce_max(input_length)
        decode_length = mtf.cast(
            max_input_length * decode_length_multiplier + decode_length_constant,
            tf.int32)
        return self.decoder.beam_search(
            partial_sequences,
            decode_length,
            variable_dtype=variable_dtype,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            encoder_inputs=inputs,
            alpha=alpha,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs)
def beam_search(self,
                inputs,
                decode_length,
                dst_attributes=None,
                variable_dtype=mtf.VariableDType(tf.float32),
                encoder_output=None,
                encoder_sequence_id=None,
                encoder_inputs=None,
                alpha=0.6,
                shared_params=None,
                encoder_layer_outputs=None,
                z=None):
    """Beam search.

    Args:
        inputs: an int32 zero-Tensor with shape
            [<batch_dims>, beam_dim, length_dim].
        decode_length: an int32 mtf scalar. Maximum decode length.
        dst_attributes: an int32 zero-Tensor with shape
            [<batch_dims>, beam_dim, length_dim] ([<batch_dims>] or
            [<batch_dims>, beam_dim]).
        variable_dtype: a mtf.VariableDType
        encoder_output: an optional Tensor
        encoder_sequence_id: an optional Tensor
        encoder_inputs: an optional Tensor
        alpha: a floating point value (length bonus)
        shared_params: an optional dictionary
        encoder_layer_outputs: optional - readonly list of tensor activations
            when decoding, one per each input layer + the embedding layer

    Returns:
        a Tensor with shape [<batch_dims>, beam_dim, length_dim]
    """
    attributes = dst_attributes
    if not self.autoregressive:
        raise ValueError("must be autoregressive")

    batch_dims = inputs.shape.dims[:-2]
    if len(batch_dims) != 1:
        raise NotImplementedError(
            "beam search supports exactly one batch dimension.")
    beam_dim = inputs.shape.dims[-2]
    length_dim = inputs.shape.dims[-1]
    length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim)
    sequence_id = 1 if encoder_sequence_id is not None else None

    if self.input_full_attention:
        # This only makes sense in the case of beam search with given partial
        # sequences, which is not yet implemented.
        # TODO(noam): implement
        raise NotImplementedError(
            "Beam search for language models not yet implemented")
    else:
        read_priority = write_priority = length_range

    context_first_part = Context(
        model=self,
        mesh=inputs.mesh,
        batch_dims=batch_dims + [beam_dim],
        length_dim=length_dim,
        variable_dtype=variable_dtype,
        mode="first_part",
        position=length_range,
        position_is_default=True,
        new_states=[],
        initial_position=initial_position,
        sequence_id=sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        constant_states=[],
        shared_params=shared_params,
        encoder_layer_outputs=encoder_layer_outputs,
        write_priority=write_priority,
        read_priority=read_priority,
        inputs=inputs,
        encoder_inputs=encoder_inputs)

    shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
    with tf.variable_scope(self.name):
        logits = self._call_internal(context_first_part,
                                     shifted_inputs,
                                     attributes=attributes,
                                     z=z)
    del logits  # There are no partial targets.
    # Replace initial states by zeros to avoid computing them.
    initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states]
    constant_states = context_first_part.constant_states

    def logits_fn(step_num, ids, states):
        """logits_fn for mtf.beam_search.beam_search()."""
        inputs_this_step = mtf.gather(ids, step_num - 1, length_dim)
        if self.attribute_embedding:
            attributes_this_step = mtf.gather(attributes, step_num - 1, length_dim)
        else:
            attributes_this_step = None

        context_incremental = Context(
            model=self,
            mesh=inputs.mesh,
            batch_dims=batch_dims + [beam_dim],
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="incremental",
            position=step_num,
            states=states,
            new_states=[],
            sequence_id=sequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=constant_states,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=step_num,
            inputs=inputs_this_step,
            encoder_inputs=encoder_inputs)
        with tf.variable_scope(self.name, reuse=True):
            logits = self._call_internal(context_incremental,
                                         inputs_this_step,
                                         attributes=attributes_this_step,
                                         z=z)
        return mtf.to_float(logits), context_incremental.new_states

    beams, unused_scores = mtf.beam_search.beam_search(
        logits_fn,
        inputs,
        alpha,
        states=initial_states,
        decode_length=decode_length,
        use_tpu=True,
        dtype=tf.float32,
        mesh_shape=self.mesh_shape,
        layout=self.layout)

    return mtf.gather(beams,
                      mtf.constant(inputs.mesh, 0, dtype=tf.int32),
                      beam_dim)
def sample_autoregressive(self,
                          partial_sequences,
                          dst_attributes=None,
                          stop_at_token=1,
                          max_steps=None,
                          temperature=0.0,
                          variable_dtype=mtf.VariableDType(tf.float32),
                          encoder_output=None,
                          encoder_sequence_id=None,
                          encoder_inputs=None,
                          shared_params=None,
                          has_partial_sequences=True,
                          encoder_layer_outputs=None,
                          never_end=False,
                          remove_partial_sequences=False,
                          sampling_keep_top_k=-1,
                          z=None):
    """Sample randomly one token at a time.

    The partial_sequences represent partial sequences to be continued. The
    first tokens of each sequence are nonzero representing the given partial
    sequences and the last tokens of each sequence are zeros, representing
    what needs to be filled in.

    If there are no partial sequences (you want to sample from the beginning),
    then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
    has_partial_sequences=False (so we can skip computation).

    The dst_attributes represent the destination attributes in which we want
    to generate sequences.

    Args:
        partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
        dst_attributes: an int32 Tensor with shape [<batch_dims>, length_dim]
            ([<batch_dims>])
        stop_at_token: an optional integer eos id. Stop when we produce it.
        max_steps: an optional integer, the max number of steps to decode.
        temperature: an optional floating point value between 0.0 and 1.0;
            0.0 means argmax, 1.0 means sample according to predicted distribution.
        variable_dtype: a mtf.VariableDType
        encoder_output: an optional Tensor
        encoder_sequence_id: an optional Tensor
        encoder_inputs: an optional Tensor
        shared_params: an optional dictionary
        has_partial_sequences: a boolean
        encoder_layer_outputs: optional - readonly list of tensor activations
            when decoding, one per each input layer + the embedding layer
        never_end: a boolean - if set, then avoid generating stop_at_token
        remove_partial_sequences: a boolean - whether to remove the partial
            sequences from the output
        sampling_keep_top_k: an integer - if not -1, only sample from the top
            k logits.

    Returns:
        a Tensor with shape [<batch_dims>, length_dim]
    """
    if not self.autoregressive:
        raise ValueError("must be autoregressive")

    inputs = partial_sequences
    attributes = dst_attributes
    batch_dims = inputs.shape.dims[:-1]
    length_dim = inputs.shape.dims[-1]
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim)
    sequence_id = 1 if encoder_sequence_id is not None else None

    length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
    if self.input_full_attention:
        read_priority = write_priority = length_range * mtf.to_int32(
            mtf.greater(length_range, initial_position))
    else:
        read_priority = write_priority = length_range

    context_first_part = Context(
        model=self,
        mesh=inputs.mesh,
        batch_dims=batch_dims,
        length_dim=length_dim,
        variable_dtype=variable_dtype,
        mode="first_part",
        position=length_range,
        position_is_default=True,
        new_states=[],
        initial_position=initial_position,
        sequence_id=sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        constant_states=[],
        shared_params=shared_params,
        encoder_layer_outputs=encoder_layer_outputs,
        write_priority=write_priority,
        read_priority=read_priority,
        inputs=inputs,
        encoder_inputs=encoder_inputs)

    shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
    with tf.variable_scope(self.name):
        logits = self._call_internal(context_first_part,
                                     shifted_inputs,
                                     attributes=attributes,
                                     z=z)
    del logits
    constant_states = context_first_part.constant_states

    if not has_partial_sequences:
        initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states]
        partial_sequences_eos_count = 0
    else:
        initial_states = context_first_part.new_states
        partial_sequences_eos_count = mtf.reduce_sum(
            mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
            reduced_dim=length_dim)

    def cond_fn(position, ids, *unused_states):
        """Should we run another loop iteration."""
        past_end = mtf.greater_equal(position, length_dim.size)
        if max_steps:
            past_end = mtf.logical_or(
                past_end,
                mtf.greater_equal(position - initial_position, max_steps))

        is_done = past_end
        if stop_at_token is not None:
            eos_count = mtf.reduce_sum(
                mtf.to_int32(mtf.equal(ids, stop_at_token)),
                reduced_dim=length_dim)
            has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count)
            is_done = mtf.logical_or(is_done, has_additional_eos)
        all_done = mtf.reduce_all(is_done)
        return mtf.logical_not(all_done)

    def body_fn(position, ids, *states):
        """One step in the decode loop."""
        inputs_this_step = mtf.gather(ids, position - 1, length_dim)
        if self.attribute_embedding:
            attributes_this_step = mtf.gather(attributes, position - 1, length_dim)
        else:
            attributes_this_step = None

        context_incremental = Context(
            model=self,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="incremental",
            position=position,
            states=states,
            new_states=[],
            sequence_id=sequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=constant_states,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=position,
            inputs=inputs_this_step,
            encoder_inputs=encoder_inputs)

        with tf.variable_scope(self.name, reuse=True):
            logits = self._call_internal(context_incremental,
                                         inputs_this_step,
                                         attributes=attributes_this_step,
                                         z=z)
        if never_end:
            logits += mtf.one_hot(
                mtf.constant(logits.mesh, stop_at_token, dtype=tf.int32),
                self.output_vocab_dim,
                on_value=-1e9,
                off_value=0.0,
                dtype=logits.dtype)

        # TBD whether this should be before or after never_end.
        # Note for adding top_p sampling in the future: in other code bases,
        # temperature is applied before the top-k truncation; this
        # implementation does it in the opposite order. For top-k this doesn't
        # matter, but for top_p it will.
        if sampling_keep_top_k != -1:
            if sampling_keep_top_k <= 0:
                raise ValueError(
                    "sampling_keep_top_k must either be -1 or positive.")
            k_largest = mtf.nth_largest_element(
                logits, n=sampling_keep_top_k,
                reduced_dim=self.output_vocab_dim)
            logits = mtf.where(mtf.less_equal(logits, k_largest),
                               mtf.ones_like(logits) * -1e6, logits)

        ids_this_step = mtf.sample_with_temperature(
            logits, self.output_vocab_dim, temperature)
        new_position = position + 1
        new_ids = ids + ids_this_step * mtf.one_hot(
            position, length_dim, dtype=tf.int32)
        return [new_position, new_ids] + context_incremental.new_states

    while_loop_inputs = [initial_position, inputs] + initial_states
    final_position, outputs = mtf.while_loop(
        cond_fn, body_fn, while_loop_inputs)[:2]
    del final_position
    if has_partial_sequences and remove_partial_sequences:
        # remove partial sequences from outputs
        partial_length = mtf.reduce_sum(
            mtf.to_int32(mtf.not_equal(partial_sequences, 0)),
            reduced_dim=length_dim)
        outputs = mtf.dynamic_shift(
            outputs, -partial_length, length_dim, wrap=False)
    return outputs
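
# Hedged, isolated version of the top-k truncation in body_fn above: logits
# at or below the n-th largest value are pushed to a large negative number so
# sampling only ever picks from the top k. `_truncate_to_top_k` is an
# illustrative helper, not from the source.
def _truncate_to_top_k(logits, k, vocab_dim):
    k_largest = mtf.nth_largest_element(logits, n=k, reduced_dim=vocab_dim)
    return mtf.where(mtf.less_equal(logits, k_largest),
                     mtf.ones_like(logits) * -1e6,
                     logits)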
def create_dummy_model(mesh,
                       shapes,
                       n_blocks=2,
                       block_param_size_str="2_2",
                       block_repeat_size_str="1_1"):
    """Creates a dummy model and layer stack with 4-dimensional input."""
    assert len(shapes) == 4
    outer_batch_size, batch_size, length, d_model = shapes
    batch_dim = mtf.Dimension("batch", batch_size)
    outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
    length_dim = mtf.Dimension("length", length)
    block_param_size = list(map(int, block_param_size_str.split("_")))
    block_repeat_size = list(map(int, block_repeat_size_str.split("_")))

    sublayers_initial = [
        transformer.sublayer_dropout,
    ]
    sublayers_per_layer = [
        transformer.sublayer_rms_norm,
        transformer.sublayer_call_layer,
        transformer.sublayer_dropout,
        transformer.sublayer_residual,
    ]
    sublayers_final = [
        transformer.sublayer_rms_norm,
        transformer.sublayer_dropout,
    ]
    submodules = [
        transformer_layers.SelfAttention(),
        transformer_layers.DenseReluDense()
    ]

    n_sublayers = np.array(block_param_size).prod()
    layers = submodules * n_sublayers
    layer_stack = funnel_transformer.FunnelTransformerLayerStack(
        layers=layers,
        n_blocks=n_blocks,
        block_param_size=block_param_size,
        block_repeat_size=block_repeat_size,
        sublayers_initial=sublayers_initial,
        sublayers_per_layer=sublayers_per_layer,
        sublayers_final=sublayers_final)

    model = transformer.Unitransformer(input_vocab_size=10,
                                       output_vocab_size=10,
                                       autoregressive=False,
                                       max_length=8,
                                       d_model=d_model,
                                       layer_stack=layer_stack)

    context = transformer.Context(
        model=model,
        mesh=mesh,
        batch_dims=[batch_dim, outer_batch_dim],
        length_dim=length_dim,
        variable_dtype=mtf.VariableDType(tf.float32),
        sequence_id=mtf.ones(mesh, mtf.Shape([length_dim])),
        position=mtf.range(mesh, length_dim, dtype=tf.int32))
    return layer_stack, context
def call_simple(self,
                inputs,
                targets,
                compute_loss,
                attributes=None,
                codeprefixedtargets=None,
                mode=tf.estimator.ModeKeys.TRAIN,
                variable_dtype=mtf.VariableDType(tf.float32),
                encoder_sequence_id=None,
                decoder_sequence_id=None,
                decoder_subsequence_id=None,
                encoder_position=None,
                decoder_position=None):  # attributes=None for debugging?
    """Compute logits based on inputs (all positions in parallel).

    This is called during training and evaluation.

    Args:
        inputs: an int32 Tensor with shape [<batch_dims>, length_dim]
        targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
        compute_loss: a boolean
        attributes: an optional int32 Tensor with shape
            [<batch_dims>, length_dim] ([<batch_dims>])
        mode: a tf.estimator.ModeKeys
        variable_dtype: a mtf.VariableDType
        encoder_sequence_id: an optional Tensor
        decoder_sequence_id: an optional Tensor
        decoder_subsequence_id: an optional Tensor
        encoder_position: an optional Tensor
        decoder_position: an optional Tensor

    Returns:
        logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
        loss: an optional Scalar (if compute_loss=True)
    """
    encoder_layer_outputs = []
    shared_params = self._shared_params(inputs.mesh, variable_dtype)
    encoder_output, encoder_loss = self.encoder.call_simple(
        inputs,
        None,
        compute_loss,
        attributes=attributes,
        mode=mode,
        variable_dtype=variable_dtype,
        sequence_id=encoder_sequence_id,
        position=encoder_position,
        shared_params=shared_params,
        layer_outputs=encoder_layer_outputs)
    encoder_output = mtf.layers.rename_length_to_memory_length(encoder_output)
    if encoder_sequence_id is not None:
        encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
            encoder_sequence_id)

    if self.cut_cross_attention:
        z = mtf.gather(
            encoder_output,
            mtf.zeros(inputs.mesh,
                      mtf.Shape(inputs.shape[:-1] + [encoder_output.shape[-1]]),
                      dtype=tf.int32),
            encoder_output.shape[-2])
        encoder_output = None
    else:
        z = None

    if codeprefixedtargets:
        decoder_input = shift_targets_no_offset(codeprefixedtargets)
        # Alternatives considered:
        # shift_attribute_targets(targets, attribute_id=codeprefixedtargets),
        # codeprefixedtargets, or mtf.zeros_like(targets).
    else:
        decoder_input = shift_targets(targets)
        # shift_targets = mtf.shift(targets, offset=-1, dim=targets.shape.dims[-1], wrap=False)  # Remove token preceding ":"
        # shift_targets = mtf.shift(shift_targets, offset=1, dim=targets.shape.dims[-1], wrap=False)

    logits, loss = self.decoder.call_simple(
        decoder_input,
        targets,
        compute_loss,
        attributes=attributes,
        mode=mode,
        variable_dtype=variable_dtype,
        sequence_id=decoder_sequence_id,
        subsequence_id=decoder_subsequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        encoder_inputs=mtf.layers.rename_length_to_memory_length(inputs),
        position=decoder_position,
        shared_params=shared_params,
        encoder_layer_outputs=encoder_layer_outputs,
        z=z)  # Maybe sample_autoregressive here?

    if loss is not None and encoder_loss is not None:
        loss += encoder_loss
    return logits, loss
def run(tpu_job_name,
        data_dir,
        master_dtype,
        slice_dtype,
        activation_dtype,
        tpu,
        gcp_project,
        tpu_zone,
        autostack,
        model_dir,
        mode=gin.REQUIRED,
        iterations_per_loop=gin.REQUIRED,
        save_checkpoints_steps=gin.REQUIRED,
        eval_steps=gin.REQUIRED,
        train_steps=gin.REQUIRED,
        batch_size=gin.REQUIRED,
        text2self=gin.REQUIRED,
        dataset=gin.REQUIRED):
    """Run training/eval/inference.

    Args:
      tpu_job_name: string, name of the TPU worker binary
      data_dir: string, data_dir for TensorFlow Datasets
      master_dtype: string, datatype for checkpoints
      slice_dtype: string, datatype for variables in memory
      activation_dtype: string, datatype for activations
      tpu: string, the Cloud TPU to use for training
      gcp_project: string, project name for the Cloud TPU-enabled project
      tpu_zone: string, GCE zone where the Cloud TPU is located
      autostack: boolean, whether to internally combine variables
      model_dir: string, estimator model_dir
      mode: string, train/evaluate/infer
      iterations_per_loop: integer, steps per train loop
      save_checkpoints_steps: integer, steps per checkpoint
      eval_steps: integer, number of evaluation steps
      train_steps: integer, total number of training steps
      batch_size: integer, mini-batch size for training. Note that this is
        the global batch size, not the per-shard batch size.
      text2self: whether to train a language model (True) or an
        encoder-decoder text-to-text model (False)
      dataset: string, TensorFlow Datasets dataset name
    """
    cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
        tpu if tpu else "", zone=tpu_zone, project=gcp_project)
    my_tpu_config = tpu_config.TPUConfig(
        tpu_job_name=tpu_job_name,
        iterations_per_loop=iterations_per_loop,
        num_cores_per_replica=1,
        per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST,
    )
    run_config = tpu_config.RunConfig(
        cluster=cluster,
        model_dir=model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        tpu_config=my_tpu_config)

    dataset = transformer_dataset.TokenizedTFDSDataset(
        dataset, text2self=text2self, data_dir=data_dir or None)

    output_encoder = dataset.encoders["targets"]
    if text2self:
        input_encoder = output_encoder
    else:
        input_encoder = dataset.encoders["inputs"]

    transformer_model = model(
        input_vocab_size=transformer_dataset.padded_vocab_size(
            input_encoder.vocab_size, 128),
        output_vocab_size=transformer_dataset.padded_vocab_size(
            output_encoder.vocab_size, 128),
        text2self=text2self)

    mesh_shape = mtf.convert_to_shape(gin.query_parameter("model.mesh_shape"))
    layout_rules = mtf.convert_to_layout_rules(
        gin.query_parameter("model.layout"))

    # Data types used for variables and activations; see the comments on the
    # corresponding flags.
    master_dtype = tf.as_dtype(master_dtype)
    if slice_dtype:
        slice_dtype = tf.as_dtype(slice_dtype)
    elif not tpu or mode == "train":
        slice_dtype = tf.float32
    else:
        slice_dtype = tf.bfloat16
    if activation_dtype:
        activation_dtype = tf.as_dtype(activation_dtype)
    else:
        activation_dtype = tf.bfloat16 if tpu else tf.float32
    variable_dtype = mtf.VariableDType(
        master_dtype=master_dtype,
        slice_dtype=slice_dtype,
        activation_dtype=activation_dtype)

    length_from_config = gin.query_parameter(
        "model.length") or gin.query_parameter("model.max_length")

    model_fn = tpu_estimator_model_fn(
        transformer_model=transformer_model,
        model_dir=model_dir,
        use_tpu=tpu,
        mesh_shape=mesh_shape,
        layout_rules=layout_rules,
        text2self=text2self,
        variable_dtype=variable_dtype,
        batch_size=batch_size,
        length=length_from_config,
        autostack=autostack)

    estimator = tpu_estimator.TPUEstimator(
        model_fn=model_fn,
        config=run_config,
        train_batch_size=batch_size,
        eval_batch_size=batch_size,
        predict_batch_size=batch_size,
        use_tpu=tpu,
        export_to_tpu=False,
        params={})

    def input_fn(params):
        del params
        return dataset.load(
            batch_size=batch_size,
            length=length_from_config,
            train=(mode == "train"),
            pack=True)

    if mode == "train":
        estimator.train(input_fn=input_fn, max_steps=train_steps)
    elif mode == "evaluate":
        estimator.evaluate(input_fn=input_fn, steps=eval_steps)
    elif mode == "infer":
        decode_from_file(
            estimator,
            batch_size=batch_size,
            length=length_from_config,
            inputs_encoder=dataset.encoders[
                "targets" if text2self else "inputs"],
            targets_encoder=dataset.encoders["targets"],
            text2self=text2self)
    else:
        raise ValueError(
            "unknown mode %s - must be train/evaluate/infer" % mode)
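# run() pulls every gin.REQUIRED argument from a gin config. A hypothetical
# set of bindings (all values illustrative, not from the source) could be
# supplied programmatically like this:
import gin

gin.parse_config("""
    run.mode = 'train'
    run.iterations_per_loop = 100
    run.save_checkpoints_steps = 1000
    run.eval_steps = 10
    run.train_steps = 10000
    run.batch_size = 64
    run.text2self = False
    run.dataset = 'wmt_translate/de-en'
""")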
# Build dummy inputs for the wide-and-deep model. Ids must be integers, so
# draw them uniformly from the vocabulary (the original used randn cast to
# int32, which collapses almost every id to -1, 0, or 1).
id_hldr = mtf.import_tf_tensor(
    mesh,
    np.random.randint(
        0, vocab_dim.size,
        size=(batch_dim.size, field_dim.size)).astype(np.int32),
    shape=[batch_dim, field_dim])
wt_hldr = mtf.import_tf_tensor(
    mesh,
    np.random.randn(batch_dim.size, field_dim.size).astype(np.float32),
    shape=[batch_dim, field_dim])
label = mtf.import_tf_tensor(
    mesh,
    np.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0]).astype(np.float32),
    shape=[batch_dim])

fp16 = True
if fp16:
    float16 = mtf.VariableDType(tf.float16, tf.float16, tf.float16)
    id_hldr = mtf.cast(id_hldr, dtype=tf.int32)  # ids stay integral
    wt_hldr = mtf.cast(wt_hldr, dtype=tf.float16)
else:
    float16 = None

# This variant of widedeep also returns the embedding table.
result, embedding_table = widedeep(
    id_hldr, wt_hldr, vocab_dim, embed_dim, outdim, float16)
# label = mtf.reshape(label, new_shape=[batch_dim, outdim])
# output = mtf.layers.softmax_cross_entropy_with_logits(result, label, vocab_dim=outdim)
# result = mtf.sigmoid(result)
# result = -(label*mtf.log(result)+(1-label)*mtf.log(1-result))
# result = mtf.reduce_sum(result)
result = mtf.cast(result, dtype=tf.float32)
embedding_table = mtf.cast(embedding_table, dtype=tf.float32)
probability = mtf.sigmoid(result)
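# A hedged sketch of the binary cross-entropy loss that the commented-out
# lines above gesture at, computed on the fp32 logits. The epsilon and the
# choice of reduce_mean (rather than reduce_sum) are assumptions, not from
# the source.
eps = 1e-7
bce = -(label * mtf.log(probability + eps)
        + (1 - label) * mtf.log(1 - probability + eps))
loss = mtf.reduce_mean(bce)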
def beam_search(self,
                inputs,
                decode_length,
                variable_dtype=mtf.VariableDType(tf.float32),
                encoder_output=None,
                encoder_sequence_id=None,
                alpha=0.6,
                shared_params=None,
                encoder_layer_outputs=None):
    """Beam search.

    Args:
      inputs: an int32 zero-Tensor with shape
        [<batch_dims>, beam_dim, length_dim].
      decode_length: an int32 mtf scalar. Maximum decode length.
      variable_dtype: a mtf.VariableDType
      encoder_output: an optional Tensor
      encoder_sequence_id: an optional Tensor
      alpha: a floating point value (length bonus)
      shared_params: an optional dictionary
      encoder_layer_outputs: optional - readonly list of tensor activations
        when decoding, one per each input layer + the embedding layer

    Returns:
      a Tensor with shape [<batch_dims>, beam_dim, length_dim]
    """
    if not self.autoregressive:
        raise ValueError("must be autoregressive")

    batch_dims = inputs.shape.dims[:-2]
    if len(batch_dims) != 1:
        raise NotImplementedError(
            "beam search supports exactly one batch dimension.")
    beam_dim = inputs.shape.dims[-2]
    length_dim = inputs.shape.dims[-1]
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim)
    sequence_id = 1 if encoder_sequence_id is not None else None

    context_first_part = Context(
        mesh=inputs.mesh,
        batch_dims=batch_dims + [beam_dim],
        length_dim=length_dim,
        model_dim=self.model_dim,
        variable_dtype=variable_dtype,
        mode="first_part",
        autoregressive=self.autoregressive,
        new_states=[],
        initial_position=initial_position,
        sequence_id=sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        constant_states=[],
        shared_params=shared_params,
        layout=self.layout,
        mesh_shape=self.mesh_shape,
        encoder_layer_outputs=encoder_layer_outputs)

    shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
    with tf.variable_scope(self.name):
        logits = self._call_internal(context_first_part, shifted_inputs)
    del logits  # There are no partial targets.
    # Replace initial states by zeros to avoid computing them.
    initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states]
    constant_states = context_first_part.constant_states

    def logits_fn(step_num, ids, states):
        """logits_fn for mtf.beam_search.beam_search()."""
        context_incremental = Context(
            mesh=inputs.mesh,
            batch_dims=batch_dims + [beam_dim],
            length_dim=length_dim,
            model_dim=self.model_dim,
            variable_dtype=variable_dtype,
            mode="incremental",
            autoregressive=self.autoregressive,
            position=step_num,
            states=states,
            new_states=[],
            sequence_id=sequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=constant_states,
            shared_params=shared_params,
            layout=self.layout,
            mesh_shape=self.mesh_shape,
            encoder_layer_outputs=encoder_layer_outputs)
        inputs_this_step = mtf.gather(ids, step_num - 1, length_dim)
        with tf.variable_scope(self.name, reuse=True):
            logits = self._call_internal(context_incremental, inputs_this_step)
        return mtf.to_float(logits), context_incremental.new_states

    beams, unused_scores = mtf.beam_search.beam_search(
        logits_fn,
        inputs,
        alpha,
        states=initial_states,
        decode_length=decode_length,
        use_tpu=True,
        dtype=tf.float32,
        mesh_shape=self.mesh_shape,
        layout=self.layout)
    return mtf.gather(
        beams, mtf.constant(inputs.mesh, 0, dtype=tf.int32), beam_dim)
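# Minimal calling sketch (the mesh, model, and dimension sizes are assumed):
# beam search starts from all-zero partial sequences with shape
# [batch, beam, length] and returns the top beam.
batch = mtf.Dimension("batch", 4)
beam = mtf.Dimension("beam", 4)
length = mtf.Dimension("length", 32)
partial = mtf.zeros(mesh, mtf.Shape([batch, beam, length]), dtype=tf.int32)
decode_len = mtf.constant(mesh, 32, dtype=tf.int32)
best = model.decoder.beam_search(partial, decode_len, alpha=0.6)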
def sample_autoregressive(self,
                          partial_sequences,
                          stop_at_token=1,
                          max_steps=None,
                          temperature=1.0,
                          variable_dtype=mtf.VariableDType(tf.float32),
                          encoder_output=None,
                          encoder_sequence_id=None,
                          shared_params=None,
                          has_partial_sequences=True,
                          encoder_layer_outputs=None):
    """Sample randomly one token at a time.

    The partial_sequences represent partial sequences to be continued. The
    first tokens of each sequence are nonzero, representing the given partial
    sequences, and the last tokens of each sequence are zeros, representing
    what needs to be filled in.

    If there are no partial sequences (you want to sample from the beginning),
    then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
    has_partial_sequences=False (so we can skip computation).

    Args:
      partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
      stop_at_token: an optional integer eos id. Stop when we produce it.
      max_steps: an optional integer
      temperature: an optional floating point value between 0.0 and 1.0;
        0.0 means argmax, 1.0 means sample according to predicted distribution
      variable_dtype: a mtf.VariableDType
      encoder_output: an optional Tensor
      encoder_sequence_id: an optional Tensor
      shared_params: an optional dictionary
      has_partial_sequences: a boolean
      encoder_layer_outputs: optional - readonly list of tensor activations
        when decoding, one per each input layer + the embedding layer

    Returns:
      a Tensor with shape [<batch_dims>, length_dim]
    """
    del max_steps  # TODO(noam): implement
    if not self.autoregressive:
        raise ValueError("must be autoregressive")

    inputs = partial_sequences
    batch_dims = inputs.shape.dims[:-1]
    length_dim = inputs.shape.dims[-1]
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim)
    sequence_id = 1 if encoder_sequence_id is not None else None

    context_first_part = Context(
        mesh=inputs.mesh,
        batch_dims=batch_dims,
        length_dim=length_dim,
        model_dim=self.model_dim,
        variable_dtype=variable_dtype,
        mode="first_part",
        autoregressive=self.autoregressive,
        new_states=[],
        initial_position=initial_position,
        sequence_id=sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        constant_states=[],
        shared_params=shared_params,
        layout=self.layout,
        mesh_shape=self.mesh_shape,
        encoder_layer_outputs=encoder_layer_outputs)

    shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
    with tf.variable_scope(self.name):
        logits = self._call_internal(context_first_part, shifted_inputs)
    del logits
    constant_states = context_first_part.constant_states
    if not has_partial_sequences:
        initial_states = [
            mtf.zeros_like(t) for t in context_first_part.new_states]
    else:
        initial_states = context_first_part.new_states

    def cond_fn(position, ids, *unused_states):
        """Should we run another loop iteration?"""
        past_end = mtf.greater_equal(position, length_dim.size)
        is_done = past_end
        if stop_at_token is not None:
            has_eos = mtf.reduce_any(
                mtf.equal(ids, stop_at_token), reduced_dim=length_dim)
            is_done = mtf.logical_or(is_done, has_eos)
        all_done = mtf.reduce_all(is_done)
        return mtf.logical_not(all_done)

    def body_fn(position, ids, *states):
        """One step in the decode loop."""
        context_incremental = Context(
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            model_dim=self.model_dim,
            variable_dtype=variable_dtype,
            mode="incremental",
            autoregressive=self.autoregressive,
            position=position,
            states=states,
            new_states=[],
            sequence_id=sequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=constant_states,
            shared_params=shared_params,
            layout=self.layout,
            mesh_shape=self.mesh_shape,
            encoder_layer_outputs=encoder_layer_outputs)
        inputs_this_step = mtf.gather(ids, position - 1, length_dim)
        with tf.variable_scope(self.name, reuse=True):
            logits = self._call_internal(context_incremental, inputs_this_step)
        ids_this_step = mtf.sample_with_temperature(
            logits, self.output_vocab_dim, temperature)
        new_position = position + 1
        new_ids = ids + ids_this_step * mtf.one_hot(
            position, length_dim, dtype=tf.int32)
        return [new_position, new_ids] + context_incremental.new_states

    while_loop_inputs = [initial_position, inputs] + initial_states
    final_position, outputs = mtf.while_loop(
        cond_fn, body_fn, while_loop_inputs)[:2]
    del final_position
    return outputs
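# NumPy analogy for the in-loop id update above (illustrative only): the
# sampled token is written at `position` via a one-hot mask, leaving all
# previously decoded positions untouched.
import numpy as np

ids = np.array([5, 9, 0, 0], dtype=np.int32)       # decoded so far
position, sampled = 2, 7
one_hot = (np.arange(ids.size) == position).astype(np.int32)
new_ids = ids + sampled * one_hot                  # -> [5, 9, 7, 0]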
def sample_autoregressive(
    partial_sequences,
    other_features,
    params,
    stop_at_token=50256,
    max_steps=None,
    temperature=0.9,
    variable_dtype=mtf.VariableDType(tf.float32),
    encoder_output=None,
    encoder_sequence_id=None,
    encoder_inputs=None,
    shared_params=None,
    has_partial_sequences=True,
    encoder_layer_outputs=None,
    never_end=False,
    remove_partial_sequences=False,
    sampling_keep_top_k=-1,
    bos_id=50256,
):
    """Sample randomly one token at a time.

    The partial_sequences represent partial sequences to be continued. The
    first tokens of each sequence are nonzero, representing the given partial
    sequences, and the last tokens of each sequence are zeros, representing
    what needs to be filled in.

    If there are no partial sequences (you want to sample from the beginning),
    then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
    has_partial_sequences=False (so we can skip computation).

    Args:
      partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
      other_features: a dictionary of extra features (e.g. the vocab dim)
      params: a dictionary of run parameters
      stop_at_token: an optional integer eos id. Stop when we produce it.
      max_steps: an optional integer, the max number of steps to decode.
      temperature: an optional floating point value between 0.0 and 1.0;
        0.0 means argmax, 1.0 means sample according to predicted distribution
      variable_dtype: a mtf.VariableDType
      encoder_output: an optional Tensor
      encoder_sequence_id: an optional Tensor
      encoder_inputs: an optional Tensor
      shared_params: an optional dictionary
      has_partial_sequences: a boolean
      encoder_layer_outputs: optional - readonly list of tensor activations
        when decoding, one per each input layer + the embedding layer
      never_end: a boolean - if set, then avoid generating stop_at_token
      remove_partial_sequences: a boolean - whether to remove the partial
        sequences from the output
      sampling_keep_top_k: an integer - if not -1, only sample from the top k
        logits
      bos_id: beginning of sequence id

    Returns:
      a Tensor with shape [<batch_dims>, length_dim]
    """
    inputs = partial_sequences  # partial sequences to fill in
    batch_dims = inputs.shape.dims[:-1]
    length_dim = inputs.shape.dims[-1]
    padding_id = params.get("padding_id", 0)
    slow_sampling = params.get("slow_sampling", False)

    # Position where the zero padding starts.
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, padding_id)),
        reduced_dim=length_dim)

    length_range = mtf.range(inputs.mesh, length_dim, tf.int32)

    # Hardcoded to True for now: a vanilla autoregressive model, where each
    # position can see all previous positions. The priorities below feed into
    # the loop fn and tell each position what it may attend to.
    input_full_attention = True
    if input_full_attention:
        read_priority = write_priority = length_range * mtf.to_int32(
            mtf.greater(length_range, initial_position))
    else:
        read_priority = write_priority = length_range

    # Build the context to pass around internally. The "first_part" context
    # records the initial states of k / v / x.
    if not slow_sampling:
        context_first_part = mtf_transformer.transformer.Context(
            model=None,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="first_part",
            position=length_range,
            position_is_default=True,
            new_states=[],
            initial_position=initial_position,
            sequence_id=None,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=[],
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=inputs,
            encoder_inputs=encoder_inputs)

        with tf.variable_scope("gpt2"):
            logits, _, _ = gpt2.model({"inputs": inputs},
                                      other_features,
                                      params,
                                      inputs.mesh,
                                      variable_dtype=variable_dtype,
                                      context=context_first_part)

        if not has_partial_sequences:
            initial_states = [
                mtf.zeros_like(t) for t in context_first_part.new_states
            ]
        else:
            initial_states = context_first_part.new_states
    else:
        initial_states = []

    if not has_partial_sequences:
        partial_sequences_eos_count = 0

    if stop_at_token is not None:
        partial_sequences_eos_count = mtf.reduce_sum(
            mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
            reduced_dim=length_dim)

    def cond_fn(position, ids, *unused_states):
        """Should we run another loop iteration?"""
        past_end = mtf.greater_equal(position, length_dim.size)
        if max_steps:
            past_end = mtf.logical_or(
                past_end,
                mtf.greater_equal(position - initial_position, max_steps))
        is_done = past_end
        if stop_at_token is not None:
            eos_count = mtf.reduce_sum(
                mtf.to_int32(mtf.equal(ids, stop_at_token)),
                reduced_dim=length_dim)
            has_additional_eos = mtf.greater(eos_count,
                                             partial_sequences_eos_count)
            is_done = mtf.logical_or(is_done, has_additional_eos)
        all_done = mtf.reduce_all(is_done)
        return mtf.logical_not(all_done)

    def body_fn(position, ids, *states):
        """One step in the decode loop."""
        nonlocal sampling_keep_top_k

        context = mtf_transformer.transformer.Context(
            model=None,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="incremental",
            position=position,
            position_is_default=True,
            states=states,
            new_states=[],
            initial_position=position,
            sequence_id=None,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=ids,
            encoder_inputs=encoder_inputs) if not slow_sampling else None

        with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE):
            logits, _, _ = gpt2.model({"inputs": ids},
                                      other_features,
                                      params,
                                      inputs.mesh,
                                      variable_dtype=variable_dtype,
                                      context=context)

        # sampling_keep_top_k == -2 means: keep the top 10% of the vocabulary.
        if sampling_keep_top_k == -2:
            sampling_keep_top_k = int(logits.shape[-1].size * 0.1)

        if sampling_keep_top_k != -1:
            if sampling_keep_top_k <= 0:
                raise ValueError(
                    "sampling_keep_top_k must either be -1 or positive.")
            k_largest = mtf.nth_largest_element(
                logits,
                n=sampling_keep_top_k,
                reduced_dim=other_features["vocab_dim"])
            logits = mtf.where(
                mtf.less_equal(logits, k_largest),
                mtf.ones_like(logits) * -1e6, logits)

        ids_this_step = mtf.sample_with_temperature(
            logits, other_features["vocab_dim"], temperature)
        if slow_sampling:
            ids_this_step = mtf.shift(
                ids_this_step, offset=1, dim=length_dim, wrap=False)
        else:
            ids_this_step = mtf.reshape(ids_this_step, batch_dims)

        one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32)
        one_new_id = ids_this_step * one_hot
        new_ids = (1 - one_hot) * ids + one_new_id
        new_position = position + 1

        ret = [new_position, new_ids]
        if context is not None:
            ret += context.new_states
        return ret

    while_loop_inputs = [initial_position, inputs] + initial_states
    final_position, outputs = mtf.while_loop(
        cond_fn, body_fn, while_loop_inputs)[:2]
    del final_position
    if has_partial_sequences and remove_partial_sequences:
        # Remove the partial sequences from the outputs.
        partial_length = mtf.reduce_sum(
            mtf.to_int32(mtf.not_equal(partial_sequences, padding_id)),
            reduced_dim=length_dim)
        outputs = mtf.dynamic_shift(
            outputs, -partial_length, length_dim, wrap=False)
    return outputs
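# NumPy analogy of the top-k logit filter above (illustrative only): mask
# every logit strictly below the k-th largest with a large negative number
# before sampling, so only the top k tokens keep non-negligible probability.
import numpy as np

def top_k_filter(logits, k):
    kth_largest = np.sort(logits)[-k]
    return np.where(logits < kth_largest, -1e6, logits)

top_k_filter(np.array([1.0, 3.0, 2.0, 0.5]), k=2)  # -> [-1e6, 3.0, 2.0, -1e6]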
def decode(self,
           inputs,
           variable_dtype=mtf.VariableDType(tf.float32),
           beam_size=1,
           alpha=0.6,
           temperature=0.0,
           decode_length_multiplier=1.5,
           decode_length_constant=10):
    """Sampling or beam search.

    TODO(noam): should we make the output length dimension different from the
    input length dimension?

    Args:
      inputs: a Tensor with shape [<batch_dims>, beam_dim, length_dim]
      variable_dtype: a mtf.VariableDType
      beam_size: an integer >= 1
      alpha: a floating point value (length bonus for beam search)
      temperature: a value between 0 and 1 (must be 0 if beam_size > 1);
        0.0 means argmax, 1.0 means sample according to predicted distribution
      decode_length_multiplier: a float
      decode_length_constant: a float

    Returns:
      a Tensor with shape [<batch_dims>, beam_dim, length_dim]
    """
    encoder_layer_outputs = []
    shared_params = self._shared_params(inputs.mesh, variable_dtype)
    encoder_sequence_id = mtf.minimum(inputs, 1)
    encoder_output, encoder_loss = self.encoder.call_simple(
        inputs=inputs,
        targets=None,
        compute_loss=False,
        mode=tf.estimator.ModeKeys.PREDICT,
        variable_dtype=variable_dtype,
        sequence_id=encoder_sequence_id,
        shared_params=shared_params,
        layer_outputs=encoder_layer_outputs)
    del encoder_loss
    encoder_output = mtf.layers.rename_length_to_memory_length(encoder_output)
    encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
        encoder_sequence_id)

    if beam_size == 1:
        ids_shape = inputs.shape
        partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32)
        return self.decoder.sample_autoregressive(
            partial_sequences,
            temperature=temperature,
            variable_dtype=variable_dtype,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            shared_params=shared_params,
            has_partial_sequences=False,
            encoder_layer_outputs=encoder_layer_outputs)
    else:
        if temperature != 0:
            raise ValueError(
                "don't know how to beam search with nonzero temperature")
        # Beam search.
        beam_dim = mtf.Dimension("beam", beam_size)
        batch_dims = inputs.shape[:-1]
        length_dim = inputs.shape[-1]
        ids_shape = mtf.Shape(batch_dims + [beam_dim, length_dim])
        partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32)
        input_length = mtf.reduce_sum(
            mtf.to_float(mtf.cast(inputs, tf.bool)),
            reduced_dim=length_dim)
        max_input_length = mtf.reduce_max(input_length)
        decode_length = mtf.cast(
            max_input_length * decode_length_multiplier +
            decode_length_constant, tf.int32)
        return self.decoder.beam_search(
            partial_sequences,
            decode_length,
            variable_dtype=variable_dtype,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            alpha=alpha,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs)
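# Worked example of the decode-length rule above: with the defaults
# (multiplier 1.5, constant 10), a longest input of 20 tokens yields
# decode_length = int(20 * 1.5 + 10) = 40. A hedged calling sketch follows
# (the model and inputs tensors are assumed to exist):
samples = model.decode(
    inputs,
    beam_size=1,        # sampling path; beam_size > 1 requires temperature 0
    temperature=0.0)    # argmax decoding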
def my_model_fn(features, labels, mode, params=None, config=None):
    """Estimator model function.

    Args:
      features: input features dictionary
      labels: ignored
      mode: a tf.estimator.ModeKeys
      params: a parameter dictionary passed by the Estimator (on TPU it
        carries the TPU context)
      config: an optional RunConfig (unused)

    Returns:
      a TPUEstimatorSpec (or an EstimatorSpec when running without TPU)
    """
    del labels, config
    use_tpu = FLAGS.tpu
    global_step = tf.train.get_global_step()

    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    if use_tpu:
        ctx = params["context"]
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
        # TODO(ylc): Better estimation of replica cache size?
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                      devices_memory_usage)
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
    else:
        var_placer = None
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    model = transformer.Bitransformer(
        encoder_layer_stack=layer_stack(include_encdec_attention=False),
        decoder_layer_stack=layer_stack(include_encdec_attention=True),
        encoder_d_model=FLAGS.d_model,
        decoder_d_model=FLAGS.d_model,
        input_vocab_size=transformer_dataset.padded_vocab_size(
            transformer_dataset.inputs_vocab_size(FLAGS.dataset)),
        output_vocab_size=transformer_dataset.padded_vocab_size(
            transformer_dataset.targets_vocab_size(FLAGS.dataset)),
        max_length=FLAGS.max_length,
        shared_embedding=False,
        shared_embedding_and_softmax_weights=True,
        label_smoothing=FLAGS.label_smoothing,
        layout=FLAGS.layout,
        mesh_shape=FLAGS.mesh_shape)

    inputs = import_feature(features, mesh, "inputs")

    # Data types used for variables and activations; see the comments on the
    # corresponding flags.
    master_dtype = tf.as_dtype(FLAGS.master_dtype)
    if FLAGS.slice_dtype:
        slice_dtype = tf.as_dtype(FLAGS.slice_dtype)
    elif not FLAGS.tpu or FLAGS.mode == "train":
        slice_dtype = tf.float32
    else:
        slice_dtype = tf.bfloat16
    if FLAGS.activation_dtype:
        activation_dtype = tf.as_dtype(FLAGS.activation_dtype)
    else:
        activation_dtype = tf.bfloat16 if FLAGS.tpu else tf.float32
    variable_dtype = mtf.VariableDType(
        master_dtype=master_dtype,
        slice_dtype=slice_dtype,
        activation_dtype=activation_dtype)

    # PREDICT mode
    if mode == tf.estimator.ModeKeys.PREDICT:
        mtf_samples = model.decode(
            inputs,
            variable_dtype=variable_dtype,
            beam_size=FLAGS.beam_size,
            alpha=FLAGS.alpha,
            temperature=FLAGS.temperature)
        mtf_samples = mtf.anonymize(mtf_samples)
        lowering = mtf.Lowering(
            graph, {mesh: mesh_impl}, autostack=FLAGS.autostack)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {"outputs": outputs}
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    targets = import_feature(features, mesh, "targets")
    anon_targets = mtf.anonymize(targets)
    logits, loss = model.call_simple(
        inputs=inputs,
        targets=targets,
        compute_loss=True,
        mode=mode,
        variable_dtype=variable_dtype,
        encoder_sequence_id=import_feature(features, mesh,
                                           "inputs_segmentation"),
        decoder_sequence_id=import_feature(features, mesh,
                                           "targets_segmentation"),
        encoder_position=import_feature(features, mesh, "inputs_position"),
        decoder_position=import_feature(features, mesh, "targets_position"))

    if use_tpu and logits is not None:
        logits = mtf.anonymize(logits)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdafactorOptimizer()
        update_ops = optimizer.apply_grads(var_grads,
                                           graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                            autostack=FLAGS.autostack)

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.to_float(tf_loss)
    if not use_tpu:
        tf_loss = tf.Print(
            tf_loss, [tf_loss, tf.train.get_global_step()], "step, tf_loss")
    if logits is not None and mode != tf.estimator.ModeKeys.TRAIN:
        tf_logits = lowering.export_to_tf_tensor(logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(
            FLAGS.model_dir,
            save_steps=1000,
            saver=saver,
            listeners=[saver_listener])

        if mode == tf.estimator.ModeKeys.TRAIN:
            if use_tpu:
                return tpu_estimator.TPUEstimatorSpec(
                    mode=tf.estimator.ModeKeys.TRAIN,
                    loss=tf_loss,
                    train_op=train_op,
                    training_hooks=[restore_hook, saver_hook])
            else:
                return tf.estimator.EstimatorSpec(
                    tf.estimator.ModeKeys.TRAIN,
                    loss=tf_loss,
                    train_op=train_op,
                    training_chief_hooks=[restore_hook, saver_hook])
        elif mode == tf.estimator.ModeKeys.EVAL:

            def padded_neg_log_perplexity(logits, labels):
                weights = tf.to_float(tf.not_equal(labels, 0))
                xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=labels, logits=logits)
                return {"neg_log_perplexity": tf.metrics.mean(-xent, weights)}

            labels = lowering.export_to_tf_tensor(anon_targets)
            eval_metrics = (padded_neg_log_perplexity, [tf_logits, labels])
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)