def estimator_spec_predict(self, features, mesh, mesh_impl, use_tpu):
  mtf_samples = mtf.anonymize(self.sample(features, mesh))
  lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
  outputs = lowering.export_to_tf_tensor(mtf_samples)
  if self.has_input:
    ndims = len(outputs.shape.as_list())
    actual_batch_size = tf.shape(features["inputs"])[0]
    outputs = tf.slice(
        outputs, [0] * ndims, [actual_batch_size] + [-1] * (ndims - 1))
  predictions = {
      "outputs": outputs
  }
  if features.get("infer_targets") is not None:
    predictions["infer_targets"] = features["infer_targets"]
  if features.get("inputs") is not None:
    predictions["inputs"] = features["inputs"]

  if use_tpu:
    t2t_model.remove_summaries()
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])
  else:
    return tf.estimator.EstimatorSpec(
        tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])
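# --- Illustrative sketch (not from the original sources) ---
# estimator_spec_predict above trims away padding examples: TPU batches are
# padded up to a fixed size, so the exported predictions are sliced back down
# to the actual batch size taken from features["inputs"]. A minimal, runnable
# sketch of that slicing trick with toy shapes:
import tensorflow.compat.v1 as tf

padded_outputs = tf.zeros([8, 5])   # batch padded up to 8 examples
real_inputs = tf.zeros([6, 5])      # only 6 real examples arrived
ndims = len(padded_outputs.shape.as_list())
actual_batch_size = tf.shape(real_inputs)[0]
trimmed = tf.slice(padded_outputs, [0] * ndims,
                   [actual_batch_size] + [-1] * (ndims - 1))  # shape [6, 5]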
def model_fn(features, labels, mode, params):
  """A model is called by TpuEstimator."""
  del labels
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
  if FLAGS.use_tpu:
    ctx = params['context']
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    tf.logging.info('device_list = %s' % device_list)
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)
  mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

  with mtf.utils.outside_all_rewrites():
    logits, loss = toy_model(features, mesh)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    if FLAGS.optimizer == 'Adafactor':
      optimizer = mtf.optimize.AdafactorOptimizer()
    else:
      assert FLAGS.optimizer == 'SGD'
      optimizer = mtf.optimize.SgdOptimizer(lr=FLAGS.lr)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
  else:
    # for now, we can only export fully-replicated tensors.
    fully_replicated_logits = mtf.anonymize(logits)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  else:
    tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          FLAGS.model_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(tf_logits):
        mean_logits = tf.metrics.mean(tf_logits)
        return {'mean_logits': mean_logits}

      eval_metrics = (metric_fn, [tf_logits])
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
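# --- Illustrative sketch (not from the original sources) ---
# How a model_fn like the one above is typically handed to TPUEstimator.
# The model_dir, batch size, and toy_input_fn are hypothetical placeholders;
# the import path is the TF 1.x one used by the Mesh TensorFlow examples.
import tensorflow.compat.v1 as tf
from tensorflow.python.tpu import tpu_config, tpu_estimator

def toy_input_fn(params):
  # TPUEstimator injects the per-shard batch size via params["batch_size"].
  x = tf.random_uniform([params["batch_size"], 16])
  return tf.data.Dataset.from_tensors(x).repeat()

run_config = tpu_config.RunConfig(
    model_dir="/tmp/toy_model",  # hypothetical
    tpu_config=tpu_config.TPUConfig(iterations_per_loop=100))
estimator = tpu_estimator.TPUEstimator(
    model_fn=model_fn,
    config=run_config,
    train_batch_size=64,
    use_tpu=False)  # set True only with a real TPU cluster configured
estimator.train(input_fn=toy_input_fn, max_steps=1000)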
def model_fn(features, labels, mode, params):
    # Get global step
    global_step = tf.train.get_global_step()

    # Construct mtf graph + mesh from params
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    layout_rules = mtf.convert_to_layout_rules(params["layout"])

    # Mesh setup
    if params["use_tpu"]:
        var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape, layout_rules)
    else:
        var_placer = None
        gpu_ids = params["gpu_ids"]
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, gpu_ids)

    # Trainable variable precision
    # Store to checkpoints in master type, train in slice type, compute in activation type
    if params["precision"] == "bfloat16":
        variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.bfloat16)
    else:
        variable_dtype = mtf.VariableDType(master_dtype=tf.float32,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.float32)

    # Build mtf mesh object
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # Build mtf_features & seq length dict for getting number of microbatches
    # We need to pack inputs into a dict to pass into serialize_training_step
    features_dict = {"inputs": features, "labels": labels}
    sequence_length_dict = {"inputs": params["n_ctx"], "labels": params["n_ctx"]}

    params = add_mode_to_params(params, mode)
    batch_size = get_batch_size(params)

    batch_dim = mtf.Dimension("batch", batch_size)
    batch_dims = [batch_dim]
    feature_length = sequence_length_dict["inputs"]
    length_dim = mtf.Dimension("sequence", feature_length)

    mtf_features = {}
    for key, x in features_dict.items():
        if x is not None:
            feature_shape = mtf.Shape(batch_dims + [length_dim])
            if type(features_dict[key]) == dict:
                features_dict[key] = features_dict[key]["feature"]
            x = tf.cast(features_dict[key], tf.int32)
            x = tf.reshape(x, feature_shape.to_integer_list)
            mtf_features[key] = mtf.import_fully_replicated(
                mesh, x, feature_shape, name=key)

    # Instantiate dict for dimensions, bias, etc that can be calculated here once then passed into model
    other_features = {}
    memory_length_dim = mtf.Dimension("memory_length", length_dim.size)

    attn_bias = biasmask_attn_weights(
        mesh, length_dim, memory_length_dim,
        variable_dtype) if params["causal"] else None

    # Add attn_bias into mtf_features
    other_features["attn_bias"] = attn_bias

    # Define other Dimensions that we'll need inside the model
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
    # We need this because gathering when both the args have the same dimension in them breaks things
    # This dim is specifically for the weights
    # This prevents the "Einsum has lhs dimension without corresponding rhs or output dimension." error
    embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])

    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Set up the model for prediction
        inputs = mtf_features["inputs"]
        if params["remove_partial_sequences"] is None:
            params["remove_partial_sequences"] = False

        export = params.get("export", False)

        if not export:
            mtf_samples = sample_autoregressive(
                inputs,
                other_features=other_features,
                params=params,
                variable_dtype=variable_dtype,
                remove_partial_sequences=params["remove_partial_sequences"],
                stop_at_token=params["eos_id"],
                sampling_use_entmax=params['sampling_use_entmax'])
        else:
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    mtf_samples, loss, loss_batch = gpt2.model(
                        mtf_features, other_features, params, mesh,
                        variable_dtype=variable_dtype, context=None)

        mtf_samples = mtf.anonymize(mtf_samples)
        inputs = mtf.anonymize(inputs)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
        inputs = lowering.export_to_tf_tensor(inputs)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {"inputs": inputs, "outputs": outputs}

        def scaffold_fn():
            return tf.train.Scaffold(
                local_init_op=tf.group(
                    tf.train.Scaffold.default_local_init_op(),
                    lowering.copy_masters_to_slices(),
                    name="mtf_local_init_op"),
                ready_op=tf.concat([
                    tf.report_uninitialized_variables(),
                    resources.report_uninitialized_resources()
                ], axis=0, name="mtf_ready_op"))

        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            scaffold_fn=scaffold_fn,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    # We're not predicting, so we better be training or evaluating
    assert (mode == tf.estimator.ModeKeys.TRAIN or
            mode == tf.estimator.ModeKeys.EVAL)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Gets number of microbatches per batch for serialized training
        # if param tokens_per_mb_per_replica = None, this defaults to 1 and no microbatching is performed
        num_microbatches = int(mtf_transformer.utils.serialize_num_microbatches(
            batch_dim=batch_dim,
            sequence_length=sequence_length_dict,
            mesh_shape=mesh_shape,
            layout_rules=layout_rules,
            tokens_per_microbatch_per_replica=params["tokens_per_mb_per_replica"]))
    else:
        num_microbatches = 1

    params["num_microbatches"] = num_microbatches  # Add num microbatches to params

    if num_microbatches > 1:
        # For serialize_training_step we need to modify the model to output results in a dict
        def serialized_fn(mtf_features):
            if params["model"] == "GPT":
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(
                        mtf_features, other_features, params, mesh,
                        variable_dtype=variable_dtype)
                return {"logits": logits, "loss": loss, "loss_batch": loss_batch}
            else:
                raise Exception(
                    f"'{params['model']}' is not a valid model - please select from [GPT]")

        # Serialize the training step - Gradients are accumulated locally and reduced once.
        var_grads, output_dict = mtf.serialize_training_step(
            mtf_features, serialized_fn, batch_dim, num_microbatches)
        loss = output_dict["loss"]
        loss_batch = output_dict["loss_batch"]
        logits = output_dict["logits"]
    else:
        # If we're not splitting into microbatches, return logits & loss as is
        if params["model"] == "GPT":
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(
                        mtf_features, other_features, params, mesh,
                        variable_dtype=variable_dtype, context=None)
        else:
            raise Exception(
                f"'{params['model']}' is not a valid model - please select from [GPT]")

    # Auto layout generation
    if params["auto_layout"]:
        auto_layout(graph, mesh_shape, logits, loss)
    if params["auto_layout_and_mesh_shape"]:
        auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # In TRAIN mode, get optimizer
        if params["num_microbatches"] > 1:
            # If we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn
            # So we pass them in here
            _, update_ops, var_grads = get_optimizer(
                mesh, loss, params, variable_dtype=variable_dtype,
                inp_var_grads=var_grads)
        else:
            # Otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank
            _, update_ops, var_grads = get_optimizer(
                mesh, loss, params, variable_dtype=variable_dtype)
        # Log summaries to tensorboard
        mtf.scalar_summary("loss", loss)
        # Log gradients if in params
        if params["log_grads"] not in [None, False]:
            for g in var_grads:
                grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
                mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm)
    else:
        # For now, we can only export fully-replicated tensors.
        # This has to be done before lowering or they will not be included in the graph
        mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
        max_logits = mtf.argmax(logits, vocab_dim)
        del logits
        fully_replicated_mean_logits = mtf.anonymize(mean_logits)
        fully_replicated_max_logits = mtf.anonymize(max_logits)
        fully_replicated_loss_batch = mtf.anonymize(loss_batch)

    # Gets & prints info about no. trainable vars in the model & dimension names
    get_graph_info(graph)

    # 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.cast(tf_loss, tf.float32)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Use our patched version until mtf updates theirs
        host_call = create_host_call(params['model_path'])
        mtf.utils.remove_summaries()

        # Creates train_op
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))  # Need to manually increment global_step
        tf.logging.info(f"tf_update_ops: {tf_update_ops}")
        train_op = tf.group(tf_update_ops)
    else:
        tf_mean_logits = lowering.export_to_tf_tensor(fully_replicated_mean_logits)
        tf_max_logits = lowering.export_to_tf_tensor(fully_replicated_max_logits)
        tf_loss_batch = tf.to_float(
            lowering.export_to_tf_tensor(fully_replicated_loss_batch))

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Set up the checkpoint server and return the TPUEstimatorSpec
            saver = tf.train.Saver(
                tf.global_variables(),
                sharded=True,
                max_to_keep=10,
                keep_checkpoint_every_n_hours=2,
                defer_build=False,
                save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                params["model_path"],
                save_steps=params["steps_per_checkpoint"],
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                host_call=host_call,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])
        elif mode == tf.estimator.ModeKeys.EVAL:
            # Evaluation metrics
            def _perplexity(loss):
                perplexity = tf.exp(loss)
                return tf.metrics.mean(perplexity)

            def _bits_per_byte(loss):
                bpb = loss * (0.29335 / math.log(2))
                return tf.metrics.mean(bpb)

            def _metric_fn(tf_mean_logits, tf_loss_batch):
                mean_logits = tf.metrics.mean(tf_mean_logits)
                loss = tf.reduce_mean(tf_loss_batch)
                perp = _perplexity(loss)
                bpb = _bits_per_byte(loss)
                return {"mean_logits": mean_logits, "perplexity": perp,
                        "bits per byte": bpb}

            def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
                eos_token = params["eos_id"]
                answer_positions = tf.where(tf.math.not_equal(labels, eos_token))
                correct_answers = tf.gather_nd(
                    tf.math.equal(tf_max_logits, labels), answer_positions)
                accuracy = tf.metrics.mean(tf.cast(correct_answers, tf.float32))
                # I guess tf_loss_batch has z_loss and maybe other stuff added to it
                # so maybe this should be calculated separately in the future
                answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
                log_perplexity = tf.metrics.mean(answer_loss)
                return {"lambada_acc": accuracy, "lambada_log_ppl": log_perplexity}

            eval_task = params["eval_task"]
            if eval_task == "lambada":
                eval_metrics = (_lambada_metric_fn,
                                [labels, tf_max_logits, tf_loss_batch])
            else:
                eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
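# --- Illustrative sketch (not from the original sources) ---
# A hypothetical, abbreviated params dict for the GPT model_fn above. It only
# lists keys the function actually reads; real configs carry many more
# entries, and these particular values are made-up defaults.
toy_params = {
    "mesh_shape": "all:1",
    "layout": "batch:all",
    "use_tpu": False,
    "gpu_ids": ["gpu:0"],
    "precision": "float32",
    "n_ctx": 1024,
    "n_embd": 768,
    "n_vocab": 50257,
    "causal": True,
    "model": "GPT",
    "eos_id": 50256,
    "sampling_use_entmax": False,
    "remove_partial_sequences": False,
    "tokens_per_mb_per_replica": None,  # None => no microbatching
    "auto_layout": False,
    "auto_layout_and_mesh_shape": False,
    "log_grads": False,
    "model_path": "/tmp/gpt",  # hypothetical
    "steps_per_checkpoint": 1000,
    "eval_task": None,
}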
def my_model_fn(features, labels, mode, params=None, config=None):
  """Estimator model function.

  Args:
    features: dictionary where keys are strings like "inputs" and "targets"
      and the values are the actual values of "inputs". See TPUEstimator's
      docs for more information
    labels: ignored argument
    mode: a tf.estimator.ModeKeys
    params: dictionary containing the key "context"
    config: ignored argument

  Returns:
    a TPUEstimatorSpec
  """
  del labels, config
  global_step = tf.train.get_global_step()
  if use_tpu and "context" in params:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    # deprecated
    mesh_devices = [""] * mesh_shape.size
    physical_shape = list(
        params["context"].device_assignment.topology.mesh_shape)
    logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
        mesh_shape.to_integer_list, physical_shape)
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape,
        layout_rules,
        mesh_devices,
        ctx.device_assignment,
        logical_to_physical=logical_to_physical)
  else:
    var_placer = None
    # deprecated
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  mtf_features = {}
  for key, x in features.items():
    outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
    batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
    # Some auxiliary features may have been generated in packing.
    # The names of these new features are of the form
    #   "<original_feature_name>_<suffix>", e.g. "inputs_segmentation".
    # We look up the lengths based on the original feature name, without
    # the "_<suffix>".
    feature_length = sequence_length[key.split("_")[0]]
    length_dim = mtf.Dimension("length", feature_length)
    ensemble_dims = ([mtf.Dimension("ensemble", ensemble_inputs)]
                     if ensemble_inputs else [])
    feature_shape = mtf.Shape(
        ensemble_dims + [outer_batch_dim, batch_dim, length_dim])
    x = tf.cast(features[key], tf.int32)
    x = tf.reshape(x, feature_shape.to_integer_list)
    if not use_tpu:
      tf.logging.info("feature %s : %s" % (key, x))
      x = tf.Print(x, [x], "import feature %s" % key,
                   summarize=1000, first_n=10)
    mtf_features[key] = mtf.import_fully_replicated(
        mesh, x, feature_shape, name=key)
    if key == "targets" or key == "codeprefixedtargets" or key == "controlcode":
      anon_targets = mtf.anonymize(mtf_features[key])

  if mode == tf.estimator.ModeKeys.PREDICT:
    def _feature_shape(key):
      feature_length = sequence_length[key.split("_")[0]]
      return mtf.Shape([
          mtf.Dimension("batch", batch_size),
          mtf.Dimension("length", feature_length)
      ])

    mtf_features = {
        k: mtf.reshape(v, _feature_shape(k))
        for k, v in six.iteritems(mtf_features)
    }
    inputs = mtf_features["inputs"]
    if attribute_embedding:
      attributes = mtf_features["attribute"]
    else:
      attributes = None
    if has_partial_sequences:
      controlcodes = mtf_features["controlcode"]
    else:
      controlcodes = None

    if predict_fn:
      mtf_samples = predict_fn(
          model=transformer_model,
          features=mtf_features,
          variable_dtype=get_variable_dtype())
    elif isinstance(transformer_model, transformer.Unitransformer):
      # pad so that there is enough room for the targets
      inputs = mtf.pad(inputs, [0, sequence_length["targets"]],
                       length_dim.name)
      mtf_samples = transformer_model.sample_autoregressive(
          inputs,
          variable_dtype=get_variable_dtype(),
          remove_partial_sequences=True)
    elif isinstance(transformer_model, Bitransformer_ll):
      mtf_samples = transformer_model.decode(
          inputs,
          attributes=attributes,
          controlcodes=controlcodes,
          has_partial_sequences=has_partial_sequences,
          remove_partial_sequences=remove_partial_sequences,
          variable_dtype=get_variable_dtype())
    elif isinstance(
        transformer_model,
        (transformer.Bitransformer, transformer.StudentTeacher)):
      mtf_samples = transformer_model.decode(
          inputs, variable_dtype=get_variable_dtype())
    else:
      raise ValueError("unrecognized class")
    mtf_samples = mtf.anonymize(mtf_samples)
    inputs = mtf.anonymize(inputs)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
    inputs = lowering.export_to_tf_tensor(inputs)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {"inputs": inputs, "outputs": outputs}

    # When exporting a model, we need to communicate to TF-Serving that
    # master variables need to be copied to their slave slice variables.
    # Estimator uses a Scaffold's "local_init_op" for this purpose, so we
    # augment the default "local_init_op" here.
    #
    # The "ready_op" is also constructed here to ensure the variables
    # initialized by "local_init_op" are the same ones checked by "ready_op".
    #
    # WARNING: Any variables created outside of this model_fn()
    # (e.g. tpu_estimator/iterations_per_loop) will NOT be initialized nor
    # checked by these ops.
    def scaffold_fn():
      return tf.train.Scaffold(
          local_init_op=tf.group(
              tf.train.Scaffold.default_local_init_op(),
              lowering.copy_masters_to_slices(),
              name="mtf_local_init_op"),
          ready_op=tf.concat([
              tf.report_uninitialized_variables(),
              resources.report_uninitialized_resources()
          ], axis=0, name="mtf_ready_op"))

    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        scaffold_fn=scaffold_fn,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  assert (mode == tf.estimator.ModeKeys.TRAIN or
          mode == tf.estimator.ModeKeys.EVAL)

  def logits_and_loss(mtf_features):
    """Compute logits and loss.

    Args:
      mtf_features: a dictionary

    Returns:
      logits: a mtf.Tensor
      loss: a mtf.Tensor
    """
    if model_type == "lm":
      # TOTRY Adapt that to our case
      if "inputs" in mtf_features:
        mtf_features = _dynamic_text2self(mtf_features)
      _, _, length_dim = mtf_features["targets"].shape
      inputs = mtf.shift(mtf_features["targets"], offset=1,
                         dim=length_dim, wrap=False)
    else:
      inputs = mtf_features["inputs"]

    if attribute_embedding:
      attributes = mtf_features["attribute"]
    else:
      attributes = None

    if control_codes:
      codeprefixedtargets = mtf_features["codeprefixedtargets"]
    else:
      codeprefixedtargets = None

    if isinstance(transformer_model, transformer.Unitransformer):
      position_kwargs = dict(
          sequence_id=mtf_features.get("targets_segmentation", None),
          position=mtf_features.get("targets_position", None),
      )
    elif isinstance(
        transformer_model,
        transformer.Bitransformer) or model_type == "bi_student_teacher":
      if control_codes:
        position_kwargs = dict(
            encoder_sequence_id=mtf_features.get("inputs_segmentation", None),
            decoder_sequence_id=mtf_features.get(
                "codeprefixedtargets_segmentation", None),
            decoder_subsequence_id=mtf_features.get(
                "codeprefixedtargets_subsegmentation", None),
            encoder_position=mtf_features.get("inputs_position", None),
            decoder_position=mtf_features.get(
                "codeprefixedtargets_position", None),
        )
      else:
        position_kwargs = dict(
            encoder_sequence_id=mtf_features.get("inputs_segmentation", None),
            decoder_sequence_id=mtf_features.get("targets_segmentation", None),
            decoder_subsequence_id=mtf_features.get("targets_subsegmentation",
                                                    None),
            encoder_position=mtf_features.get("inputs_position", None),
            decoder_position=mtf_features.get("targets_position", None),
        )
    else:
      raise ValueError("unrecognized class")

    if isinstance(transformer_model, Bitransformer_ll):
      if cycle_consistency_loss:
        logits_ae, l_ae = transformer_model.call_simple(
            inputs=inputs,
            targets=mtf_features["targets"],
            compute_loss=True,
            attributes=attributes,
            codeprefixedtargets=codeprefixedtargets,
            mode=mode,
            variable_dtype=get_variable_dtype(),
            **position_kwargs)

        if has_partial_sequences:
          controlcodes = mtf_features["controlcode"]
        else:
          controlcodes = None

        with gin.config_scope('training'):
          mtf_samples = transformer_model.decode(
              inputs,
              attributes=attributes,
              controlcodes=controlcodes,
              has_partial_sequences=has_partial_sequences,
              remove_partial_sequences=remove_partial_sequences,
              variable_dtype=get_variable_dtype())
          # mtf_samples = mtf.anonymize(mtf_samples)
        outputs = mtf_samples

        logits_cycle, l_cycle = transformer_model.call_simple(
            inputs=outputs,
            targets=mtf_features["targets"],
            compute_loss=True,
            attributes=attributes,
            codeprefixedtargets=codeprefixedtargets,
            mode=mode,
            variable_dtype=get_variable_dtype(),
            **position_kwargs)

        loss_ae_cycle = lambda_ae * l_ae + lambda_cycle * l_cycle
        return logits_cycle, loss_ae_cycle
      else:
        return transformer_model.call_simple(
            inputs=inputs,
            targets=mtf_features["targets"],
            compute_loss=True,
            attributes=attributes,
            codeprefixedtargets=codeprefixedtargets,
            mode=mode,
            variable_dtype=get_variable_dtype(),
            **position_kwargs)
    else:
      return transformer_model.call_simple(
          inputs=inputs,
          targets=mtf_features["targets"],
          compute_loss=True,
          mode=mode,
          variable_dtype=get_variable_dtype(),
          num_microbatches=num_microbatches,
          **position_kwargs)

  if mode == tf.estimator.ModeKeys.TRAIN:
    num_microbatches = serialize_num_microbatches(
        batch_dim, sequence_length, mesh_shape, layout_rules)
    if num_microbatches > 1:
      def serialized_fn(mtf_features):
        return {
            "loss": (logits_and_loss(mtf_features)[1] / num_microbatches)
        }
      var_grads, loss_dict = mtf.serialize_training_step(
          mtf_features, serialized_fn, batch_dim, num_microbatches)
      loss = loss_dict["loss"]
    else:
      loss = logits_and_loss(mtf_features)[1]
      var_grads = mtf.gradients(
          [loss], [v.outputs[0] for v in graph.trainable_variables])

    if tpu_summaries:
      mtf.scalar_summary("loss", loss)

    if callable(learning_rate_schedule):
      # the following happens on CPU since TPU can't handle summaries.
      with mtf.utils.outside_all_rewrites():
        learning_rate = learning_rate_schedule(
            step=tf.train.get_global_step())
        tf.summary.scalar("learning_rate", learning_rate)
    else:
      learning_rate = learning_rate_schedule

    if isinstance(variable_filter, str):
      pattern = re.compile(variable_filter)
      variable_filter_fn = lambda v: pattern.search(v.name)
    elif variable_filter is None:
      variable_filter_fn = lambda v: True
    elif callable(variable_filter):
      variable_filter_fn = variable_filter
    else:
      raise ValueError(
          "variable_filter must be None, a string, or a callable function")
    trainable_vars = [
        v for v in graph.trainable_variables if variable_filter_fn(v)
    ]
    trainable_var_grads = [
        g for g, v in zip(var_grads, graph.trainable_variables)
        if variable_filter_fn(v)
    ]
    if len(trainable_vars) != len(graph.trainable_variables):
      tf.logging.info("Variables being trained:")
      tf.logging.info([v.name for v in trainable_vars])
      tf.logging.info("Variables not being trained:")
      tf.logging.info([
          v.name for v in graph.trainable_variables
          if not variable_filter_fn(v)
      ])

    update_ops = optimizer(learning_rate=learning_rate).apply_grads(
        trainable_var_grads, trainable_vars)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.cast(tf_loss, tf.float32)
    if not use_tpu:
      tf_loss = tf.Print(tf_loss, [tf_loss, tf.train.get_global_step()],
                         "step, tf_loss")

    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

    if hasattr(transformer_model, "initialize"):
      with mtf.utils.outside_all_rewrites():
        transformer_model.initialize()

    if tpu_summaries:
      # has to be outside of
      # with mtf.utils.outside_all_rewrites()
      host_call = mtf.utils.create_host_call(model_dir)
      mtf.utils.remove_summaries()
    else:
      host_call = None

    with mtf.utils.outside_all_rewrites():
      if init_checkpoint:
        ckpt_vars = {v for v, _ in tf.train.list_variables(init_checkpoint)}
        global_vars = {v.op.name for v in tf.global_variables()}
        restore_vars = ckpt_vars.intersection(global_vars)
        tf.logging.info("Initializing variables from %s:", init_checkpoint)
        tf.logging.debug("\n".join(sorted(restore_vars)))
        tf.logging.info("Variables in %s but not in graph:", init_checkpoint)
        tf.logging.info("\n".join(sorted(ckpt_vars - global_vars)))
        tf.logging.info("Variables in graph but not in %s:", init_checkpoint)
        tf.logging.info("\n".join(sorted(global_vars - ckpt_vars)))
        tf.train.init_from_checkpoint(init_checkpoint,
                                      {v: v for v in restore_vars})

      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=keep_checkpoint_max,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          model_dir,
          save_steps=save_checkpoints_steps,
          saver=saver,
          listeners=[saver_listener])
      gin_config_saver_hook = gin.tf.GinConfigSaverHook(
          model_dir, summarize_config=True, include_step_in_filename=False)

      if use_tpu:
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            host_call=host_call,
            training_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
      else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
  elif mode == tf.estimator.ModeKeys.EVAL:
    logits, loss = logits_and_loss(mtf_features)
    anon_logits = mtf.anonymize(logits)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
    tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
    tf_logits = tf.cast(lowering.export_to_tf_tensor(anon_logits), tf.float32)

    def simple_metrics(logits, labels):
      """Simple metrics for teacher-forced eval."""
      weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
      xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=labels, logits=logits)
      predictions = tf.cast(tf.argmax(logits, axis=-1), labels.dtype)
      token_correct = tf.cast(
          tf.equal(predictions, labels), tf.float32) * weights
      sequence_correct = tf.to_float(
          tf.equal(tf.reduce_sum(token_correct, -1),
                   tf.reduce_sum(weights, -1)))
      sequence_weights = tf.to_float(
          tf.not_equal(tf.reduce_sum(weights, -1), 0))
      return {
          "neg_log_perplexity": tf.metrics.mean(-xent, weights),
          "token_accuracy": tf.metrics.mean(token_correct, weights),
          "sequence_accuracy": tf.metrics.mean(sequence_correct,
                                               sequence_weights)
      }

    labels = lowering.export_to_tf_tensor(anon_targets)
    eval_metrics = (simple_metrics, [tf_logits, labels])
    with mtf.utils.outside_all_rewrites():
      restore_hook = mtf.MtfRestoreHook(lowering)
    return tpu_estimator.TPUEstimatorSpec(
        tf.estimator.ModeKeys.EVAL,
        evaluation_hooks=[restore_hook],
        loss=tf_loss,
        eval_metrics=eval_metrics)
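# --- Illustrative sketch (not from the original sources) ---
# The variable_filter dispatch in my_model_fn above accepts None, a regex
# string, or a callable. A self-contained re-implementation of just that
# dispatch, with a hypothetical FakeVar stand-in for mtf variables (only a
# .name attribute is needed):
import re

class FakeVar(object):
  def __init__(self, name):
    self.name = name

def make_variable_filter_fn(variable_filter):
  if isinstance(variable_filter, str):
    pattern = re.compile(variable_filter)
    return lambda v: pattern.search(v.name) is not None
  elif variable_filter is None:
    return lambda v: True
  elif callable(variable_filter):
    return variable_filter
  raise ValueError("variable_filter must be None, a string, or a callable")

variables = [FakeVar("encoder/kernel"), FakeVar("decoder/kernel")]
decoder_only = make_variable_filter_fn(r"^decoder/")
print([v.name for v in variables if decoder_only(v)])  # ['decoder/kernel']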
def my_model_fn(features, labels, mode, params=None, config=None):
  """Estimator model function.

  Args:
    features: input features dictionary
    labels: ignored
    mode: a tf.estimator.ModeKeys
    params: dictionary containing the key "context" (when running on TPU)
    config: ignored

  Returns:
    a TPUEstimatorSpec
  """
  del labels, config
  use_tpu = FLAGS.tpu
  global_step = tf.train.get_global_step()

  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  model = transformer.Bitransformer(
      encoder_layer_stack=layer_stack(include_encdec_attention=False),
      decoder_layer_stack=layer_stack(include_encdec_attention=True),
      encoder_d_model=FLAGS.d_model,
      decoder_d_model=FLAGS.d_model,
      input_vocab_size=transformer_dataset.padded_vocab_size(
          transformer_dataset.inputs_vocab_size(FLAGS.dataset)),
      output_vocab_size=transformer_dataset.padded_vocab_size(
          transformer_dataset.targets_vocab_size(FLAGS.dataset)),
      max_length=FLAGS.max_length,
      shared_embedding=False,
      shared_embedding_and_softmax_weights=True,
      label_smoothing=FLAGS.label_smoothing,
      layout=FLAGS.layout,
      mesh_shape=FLAGS.mesh_shape)

  inputs = import_feature(features, mesh, "inputs")

  # Data-types used for variables and activations
  # See comments in the FLAGS
  master_dtype = tf.as_dtype(FLAGS.master_dtype)
  if FLAGS.slice_dtype:
    slice_dtype = tf.as_dtype(FLAGS.slice_dtype)
  elif not FLAGS.tpu or FLAGS.mode == "train":
    slice_dtype = tf.float32
  else:
    slice_dtype = tf.bfloat16
  if FLAGS.activation_dtype:
    activation_dtype = tf.as_dtype(FLAGS.activation_dtype)
  else:
    activation_dtype = tf.bfloat16 if FLAGS.tpu else tf.float32
  variable_dtype = mtf.VariableDType(
      master_dtype=master_dtype,
      slice_dtype=slice_dtype,
      activation_dtype=activation_dtype)

  # PREDICT mode
  if mode == tf.estimator.ModeKeys.PREDICT:
    mtf_samples = model.decode(
        inputs,
        variable_dtype=variable_dtype,
        beam_size=FLAGS.beam_size,
        alpha=FLAGS.alpha,
        temperature=FLAGS.temperature)
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                            autostack=FLAGS.autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {
        "outputs": outputs
    }
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  targets = import_feature(features, mesh, "targets")
  anon_targets = mtf.anonymize(targets)
  logits, loss = model.call_simple(
      inputs=inputs,
      targets=targets,
      compute_loss=True,
      mode=mode,
      variable_dtype=variable_dtype,
      encoder_sequence_id=import_feature(features, mesh,
                                         "inputs_segmentation"),
      decoder_sequence_id=import_feature(features, mesh,
                                         "targets_segmentation"),
      encoder_position=import_feature(features, mesh, "inputs_position"),
      decoder_position=import_feature(features, mesh, "targets_position"))

  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.AdafactorOptimizer()
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=FLAGS.autostack)

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if not use_tpu:
    tf_loss = tf.Print(tf_loss, [tf_loss, tf.train.get_global_step()],
                       "step, tf_loss")
  if logits and mode != tf.estimator.ModeKeys.TRAIN:
    tf_logits = lowering.export_to_tf_tensor(logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        FLAGS.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])

    if mode == tf.estimator.ModeKeys.TRAIN:
      if use_tpu:
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[restore_hook, saver_hook])
      else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:

      def padded_neg_log_perplexity(logits, labels):
        weights = tf.to_float(tf.not_equal(labels, 0))
        xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
        return {"neg_log_perplexity": tf.metrics.mean(-xent, weights)}

      labels = lowering.export_to_tf_tensor(anon_targets)
      eval_metrics = (padded_neg_log_perplexity, [tf_logits, labels])
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator."""
  tf.logging.info("*** Features ***")
  for name in sorted(features.keys()):
    tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

  # MTF setup.
  graph = mtf.Graph()
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

  ctx = params["context"]
  num_hosts = ctx.num_hosts
  host_placement_fn = ctx.tpu_host_placement_function
  device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
  tf.logging.info("device_list = %s" % device_list)
  replica_cache_size = 300 * 1000000  # 300M per replica
  # Worker 0 caches all the TPU binaries.
  worker0_mem = replica_cache_size * ctx.num_replicas
  devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
  var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                devices_memory_usage)
  mesh_devices = [""] * mesh_shape.size
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

  input_ids = features["input_ids"]
  input_mask = features["input_mask"]
  segment_ids = features["segment_ids"]
  label_ids = features["label_ids"]
  is_real_example = None
  if "is_real_example" in features:
    is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
  else:
    is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

  batch_size = input_ids.get_shape()[0].value
  batch_dim = mtf.Dimension("batch", batch_size)
  seq_length = input_ids.get_shape()[1].value
  seq_dim = mtf.Dimension("seq", seq_length)
  num_labels_dim = mtf.Dimension("seq", num_labels)
  mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids, [batch_dim, seq_dim])
  mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                        [batch_dim, seq_dim])
  mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                         [batch_dim, seq_dim])
  mtf_label_ids = mtf.import_tf_tensor(mesh, label_ids, [batch_dim])

  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  (total_loss, per_example_loss, logits,
   probabilities) = create_model(bert_config, is_training, mtf_input_ids,
                                 mtf_input_mask, mtf_segment_ids,
                                 mtf_label_ids, num_labels_dim, layout_rules,
                                 mesh_shape)
  total_loss = mtf.anonymize(total_loss)
  per_example_loss = mtf.anonymize(per_example_loss)
  logits = mtf.anonymize(logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    _, update_ops = optimization_lib.create_optimizer(
        total_loss,
        learning_rate,
        num_train_steps,
        num_warmup_steps,
        max_optimized_variable_size=FLAGS.max_optimized_variable_size,
        optimizer=FLAGS.optimizer,
        clip_gradients=FLAGS.clip_gradients)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_global_step()
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  elif mode == tf.estimator.ModeKeys.EVAL:

    def metric_fn(per_example_loss, label_ids, logits, is_real_example):
      predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
      accuracy = tf.metrics.accuracy(
          labels=label_ids, predictions=predictions, weights=is_real_example)
      loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
      return {
          "eval_accuracy": accuracy,
          "eval_loss": loss,
      }

    eval_metrics = (metric_fn, [
        lowering.export_to_tf_tensor(per_example_loss), label_ids,
        lowering.export_to_tf_tensor(logits), is_real_example
    ])

  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  if init_checkpoint:
    (assignment_map,
     initialized_variable_names) = bert_lib.get_assignment_map_from_checkpoint(
         tvars, init_checkpoint)
    if use_tpu:

      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                    init_string)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          FLAGS.output_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])
      return tf.estimator.tpu.TPUEstimatorSpec(
          mode,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook],
          scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
      return tf.estimator.tpu.TPUEstimatorSpec(
          mode,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics,
          scaffold_fn=scaffold_fn)
    else:
      return tf.estimator.tpu.TPUEstimatorSpec(
          mode,
          prediction_hooks=[restore_hook],
          predictions={
              "probabilities": lowering.export_to_tf_tensor(probabilities)
          },
          scaffold_fn=scaffold_fn)
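# --- Illustrative sketch (not from the original sources) ---
# The warm-start pattern used above, factored out with hypothetical names: on
# TPU, tf.train.init_from_checkpoint must run inside the scaffold function so
# it executes where the graph is finally built; off TPU it can run directly
# at graph-construction time.
import tensorflow.compat.v1 as tf

def make_scaffold_fn(init_checkpoint, assignment_map, use_tpu):
  if not use_tpu:
    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
    return None

  def tpu_scaffold():
    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
    return tf.train.Scaffold()

  return tpu_scaffold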
def estimator_model_fn(cls,
                       hparams,
                       features,
                       labels,
                       mode,
                       config=None,
                       params=None,
                       decode_hparams=None):
  hparams = copy.deepcopy(hparams)
  use_tpu = params and params.get("use_tpu", False)
  hparams.use_tpu = use_tpu
  # merge decode_hparams into hparams if present
  if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
    for k, v in six.iteritems(decode_hparams.values()):
      if hasattr(hparams, k) and getattr(hparams, k) != v:
        tf.logging.warning(
            "Overriding hparams.%s with %s from decode_hparams" % (k, v))
        setattr(hparams, k, v)

  # Instantiate model
  data_parallelism = None
  if not use_tpu and config:
    data_parallelism = config.data_parallelism
  model = cls(
      hparams,
      mode,
      data_parallelism=data_parallelism,
      decode_hparams=decode_hparams)

  global_step = tf.train.get_global_step()

  mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(hparams.layout)
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    if len(data_parallelism.ps_devices) == 1:
      mesh_devices = [""] * mesh_shape.size
    else:
      assert len(data_parallelism.ps_devices) == mesh_shape.size
      mesh_devices = data_parallelism.ps_devices
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  # PREDICT mode
  if mode == tf.estimator.ModeKeys.PREDICT:
    return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)

  logits, loss = model.mtf_model_fn(features, mesh)
  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    lr = learning_rate.learning_rate_schedule(hparams)
    tf.summary.scalar("learning_rate", lr)
    mtf_lr = mtf.import_tf_tensor(
        mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
    optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
    update_ops = []
    for grad, var in zip(var_grads, graph.trainable_variables):
      update_ops.extend(optimizer.apply_grad(grad, var))

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if logits and mode != tf.estimator.ModeKeys.TRAIN:
    tf_logits = lowering.export_to_tf_tensor(logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        hparams.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])

  # EVAL mode
  if mode == tf.estimator.ModeKeys.EVAL:
    tf_logits = lowering.export_to_tf_tensor(logits)
    return model.estimator_spec_eval(features, tf_logits, labels, tf_loss,
                                     restore_hook, use_tpu)

  if use_tpu:
    # TPU host call. Important: need to be called before remove_summaries()
    if hparams.tpu_enable_host_call:
      host_call = t2t_model.create_host_call(hparams.model_dir)
    else:
      host_call = None
    t2t_model.remove_summaries()
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=tf_loss,
        train_op=train_op,
        host_call=host_call,
        training_hooks=[restore_hook, saver_hook])
  else:
    return tf.estimator.EstimatorSpec(
        tf.estimator.ModeKeys.TRAIN,
        loss=tf_loss,
        train_op=train_op,
        training_chief_hooks=[restore_hook, saver_hook])
def _model_fn(input_fea, input_lab):
  """Creates a model, add summary, modes (train or eval), and hooks."""

  def _add_summary(lowering, train_or_eval, tf_loss, scalars, global_step):
    """Add all summaries."""
    for k in scalars.keys():
      if not isinstance(scalars[k], tf.Tensor):
        scalars[k] = tf.cast(
            lowering.export_to_tf_tensor(scalars[k]), tf.float32)

    def _host_loss_summary(global_step, tf_loss, **scalars):
      """Add summary.scalar in host side."""
      gs = tf.cast(global_step, tf.int64)
      sum_loss = tf.contrib.summary.scalar(
          '{}_loss'.format(train_or_eval), tf_loss, step=gs)
      sum_ops = [sum_loss.op]
      for description, tf_metric in scalars.iteritems():
        sum_metric = tf.contrib.summary.scalar(
            '{}_{}'.format(train_or_eval, description), tf_metric, step=gs)
        sum_ops.append(sum_metric)
      with tf.control_dependencies(sum_ops):
        return tf.identity(tf_loss)

    # Cast the global step to tf.int32, since
    # outside_compilation does not support tf.int64.
    tf_loss = tpu.outside_compilation(
        _host_loss_summary,
        tf.cast(global_step, tf.int32),
        tf_loss,
        **scalars)

    return tf_loss

  global_step = tf.train.get_or_create_global_step()
  graph = mtf.Graph()

  # Worker 0 caches all the TPU binaries.
  replica_cache_size = 300 * 1024 * 1024  # 300M per replica.
  worker0_mem = replica_cache_size * 8 * num_hosts
  devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)

  tf.logging.info('cpu_devices: {}, devices_mem: {}'.format(
      cpu_devices, devices_memory_usage))
  var_placer = mtf.utils.BalancedVariablePlacer(
      cpu_devices, devices_memory_usage)

  mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = unet.get_layout()
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape, layout_rules, None, d_assignment)

  with mtf.utils.outside_all_rewrites():
    # Do not tpu_rewrite this part.
    preds, loss, scalars, bn_update_ops = (
        unet.unet_with_spatial_partition(
            mesh, train_or_eval, input_fea, input_lab))

  if train_or_eval == 'train':
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])

    lr = FLAGS.lr * tf.pow(
        FLAGS.lr_drop_rate,
        tf.floor(tf.cast(global_step, tf.float32) / FLAGS.lr_drop_steps))
    scalars['learning_rate'] = lr

    optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=lr)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

    # This is where the actual tf graph got built.
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf_update_ops.extend(
        [lowering.lowered_operation(op) for op in bn_update_ops])
    tf_update_ops_group = tf.group(tf_update_ops)

  else:  # train_or_eval == 'eval':
    preds = [mtf.anonymize(pred) for pred in preds]

    # This is where the actual tf graph got built.
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_preds = [
        tf.cast(lowering.export_to_tf_tensor(pred), tf.float32)
        for pred in preds
    ]

  tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
  if FLAGS.write_summary:
    tf_loss = _add_summary(
        lowering, train_or_eval, tf_loss, scalars, global_step)
  master_to_slice_hook = mtf.MtfRestoreHook(lowering)

  if train_or_eval == 'train':
    with mtf.utils.outside_all_rewrites():
      saver = tf.train.Saver(tf.global_variables(), save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      slice_to_master_hook = tf.train.CheckpointSaverHook(
          FLAGS.checkpoint_dir,
          save_steps=FLAGS.save_checkpoints_steps,
          saver=saver,
          listeners=[saver_listener])
      captured_hooks.capture([master_to_slice_hook, slice_to_master_hook])
    return tf_update_ops_group

  else:  # train_or_eval == 'eval':
    tf_preds.extend([tf_loss, global_step])
    tf_preds_dtypes = [tf_pred.dtype for tf_pred in tf_preds]
    tf_preds_shapes = [tf_pred.shape for tf_pred in tf_preds]
    captured_hooks.capture([master_to_slice_hook, None])
    captured_output_dtypes_shapes.capture(
        [tf_preds_dtypes, tf_preds_shapes])
    return tpu_ops.outfeed_enqueue_tuple(tf_preds)
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator."""
  tf.logging.info("*** Features ***")
  for name in sorted(features.keys()):
    tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

  # MTF setup.
  graph = mtf.Graph()
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

  ctx = params["context"]
  num_hosts = ctx.num_hosts
  host_placement_fn = ctx.tpu_host_placement_function
  device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
  tf.logging.info("device_list = %s" % device_list)
  replica_cache_size = 300 * 1000000  # 300M per replica
  # Worker 0 caches all the TPU binaries.
  worker0_mem = replica_cache_size * ctx.num_replicas
  devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
  var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                devices_memory_usage)
  mesh_devices = [""] * mesh_shape.size
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

  unique_ids = features["unique_ids"]
  input_ids = features["input_ids"]
  input_mask = features["input_mask"]
  segment_ids = features["segment_ids"]

  batch_size = input_ids.get_shape()[0].value
  batch_dim = mtf.Dimension("batch", batch_size)
  seq_length = input_ids.get_shape()[1].value
  seq_dim = mtf.Dimension("seq", seq_length)

  mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids, [batch_dim, seq_dim])
  mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                        [batch_dim, seq_dim])
  mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                         [batch_dim, seq_dim])

  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  (start_logits, end_logits) = create_model(
      bert_config=bert_config,
      is_training=is_training,
      input_ids=mtf_input_ids,
      input_mask=mtf_input_mask,
      segment_ids=mtf_segment_ids)

  if mode == tf.estimator.ModeKeys.TRAIN:

    def compute_loss(logits, positions):
      one_hot_positions = mtf.one_hot(positions, output_dim=seq_dim)
      log_probs = mtf.log_softmax(logits, seq_dim)
      loss = -mtf.reduce_mean(
          mtf.reduce_sum(one_hot_positions * log_probs, reduced_dim=seq_dim))
      return loss

    start_positions = features["start_positions"]
    mtf_start_positions = mtf.import_tf_tensor(mesh, start_positions,
                                               [batch_dim])
    end_positions = features["end_positions"]
    mtf_end_positions = mtf.import_tf_tensor(mesh, end_positions, [batch_dim])

    start_loss = compute_loss(start_logits, mtf_start_positions)
    end_loss = compute_loss(end_logits, mtf_end_positions)
    total_loss = (start_loss + end_loss) / 2.0
    _, update_ops = optimization_lib.create_optimizer(
        total_loss,
        learning_rate,
        num_train_steps,
        num_warmup_steps,
        max_optimized_variable_size=FLAGS.max_optimized_variable_size,
        optimizer=FLAGS.optimizer,
        clip_gradients=FLAGS.clip_gradients)
  elif mode == tf.estimator.ModeKeys.PREDICT:
    start_logits = mtf.anonymize(start_logits)
    end_logits = mtf.anonymize(end_logits)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))
    global_step = tf.train.get_global_step()
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
    train_op = tf.group(tf_update_ops)

  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  if init_checkpoint:
    (assignment_map,
     initialized_variable_names) = bert_lib.get_assignment_map_from_checkpoint(
         tvars, init_checkpoint)
    if use_tpu:

      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                    init_string)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          FLAGS.output_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])
      return tf.estimator.tpu.TPUEstimatorSpec(
          mode,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook],
          scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.PREDICT:
      predictions = {
          "unique_ids": unique_ids,
          "start_logits": lowering.export_to_tf_tensor(start_logits),
          "end_logits": lowering.export_to_tf_tensor(end_logits),
      }
      return tf.estimator.tpu.TPUEstimatorSpec(
          mode,
          prediction_hooks=[restore_hook],
          predictions=predictions,
          scaffold_fn=scaffold_fn)
    else:
      raise ValueError("Only TRAIN and PREDICT modes are supported: %s" % mode)
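# --- Illustrative sketch (not from the original sources) ---
# The compute_loss above is the usual span-extraction loss: cross-entropy of
# the true start/end positions under log-softmaxed logits over the sequence.
# A tiny NumPy check with toy shapes (batch 4, sequence 16):
import numpy as np

def span_loss(logits, positions):
  # logits: [batch, seq_length]; positions: [batch] integer indices.
  logits = logits - logits.max(axis=-1, keepdims=True)
  log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
  return -log_probs[np.arange(len(positions)), positions].mean()

rng = np.random.RandomState(0)
start_loss = span_loss(rng.randn(4, 16), np.array([3, 0, 7, 15]))
end_loss = span_loss(rng.randn(4, 16), np.array([5, 1, 9, 15]))
total_loss = (start_loss + end_loss) / 2.0  # mirrors the mtf version above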
def my_model_fn(features, labels, mode, params=None, config=None):
  """Estimator model function.

  Args:
    features: input features dictionary
    labels: ignored
    mode: a tf.estimator.ModeKeys
    params: an optional parameter dictionary; on TPU it carries the
      TPUContext under the "context" key
    config: an optional tf.estimator.RunConfig; ignored

  Returns:
    a tpu_estimator.TPUEstimatorSpec or tf.estimator.EstimatorSpec
  """
  del labels, config
  global_step = tf.train.get_global_step()
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  def _import_feature(key):
    """Import a feature from the features dictionary into a mtf.Tensor.

    Args:
      key: a string

    Returns:
      a mtf.Tensor with dtype int32 and shape [batch_dim, length_dim]
    """
    batch_dim = mtf.Dimension("batch", batch_size)
    length_dim = mtf.Dimension("length", length)
    mtf_shape = mtf.Shape([batch_dim, length_dim])
    if key not in features:
      return None
    x = tf.to_int32(features[key])
    if not use_tpu:
      x = tf.Print(x, [x], "import feature %s" % key, summarize=1000,
                   first_n=1)
    return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

  # PREDICT mode
  if mode == tf.estimator.ModeKeys.PREDICT:
    inputs = _import_feature("inputs")
    if text2self:
      mtf_samples = transformer_model.sample_autoregressive(
          inputs, variable_dtype=variable_dtype, temperature=temperature)
    else:
      mtf_samples = transformer_model.decode(
          inputs,
          variable_dtype=variable_dtype,
          beam_size=beam_size,
          alpha=alpha,
          temperature=temperature)
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {"outputs": outputs}
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  targets = _import_feature("targets")
  anon_targets = mtf.anonymize(targets)
  if text2self:
    _, length_dim = targets.shape
    inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
  else:
    inputs = _import_feature("inputs")

  if text2self:
    position_kwargs = dict(
        sequence_id=_import_feature("targets_segmentation"),
        position=_import_feature("targets_position"),
    )
  else:
    position_kwargs = dict(
        encoder_sequence_id=_import_feature("inputs_segmentation"),
        decoder_sequence_id=_import_feature("targets_segmentation"),
        encoder_position=_import_feature("inputs_position"),
        decoder_position=_import_feature("targets_position"),
    )

  logits, loss = transformer_model.call_simple(
      inputs=inputs,
      targets=targets,
      compute_loss=True,
      mode=mode,
      variable_dtype=variable_dtype,
      **position_kwargs)
  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.AdafactorOptimizer()
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if not use_tpu:
    tf_loss = tf.Print(tf_loss, [tf_loss, tf.train.get_global_step()],
                       "step, tf_loss")
  if logits is not None and mode != tf.estimator.ModeKeys.TRAIN:
    tf_logits = lowering.export_to_tf_tensor(logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])

    if mode == tf.estimator.ModeKeys.TRAIN:
      if use_tpu:
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[restore_hook, saver_hook])
      else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:

      def padded_neg_log_perplexity(logits, labels):
        weights = tf.to_float(tf.not_equal(labels, 0))
        xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
        return {"neg_log_perplexity": tf.metrics.mean(-xent, weights)}

      labels = lowering.export_to_tf_tensor(anon_targets)
      eval_metrics = (padded_neg_log_perplexity, [tf_logits, labels])
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
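# --- Usage sketch (not part of the original snippets): how a model_fn like
# my_model_fn above is typically handed to a TPUEstimator. The model_dir,
# batch sizes, and my_input_fn are hypothetical placeholders; only the
# RunConfig/TPUEstimator call pattern is the point.
import tensorflow.compat.v1 as tf

run_config = tf.estimator.tpu.RunConfig(
    model_dir="/tmp/mtf_model",  # hypothetical
    tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=100))
estimator = tf.estimator.tpu.TPUEstimator(
    model_fn=my_model_fn,
    config=run_config,
    train_batch_size=64,  # hypothetical
    eval_batch_size=64,   # hypothetical
    use_tpu=False)        # set True when running on an actual TPU worker
# estimator.train(input_fn=my_input_fn, max_steps=10000)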
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator."""
  tf.logging.info("*** Features ***")
  for name in sorted(features.keys()):
    tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

  # MTF setup.
  graph = mtf.Graph()
  # mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  # layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
  if FLAGS.mode == "auto_parallel":
    mesh_shape_map = {
        1: [("processor_rows", 1)],
        2: [("processor_rows", 2)],
        4: [("processor_rows", 2), ("processor_cols", 2)],
        8: [("processor_rows", 2), ("processor_cols", 4)]
    }
  elif FLAGS.mode == "data_parallel":
    mesh_shape_map = {
        1: [("processor_rows", 1)],
        2: [("processor_rows", 2)],
        4: [("processor_rows", 4)],
        8: [("processor_rows", 8)]
    }
  else:
    raise ValueError("unknown FLAGS.mode: %s" % FLAGS.mode)

  mesh_shape = mesh_shape_map[FLAGS.gpu_num]
  devices = [f"gpu:{i}" for i in range(FLAGS.gpu_num)]
  var_placer = None
  mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

  input_ids = features["input_ids"]
  input_mask = features["input_mask"]
  segment_ids = features["segment_ids"]
  masked_lm_positions = features["masked_lm_positions"]
  masked_lm_ids = features["masked_lm_ids"]
  masked_lm_weights = features["masked_lm_weights"]
  next_sentence_labels = tf.squeeze(features["next_sentence_labels"], 1)

  batch_size = input_ids.get_shape()[0].value
  batch_dim = mtf.Dimension("batch", batch_size)
  seq_length = input_ids.get_shape()[1].value
  seq_dim = mtf.Dimension("seq", seq_length)
  max_predictions_per_seq = masked_lm_positions.get_shape()[1].value
  max_predictions_per_seq_dim = mtf.Dimension("max_pred_seq",
                                              max_predictions_per_seq)

  mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids, [batch_dim, seq_dim])
  mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                        [batch_dim, seq_dim])
  mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                         [batch_dim, seq_dim])
  mtf_masked_lm_positions = mtf.import_tf_tensor(
      mesh, masked_lm_positions, [batch_dim, max_predictions_per_seq_dim])
  mtf_masked_lm_ids = mtf.import_tf_tensor(
      mesh, masked_lm_ids, [batch_dim, max_predictions_per_seq_dim])
  mtf_masked_lm_weights = mtf.import_tf_tensor(
      mesh, masked_lm_weights, [batch_dim, max_predictions_per_seq_dim])
  mtf_next_sentence_labels = mtf.import_tf_tensor(
      mesh, next_sentence_labels, [batch_dim])

  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  model = bert_lib.BertModel(
      config=bert_config,
      is_training=is_training,
      input_ids=mtf_input_ids,
      input_mask=mtf_input_mask,
      token_type_ids=mtf_segment_ids,
      mesh_shape=mesh_shape)
  (masked_lm_loss, masked_lm_example_loss,
   masked_lm_logits) = model.get_masked_lm_output(
       mtf_masked_lm_positions, mtf_masked_lm_ids, mtf_masked_lm_weights)
  (next_sentence_loss, next_sentence_example_loss,
   next_sentence_logits) = model.get_next_sentence_output(
       mtf_next_sentence_labels)

  extra_loss = model.get_extra_loss()
  total_loss = masked_lm_loss + next_sentence_loss
  total_loss = mtf.anonymize(total_loss)
  masked_lm_example_loss = mtf.anonymize(masked_lm_example_loss)
  masked_lm_logits = mtf.anonymize(masked_lm_logits)
  next_sentence_example_loss = mtf.anonymize(next_sentence_example_loss)
  next_sentence_logits = mtf.anonymize(next_sentence_logits)

  outputs = [total_loss]
  if FLAGS.mode == "auto_parallel":
    layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, outputs)
  elif FLAGS.mode == "data_parallel":
    layout_rules = [('batch', 'processor_rows')]
  else:
    raise ValueError("unknown FLAGS.mode: %s" % FLAGS.mode)

  variables = graph._all_variables  # pylint: disable=protected-access
  for v in variables:
    tf.logging.info("[parameter] (name,shape,dtype): ({},{},{})".format(
        v.name, v.shape, v.dtype.master_dtype))
  tf.logging.info("layout rules: {}".format(layout_rules))
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, layout_rules, devices)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    _, update_ops = optimization_lib.create_optimizer(
        total_loss + extra_loss,
        learning_rate,
        num_train_steps,
        num_warmup_steps,
        optimizer=FLAGS.optimizer,
        clip_gradients=FLAGS.clip_gradients)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_global_step()
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      # saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      # saver_hook = tf.train.CheckpointSaverHook(
      #     FLAGS.output_dir,
      #     save_steps=1000,
      #     saver=saver,
      #     listeners=[saver_listener])
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook])
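# --- Self-contained sketch (not part of the original snippets) of the
# auto_parallel branch above: mtf.auto_mtf.layout searches over assignments
# of tensor dimensions to mesh dimensions and returns layout rules for the
# given computation. The toy graph below is illustrative.
import mesh_tensorflow as mtf
import mesh_tensorflow.auto_mtf  # registers the mtf.auto_mtf submodule
import tensorflow.compat.v1 as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "layout_demo")
batch = mtf.Dimension("batch", 64)
hidden = mtf.Dimension("hidden", 256)
x = mtf.get_variable(mesh, "x", mtf.Shape([batch, hidden]),
                     initializer=tf.zeros_initializer())
loss = mtf.reduce_mean(mtf.square(x))
mesh_shape = mtf.convert_to_shape("processor_rows:2;processor_cols:2")
layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [loss])
print(layout_rules)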
def my_model_fn(features, labels, mode, params=None, config=None):
  """Estimator model function.

  Args:
    features: input features dictionary
    labels: ignored
    mode: a tf.estimator.ModeKeys
    params: an optional parameter dictionary; on TPU it carries the
      TPUContext under the "context" key
    config: an optional tf.estimator.RunConfig; ignored

  Returns:
    a tpu_estimator.TPUEstimatorSpec or tf.estimator.EstimatorSpec
  """
  del labels, config
  global_step = tf.train.get_global_step()
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    physical_shape = list(
        params["context"].device_assignment.topology.mesh_shape)
    logical_to_physical = _logical_to_physical(physical_shape, mesh_shape)
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape,
        layout_rules,
        mesh_devices,
        ctx.device_assignment,
        logical_to_physical=logical_to_physical)
  else:
    var_placer = None
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
  batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
  length_dim = mtf.Dimension("length", sequence_length)
  feature_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])

  mtf_features = {}
  for key, x in features.items():
    x = tf.to_int32(features[key])
    x = tf.reshape(
        x,
        [outer_batch_size, batch_size // outer_batch_size, sequence_length])
    if not use_tpu:
      x = tf.Print(x, [x], "import feature %s" % key, summarize=1000,
                   first_n=1)
    mtf_features[key] = mtf.import_fully_replicated(
        mesh, x, feature_shape, name=key)

  if mode == tf.estimator.ModeKeys.PREDICT:
    inputs = mtf_features["inputs"]
    inputs = mtf.reshape(
        inputs,
        mtf.Shape([
            mtf.Dimension("batch", batch_size),
            mtf.Dimension("length", sequence_length)
        ]))
    if isinstance(transformer_model, transformer.Unitransformer):
      mtf_samples = transformer_model.sample_autoregressive(
          inputs, variable_dtype=get_variable_dtype())
    elif isinstance(
        transformer_model,
        (transformer.Bitransformer, transformer.StudentTeacher)):
      mtf_samples = transformer_model.decode(
          inputs, variable_dtype=get_variable_dtype())
    else:
      raise ValueError("unrecognized class")
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {"outputs": outputs}
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  elif mode == tf.estimator.ModeKeys.EVAL:
    raise NotImplementedError("We don't expect to use mode == eval.")

  else:
    assert mode == tf.estimator.ModeKeys.TRAIN
    num_microbatches = serialize_num_microbatches(
        batch_dim, length_dim, mesh_shape, layout_rules)

    def model_fn(mtf_features):
      """The kind of function we need for mtf.serialize_training_step.

      Args:
        mtf_features: a dictionary

      Returns:
        a dictionary
      """
      targets = mtf_features["targets"]
      if model_type == "lm":
        _, _, length_dim = targets.shape
        inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
      else:
        inputs = mtf_features["inputs"]

      if isinstance(transformer_model, transformer.Unitransformer):
        position_kwargs = dict(
            sequence_id=mtf_features.get("targets_segmentation", None),
            position=mtf_features.get("targets_position", None),
        )
      elif isinstance(
          transformer_model,
          transformer.Bitransformer) or model_type == "bi_student_teacher":
        position_kwargs = dict(
            encoder_sequence_id=mtf_features.get("inputs_segmentation",
                                                 None),
            decoder_sequence_id=mtf_features.get("targets_segmentation",
                                                 None),
            encoder_position=mtf_features.get("inputs_position", None),
            decoder_position=mtf_features.get("targets_position", None),
        )
      else:
        raise ValueError("unrecognized class")

      logits, loss = transformer_model.call_simple(
          inputs=inputs,
          targets=targets,
          compute_loss=True,
          mode=mode,
          variable_dtype=get_variable_dtype(),
          **position_kwargs)
      if num_microbatches > 1:
        loss /= float(num_microbatches)
      del logits
      return {"loss": loss}

    if num_microbatches > 1:
      var_grads, loss_dict = mtf.serialize_training_step(
          mtf_features, model_fn, batch_dim, num_microbatches)
    else:
      loss_dict = model_fn(mtf_features)
      var_grads = mtf.gradients(
          [loss_dict["loss"]],
          [v.outputs[0] for v in graph.trainable_variables])
    loss = loss_dict["loss"]

    if callable(learning_rate_schedule):
      # the following happens on CPU since TPU can't handle summaries.
      with mtf.utils.outside_all_rewrites():
        learning_rate = learning_rate_schedule(
            step=tf.train.get_global_step())
        tf.summary.scalar("learning_rate", learning_rate)
    else:
      learning_rate = learning_rate_schedule

    update_ops = optimizer(learning_rate=learning_rate).apply_grads(
        var_grads, graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.to_float(tf_loss)
    if not use_tpu:
      tf_loss = tf.Print(tf_loss, [tf_loss, tf.train.get_global_step()],
                         "step, tf_loss")

    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

    if hasattr(transformer_model, "initialize"):
      with mtf.utils.outside_all_rewrites():
        transformer_model.initialize()

    with mtf.utils.outside_all_rewrites():
      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=keep_checkpoint_max,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          model_dir,
          save_steps=save_checkpoints_steps,
          saver=saver,
          listeners=[saver_listener])
      gin_config_saver_hook = gin.tf.GinConfigSaverHook(
          model_dir, summarize_config=True)

      if use_tpu:
        if tpu_summaries:
          tf.summary.scalar("loss", tf_loss)
          host_call = mtf.utils.create_host_call(model_dir)
          mtf.utils.remove_summaries()
        else:
          host_call = None
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            host_call=host_call,
            training_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
      else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
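# --- Note (not part of the original snippets): why model_fn above divides
# the loss by num_microbatches. mtf.serialize_training_step runs model_fn
# once per microbatch and accumulates (sums) the microbatch gradients, so
# each microbatch must contribute grad(loss_i / n) for the accumulated
# gradient to match the gradient of the mean loss over the whole batch:
#
#   sum_i grad(loss_i / n) = grad((1 / n) * sum_i loss_i) = grad(mean_i loss_i)
#
# A trivial numeric check of the scaling identity:
import numpy as np

per_microbatch_losses = np.array([1.0, 2.0, 3.0, 4.0])
n = len(per_microbatch_losses)
assert np.isclose(np.sum(per_microbatch_losses / n),
                  np.mean(per_microbatch_losses))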
def estimator_model_fn(cls,
                       hparams,
                       features,
                       labels,
                       mode,
                       config=None,
                       params=None,
                       decode_hparams=None,
                       use_tpu=False):
  hparams = copy.deepcopy(hparams)
  hparams.use_tpu = use_tpu
  # merge decode_hparams into hparams if present
  if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
    for k, v in six.iteritems(decode_hparams.values()):
      if hasattr(hparams, k) and getattr(hparams, k) != v:
        tf.logging.warning(
            "Overriding hparams.%s with %s from decode_hparams" % (k, v))
      setattr(hparams, k, v)

  # Instantiate model
  data_parallelism = None
  if not use_tpu and config:
    data_parallelism = config.data_parallelism
  model = cls(
      hparams,
      mode,
      data_parallelism=data_parallelism,
      decode_hparams=decode_hparams)

  global_step = tf.train.get_global_step()
  mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(hparams.layout)
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    if data_parallelism is None or len(data_parallelism.ps_devices) == 1:
      mesh_devices = [""] * mesh_shape.size
    else:
      assert len(data_parallelism.ps_devices) == mesh_shape.size
      mesh_devices = data_parallelism.ps_devices
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  # PREDICT mode
  if mode == tf.estimator.ModeKeys.PREDICT:
    return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)

  logits, loss = model.mtf_model_fn(features, mesh)
  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    lr = learning_rate.learning_rate_schedule(hparams)
    tf.summary.scalar("learning_rate", lr)
    mtf_lr = mtf.import_tf_tensor(
        mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
    optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
    update_ops = []
    for grad, var in zip(var_grads, graph.trainable_variables):
      update_ops.extend(optimizer.apply_grad(grad, var))

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if logits is not None and mode != tf.estimator.ModeKeys.TRAIN:
    tf_logits = lowering.export_to_tf_tensor(logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        hparams.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])

    # EVAL mode
    if mode == tf.estimator.ModeKeys.EVAL:
      tf_logits = lowering.export_to_tf_tensor(logits)
      return model.estimator_spec_eval(features, tf_logits, labels, tf_loss,
                                       restore_hook, use_tpu)

    if use_tpu:
      # TPU host call. Important: need to be called before remove_summaries()
      if hparams.tpu_enable_host_call:
        host_call = t2t_model.create_host_call(hparams.model_dir)
      else:
        host_call = None
      t2t_model.remove_summaries()
      return tpu_estimator.TPUEstimatorSpec(
          mode=tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          host_call=host_call,
          training_hooks=[restore_hook, saver_hook])
    else:
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_chief_hooks=[restore_hook, saver_hook])
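# --- Note (not part of the original snippets): two equivalent update-op
# patterns appear in these model_fns. estimator_model_fn above loops
# optimizer.apply_grad over (grad, var) pairs; other snippets call the bulk
# helper optimizer.apply_grads, which effectively runs the same loop.
# Sketch, assuming optimizer, var_grads, and graph are as in the code above:
update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

# ...is equivalent to:
update_ops = []
for grad, var in zip(var_grads, graph.trainable_variables):
  update_ops.extend(optimizer.apply_grad(grad, var))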
def model_fn(features, labels, mode, params):
  """A model is called by TpuEstimator."""
  del labels
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, 'my_mesh')
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  mesh_devices = [''] * mesh_shape.size
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape, mtf.convert_to_layout_rules(FLAGS.layout),
      mesh_devices, params['context'].device_assignment)
  with mtf.utils.outside_all_rewrites():
    logits, loss = toy_model(features, mesh)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.AdafactorOptimizer()
    update_ops = []
    for grad, var in zip(var_grads, graph.trainable_variables):
      update_ops.extend(optimizer.apply_grad(grad, var))
  else:
    # for now, we can only export fully-replicated tensors.
    fully_replicated_logits = mtf.anonymize(logits)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_loss = lowering.export_to_tf_tensor(loss)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  else:
    tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          FLAGS.model_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(tf_logits):
        mean_logits = tf.metrics.mean(tf_logits)
        return {'mean_logits': mean_logits}

      eval_metrics = (metric_fn, [tf_logits])
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
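# --- Hypothetical sketch (not part of the original snippets): toy_model is
# called above but never defined in this document. A minimal stand-in with
# the same (features, mesh) -> (logits, loss) contract might look like this;
# the dimension sizes and FLAGS names are assumptions.
def toy_model(features, mesh):
  """A small autoencoder-style mtf model over the input features."""
  batch_dim = mtf.Dimension('batch', FLAGS.batch_size)     # assumed flag
  io_dim = mtf.Dimension('io', FLAGS.io_size)              # assumed flag
  hidden_dim = mtf.Dimension('hidden', FLAGS.hidden_size)  # assumed flag
  # features is assumed to be a float tensor of shape [batch, io].
  x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
  h = mtf.layers.dense(x, hidden_dim, name='layer1')
  logits = mtf.layers.dense(h, io_dim, name='layer2')
  loss = mtf.reduce_mean(mtf.square(logits - x))
  return logits, loss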
def _model_fn(input_fea, input_lab):
  """Creates a model, adds summaries, modes (train or eval), and hooks."""

  # input_fea and input_lab should each be a list (laid_out_tensors).
  if not isinstance(input_fea, list):
    input_fea = [input_fea]
  if not isinstance(input_lab, list):
    input_lab = [input_lab]

  def _add_summary(lowering, train_or_eval, tf_loss, scalars, global_step):
    """Add all summaries."""
    for k in scalars.keys():
      if not isinstance(scalars[k], tf.Tensor):
        scalars[k] = tf.cast(
            lowering.export_to_tf_tensor(scalars[k]), tf.float32)

    def _host_loss_summary(global_step, tf_loss, **scalars):
      """Add summary.scalar in host side."""
      gs = tf.cast(global_step, tf.int64)
      sum_loss = contrib_summary.scalar(
          '{}_loss'.format(train_or_eval), tf_loss, step=gs)
      sum_ops = [sum_loss.op]
      for description, tf_metric in scalars.items():
        sum_metric = contrib_summary.scalar(
            '{}_{}'.format(train_or_eval, description), tf_metric, step=gs)
        sum_ops.append(sum_metric)
      with tf.control_dependencies(sum_ops):
        return tf.identity(tf_loss)

    if FLAGS.use_tpu:
      # Cast the global step to tf.int32, since
      # outside_compilation does not support tf.int64.
      tf_loss = tpu.outside_compilation(
          _host_loss_summary,
          tf.cast(global_step, tf.int32),
          tf_loss,
          **scalars)
    else:
      tf_loss = _host_loss_summary(
          tf.cast(global_step, tf.int32), tf_loss, **scalars)

    return tf_loss

  global_step = tf.train.get_or_create_global_step()
  graph, mesh, mesh_impl = mesh_context.create_graph_mesh_and_mesh_impl()

  with mtf.utils.outside_all_rewrites():
    # Do not tpu_rewrite this part. Inside this unet, using TensorFlow ops
    # instead of Mesh TensorFlow ops would cause host-to-TPU send/recv.
    preds, loss, scalars, bn_update_ops = (
        unet.unet_with_spatial_partition(
            mesh, mesh_impl, train_or_eval, input_fea, input_lab))

  if train_or_eval == 'train':
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])

    lr = FLAGS.lr * tf.pow(
        FLAGS.lr_drop_rate,
        tf.floor(tf.cast(global_step, tf.float32) / FLAGS.lr_drop_steps))
    scalars['learning_rate'] = lr

    optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=lr)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

    # This is where the actual tf graph gets built.
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf_update_ops.extend(
        [lowering.lowered_operation(op) for op in bn_update_ops])

  else:  # train_or_eval == 'eval':
    preds = [mtf.anonymize(pred) for pred in preds]

    # This is where the actual tf graph gets built.
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_preds = [
        tf.cast(lowering.export_to_tf_tensor(pred), tf.float32)
        for pred in preds
    ]

  tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
  if FLAGS.write_summary:
    tf_loss = _add_summary(
        lowering, train_or_eval, tf_loss, scalars, global_step)
  master_to_slice_hook = mtf.MtfRestoreHook(lowering)

  if train_or_eval == 'train':
    with mtf.utils.outside_all_rewrites():
      saver = tf.train.Saver(tf.global_variables(),
                             save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      slice_to_master_hook = tf.train.CheckpointSaverHook(
          FLAGS.checkpoint_dir,
          save_steps=FLAGS.save_checkpoints_steps,
          saver=saver,
          listeners=[saver_listener])
      captured_hooks.capture([master_to_slice_hook, slice_to_master_hook])
    return tf.group([tf_loss] + tf_update_ops)

  else:  # train_or_eval == 'eval':
    if FLAGS.use_tpu:
      tf_preds.extend([tf_loss, global_step])
      tf_preds_dtypes = [tf_pred.dtype for tf_pred in tf_preds]
      tf_preds_shapes = [tf_pred.shape for tf_pred in tf_preds]
      captured_hooks.capture([master_to_slice_hook, None])
      captured_output_dtypes_shapes.capture(
          [tf_preds_dtypes, tf_preds_shapes])
      return tpu_ops.outfeed_enqueue_tuple(tf_preds)
    else:
      tf_preds.extend([tf_loss, global_step])
      captured_hooks.capture([master_to_slice_hook, None])
      return tf_preds
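# --- Hypothetical sketch (not part of the original snippets): a host-side
# consumer for the eval outfeed that _model_fn enqueues above. The captured
# dtypes/shapes are replayed into one dequeue op per TPU core, mirroring the
# enqueue order. The import path and num_cores are assumptions that vary by
# TF version and topology; real drivers also pin each dequeue to the host
# device that serves the core.
import tensorflow.compat.v1 as tf
from tensorflow.python.tpu.ops import tpu_ops  # TF >= 1.13 path

def make_eval_dequeue_ops(tf_preds_dtypes, tf_preds_shapes, num_cores):
  """Builds one outfeed_dequeue_tuple per core."""
  dequeue_ops = []
  for core in range(num_cores):
    dequeue_ops.append(
        tpu_ops.outfeed_dequeue_tuple(
            dtypes=tf_preds_dtypes,
            shapes=tf_preds_shapes,
            device_ordinal=core))
  return dequeue_ops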
def my_model_fn(features, labels, mode, params=None, config=None):
  """Estimator model function.

  Args:
    features: input features dictionary
    labels: ignored
    mode: a tf.estimator.ModeKeys
    params: an optional parameter dictionary; on TPU it carries the
      TPUContext under the "context" key
    config: an optional tf.estimator.RunConfig; ignored

  Returns:
    a tpu_estimator.TPUEstimatorSpec or tf.estimator.EstimatorSpec
  """
  del labels, config
  global_step = tf.train.get_global_step()
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    physical_shape = list(
        params["context"].device_assignment.topology.mesh_shape)
    logical_to_physical = _logical_to_physical(physical_shape, mesh_shape)
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape,
        layout_rules,
        mesh_devices,
        ctx.device_assignment,
        logical_to_physical=logical_to_physical)
  else:
    var_placer = None
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  def _import_feature(key, allow_missing=False):
    """Import a feature from the features dictionary into a mtf.Tensor.

    Args:
      key: a string
      allow_missing: a boolean

    Returns:
      a mtf.Tensor with dtype int32 and shape [batch_dim, length_dim]
    """
    outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
    batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
    length_dim = mtf.Dimension("length", sequence_length)
    mtf_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])
    if key not in features:
      if allow_missing:
        return None
      else:
        raise ValueError(
            "feature not found: %s (features: %s)" % (key, features))
    tf.logging.info("Import feature %s: %s" % (key, features[key]))
    x = tf.to_int32(features[key])
    x = tf.reshape(x, [outer_batch_size, batch_size // outer_batch_size, -1])
    if not use_tpu:
      x = tf.Print(x, [x], "import feature %s" % key, summarize=1000,
                   first_n=1)
    return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

  if mode == tf.estimator.ModeKeys.PREDICT:
    inputs = _import_feature("inputs")
    inputs = mtf.reshape(
        inputs,
        mtf.Shape([
            mtf.Dimension("batch", batch_size),
            mtf.Dimension("length", sequence_length)
        ]))
    if isinstance(transformer_model, transformer.Unitransformer):
      mtf_samples = transformer_model.sample_autoregressive(
          inputs, variable_dtype=get_variable_dtype())
    elif isinstance(transformer_model, transformer.Bitransformer):
      mtf_samples = transformer_model.decode(
          inputs, variable_dtype=get_variable_dtype())
    else:
      raise ValueError("unrecognized class")
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {"outputs": outputs}
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  targets = _import_feature("targets")
  anon_targets = mtf.anonymize(targets)
  if model_type == "lm":
    _, length_dim = targets.shape
    inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
  else:
    inputs = _import_feature("inputs")

  if mode == tf.estimator.ModeKeys.EVAL:
    if isinstance(transformer_model, transformer.Unitransformer):
      mtf_samples = transformer_model.sample_autoregressive(
          inputs, variable_dtype=get_variable_dtype())
    elif isinstance(transformer_model, transformer.Bitransformer):
      mtf_samples = transformer_model.decode(
          inputs, variable_dtype=get_variable_dtype())
    else:
      raise ValueError("unrecognized class")
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    labels = lowering.export_to_tf_tensor(anon_targets)
    restore_hook = mtf.MtfRestoreHook(lowering)

    # metric_names would become locally scoped if we simply assigned
    # ["padded_neg_log_perplexity"] to it when it is None.
    local_metric_names = metric_names or ["token_accuracy"]

    def metric_fn(labels, outputs):
      return get_metric_fns(local_metric_names, labels, outputs)

    eval_metrics = (metric_fn, [labels, outputs])
    return tpu_estimator.TPUEstimatorSpec(
        tf.estimator.ModeKeys.EVAL,
        # Unfortunately TPUEstimatorSpec requires us to provide a value for
        # loss when in EVAL mode. Since we are sampling or decoding from the
        # model, we don't have a loss to report.
        loss=tf.constant(0.),
        evaluation_hooks=[restore_hook],
        eval_metrics=eval_metrics)

  if isinstance(transformer_model, transformer.Unitransformer):
    position_kwargs = dict(
        sequence_id=_import_feature("targets_segmentation", True),
        position=_import_feature("targets_position", True),
    )
  elif isinstance(transformer_model, transformer.Bitransformer):
    position_kwargs = dict(
        encoder_sequence_id=_import_feature("inputs_segmentation", True),
        decoder_sequence_id=_import_feature("targets_segmentation", True),
        encoder_position=_import_feature("inputs_position", True),
        decoder_position=_import_feature("targets_position", True),
    )
  else:
    raise ValueError("unrecognized class")

  logits, loss = transformer_model.call_simple(
      inputs=inputs,
      targets=targets,
      compute_loss=True,
      mode=mode,
      variable_dtype=get_variable_dtype(),
      **position_kwargs)
  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=learning_rate)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if not use_tpu:
    tf_loss = tf.Print(tf_loss, [tf_loss, tf.train.get_global_step()],
                       "step, tf_loss")

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=checkpoints_to_keep,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        model_dir,
        save_steps=save_steps,
        saver=saver,
        listeners=[saver_listener])
    gin_config_saver_hook = gin.tf.GinConfigSaverHook(
        model_dir, summarize_config=True)

    if mode == tf.estimator.ModeKeys.TRAIN:
      if use_tpu:
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
      else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator."""
  tf.logging.info("*** Features ***")
  for name in sorted(features.keys()):
    tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

  # MTF setup.
  graph = mtf.Graph()
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

  ctx = params["context"]
  num_hosts = ctx.num_hosts
  host_placement_fn = ctx.tpu_host_placement_function
  device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
  tf.logging.info("device_list = %s" % device_list)
  replica_cache_size = 300 * 1000000  # 300M per replica
  # Worker 0 caches all the TPU binaries.
  worker0_mem = replica_cache_size * ctx.num_replicas
  devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
  var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                devices_memory_usage)
  mesh_devices = [""] * mesh_shape.size
  physical_shape = list(ctx.device_assignment.topology.mesh_shape)
  logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
      mesh_shape.to_integer_list, physical_shape)
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape,
      layout_rules,
      mesh_devices,
      ctx.device_assignment,
      logical_to_physical=logical_to_physical)
  mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

  input_ids = features["input_ids"]
  input_mask = features["input_mask"]
  segment_ids = features["segment_ids"]
  masked_lm_positions = features["masked_lm_positions"]
  masked_lm_ids = features["masked_lm_ids"]
  masked_lm_weights = features["masked_lm_weights"]
  next_sentence_labels = tf.squeeze(features["next_sentence_labels"], 1)

  batch_size = input_ids.get_shape()[0].value
  batch_dim = mtf.Dimension("batch", batch_size)
  seq_length = input_ids.get_shape()[1].value
  seq_dim = mtf.Dimension("seq", seq_length)
  max_predictions_per_seq = masked_lm_positions.get_shape()[1].value
  max_predictions_per_seq_dim = mtf.Dimension("max_pred_seq",
                                              max_predictions_per_seq)

  mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids, [batch_dim, seq_dim])
  mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                        [batch_dim, seq_dim])
  mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                         [batch_dim, seq_dim])
  mtf_masked_lm_positions = mtf.import_tf_tensor(
      mesh, masked_lm_positions, [batch_dim, max_predictions_per_seq_dim])
  mtf_masked_lm_ids = mtf.import_tf_tensor(
      mesh, masked_lm_ids, [batch_dim, max_predictions_per_seq_dim])
  mtf_masked_lm_weights = mtf.import_tf_tensor(
      mesh, masked_lm_weights, [batch_dim, max_predictions_per_seq_dim])
  mtf_next_sentence_labels = mtf.import_tf_tensor(
      mesh, next_sentence_labels, [batch_dim])

  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  model = bert_lib.BertModel(
      config=bert_config,
      is_training=is_training,
      input_ids=mtf_input_ids,
      input_mask=mtf_input_mask,
      token_type_ids=mtf_segment_ids,
      layout=layout_rules,
      mesh_shape=mesh_shape)
  (masked_lm_loss, masked_lm_example_loss,
   masked_lm_logits) = model.get_masked_lm_output(
       mtf_masked_lm_positions, mtf_masked_lm_ids, mtf_masked_lm_weights)
  (next_sentence_loss, next_sentence_example_loss,
   next_sentence_logits) = model.get_next_sentence_output(
       mtf_next_sentence_labels)

  extra_loss = model.get_extra_loss()
  total_loss = masked_lm_loss + next_sentence_loss
  total_loss = mtf.anonymize(total_loss)
  masked_lm_example_loss = mtf.anonymize(masked_lm_example_loss)
  masked_lm_logits = mtf.anonymize(masked_lm_logits)
  next_sentence_example_loss = mtf.anonymize(next_sentence_example_loss)
  next_sentence_logits = mtf.anonymize(next_sentence_logits)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    _, update_ops = optimization_lib.create_optimizer(
        total_loss + extra_loss,
        learning_rate,
        num_train_steps,
        num_warmup_steps,
        optimizer=FLAGS.optimizer,
        clip_gradients=FLAGS.clip_gradients)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_global_step()
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  elif mode == tf.estimator.ModeKeys.EVAL:

    def metric_fn(masked_lm_example_loss, masked_lm_logits, masked_lm_ids,
                  masked_lm_weights, next_sentence_example_loss,
                  next_sentence_logits, next_sentence_labels):
      """Computes the loss and accuracy of the model."""
      masked_lm_logits = tf.reshape(masked_lm_logits,
                                    [-1, masked_lm_logits.shape[-1]])
      masked_lm_predictions = tf.argmax(
          masked_lm_logits, axis=-1, output_type=tf.int32)
      masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
      masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
      masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
      masked_lm_accuracy = tf.metrics.accuracy(
          labels=masked_lm_ids,
          predictions=masked_lm_predictions,
          weights=masked_lm_weights)
      masked_lm_mean_loss = tf.metrics.mean(
          values=masked_lm_example_loss, weights=masked_lm_weights)

      next_sentence_logits = tf.reshape(
          next_sentence_logits, [-1, next_sentence_logits.shape[-1]])
      next_sentence_predictions = tf.argmax(
          next_sentence_logits, axis=-1, output_type=tf.int32)
      next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
      next_sentence_accuracy = tf.metrics.accuracy(
          labels=next_sentence_labels,
          predictions=next_sentence_predictions)
      next_sentence_mean_loss = tf.metrics.mean(
          values=next_sentence_example_loss)

      return {
          "masked_lm_accuracy": masked_lm_accuracy,
          "masked_lm_loss": masked_lm_mean_loss,
          "next_sentence_accuracy": next_sentence_accuracy,
          "next_sentence_loss": next_sentence_mean_loss,
      }

    eval_metrics = (metric_fn, [
        lowering.export_to_tf_tensor(masked_lm_example_loss),
        lowering.export_to_tf_tensor(masked_lm_logits), masked_lm_ids,
        masked_lm_weights,
        lowering.export_to_tf_tensor(next_sentence_example_loss),
        lowering.export_to_tf_tensor(next_sentence_logits),
        next_sentence_labels
    ])

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          FLAGS.output_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])
      return tf.estimator.tpu.TPUEstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:
      return tf.estimator.tpu.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
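# --- Note (not part of the original snippets): the eval_metrics contract
# used above, in miniature. TPUEstimatorSpec takes a (metric_fn, tensors)
# pair; the tensors are shipped off the TPU and metric_fn runs on the host,
# returning a dict of tf.metrics-style (value, update_op) tuples. Sketch,
# assuming per_example_loss, tf_loss, and restore_hook exist as above:
def metric_fn(per_example_loss):
  return {"mean_loss": tf.metrics.mean(per_example_loss)}

eval_metrics = (metric_fn, [per_example_loss])
spec = tf.estimator.tpu.TPUEstimatorSpec(
    tf.estimator.ModeKeys.EVAL,
    loss=tf_loss,
    evaluation_hooks=[restore_hook],
    eval_metrics=eval_metrics)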
def __call__(self, features, labels, mode, params):  # this is the model_fn
  """A model is called by TpuEstimator."""
  global_step = tf.train.get_global_step()

  # Graph setup
  graph = mtf.Graph()
  mesh_shape = mtf.convert_to_shape(self.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(self.layout)
  if params["use_tpu"]:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # Worker 0 caches all the TPU binaries.
    replica_cache_size = 300 * 1024 * 1024  # 300M per replica.
    worker0_mem = replica_cache_size * 8 * num_hosts
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  # RUN Model
  with mtf.utils.outside_all_rewrites():
    logits, loss = self.model(mesh, features, labels, params)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    if self.optimizer == "Adafactor":
      optimizer = mtf.optimize.AdafactorOptimizer()
    else:
      assert self.optimizer == "SGD"
      optimizer = mtf.optimize.SgdOptimizer(
          learning_rate=self.learning_rate)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
  else:
    # for now, we can only export fully-replicated tensors.
    fully_replicated_logits = mtf.anonymize(logits)

  # convert back to tensorflow format
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))
  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  else:
    tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

  # create estimator
  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True,
      )
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          self.model_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener],
      )
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook],
      )
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(tf_logits):
        mean_logits = tf.metrics.mean(tf_logits)
        return {"mean_logits": mean_logits}

      eval_metrics = (metric_fn, [tf_logits])
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics,
      )
    elif mode == tf.estimator.ModeKeys.PREDICT:
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.PREDICT,
          predictions={"logits": tf_logits},
          prediction_hooks=[restore_hook],
      )

@property
def dense_initializer(self):
  if self.config.initializer_range:
    return tf.truncated_normal_initializer(
        stddev=self.config.initializer_range)
  else:
    return mtf.layers.VarianceScalingInitializer(scale=0.4)

@property
def embedding_initializer(self):
  initializer = self.dense_initializer
  if isinstance(initializer, mtf.layers.DenseInitializer):
    # embedding matrix is also used as classifier weight matrix.
    # scale it appropriately.
    return initializer(
        reduced_dims=[self.model_dim], new_dims=[self.vocab_dim])
  else:
    return initializer

@property
def num_hidden_layers(self):
  return self.config.num_hidden_layers

def normalize(self, x, reduce_dim):
  return nn.layer_norm(
      x,
      reduce_dim,
      subtract_mean=self.config.use_bias,
      use_bias=self.config.use_bias,
  )

def model(self, mesh, x, y, params):
  # x :: [batch, io, vocab]
  if params["precision"] == "bfloat16":
    dtype = tf.bfloat16
    # master has type float32, slice and activation have type bfloat16
    variable_dtype = mtf.VariableDType(tf.float32, tf.bfloat16, tf.bfloat16)
  else:
    dtype = tf.float32
    # master, slice and activation dtypes are all float32
    variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)

  # Build the actual model
  batch_dim = mtf.Dimension("batch", params["batch_size"])
  vocab_dim = mtf.Dimension("vocab", params["vocab_size"])
  io_dim = mtf.Dimension("sequence", params["io"])
  io_chan_dim = mtf.Dimension("io", params["io_channels"])

  # from input to mtf
  x = mtf.import_tf_tensor(
      mesh, x, mtf.Shape([batch_dim, io_dim, vocab_dim]))

  # Embeddings
  with tf.variable_scope("toy", default_name="seq2seq"):
    with tf.variable_scope("embeddings"):
      # Perform embedding lookup on the word ids.
      embedding_table = mtf.get_variable(
          mesh,
          "word_embeddings",
          mtf.Shape([vocab_dim, io_chan_dim]),
          initializer=self.embedding_initializer,
      )
      word_embedding_output = mtf.gather(
          embedding_table, x, dim=vocab_dim, output_shape=io_chan_dim)

      # Add positional embeddings, then layer-normalize and apply dropout.
      pos_embedding = mtf.get_variable(
          mesh,
          "pos_embeddings",
          mtf.Shape([io_dim, io_chan_dim]),
          initializer=self.embedding_initializer,
      )
      embedding_output = word_embedding_output + pos_embedding
      embedding_output = self.normalize(embedding_output, io_chan_dim)
      embedding_output = mtf.dropout(
          embedding_output,
          keep_prob=1.0 - self.config.layer_output_dropout_prob,
      )

  x = mtf.cast(embedding_output, variable_dtype.activation_dtype)

  h = x
  for lnum in range(1, self.num_hidden_layers + 2):
    if lnum + 1 == self.num_hidden_layers + 2:
      # output layer
      dim = io_dim
    elif lnum % 2 == 0:
      dim = mtf.Dimension("hidden_even", io_chan_dim.size)
    else:
      dim = mtf.Dimension("hidden_odd", io_chan_dim.size)
    h = mtf.layers.dense(
        h,
        dim,
        use_bias=False,
        master_dtype=variable_dtype.master_dtype,
        slice_dtype=variable_dtype.slice_dtype,
        name="layer_%d" % lnum,
    )
  prediction = h  # project back to token dimensions

  # import the target onto the mesh so the loss is computed in Mesh
  # TensorFlow; assumes y matches the prediction's shape.
  y = mtf.import_tf_tensor(mesh, y, prediction.shape)
  y = mtf.cast(y, variable_dtype.activation_dtype)

  # compute the mean squared loss between the target and the output
  loss = mtf.reduce_mean(mtf.square(y - prediction))
  return prediction, loss
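# --- Self-contained sketch (not part of the original snippets) of the
# mixed-precision convention used in model() above: variables keep a float32
# master copy, while variable slices and activations may be bfloat16.
# mtf.cast moves a tensor to the activation dtype; dense layers receive the
# master/slice dtypes. Names below are illustrative.
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "dtype_demo")
batch_dim = mtf.Dimension("batch", 8)
io_dim = mtf.Dimension("io", 16)
hidden_dim = mtf.Dimension("hidden", 32)
variable_dtype = mtf.VariableDType(
    master_dtype=tf.float32,
    slice_dtype=tf.bfloat16,
    activation_dtype=tf.bfloat16)
x = mtf.import_tf_tensor(
    mesh, tf.zeros([8, 16]), mtf.Shape([batch_dim, io_dim]))
x = mtf.cast(x, variable_dtype.activation_dtype)
h = mtf.layers.dense(
    x,
    hidden_dim,
    master_dtype=variable_dtype.master_dtype,
    slice_dtype=variable_dtype.slice_dtype,
    name="hidden")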