def model_fn(features, labels, mode, params):
    """Build a one-layer softmax classifier spec for TPUEstimator.

    The model multiplies ``features`` by a ``[1000, 10]`` weight matrix ``W``
    (created under variable scope ``'m'`` with ``AUTO_REUSE``) and trains it
    with sparse softmax cross-entropy.

    Args:
        features: tensor fed to ``math_ops.matmul`` against ``W``.
        labels: sparse class labels for the cross-entropy loss.
        mode: a ``model_fn_lib.ModeKeys`` value.
        params: unused.

    Returns:
        A ``TPUEstimatorSpec`` for TRAIN and EVAL. Any other mode falls
        through and returns ``None`` (unchanged from the original).
    """
    del params  # unused

    # NOTE(review): the source was whitespace-collapsed; only the variable
    # creation is assumed to live inside the variable scope -- confirm.
    with variable_scope.variable_scope('m', reuse=variable_scope.AUTO_REUSE):
        w = variable_scope.get_variable('W', shape=[1000, 10])
    logits = math_ops.matmul(features, w)
    loss = losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    if mode == model_fn_lib.ModeKeys.TRAIN:
        optimizer = training.RMSPropOptimizer(learning_rate=0.01)
        # Aggregate gradients across TPU shards before applying them.
        optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
        train_op = optimizer.minimize(loss, training.get_global_step())
        return tpu_estimator.TPUEstimatorSpec(
            mode=model_fn_lib.ModeKeys.TRAIN,
            loss=loss,
            train_op=train_op,
        )
    elif mode == model_fn_lib.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Return recall@1 and recall@5 metrics for the given batch."""
            labels = math_ops.cast(labels, dtypes.int64)
            logging.info('LABELS %s %s', labels, logits)
            return {
                'recall@1': metrics_lib.recall_at_k(labels, logits, 1),
                'recall@5': metrics_lib.recall_at_k(labels, logits, 5),
            }

        # The loss tensor built above is reused here; the original created a
        # second, identical cross-entropy op (and a duplicate entry in the
        # losses collection) for no benefit.
        eval_metrics = (metric_fn, [labels, logits])
        return tpu_estimator.TPUEstimatorSpec(
            mode=model_fn_lib.ModeKeys.EVAL,
            loss=loss,
            eval_metrics=eval_metrics)
def model_fn(features, labels, mode, params):
    """Estimator model_fn: frozen TF-Hub Inception V3 features + softmax head.

    Args:
        features: input batch passed directly to the Keras model.
        labels: sparse integer class labels (unused in PREDICT mode).
        mode: a ``tf.estimator.ModeKeys`` value.
        params: dict; reads ``"use_compat"``, ``"learning_rate"`` and
            ``"use_tpu"`` in TRAIN mode.

    Returns:
        A ``TPUEstimatorSpec`` for PREDICT, EVAL and TRAIN. Any other mode
        returns ``None`` (unchanged from the original).

    Note:
        The model's final layer already applies ``softmax``, so the values
        exported under ``"predictions"`` are class probabilities. The loss is
        therefore built with ``from_logits=False``; the original passed
        ``from_logits=True``, which applied softmax a second time and
        produced a distorted loss.
    """
    # Module-level constant; it is only read here, so no ``global`` needed.
    assert NUM_CLASSES is not None

    model = tf.keras.Sequential([
        hub.KerasLayer(
            "https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4",
            output_shape=[2048],
            trainable=False),
        tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
    ])

    optimizer = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        if not params["use_compat"]:
            optimizer = tf.optimizers.Adam(params["learning_rate"])
        else:
            optimizer = tf.compat.v1.train.AdamOptimizer(
                params["learning_rate"])
        # NOTE(review): the source was whitespace-collapsed, so the nesting of
        # this check is reconstructed; it is assumed to apply to both
        # optimizer flavours -- confirm against the original layout.
        if params["use_tpu"]:
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

    # Forward pass and loss are both computed under the tape so that the
    # gradient of the loss w.r.t. the trainable variables is recorded.
    with tf.GradientTape() as tape:
        probabilities = model(features)
        if mode == tf.estimator.ModeKeys.PREDICT:
            preds = {"predictions": probabilities}
            return tpu_estimator.TPUEstimatorSpec(mode, predictions=preds)
        loss = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=False)(labels, probabilities)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tpu_estimator.TPUEstimatorSpec(mode, loss=loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        variables = model.trainable_variables
        gradients = tape.gradient(loss, variables)
        global_step = tf.compat.v1.train.get_global_step()
        if not params['use_compat']:
            # The Keras optimizer does not touch the Estimator global step,
            # so bump it explicitly and order the update before the apply.
            update_global_step = tf.compat.v1.assign(
                global_step, global_step + 1, name='update_global_step')
            with tf.control_dependencies([update_global_step]):
                train_op = optimizer.apply_gradients(
                    zip(gradients, variables))
        else:
            train_op = optimizer.apply_gradients(
                zip(gradients, variables), global_step=global_step)
        return tpu_estimator.TPUEstimatorSpec(
            mode, loss=loss, train_op=train_op)
def model_fn(features, labels, mode, params):
    """A model is called by TpuEstimator.

    Builds the benchmark graph on a SIMD mesh, lowers it to TensorFlow and
    reports the lowered scalar as the spec's loss. ``features`` and ``labels``
    are ignored.
    """
    del features, labels

    shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    rules = mtf.convert_to_layout_rules(FLAGS.layout)

    tpu_context = params['context']
    place_on_host = tpu_context.tpu_host_placement_function
    host_devices = []
    for host_index in range(tpu_context.num_hosts):
        host_devices.append(place_on_host(host_id=host_index))
    # The host device list is only computed for logging.
    tf.logging.info('device_list = %s' % host_devices)

    placeholder_devices = [''] * shape.size
    simd_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        shape, rules, placeholder_devices, tpu_context.device_assignment)

    mtf_graph = mtf.Graph()
    fft_mesh = mtf.Mesh(mtf_graph, "fft_mesh")

    with mtf.utils.outside_all_rewrites():
        benchmark_sum = benchmark_model(fft_mesh)

    lowered = mtf.Lowering(mtf_graph, {fft_mesh: simd_impl})
    loss_tensor = tf.to_float(lowered.export_to_tf_tensor(benchmark_sum))

    with mtf.utils.outside_all_rewrites():
        spec = tpu_estimator.TPUEstimatorSpec(mode, loss=loss_tensor)
        return spec
def model_fn(features, labels, mode, params):
    """A model is called by TpuEstimator.

    Runs the n-body model on a SIMD mesh and exports the first example of the
    resulting field as the ``'field'`` prediction. ``features`` and ``labels``
    are ignored.
    """
    del features, labels

    shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    rules = mtf.convert_to_layout_rules(FLAGS.layout)

    tpu_context = params['context']
    place_on_host = tpu_context.tpu_host_placement_function
    host_devices = []
    for host_index in range(tpu_context.num_hosts):
        host_devices.append(place_on_host(host_id=host_index))
    # The host device list is only computed for logging.
    tf.logging.info('device_list = %s' % host_devices)

    placeholder_devices = [''] * shape.size
    simd_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        shape, rules, placeholder_devices, tpu_context.device_assignment)

    mtf_graph = mtf.Graph()
    fft_mesh = mtf.Mesh(mtf_graph, "fft_mesh")

    with mtf.utils.outside_all_rewrites():
        field = nbody_model(fft_mesh)
        # The field must be rank 4; only the batch and z dimensions are
        # reused below.
        batch_dim, _, _, z_dim = field.shape
        unsplit_x = mtf.Dimension("nx_nosplit", FLAGS.cube_size)
        unsplit_y = mtf.Dimension("ny_nosplit", FLAGS.cube_size)

        # Until we implement distributed outputs, we only return one example
        split_sizes = [1, FLAGS.batch_size - 1]
        first_example, _ = mtf.split(field, batch_dim, split_sizes)
        single_batch = mtf.Dimension("bs", 1)
        first_example = mtf.reshape(
            first_example, [single_batch, unsplit_x, unsplit_y, z_dim])

    lowered = mtf.Lowering(mtf_graph, {fft_mesh: simd_impl})
    exported_field = tf.to_float(lowered.export_to_tf_tensor(first_example))

    with mtf.utils.outside_all_rewrites():
        predictions = {'field': exported_field}
        return tpu_estimator.TPUEstimatorSpec(mode, predictions=predictions)
def model_fn(features, labels, mode, params):
    """Estimator model_fn that builds, lowers and wraps the mtf GPT model.

    Args:
        features: input token tensor, or a dict holding it under ``"feature"``.
        labels: label token tensor (same conventions as ``features``); may be
            ``None``, in which case it is not imported into the mesh.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: configuration dict. Keys read here include ``mesh_shape``,
            ``layout``, ``use_tpu``, ``gpu_ids``, ``precision``, ``n_ctx``,
            ``causal``, ``n_embd``, ``n_vocab``, ``model``, ``auto_layout``,
            ``auto_layout_and_mesh_shape``, ``num_cores``, ``log_grads``,
            ``model_path``, ``steps_per_checkpoint``, ``eval_task`` and
            ``eos_id``. ``num_microbatches`` is written back into it.

    Returns:
        A ``TPUEstimatorSpec`` for PREDICT, TRAIN or EVAL.

    Raises:
        ValueError: if ``params["model"]`` is not ``"GPT"`` (TRAIN / EVAL).
        AssertionError: if ``mode`` is not one of the three Estimator modes.
    """
    # Get global step
    global_step = tf.train.get_global_step()

    # Construct mtf graph + mesh from params
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    layout_rules = mtf.convert_to_layout_rules(params["layout"])

    # Mesh setup
    if params["use_tpu"]:
        var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape,
                                                layout_rules)
    else:
        var_placer = None
        gpu_ids = params["gpu_ids"]
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, gpu_ids)

    # Trainable variable precision
    # Store to checkpoints in master type, train in slice type, compute in
    # activation type
    if params["precision"] == "bfloat16":
        variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.bfloat16)
    else:
        variable_dtype = mtf.VariableDType(master_dtype=tf.float32,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.float32)

    # Build mtf mesh object
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # Build mtf_features & seq length dict for getting number of microbatches
    # We need to pack inputs into a dict to pass into serialize_training_step
    features_dict = {"inputs": features, "labels": labels}
    sequence_length_dict = {
        "inputs": params["n_ctx"],
        "labels": params["n_ctx"]
    }

    params = add_mode_to_params(params, mode)
    batch_size = get_batch_size(params)

    batch_dim = mtf.Dimension("batch", batch_size)
    batch_dims = [batch_dim]
    feature_length = sequence_length_dict["inputs"]
    length_dim = mtf.Dimension("sequence", feature_length)

    mtf_features = {}
    for key, x in features_dict.items():
        if x is None:
            continue
        feature_shape = mtf.Shape(batch_dims + [length_dim])
        # Inputs may arrive wrapped in a dict; unwrap without mutating the
        # dict being iterated. isinstance also accepts dict subclasses.
        if isinstance(x, dict):
            x = x["feature"]
        x = tf.cast(x, tf.int32)
        x = tf.reshape(x, feature_shape.to_integer_list)
        mtf_features[key] = mtf.import_fully_replicated(mesh,
                                                        x,
                                                        feature_shape,
                                                        name=key)

    # Instantiate dict for dimensions, bias, etc that can be calculated here
    # once then passed into model
    other_features = {}
    memory_length_dim = mtf.Dimension("memory_length", length_dim.size)

    attn_bias = biasmask_attn_weights(
        mesh, length_dim, memory_length_dim,
        variable_dtype) if params["causal"] else None

    # Add attn_bias into mtf_features
    other_features["attn_bias"] = attn_bias

    # Define other Dimensions that we'll need inside the model
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
    # We need this because gathering when both the args have the same
    # dimension in them breaks things
    # This dim is specifically for the weights
    # This prevents the "Einsum has lhs dimension without corresponding rhs or
    # output dimension." error
    embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])

    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Set up the model for prediction
        inputs = mtf_features["inputs"]
        if params["remove_partial_sequences"] is None:
            params["remove_partial_sequences"] = False

        export = params.get("export", False)

        if not export:
            mtf_samples = sample_autoregressive(
                inputs,
                other_features=other_features,
                params=params,
                variable_dtype=variable_dtype,
                remove_partial_sequences=params["remove_partial_sequences"],
                stop_at_token=params["eos_id"],
                sampling_use_entmax=params['sampling_use_entmax'])
        else:
            # When exporting, a single forward pass is used; only the first
            # model output is kept.
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    mtf_samples, _, _ = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype,
                        context=None)

        mtf_samples = mtf.anonymize(mtf_samples)
        inputs = mtf.anonymize(inputs)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
        inputs = lowering.export_to_tf_tensor(inputs)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {"inputs": inputs, "outputs": outputs}

        def scaffold_fn():
            """Scaffold that also copies master variables to their slices."""
            return tf.train.Scaffold(
                local_init_op=tf.group(
                    tf.train.Scaffold.default_local_init_op(),
                    lowering.copy_masters_to_slices(),
                    name="mtf_local_init_op"),
                ready_op=tf.concat([
                    tf.report_uninitialized_variables(),
                    resources.report_uninitialized_resources()
                ],
                                   axis=0,
                                   name="mtf_ready_op"))

        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            scaffold_fn=scaffold_fn,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    # We're not predicting, so we better be training or evaluating
    assert (mode == tf.estimator.ModeKeys.TRAIN
            or mode == tf.estimator.ModeKeys.EVAL)

    # Validate the model name once, instead of in each build path below.
    # ValueError is a subclass of the Exception raised previously.
    if params["model"] != "GPT":
        raise ValueError(
            f"'{params['model']}' is not a valid model - please select from [GPT]"
        )

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Gets number of microbatches per batch for serialized training
        # if param tokens_per_mb_per_replica = None, this defaults to 1 and no
        # microbatching is performed
        num_microbatches = int(
            mtf_transformer.utils.serialize_num_microbatches(
                batch_dim=batch_dim,
                sequence_length=sequence_length_dict,
                mesh_shape=mesh_shape,
                layout_rules=layout_rules,
                tokens_per_microbatch_per_replica=params[
                    "tokens_per_mb_per_replica"]))
    else:
        num_microbatches = 1

    # Add num microbatches to params
    params["num_microbatches"] = num_microbatches

    if num_microbatches > 1:

        # For serialize_training_step we need to modify the model to output
        # results in a dict
        def serialized_fn(mtf_features):
            with tf.variable_scope('gpt2'):
                logits, loss, loss_batch = gpt2.model(
                    mtf_features,
                    other_features,
                    params,
                    mesh,
                    variable_dtype=variable_dtype)
            return {
                "logits": logits,
                "loss": loss,
                "loss_batch": loss_batch
            }

        # Serialize the training step - Gradients are accumulated locally and
        # reduced once.
        var_grads, output_dict = mtf.serialize_training_step(
            mtf_features, serialized_fn, batch_dim, num_microbatches)
        loss = output_dict["loss"]
        loss_batch = output_dict["loss_batch"]
        logits = output_dict["logits"]
    else:
        # If we're not splitting into microbatches, return logits & loss as is
        with mtf.utils.outside_all_rewrites():
            with tf.variable_scope('gpt2'):
                logits, loss, loss_batch = gpt2.model(
                    mtf_features,
                    other_features,
                    params,
                    mesh,
                    variable_dtype=variable_dtype,
                    context=None)

    # Auto layout generation
    if params["auto_layout"]:
        auto_layout(graph, mesh_shape, logits, loss)
    if params["auto_layout_and_mesh_shape"]:
        auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # In TRAIN mode, get optimizer
        if params["num_microbatches"] > 1:
            # If we are splitting the batch into microbatches, var grads are
            # created in the serialize_training_step fn
            # So we pass them in here
            _, update_ops, var_grads = get_optimizer(
                mesh,
                loss,
                params,
                variable_dtype=variable_dtype,
                inp_var_grads=var_grads)
        else:
            # Otherwise, they are created in the get_optimizer fn, so we leave
            # inp_var_grads blank
            _, update_ops, var_grads = get_optimizer(
                mesh, loss, params, variable_dtype=variable_dtype)
        # Log summaries to tensorboard
        mtf.scalar_summary("loss", loss)
        # Log gradients if in params
        if params["log_grads"] not in [None, False]:
            for g in var_grads:
                grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
                mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm)
    else:
        # For now, we can only export fully-replicated tensors.
        # This has to be done before lowering or they will not be included in
        # the graph
        mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
        max_logits = mtf.argmax(logits, vocab_dim)
        del logits
        fully_replicated_mean_logits = mtf.anonymize(mean_logits)
        fully_replicated_max_logits = mtf.anonymize(max_logits)
        fully_replicated_loss_batch = mtf.anonymize(loss_batch)

    # Gets & prints info about no. trainable vars in the model & dimension
    # names
    get_graph_info(graph)

    # 'lowers' mtf tensors into a tf graph - this enables us to export results
    # as tf tensors
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.cast(tf_loss, tf.float32)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Use our patched version until mtf updates theirs
        host_call = create_host_call(params['model_path'])
        mtf.utils.remove_summaries()

        # Creates train_op
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        # Need to manually increment global_step
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf.logging.info(f"tf_update_ops: {tf_update_ops}")
        train_op = tf.group(tf_update_ops)
    else:
        tf_mean_logits = lowering.export_to_tf_tensor(
            fully_replicated_mean_logits)
        tf_max_logits = lowering.export_to_tf_tensor(
            fully_replicated_max_logits)
        tf_loss_batch = tf.to_float(
            lowering.export_to_tf_tensor(fully_replicated_loss_batch))

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Set up the checkpoint server and return the TPUEstimatorSpec
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                params["model_path"],
                save_steps=params["steps_per_checkpoint"],
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                host_call=host_call,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])

        elif mode == tf.estimator.ModeKeys.EVAL:
            # Evaluation metrics
            def _perplexity(loss):
                perplexity = tf.exp(loss)
                return tf.metrics.mean(perplexity)

            def _bits_per_byte(loss):
                bpb = loss * (0.29335 / math.log(2))
                return tf.metrics.mean(bpb)

            def _metric_fn(tf_mean_logits, tf_loss_batch):
                mean_logits = tf.metrics.mean(tf_mean_logits)
                loss = tf.reduce_mean(tf_loss_batch)
                perp = _perplexity(loss)
                bpb = _bits_per_byte(loss)
                return {
                    "mean_logits": mean_logits,
                    "perplexity": perp,
                    "bits per byte": bpb
                }

            def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
                eos_token = params["eos_id"]
                # Positions whose label is not the EOS token are scored.
                answer_positions = tf.where(
                    tf.math.not_equal(labels, eos_token))

                correct_answers = tf.gather_nd(
                    tf.math.equal(tf_max_logits, labels), answer_positions)
                accuracy = tf.metrics.mean(
                    tf.cast(correct_answers, tf.float32))

                # I guess tf_loss_batch has z_loss and maybe other stuff
                # added to it so maybe this should be calculated separately
                # in the future
                answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
                log_perplexity = tf.metrics.mean(answer_loss)

                return {
                    "lambada_acc": accuracy,
                    "lambada_log_ppl": log_perplexity
                }

            eval_task = params["eval_task"]
            if eval_task == "lambada":
                eval_metrics = (_lambada_metric_fn,
                                [labels, tf_max_logits, tf_loss_batch])
            else:
                eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
def my_model_fn(features, labels, mode, params=None, config=None):
    """Estimator model function.

    Many names used here (``use_tpu``, ``mesh_shape``, ``layout_rules``,
    ``transformer_model``, ``sequence_length``, ``batch_size``, ...) are not
    defined in this function; they are presumably supplied by an enclosing
    scope -- verify against the surrounding file.

    Args:
        features: dictionary where keys are strings like "inputs" and
            "targets" and the values are the actual values of "inputs". See
            TPUEstimator's docs for more information
        labels: ignored argument
        mode: a tf.estimator.ModeKeys
        params: dictionary containing the key "context"
        config: ignored argument

    Returns:
        a TPUEstimatorSpec (or an EstimatorSpec when training without TPU)

    Raises:
        ValueError: for an unrecognized model class or ``variable_filter``,
            or when EVAL is requested without a targets-like feature.
    """
    del labels, config
    global_step = tf.train.get_global_step()

    # ``params`` defaults to None; guard it so that the membership test does
    # not raise TypeError and the non-TPU mesh is used instead.
    if use_tpu and params is not None and "context" in params:
        ctx = params["context"]
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [
            host_placement_fn(host_id=t) for t in range(num_hosts)
        ]
        # TODO(ylc): Better estimation of replica cache size?
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(
            device_list, devices_memory_usage)
        # deprecated
        mesh_devices = [""] * mesh_shape.size
        physical_shape = list(
            params["context"].device_assignment.topology.mesh_shape)
        logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
            mesh_shape.to_integer_list, physical_shape)
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape,
            layout_rules,
            mesh_devices,
            ctx.device_assignment,
            logical_to_physical=logical_to_physical)
    else:
        var_placer = None
        # deprecated
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # These dimensions do not depend on the feature, so build them once.
    outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
    batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)

    mtf_features = {}
    # Exported labels for EVAL; stays None when no targets-like feature is
    # present. If several such features exist, the last one iterated wins
    # (same as the original behavior).
    anon_targets = None
    for key, x in features.items():
        # Some auxiliary features may have been generated in packing.
        # The names of these new features are of the form
        # "<original_feature_name>_<suffix>", e.g. "inputs_segmentation".
        # We look up the lengths based on the original feature name, without
        # the "_<suffix>".
        feature_length = sequence_length[key.split("_")[0]]
        length_dim = mtf.Dimension("length", feature_length)
        ensemble_dims = ([mtf.Dimension("ensemble", ensemble_inputs)]
                         if ensemble_inputs else [])
        feature_shape = mtf.Shape(ensemble_dims +
                                  [outer_batch_dim, batch_dim, length_dim])
        x = tf.cast(x, tf.int32)
        x = tf.reshape(x, feature_shape.to_integer_list)
        if not use_tpu:
            tf.logging.info("feature %s : %s" % (key, x))
            x = tf.Print(x, [x],
                         "import feature %s" % key,
                         summarize=1000,
                         first_n=10)
        mtf_features[key] = mtf.import_fully_replicated(mesh,
                                                        x,
                                                        feature_shape,
                                                        name=key)
        if key in ("targets", "codeprefixedtargets", "controlcode"):
            anon_targets = mtf.anonymize(mtf_features[key])

    if mode == tf.estimator.ModeKeys.PREDICT:

        def _feature_shape(key):
            """Return the flat (batch, length) shape for feature ``key``."""
            feature_length = sequence_length[key.split("_")[0]]
            return mtf.Shape([
                mtf.Dimension("batch", batch_size),
                mtf.Dimension("length", feature_length)
            ])

        mtf_features = {
            k: mtf.reshape(v, _feature_shape(k))
            for k, v in six.iteritems(mtf_features)
        }
        inputs = mtf_features["inputs"]

        attributes = (mtf_features["attribute"]
                      if attribute_embedding else None)
        controlcodes = (mtf_features["controlcode"]
                        if has_partial_sequences else None)

        if predict_fn:
            mtf_samples = predict_fn(model=transformer_model,
                                     features=mtf_features,
                                     variable_dtype=get_variable_dtype())
        elif isinstance(transformer_model, transformer.Unitransformer):
            # pad so that there is enough room for the targets
            inputs = mtf.pad(inputs, [0, sequence_length["targets"]],
                             length_dim.name)
            mtf_samples = transformer_model.sample_autoregressive(
                inputs,
                variable_dtype=get_variable_dtype(),
                remove_partial_sequences=True)
        elif isinstance(transformer_model, Bitransformer_ll):
            # Checked before the generic Bitransformer branch below.
            mtf_samples = transformer_model.decode(
                inputs,
                attributes=attributes,
                controlcodes=controlcodes,
                has_partial_sequences=has_partial_sequences,
                remove_partial_sequences=remove_partial_sequences,
                variable_dtype=get_variable_dtype())
        elif isinstance(
                transformer_model,
                (transformer.Bitransformer, transformer.StudentTeacher)):
            mtf_samples = transformer_model.decode(
                inputs, variable_dtype=get_variable_dtype())
        else:
            raise ValueError("unrecognized class")

        mtf_samples = mtf.anonymize(mtf_samples)
        inputs = mtf.anonymize(inputs)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                autostack=autostack)
        inputs = lowering.export_to_tf_tensor(inputs)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {"inputs": inputs, "outputs": outputs}

        # When exporting a model, we need to communicate to TF-Serving that
        # master variables need to be copied to their slave slice variables.
        # Estimator uses a Scaffold's "local_init_op" for this purpose, so we
        # augment the default "local_init_op" here.
        #
        # The "ready_op" is also constructed here to ensure the variables
        # initialized by "local_init_op" are the same ones checked by
        # "ready_op".
        #
        # WARNING: Any variables created outside of this model_fn()
        # (e.g. tpu_estimator/iterations_per_loop) will NOT be initialized nor
        # checked by these ops.
        def scaffold_fn():
            return tf.train.Scaffold(
                local_init_op=tf.group(
                    tf.train.Scaffold.default_local_init_op(),
                    lowering.copy_masters_to_slices(),
                    name="mtf_local_init_op"),
                ready_op=tf.concat([
                    tf.report_uninitialized_variables(),
                    resources.report_uninitialized_resources()
                ],
                                   axis=0,
                                   name="mtf_ready_op"))

        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            scaffold_fn=scaffold_fn,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    assert (mode == tf.estimator.ModeKeys.TRAIN
            or mode == tf.estimator.ModeKeys.EVAL)

    # ``logits_and_loss`` reads this from the enclosing scope. It used to be
    # assigned only in TRAIN mode, so EVAL with a model that is not a
    # Bitransformer_ll failed with a NameError. Default to a single
    # microbatch; TRAIN overwrites it below before the first call.
    num_microbatches = 1

    def logits_and_loss(mtf_features):
        """Compute logits and loss.

        Args:
            mtf_features: a dictionary

        Returns:
            logits: a mtf.Tensor
            loss: a mtf.Tensor
        """
        if model_type == "lm":
            # TOTRY Adapt that to our case
            if "inputs" in mtf_features:
                mtf_features = _dynamic_text2self(mtf_features)
            _, _, length_dim = mtf_features["targets"].shape
            inputs = mtf.shift(mtf_features["targets"],
                               offset=1,
                               dim=length_dim,
                               wrap=False)
        else:
            inputs = mtf_features["inputs"]

        attributes = (mtf_features["attribute"]
                      if attribute_embedding else None)
        codeprefixedtargets = (mtf_features["codeprefixedtargets"]
                               if control_codes else None)

        if isinstance(transformer_model, transformer.Unitransformer):
            position_kwargs = dict(
                sequence_id=mtf_features.get("targets_segmentation", None),
                position=mtf_features.get("targets_position", None),
            )
        elif isinstance(transformer_model, transformer.Bitransformer
                        ) or model_type == "bi_student_teacher":
            # With control codes the decoder-side auxiliary features are
            # keyed on "codeprefixedtargets" instead of "targets".
            decoder_key = "codeprefixedtargets" if control_codes else "targets"
            position_kwargs = dict(
                encoder_sequence_id=mtf_features.get(
                    "inputs_segmentation", None),
                decoder_sequence_id=mtf_features.get(
                    decoder_key + "_segmentation", None),
                decoder_subsequence_id=mtf_features.get(
                    decoder_key + "_subsegmentation", None),
                encoder_position=mtf_features.get("inputs_position", None),
                decoder_position=mtf_features.get(
                    decoder_key + "_position", None),
            )
        else:
            raise ValueError("unrecognized class")

        if not isinstance(transformer_model, Bitransformer_ll):
            return transformer_model.call_simple(
                inputs=inputs,
                targets=mtf_features["targets"],
                compute_loss=True,
                mode=mode,
                variable_dtype=get_variable_dtype(),
                num_microbatches=num_microbatches,
                **position_kwargs)

        def _call_ll(model_inputs):
            """Teacher-forced pass of the Bitransformer_ll on model_inputs."""
            return transformer_model.call_simple(
                inputs=model_inputs,
                targets=mtf_features["targets"],
                compute_loss=True,
                attributes=attributes,
                codeprefixedtargets=codeprefixedtargets,
                mode=mode,
                variable_dtype=get_variable_dtype(),
                **position_kwargs)

        if not cycle_consistency_loss:
            return _call_ll(inputs)

        # Auto-encoding pass on the real inputs; only its loss is used.
        _, l_ae = _call_ll(inputs)

        controlcodes = (mtf_features["controlcode"]
                        if has_partial_sequences else None)

        with gin.config_scope('training'):
            mtf_samples = transformer_model.decode(
                inputs,
                attributes=attributes,
                controlcodes=controlcodes,
                has_partial_sequences=has_partial_sequences,
                remove_partial_sequences=remove_partial_sequences,
                variable_dtype=get_variable_dtype())
            # mtf_samples = mtf.anonymize(mtf_samples)
        outputs = mtf_samples

        # Cycle pass: feed the decoded samples back in as inputs.
        logits_cycle, l_cycle = _call_ll(outputs)

        loss_ae_cycle = lambda_ae * l_ae + lambda_cycle * l_cycle
        return logits_cycle, loss_ae_cycle

    if mode == tf.estimator.ModeKeys.TRAIN:
        num_microbatches = serialize_num_microbatches(
            batch_dim, sequence_length, mesh_shape, layout_rules)
        if num_microbatches > 1:

            def serialized_fn(mtf_features):
                return {
                    "loss":
                    (logits_and_loss(mtf_features)[1] / num_microbatches)
                }

            var_grads, loss_dict = mtf.serialize_training_step(
                mtf_features, serialized_fn, batch_dim, num_microbatches)
            loss = loss_dict["loss"]
        else:
            loss = logits_and_loss(mtf_features)[1]
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])

        if tpu_summaries:
            mtf.scalar_summary("loss", loss)

        if callable(learning_rate_schedule):
            # the following happens on CPU since TPU can't handle summaries.
            with mtf.utils.outside_all_rewrites():
                learning_rate = learning_rate_schedule(
                    step=tf.train.get_global_step())
                tf.summary.scalar("learning_rate", learning_rate)
        else:
            learning_rate = learning_rate_schedule

        if isinstance(variable_filter, str):
            pattern = re.compile(variable_filter)
            variable_filter_fn = lambda v: pattern.search(v.name)
        elif variable_filter is None:
            variable_filter_fn = lambda v: True
        elif callable(variable_filter):
            variable_filter_fn = variable_filter
        else:
            raise ValueError(
                "variable_filter must be None, a string, or a callable function"
            )
        trainable_vars = [
            v for v in graph.trainable_variables if variable_filter_fn(v)
        ]
        trainable_var_grads = [
            g for g, v in zip(var_grads, graph.trainable_variables)
            if variable_filter_fn(v)
        ]
        if len(trainable_vars) != len(graph.trainable_variables):
            tf.logging.info("Variables being trained:")
            tf.logging.info([v.name for v in trainable_vars])
            tf.logging.info("Variables not being trained:")
            tf.logging.info([
                v.name for v in graph.trainable_variables
                if not variable_filter_fn(v)
            ])

        update_ops = optimizer(learning_rate=learning_rate).apply_grads(
            trainable_var_grads, trainable_vars)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                autostack=autostack)

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.cast(tf_loss, tf.float32)
        if not use_tpu:
            tf_loss = tf.Print(
                tf_loss, [tf_loss, tf.train.get_global_step()],
                "step, tf_loss")

        tf_update_ops = [
            lowering.lowered_operation(op) for op in update_ops
        ]
        # The global step is incremented manually alongside the updates.
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)

        if hasattr(transformer_model, "initialize"):
            with mtf.utils.outside_all_rewrites():
                transformer_model.initialize()

        if tpu_summaries:
            # has to be outside of
            # with mtf.utils.outside_all_rewrites()
            host_call = mtf.utils.create_host_call(model_dir)
            mtf.utils.remove_summaries()
        else:
            host_call = None

        with mtf.utils.outside_all_rewrites():

            if init_checkpoint:
                ckpt_vars = {
                    v
                    for v, _ in tf.train.list_variables(init_checkpoint)
                }
                global_vars = {v.op.name for v in tf.global_variables()}
                restore_vars = ckpt_vars.intersection(global_vars)
                tf.logging.info("Initializing variables from %s:",
                                init_checkpoint)
                tf.logging.debug("\n".join(sorted(restore_vars)))
                tf.logging.info("Variables in %s but not in graph:",
                                init_checkpoint)
                tf.logging.info("\n".join(sorted(ckpt_vars - global_vars)))
                tf.logging.info("Variables in graph but not in %s:",
                                init_checkpoint)
                tf.logging.info("\n".join(sorted(global_vars - ckpt_vars)))
                tf.train.init_from_checkpoint(init_checkpoint,
                                              {v: v
                                               for v in restore_vars})

            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=keep_checkpoint_max,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                model_dir,
                save_steps=save_checkpoints_steps,
                saver=saver,
                listeners=[saver_listener])
            gin_config_saver_hook = gin.tf.GinConfigSaverHook(
                model_dir,
                summarize_config=True,
                include_step_in_filename=False)

            if use_tpu:
                return tpu_estimator.TPUEstimatorSpec(
                    mode=tf.estimator.ModeKeys.TRAIN,
                    loss=tf_loss,
                    train_op=train_op,
                    host_call=host_call,
                    training_hooks=[
                        restore_hook,
                        saver_hook,
                        gin_config_saver_hook,
                    ])
            else:
                return tf.estimator.EstimatorSpec(
                    tf.estimator.ModeKeys.TRAIN,
                    loss=tf_loss,
                    train_op=train_op,
                    training_chief_hooks=[
                        restore_hook,
                        saver_hook,
                        gin_config_saver_hook,
                    ])
    elif mode == tf.estimator.ModeKeys.EVAL:
        if anon_targets is None:
            # Previously this surfaced later as an unbound-variable error.
            raise ValueError(
                "EVAL mode requires a 'targets', 'codeprefixedtargets' or "
                "'controlcode' feature to use as labels")

        logits, loss = logits_and_loss(mtf_features)
        anon_logits = mtf.anonymize(logits)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                autostack=autostack)
        tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
        tf_logits = tf.cast(lowering.export_to_tf_tensor(anon_logits),
                            tf.float32)

        def simple_metrics(logits, labels):
            """Simple metrics for teacher-forced eval."""
            # Label id 0 is treated as padding and given zero weight.
            weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
            xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=labels, logits=logits)
            predictions = tf.cast(tf.argmax(logits, axis=-1), labels.dtype)
            token_correct = tf.cast(tf.equal(predictions, labels),
                                    tf.float32) * weights
            # A sequence counts as correct when every weighted token is.
            sequence_correct = tf.to_float(
                tf.equal(tf.reduce_sum(token_correct, -1),
                         tf.reduce_sum(weights, -1)))
            sequence_weights = tf.to_float(
                tf.not_equal(tf.reduce_sum(weights, -1), 0))
            return {
                "neg_log_perplexity":
                tf.metrics.mean(-xent, weights),
                "token_accuracy":
                tf.metrics.mean(token_correct, weights),
                "sequence_accuracy":
                tf.metrics.mean(sequence_correct, sequence_weights)
            }

        labels = lowering.export_to_tf_tensor(anon_targets)
        eval_metrics = (simple_metrics, [tf_logits, labels])
        with mtf.utils.outside_all_rewrites():
            restore_hook = mtf.MtfRestoreHook(lowering)
        return tpu_estimator.TPUEstimatorSpec(
            tf.estimator.ModeKeys.EVAL,
            evaluation_hooks=[restore_hook],
            loss=tf_loss,
            eval_metrics=eval_metrics)
def model_fn(features, labels, mode, params):
    """Build the toy Mesh TensorFlow model and wrap it for TPUEstimator.

    The mesh layout comes from ``FLAGS.mesh_shape`` / ``FLAGS.layout``. On TPU
    the variable placer and SIMD mesh are derived from ``params['context']``;
    otherwise a placement mesh is used. ``toy_model`` produces the logits and
    the loss, which are lowered to regular TensorFlow tensors.

    Args:
        features: input passed straight to ``toy_model``.
        labels: ignored.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: estimator params; ``params['context']`` is read on TPU.

    Returns:
        A ``TPUEstimatorSpec`` for TRAIN and EVAL. Any other mode falls
        through and returns ``None``.
    """
    del labels  # Not used by this model_fn.
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    step_var = tf.train.get_global_step()

    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    # Both mesh implementations receive one empty device string per mesh slot.
    mesh_devices = [''] * mesh_shape.size

    if not FLAGS.use_tpu:
        var_placer = None
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)
    else:
        ctx = params['context']
        host_count = ctx.num_hosts
        place_on_host = ctx.tpu_host_placement_function
        device_list = [place_on_host(host_id=host) for host in range(host_count)]
        tf.logging.info('device_list = %s' % device_list)
        # TODO(ylc): Better estimation of replica cache size?
        bytes_per_replica = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries; the other hosts start at zero.
        memory_usage = [bytes_per_replica * ctx.num_replicas]
        memory_usage += [0] * (host_count - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(device_list, memory_usage)
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)

    mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

    with mtf.utils.outside_all_rewrites():
        logits, loss = toy_model(features, mesh)

    # Everything that has to exist as mtf ops must be built before lowering.
    if is_training:
        trainable_outputs = [var.outputs[0] for var in graph.trainable_variables]
        var_grads = mtf.gradients([loss], trainable_outputs)
        if FLAGS.optimizer == 'Adafactor':
            optimizer = mtf.optimize.AdafactorOptimizer()
        else:
            assert FLAGS.optimizer == 'SGD'
            optimizer = mtf.optimize.SgdOptimizer(learning_rate=FLAGS.lr)
        update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
    else:
        # for now, we can only export fully-replicated tensors.
        fully_replicated_logits = mtf.anonymize(logits)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))

    if is_training:
        tf_update_ops = []
        for mtf_op in update_ops:
            tf_update_ops.append(lowering.lowered_operation(mtf_op))
        # The global step is advanced together with the variable updates.
        tf_update_ops.append(tf.assign_add(step_var, 1))
        tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
        train_op = tf.group(tf_update_ops)
    else:
        tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)

        if is_training:
            saver = tf.train.Saver(
                tf.global_variables(),
                sharded=True,
                max_to_keep=10,
                keep_checkpoint_every_n_hours=2,
                defer_build=False,
                save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            checkpoint_hook = tf.train.CheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[mtf.MtfCheckpointSaverListener(lowering)])
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, checkpoint_hook])

        if mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(tf_logits):
                return {'mean_logits': tf.metrics.mean(tf_logits)}

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=(metric_fn, [tf_logits]))
def vae_model_fn(features, labels, mode, params):
    """Estimator model_fn that trains or evaluates a ``DiscreteVAE``.

    Args:
        features: batch handed to ``model.forward`` and logged as the input
            image.
        labels: not read by this function.
        mode: a ``tf.estimator.ModeKeys`` value; PREDICT is not implemented.
        params: configuration mapping (model sizes, gumbel/temperature
            settings, ``lr``, ``model_path``, ...).

    Returns:
        A ``TPUEstimatorSpec`` carrying the loss, the train op, the eval
        metrics and, in TRAIN mode only, a summary-writing host call.

    Raises:
        NotImplementedError: if ``mode`` is PREDICT.
    """
    train_mode = tf.estimator.ModeKeys.TRAIN
    eval_mode = tf.estimator.ModeKeys.EVAL

    image_size = params["dataset"]["image_size"]  # TODO: check equal
    mode_str = mode_to_str(mode)
    # Carried over from the original; the value itself is not used below.
    batch_size = params[f"{mode_str}_batch_size"]
    n_channels = params.get("input_channels", 3)

    default_convblocks = [(3, 64), (3, 128), (3, 256)]
    model = DiscreteVAE(
        num_tokens=params["num_tokens"],
        dim=params["n_embd"],
        hidden_dim=params["hidden_dim"],
        input_channels=n_channels,
        convblocks=params.get("convblocks", default_convblocks),
        recompute_grad=params.get("recompute_grad", False),
        use_bf16=params.get("use_bf16", False),
        stack_factor=params.get("stack_factor", 1),
        dimensions=image_size)

    if mode == tf.estimator.ModeKeys.PREDICT:
        raise NotImplementedError

    # We're not predicting, so we better be training or evaluating
    assert mode in (train_mode, eval_mode)

    gumbel_key = "train_gumbel_hard" if mode == train_mode else "eval_gumbel_hard"
    hard_gumbel = params.get(gumbel_key, True)

    # Linearly move the temperature from temp_start to temp over
    # temp_anneal_steps global steps, when annealing is configured.
    anneal_steps = params.get("temp_anneal_steps", None)
    if anneal_steps:
        progress = tf.cast(tf.train.get_global_step(), tf.float32) / anneal_steps
        progress = tf.minimum(progress, tf.constant(1.0))
        temp_start = params["temp_start"]
        temperature = temp_start - progress * (temp_start - params["temp"])
    else:
        temperature = params.get("temp", 1.0)

    forward_kwargs = {
        "return_recon_loss": True,
        "temperature": temperature,
        "hard_gumbel": hard_gumbel,
    }

    # TODO: add back in microbatching
    if params.get("use_bf16", False):
        with tf.tpu.bfloat16_scope(), tf.variable_scope("vae"):
            loss, reconstruction = model.forward(features, **forward_kwargs)
            # Bring the outputs back to float32 after the bfloat16 forward pass.
            loss = tf.cast(loss, tf.float32)
            reconstruction = tf.cast(reconstruction, tf.float32)
    else:
        with tf.variable_scope("vae"):
            loss, reconstruction = model.forward(features, **forward_kwargs)

    adam = tf.train.AdamOptimizer(learning_rate=params["lr"])
    optimizer = tf.tpu.CrossShardOptimizer(adam)

    global_step = tf.train.get_or_create_global_step()
    pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(pending_updates):
        train_op = optimizer.minimize(loss, global_step)

    def _denormalize(img):
        # Maps x to (x + 1) / 2 before the image is written as a summary.
        return (img + 1) / 2

    def host_call_fn(gs, loss, images, reconstruction):
        # Writes the training loss and image summaries under model_path.
        step = gs[0]
        mean_loss = tf.math.reduce_mean(loss)
        writer = tf2.summary.create_file_writer(params['model_path'])
        with writer.as_default():
            tf2.summary.scalar('loss', mean_loss, step=step)
            tf2.summary.image('input_image', _denormalize(images), step=step)
            tf2.summary.image('reconstruction_image',
                              _denormalize(reconstruction), step=step)
            return tf.summary.all_v2_summary_ops()

    def metric_fn(gs, loss, images, reconstruction):
        # Eval metrics; image summaries are gated by record_if on the mean loss.
        step = gs[0]
        mean_loss = tf.math.reduce_mean(loss)
        writer = tf2.summary.create_file_writer(params['model_path'])
        with writer.as_default():
            loss_op = tf.metrics.mean(mean_loss)

            with tf2.summary.record_if(loss_op[0] < tf.constant(1e-9)):
                tf2.summary.image('eval/input_image',
                                  _denormalize(images), step=step)
                tf2.summary.image('eval/reconstruction_image',
                                  _denormalize(reconstruction), step=step)

            # The no-op depends on every v2 summary op and is returned as the
            # update op of a dummy metric.
            with tf.control_dependencies(tf.summary.all_v2_summary_ops()):
                dummy_op = tf.no_op()

            return {"_loss": loss_op, "zzz_dummy": (tf.constant(0), dummy_op)}

    # host_call expects [batch_size, ...] tensors, so the scalar global step
    # and loss are reshaped to carry a leading dimension of one.
    gs_t = tf.reshape(global_step, [1])
    loss_t = tf.reshape(loss, [1])
    host_tensors = [gs_t, loss_t, features, reconstruction]
    host_call = (host_call_fn, host_tensors)
    metric = (metric_fn, [gs_t, loss_t, features, reconstruction])

    return tpu_estimator.TPUEstimatorSpec(
        mode,
        loss=loss,
        host_call=host_call if mode == train_mode else None,
        train_op=train_op,
        eval_metrics=metric)
def dalle_model_fn(features, labels, mode, params):
    """Estimator model_fn for the Mesh TensorFlow DALLE model.

    ``features`` is run through the pre-trained VAE to obtain image tokens
    (argmax over the VAE logits) and is also imported as ``image_inputs``;
    ``labels`` is reshaped to ``[batch, text_seq_len]`` and concatenated with
    the vocabulary-offset image tokens to form the ``tokens`` input.

    Args:
        features: image batch (fed to ``vae.forward``).
        labels: text token batch; may be ``None``, in which case no ``tokens``
            feature is built.
        mode: a ``tf.estimator.ModeKeys`` value; PREDICT is not implemented.
        params: configuration mapping. ``params["num_microbatches"]`` is
            written by this function.

    Returns:
        A ``TPUEstimatorSpec`` for TRAIN or EVAL.

    Raises:
        NotImplementedError: if ``mode`` is PREDICT.
    """
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    global_step = tf.train.get_global_step()
    mode_str = mode_to_str(mode)

    # The VAE is loaded into the TensorFlow graph before any mtf objects exist.
    vae, vae_checkpoint_path = load_vae_model(params, mode_str)
    initialize_vae_weights(vae_checkpoint_path)

    image_size = params["dataset"]["image_size"]
    downsampled_side = vae.H // (2 ** len(vae.convblocks))
    # TODO: check this is correct
    image_seq_len = downsampled_side ** 2 // (vae.stack_factor ** 2)
    batch_size = params[f"{mode_str}_batch_size"]
    n_channels = params.get("input_channels", 3)

    with tf.variable_scope("vae"):
        vae_logits = vae.forward(features, return_logits=True)

    # TODO: using argmax sampling for now, but is that optimal?
    image_tokens = tf.math.argmax(vae_logits, -1)
    image_tokens = tf.reshape(image_tokens, (batch_size, image_seq_len))
    image_tokens = tf.cast(image_tokens, tf.int32)

    # Construct mtf graph + mesh from params
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    layout_rules = mtf.convert_to_layout_rules(params["layout"])

    if params["use_tpu"]:
        var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape, layout_rules)
    else:
        var_placer = None
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, params["gpu_ids"])

    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    model = DALLE(
        n_embd=params["n_embd"],
        text_vocab_size=params["text_vocab_size"],
        image_vocab_size=params["image_vocab_size"],
        text_seq_len=params["text_seq_len"],
        image_seq_len=image_seq_len,
        n_layers=params["n_layers"],
        n_heads=params["n_heads"],
        batch_size=batch_size,
        bf_16=params["bf_16"],
        mode=mode_str,
        params=params,
    )
    batch_dim = model.dimensions["batch_dim"]

    # Inputs are packed into a dict so they can be handed to
    # serialize_training_step. Images are imported first, then the tokens.
    mtf_features = {}
    if features is not None:
        image_shape = mtf.Shape([
            batch_dim,
            mtf.Dimension("img_height_dim", vae.H),
            mtf.Dimension("img_width_dim", vae.W),
            mtf.Dimension("img_channel_dim", vae.num_ch),
        ])
        nhwc_images = tf.reshape(
            features, [batch_size, image_size, image_size, n_channels])
        mtf_features["image_inputs"] = mtf.import_fully_replicated(
            mesh, nhwc_images, image_shape, name="image_inputs")
    if labels is not None:
        text_tokens = tf.reshape(labels, [batch_size, params["text_seq_len"]])
        # Image token ids are shifted past the text vocabulary.
        all_tokens = tf.concat(
            (text_tokens, image_tokens + model.text_vocab_size), axis=1)
        token_shape = mtf.Shape([batch_dim, model.dimensions["total_seq_dim"]])
        mtf_features["tokens"] = mtf.import_fully_replicated(
            mesh, all_tokens, token_shape, name="text_inputs")

    scalar_summary("input_image", mtf_features["image_inputs"])

    if mode == tf.estimator.ModeKeys.PREDICT:
        raise NotImplementedError

    # We're not predicting, so we better be training or evaluating
    assert is_training or mode == tf.estimator.ModeKeys.EVAL

    # Microbatching is only considered in TRAIN mode; everywhere else the
    # batch is processed in one piece.
    num_microbatches = 1
    if is_training:
        num_microbatches = int(
            mtf_transformer.utils.serialize_num_microbatches(
                batch_dim=batch_dim,
                sequence_length=model.total_seq_dim,
                mesh_shape=mesh_shape,
                layout_rules=layout_rules,
                tokens_per_microbatch_per_replica=params[
                    "tokens_per_mb_per_replica"]))
    params["num_microbatches"] = num_microbatches  # Add num microbatches to params

    use_microbatches = num_microbatches > 1
    if use_microbatches:

        def serialized_fn(mtf_features):
            # serialize_training_step needs the model outputs as a dict.
            loss, loss_batch = model.forward(mtf_features, return_loss=True)
            return {"loss": loss, "loss_batch": loss_batch}

        # Gradients are accumulated locally and reduced once.
        var_grads, output_dict = mtf.serialize_training_step(
            mtf_features, serialized_fn, batch_dim, num_microbatches)
        loss = output_dict["loss"]
    else:
        # TODO: loss_batch may be needed for some metrics - otherwise, remove from output
        loss, _ = model.forward(mtf_features, return_loss=True)

    if is_training:
        optimizer_kwargs = {"variable_dtype": model.variable_dtype}
        if use_microbatches:
            # Gradients were already produced by serialize_training_step.
            optimizer_kwargs["inp_var_grads"] = var_grads
        _, update_ops, _ = get_optimizer(mesh, loss, params, **optimizer_kwargs)
        # Log summaries to tensorboard
        scalar_summary("loss", loss)

    # Gets & prints info about no. trainable vars in the model & dimension names
    get_graph_info(graph)

    # 'lowers' mtf tensors into a tf graph so results can be exported as tf tensors
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=False)
    tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)

    if is_training:
        # Use our patched version until mtf updates theirs
        host_call = create_host_call(params['model_path'])
        mtf.utils.remove_summaries()

        tf_update_ops = []
        for mtf_op in update_ops:
            tf_update_ops.append(lowering.lowered_operation(mtf_op))
        # Need to manually increment global_step
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)

        if is_training:
            # Set up the checkpoint saver and return the TPUEstimatorSpec
            saver = tf.train.Saver(
                tf.global_variables(),
                sharded=True,
                max_to_keep=params.get("max_checkpoints", 5),
                keep_checkpoint_every_n_hours=2,
                defer_build=False,
                save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            checkpoint_hook = tf.train.CheckpointSaverHook(
                params["model_path"],
                save_steps=params["steps_per_checkpoint"],
                saver=saver,
                listeners=[mtf.MtfCheckpointSaverListener(lowering)])
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                host_call=host_call,
                train_op=train_op,
                training_hooks=[restore_hook, checkpoint_hook])

        if mode == tf.estimator.ModeKeys.EVAL:
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=None)
def __call__(self, features, labels, mode, params):  # this is the model_fn
    """Estimator ``model_fn``: build, lower and wrap the Mesh TensorFlow model.

    Args:
        features: input tensor handed to ``self.model`` as ``x``.
        labels: optional targets handed to ``self.model`` as ``y``; when
            ``None`` the features themselves are used as targets.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: estimator params; reads ``params["use_tpu"]`` and, on TPU,
            ``params["context"]``. Also forwarded to ``self.model``.

    Returns:
        A ``TPUEstimatorSpec`` for TRAIN, EVAL or PREDICT.

    Raises:
        AssertionError: in TRAIN mode when ``self.optimizer`` is neither
            ``"Adafactor"`` nor ``"SGD"``.
    """
    global_step = tf.train.get_global_step()

    # Graph setup
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(self.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(self.layout)
    mesh_devices = [""] * mesh_shape.size

    if params["use_tpu"]:
        ctx = params["context"]
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [
            host_placement_fn(host_id=t) for t in range(num_hosts)
        ]
        # Worker 0 caches all the TPU binaries.
        replica_cache_size = 300 * 1024 * 1024  # 300M per replica.
        # NOTE(review): presumably assumes 8 replicas per host - confirm
        # against ctx.num_replicas.
        worker0_mem = replica_cache_size * 8 * num_hosts
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(
            device_list, devices_memory_usage)
        # The fourth argument is the TPU device assignment. The memory
        # estimate above belongs to the variable placer only.
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
    else:
        var_placer = None
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)

    # A single mesh is created once the placer is known.
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # RUN Model
    # `model` takes (mesh, x, y, params).
    # NOTE(review): using the features as targets when no labels are supplied
    # follows the "loss between the input and the output" comment in `model`;
    # confirm this is the intended training signal.
    targets = features if labels is None else labels
    with mtf.utils.outside_all_rewrites():
        logits, loss = self.model(mesh, features, targets, params)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        if self.optimizer == "Adafactor":
            optimizer = mtf.optimize.AdafactorOptimizer()
        else:
            assert self.optimizer == "SGD"
            optimizer = mtf.optimize.SgdOptimizer(
                learning_rate=self.learning_rate)
        update_ops = optimizer.apply_grads(var_grads,
                                           graph.trainable_variables)
    else:
        # for now, we can only export fully-replicated tensors.
        fully_replicated_logits = mtf.anonymize(logits)

    # convert back to tensorflow format
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [
            lowering.lowered_operation(op) for op in update_ops
        ]
        # The global step is incremented together with the variable updates.
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
        train_op = tf.group(tf_update_ops)
    else:
        tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

    # create estimator
    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)

        if mode == tf.estimator.ModeKeys.TRAIN:
            saver = tf.train.Saver(
                tf.global_variables(),
                sharded=True,
                max_to_keep=10,
                keep_checkpoint_every_n_hours=2,
                defer_build=False,
                save_relative_paths=True,
            )
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                self.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener],
            )
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook],
            )
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(tf_logits):
                mean_logits = tf.metrics.mean(tf_logits)
                return {"mean_logits": mean_logits}

            eval_metrics = (metric_fn, [tf_logits])
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics,
            )
        elif mode == tf.estimator.ModeKeys.PREDICT:
            # PREDICT has no eval metrics; it returns the lowered logits and
            # attaches the restore hook as a prediction hook.
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.PREDICT,
                predictions={"logits": tf_logits},
                prediction_hooks=[restore_hook],
            )


@property
def dense_initializer(self):
    """Initializer for dense weights.

    Returns a truncated-normal initializer with
    ``stddev=config.initializer_range`` when that value is truthy, otherwise
    ``mtf.layers.VarianceScalingInitializer(scale=0.4)``.
    """
    if self.config.initializer_range:
        return tf.truncated_normal_initializer(
            stddev=self.config.initializer_range)
    else:
        return mtf.layers.VarianceScalingInitializer(scale=0.4)


@property
def embedding_initializer(self):
    """Initializer for embedding tables, derived from ``dense_initializer``."""
    initializer = self.dense_initializer
    if isinstance(initializer, mtf.layers.DenseInitializer):
        # embedding matrix is also used as classifier weight matrix.
        # scale it appropriately.
        return initializer(reduced_dims=[self.model_dim],
                           new_dims=[self.vocab_dim])
    else:
        return initializer


@property
def num_hidden_layers(self):
    """Number of hidden layers, read from ``config.num_hidden_layers``."""
    return self.config.num_hidden_layers


def normalize(self, x, reduce_dim):
    """Apply ``nn.layer_norm`` to ``x`` over ``reduce_dim``.

    Both ``subtract_mean`` and ``use_bias`` follow ``config.use_bias``.
    """
    return nn.layer_norm(
        x,
        reduce_dim,
        subtract_mean=self.config.use_bias,
        use_bias=self.config.use_bias,
    )


def model(self, mesh, x, y, params):
    """Build the mtf model and return ``(prediction, loss)``.

    Args:
        mesh: the ``mtf.Mesh`` to build on.
        x: TensorFlow tensor imported with shape ``[batch, sequence, vocab]``.
        y: target subtracted from the prediction in the squared-error loss.
        params: reads ``precision``, ``batch_size``, ``vocab_size``, ``io``
            and ``io_channels``.

    Returns:
        Tuple ``(prediction, loss)`` of mtf tensors.
    """
    # x :: [batch, io, vocab]
    if params["precision"] == "bfloat16":
        # master has type float32, slice and activation have type bfloat16
        variable_dtype = mtf.VariableDType(tf.float32, tf.bfloat16,
                                           tf.bfloat16)
    else:
        # master, slice and activation are all float32
        variable_dtype = mtf.VariableDType(tf.float32, tf.float32,
                                           tf.float32)

    # Build the actual model
    batch_dim = mtf.Dimension("batch", params["batch_size"])
    vocab_dim = mtf.Dimension("vocab", params["vocab_size"])
    io_dim = mtf.Dimension("sequence", params["io"])
    io_chan_dim = mtf.Dimension("io", params["io_channels"])

    # from input to mtf
    x = mtf.import_tf_tensor(mesh, x,
                             mtf.Shape([batch_dim, io_dim, vocab_dim]))

    # Embeddings
    # tf.variable_scope takes the scope as its first argument
    # (`name_or_scope`); there is no `scope` keyword.
    with tf.variable_scope("toy", default_name="seq2seq"):
        with tf.variable_scope("embeddings"):
            # Perform embedding lookup on the word ids.
            embedding_table = mtf.get_variable(
                mesh,
                "word_embeddings",
                mtf.Shape([vocab_dim, io_chan_dim]),
                initializer=self.embedding_initializer,
            )
            # NOTE(review): x was imported with a vocab dimension but is used
            # here as gather indices - confirm the intended input encoding.
            word_embedding_output = mtf.gather(
                embedding_table, x, dim=vocab_dim, output_shape=io_chan_dim)

            # Add positional embeddings and token type embeddings, then layer
            # normalize and perform dropout.
            embedding_output = word_embedding_output
            pos_embedding = mtf.get_variable(
                mesh,
                "pos_embeddings",
                mtf.Shape([io_dim, io_chan_dim]),
                initializer=self.embedding_initializer,
            )
            # `normalize` requires the dimension to reduce over; the
            # embedding channel dimension is used here.
            embedding_output = self.normalize(embedding_output, io_chan_dim)
            embedding_output = mtf.dropout(
                embedding_output,
                keep_prob=1.0 - self.config.layer_output_dropout_prob,
            )

        # shift token by pos embeddings
        # NOTE(review): the normalized/dropped-out `embedding_output` is not
        # used below; the raw word embeddings are - confirm this is intended.
        x = word_embedding_output + pos_embedding
        x = mtf.cast(x, variable_dtype.activation_dtype)

        h = x
        last_layer = self.num_hidden_layers + 1
        for lnum in range(1, last_layer + 1):
            if lnum == last_layer:
                # output layer
                # NOTE(review): `h` presumably already carries io_dim at this
                # point - confirm the intended output dimension.
                dim = io_dim
            elif lnum % 2 == 0:
                # mtf.Dimension takes an integer size, not another Dimension.
                dim = mtf.Dimension("hidden_even", io_chan_dim.size)
            else:
                dim = mtf.Dimension("hidden_odd", io_chan_dim.size)
            h = mtf.layers.dense(
                h,
                dim,
                use_bias=False,
                master_dtype=variable_dtype.master_dtype,
                slice_dtype=variable_dtype.slice_dtype,
                name="layer_%d" % lnum,
            )

        prediction = h

        # project back to token dimensions
        # compute the mean square loss between the input and the output
        # NOTE(review): `y` is used as passed in (it is not imported into the
        # mesh here) - confirm callers provide a compatible tensor.
        loss = mtf.reduce_mean(mtf.square(y - prediction))
    return prediction, loss