def test_get_laidout_tensors(self, is_eval_mode):
  mesh_shape = "mesh_x:2, mesh_y:1"
  layout = "batch:mesh_x, io:mesh_y"
  batch_io_dim = 4

  with tf.Session() as sess:
    topology, num_cores = self.initialize_system(sess)

    # Get a device_assignment object for mtf.
    d_assignment = device_assignment.device_assignment(
        topology, computation_shape=[1, 1, 1], num_replicas=num_cores)

    # Hacked dataset creator: creates different datasets for the first and
    # second call, in order to test SimdMeshImplInputReader.
    self.sub_batch_created_times = 0

    def stateful_ds_creator():
      whole_batch = tf.eye(batch_io_dim, dtype=tf.float32)
      sub_batch = tf.slice(whole_batch,
                           [self.sub_batch_created_times * 2, 0],
                           [2, 4])
      self.sub_batch_created_times += 1
      return tf.data.Dataset.from_tensors(sub_batch).repeat().unbatch()

    batch_dim = mtf.Dimension("batch", batch_io_dim)
    io_dim = mtf.Dimension("io", batch_io_dim)
    mtf_input_shapes = [mtf.Shape([batch_dim, io_dim])]

    # Get mesh_impl.
    mesh_shape = mtf.convert_to_shape(mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(layout)
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, None, d_assignment)

    simd_input_reader = input_reader.SimdMeshImplInputReader(
        mesh_impl,
        stateful_ds_creator,
        mtf_input_shapes,
        external_worker=False,
        is_eval_mode=is_eval_mode)

    def model_fn(features):
      return features

    replicated_computation = tpu.replicate(
        computation=model_fn,
        inputs=[[]] * num_cores,
        infeed_queue=simd_input_reader.infeed_queue,
        device_assignment=d_assignment)

    simd_input_reader.start_infeed_thread(sess, 1)
    results = sess.run(replicated_computation)
    print("results: {}".format(results))

    core_0_data = results[0][0]
    core_1_data = results[1][0]
    print("core_0_data: {}".format(core_0_data))
    print("core_1_data: {}".format(core_1_data))

    if is_eval_mode:
      # If there is only one dataset object, then the stateful_ds_creator()
      # should be called only once.
      self.assertAllClose(
          np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32),
          core_0_data)
      self.assertAllClose(
          np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32),
          core_1_data)
    else:
      # If there are two dataset objects, then the stateful_ds_creator()
      # should be called twice.
      self.assertAllClose(
          np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32),
          core_0_data)
      self.assertAllClose(
          np.array([[0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32),
          core_1_data)

    sess.run(tf.tpu.shutdown_system())
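# A minimal, self-contained sketch (not part of the test above) of how the
# mesh-shape and layout strings used in the test are parsed, and how a tensor
# shape maps onto mesh axes. It only assumes mesh_tensorflow is importable.
import mesh_tensorflow as mtf

sketch_mesh_shape = mtf.convert_to_shape("mesh_x:2, mesh_y:1")
sketch_layout_rules = mtf.convert_to_layout_rules("batch:mesh_x, io:mesh_y")
sketch_tensor_shape = mtf.Shape(
    [mtf.Dimension("batch", 4), mtf.Dimension("io", 4)])
# tensor_axis_to_mesh_axis reports that "batch" is split over mesh axis 0
# ("mesh_x") and "io" over mesh axis 1 ("mesh_y").
print(sketch_layout_rules.tensor_layout(
    sketch_tensor_shape, sketch_mesh_shape).tensor_axis_to_mesh_axis)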
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator."""
  tf.logging.info("*** Features ***")
  for name in sorted(features.keys()):
    tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

  # MTF setup.
  graph = mtf.Graph()
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

  ctx = params["context"]
  num_hosts = ctx.num_hosts
  host_placement_fn = ctx.tpu_host_placement_function
  device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
  tf.logging.info("device_list = %s" % device_list)
  replica_cache_size = 300 * 1000000  # 300M per replica.
  # Worker 0 caches all the TPU binaries.
  worker0_mem = replica_cache_size * ctx.num_replicas
  devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
  var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                devices_memory_usage)
  mesh_devices = [""] * mesh_shape.size
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

  input_ids = features["input_ids"]
  input_mask = features["input_mask"]
  segment_ids = features["segment_ids"]
  label_ids = features["label_ids"]
  if "is_real_example" in features:
    is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
  else:
    is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

  batch_size = input_ids.get_shape()[0].value
  batch_dim = mtf.Dimension("batch", batch_size)
  seq_length = input_ids.get_shape()[1].value
  seq_dim = mtf.Dimension("seq", seq_length)
  num_labels_dim = mtf.Dimension("seq", num_labels)

  mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids, [batch_dim, seq_dim])
  mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                        [batch_dim, seq_dim])
  mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                         [batch_dim, seq_dim])
  mtf_label_ids = mtf.import_tf_tensor(mesh, label_ids, [batch_dim])

  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  (total_loss, per_example_loss, logits,
   probabilities) = create_model(bert_config, is_training, mtf_input_ids,
                                 mtf_input_mask, mtf_segment_ids,
                                 mtf_label_ids, num_labels_dim, layout_rules,
                                 mesh_shape)

  total_loss = mtf.anonymize(total_loss)
  per_example_loss = mtf.anonymize(per_example_loss)
  logits = mtf.anonymize(logits)

  # TRAIN mode.
  if mode == tf.estimator.ModeKeys.TRAIN:
    _, update_ops = optimization_lib.create_optimizer(
        total_loss,
        learning_rate,
        num_train_steps,
        num_warmup_steps,
        max_optimized_variable_size=FLAGS.max_optimized_variable_size,
        optimizer=FLAGS.optimizer,
        clip_gradients=FLAGS.clip_gradients)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_global_step()
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  elif mode == tf.estimator.ModeKeys.EVAL:

    def metric_fn(per_example_loss, label_ids, logits, is_real_example):
      predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
      accuracy = tf.metrics.accuracy(
          labels=label_ids, predictions=predictions, weights=is_real_example)
      loss = tf.metrics.mean(
          values=per_example_loss, weights=is_real_example)
      return {
          "eval_accuracy": accuracy,
          "eval_loss": loss,
      }

    eval_metrics = (metric_fn, [
        lowering.export_to_tf_tensor(per_example_loss), label_ids,
        lowering.export_to_tf_tensor(logits), is_real_example
    ])

  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  if init_checkpoint:
    (assignment_map, initialized_variable_names
    ) = bert_lib.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
    if use_tpu:

      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                    init_string)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          FLAGS.output_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])

      return tf.estimator.tpu.TPUEstimatorSpec(
          mode,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook],
          scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
      return tf.estimator.tpu.TPUEstimatorSpec(
          mode,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics,
          scaffold_fn=scaffold_fn)
    else:
      return tf.estimator.tpu.TPUEstimatorSpec(
          mode,
          prediction_hooks=[restore_hook],
          predictions={
              "probabilities": lowering.export_to_tf_tensor(probabilities)
          },
          scaffold_fn=scaffold_fn)
def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator."""
  tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                  (features, labels, mode, params))
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  logits, loss = cifar_model(features, labels, mesh)
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
  mesh_size = mesh_shape.size
  # For manual device placement (e.g. GPUs 1, 2), comment out the line below
  # and uncomment the one after it.
  mesh_devices = [""] * mesh_size
  # mesh_devices = ["GPU:" + str(i) for i in range(mesh_size)]
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, layout_rules, mesh_devices)

  labels = features["label"]

  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])

    # Variables that affect learning rate.
    num_batches_per_epoch = (
        NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size)
    decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)

    # Decay the learning rate exponentially based on the number of steps.
    lr = tf.train.exponential_decay(
        INITIAL_LEARNING_RATE,
        global_step,
        decay_steps,
        LEARNING_RATE_DECAY_FACTOR,
        staircase=True)
    tf.summary.scalar("learning_rate", lr)
    mtf_lr = mtf.import_tf_tensor(
        mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
    optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=mtf_lr)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  restore_hook = mtf.MtfRestoreHook(lowering)

  tf_logits = lowering.export_to_tf_tensor(logits)
  if mode != tf.estimator.ModeKeys.PREDICT:
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.to_float(tf_loss)
    tf.summary.scalar("loss", tf_loss)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        FLAGS.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])

    accuracy = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(tf_logits, axis=1))

    # Name tensors to be logged with LoggingTensorHook.
    tf.identity(tf_loss, "cross_entropy")
    tf.identity(accuracy[1], name="train_accuracy")

    # Save accuracy scalar to Tensorboard output.
    tf.summary.scalar("train_accuracy", accuracy[1])

    # restore_hook must come before saver_hook.
    return tf.estimator.EstimatorSpec(
        tf.estimator.ModeKeys.TRAIN,
        loss=tf_loss,
        train_op=train_op,
        training_chief_hooks=[restore_hook, saver_hook])

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        "classes": tf.argmax(tf_logits, axis=1),
        "probabilities": tf.nn.softmax(tf_logits),
    }
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[restore_hook],
        export_outputs={
            "classify": tf.estimator.export.PredictOutput(predictions)
        })

  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=tf_loss,
        evaluation_hooks=[restore_hook],
        eval_metric_ops={
            "accuracy":
                tf.metrics.accuracy(
                    labels=labels, predictions=tf.argmax(tf_logits, axis=1)),
        })
def model_fn(features, labels, mode, params):
  """A model is called by TpuEstimator."""
  del labels
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
  if FLAGS.use_tpu:
    ctx = params['context']
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    tf.logging.info('device_list = %s' % device_list)
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica.
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)
  mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

  with mtf.utils.outside_all_rewrites():
    logits, loss = toy_model(features, mesh)

  # TRAIN mode.
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    if FLAGS.optimizer == 'Adafactor':
      optimizer = mtf.optimize.AdafactorOptimizer()
    else:
      assert FLAGS.optimizer == 'SGD'
      optimizer = mtf.optimize.SgdOptimizer(learning_rate=FLAGS.lr)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
  else:
    # For now, we can only export fully-replicated tensors.
    fully_replicated_logits = mtf.anonymize(logits)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  else:
    tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          FLAGS.model_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(tf_logits):
        mean_logits = tf.metrics.mean(tf_logits)
        return {'mean_logits': mean_logits}

      eval_metrics = (metric_fn, [tf_logits])
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator."""
  tf.logging.info("*** Features ***")
  for name in sorted(features.keys()):
    tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

  # MTF setup.
  graph = mtf.Graph()
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

  if FLAGS.use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    tf.logging.info("device_list = %s" % device_list)
    replica_cache_size = 300 * 1000000  # 300M per replica.
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    physical_shape = list(ctx.device_assignment.topology.mesh_shape)
    logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
        mesh_shape.to_integer_list, physical_shape)
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape,
        layout_rules,
        mesh_devices,
        ctx.device_assignment,
        logical_to_physical=logical_to_physical)
  else:
    var_placer = None
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, [""] * mesh_shape.size)

  mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

  input_ids = features["input_ids"]
  input_mask = features["input_mask"]
  segment_ids = features["segment_ids"]
  masked_lm_positions = features["masked_lm_positions"]
  masked_lm_ids = features["masked_lm_ids"]
  masked_lm_weights = features["masked_lm_weights"]
  next_sentence_labels = tf.squeeze(features["next_sentence_labels"], 1)

  batch_size = input_ids.get_shape()[0].value
  batch_dim = mtf.Dimension("batch", batch_size)
  seq_length = input_ids.get_shape()[1].value
  seq_dim = mtf.Dimension("seq", seq_length)
  max_predictions_per_seq = masked_lm_positions.get_shape()[1].value
  max_predictions_per_seq_dim = mtf.Dimension("max_pred_seq",
                                              max_predictions_per_seq)

  mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids, [batch_dim, seq_dim])
  mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                        [batch_dim, seq_dim])
  mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                         [batch_dim, seq_dim])
  mtf_masked_lm_positions = mtf.import_tf_tensor(
      mesh, masked_lm_positions, [batch_dim, max_predictions_per_seq_dim])
  mtf_masked_lm_ids = mtf.import_tf_tensor(
      mesh, masked_lm_ids, [batch_dim, max_predictions_per_seq_dim])
  mtf_masked_lm_weights = mtf.import_tf_tensor(
      mesh, masked_lm_weights, [batch_dim, max_predictions_per_seq_dim])
  mtf_next_sentence_labels = mtf.import_tf_tensor(
      mesh, next_sentence_labels, [batch_dim])

  is_training = (mode == tf_estimator.ModeKeys.TRAIN)
  model = bert_lib.BertModel(
      config=bert_config,
      is_training=is_training,
      input_ids=mtf_input_ids,
      input_mask=mtf_input_mask,
      token_type_ids=mtf_segment_ids,
      layout=layout_rules,
      mesh_shape=mesh_shape)

  (masked_lm_loss, masked_lm_example_loss,
   masked_lm_logits) = model.get_masked_lm_output(
       mtf_masked_lm_positions, mtf_masked_lm_ids, mtf_masked_lm_weights)

  (next_sentence_loss, next_sentence_example_loss,
   next_sentence_logits) = model.get_next_sentence_output(
       mtf_next_sentence_labels)

  extra_loss = model.get_extra_loss()

  total_loss = masked_lm_loss + next_sentence_loss
  total_loss = mtf.anonymize(total_loss)
  masked_lm_example_loss = mtf.anonymize(masked_lm_example_loss)
  masked_lm_logits = mtf.anonymize(masked_lm_logits)
  next_sentence_example_loss = mtf.anonymize(next_sentence_example_loss)
  next_sentence_logits = mtf.anonymize(next_sentence_logits)

  # TRAIN mode.
  if mode == tf_estimator.ModeKeys.TRAIN:
    _, update_ops = optimization_lib.create_optimizer(
        total_loss + extra_loss,
        learning_rate,
        num_train_steps,
        num_warmup_steps,
        optimizer=FLAGS.optimizer,
        clip_gradients=FLAGS.clip_gradients)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

  if mode == tf_estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_global_step()
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  elif mode == tf_estimator.ModeKeys.EVAL:

    def metric_fn(masked_lm_example_loss, masked_lm_logits, masked_lm_ids,
                  masked_lm_weights, next_sentence_example_loss,
                  next_sentence_logits, next_sentence_labels):
      """Computes the loss and accuracy of the model."""
      masked_lm_logits = tf.reshape(masked_lm_logits,
                                    [-1, masked_lm_logits.shape[-1]])
      masked_lm_predictions = tf.argmax(
          masked_lm_logits, axis=-1, output_type=tf.int32)
      masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
      masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
      masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
      masked_lm_accuracy = tf.metrics.accuracy(
          labels=masked_lm_ids,
          predictions=masked_lm_predictions,
          weights=masked_lm_weights)
      masked_lm_mean_loss = tf.metrics.mean(
          values=masked_lm_example_loss, weights=masked_lm_weights)
      next_sentence_logits = tf.reshape(
          next_sentence_logits, [-1, next_sentence_logits.shape[-1]])
      next_sentence_predictions = tf.argmax(
          next_sentence_logits, axis=-1, output_type=tf.int32)
      next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
      next_sentence_accuracy = tf.metrics.accuracy(
          labels=next_sentence_labels, predictions=next_sentence_predictions)
      next_sentence_mean_loss = tf.metrics.mean(
          values=next_sentence_example_loss)
      return {
          "masked_lm_accuracy": masked_lm_accuracy,
          "masked_lm_loss": masked_lm_mean_loss,
          "next_sentence_accuracy": next_sentence_accuracy,
          "next_sentence_loss": next_sentence_mean_loss,
      }

    eval_metrics = (metric_fn, [
        lowering.export_to_tf_tensor(masked_lm_example_loss),
        lowering.export_to_tf_tensor(masked_lm_logits), masked_lm_ids,
        masked_lm_weights,
        lowering.export_to_tf_tensor(next_sentence_example_loss),
        lowering.export_to_tf_tensor(next_sentence_logits),
        next_sentence_labels
    ])

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf_estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          FLAGS.output_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])
      return tf_estimator.tpu.TPUEstimatorSpec(
          tf_estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    elif mode == tf_estimator.ModeKeys.EVAL:
      return tf_estimator.tpu.TPUEstimatorSpec(
          tf_estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
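# An illustrative numpy analogue of the weighted accuracy computed in
# metric_fn above: positions with weight 0 (padding) do not count toward the
# metric. This is a sketch for intuition only, not part of the estimator code.
import numpy as np

sketch_logits = np.array([[2.0, 1.0], [0.5, 3.0], [4.0, 0.0]])  # [positions, vocab]
sketch_label_ids = np.array([0, 1, 1])
sketch_weights = np.array([1.0, 1.0, 0.0])  # last position is padding

sketch_predictions = np.argmax(sketch_logits, axis=-1)
sketch_accuracy = (np.sum((sketch_predictions == sketch_label_ids) *
                          sketch_weights) / np.sum(sketch_weights))
assert sketch_accuracy == 1.0  # both real positions correct; padding ignored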
def maybe_reshape_attention_input_for_2d_sharding(
    context, q, k, v, bias, unsplittable_dims):
  """Reshape the inputs to attention to split over an unused mesh dimension.

  In the case where the attention computation is unnecessarily replicated,
  this function reshapes the attention inputs to remove the unnecessary
  replication.

  This becomes relevant when doing 2-dimensional model parallelism.
  d_model is sharded over one mesh dimension and [vocab, num_heads, d_ff]
  are sharded over the other mesh dimension.  This fully distributes all of
  the einsum operations, except for the internals of the attention
  computation.

  To distribute that computation, this function creates a new
  tensor-dimension from the low bits of either the batch dimension or the
  num_heads dimension, and then splits that dimension over the unused mesh
  dimension.

  Args:
    context: a transformer.Context
    q: a Tensor
    k: a Tensor
    v: a Tensor
    bias: a Tensor
    unsplittable_dims: a list of tensor-dimensions not to split.  The
      key/value dimensions should be passed here.

  Returns:
    reshaped_q: a Tensor
    reshaped_k: a Tensor
    reshaped_v: a Tensor
    reshaped_bias: a Tensor
  """
  original_inputs = q, k, v, bias
  # We need to know the layout and mesh-shape to figure out what to do.
  if not context or not context.model.layout or not context.model.mesh_shape:
    return original_inputs
  mesh_shape = mtf.convert_to_shape(context.model.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(context.model.layout)
  # Find a mesh dim that is unused (no tensor-dimension is split across it).
  mesh_axis_used = [False] * mesh_shape.ndims
  for x in original_inputs:
    for mesh_axis in layout_rules.tensor_layout(
        x.shape, mesh_shape).tensor_axis_to_mesh_axis:
      if mesh_axis is not None:
        mesh_axis_used[mesh_axis] = True
  if False not in mesh_axis_used:
    return original_inputs
  mesh_dim = mesh_shape.dims[mesh_axis_used.index(False)]
  # Choose an appropriate name for the new tensor-dimension so that the
  # layout will know to split it across the unused mesh dimension.
  tensor_dim_name = layout_rules.mesh_dimension_name_to_tensor_dimension_names(
      mesh_dim.name)
  if tensor_dim_name:
    tensor_dim_name = tensor_dim_name[0]
  else:
    return original_inputs
  # Find a tensor-dimension that we can further split, by breaking off the
  # lower bits into our new tensor-dimension.
  # This resplittable tensor-dimension must be present in all of q, k, v
  # and must be large enough to be further split.
  resplittable_dim = None
  for d in q.shape.dims:
    if d in k.shape.dims and d in v.shape.dims and d not in unsplittable_dims:
      num_splits = mtf.tensor_dim_to_mesh_dim_size(
          context.model.layout, context.model.mesh_shape, d)
      if d.size % (num_splits * mesh_dim.size) == 0:
        resplittable_dim = d
        break
  if not resplittable_dim:
    return original_inputs
  new_dim_high = mtf.Dimension(resplittable_dim.name, num_splits)
  new_dim_low = mtf.Dimension(tensor_dim_name,
                              resplittable_dim.size // num_splits)

  def _my_reshape(x):
    if x and resplittable_dim in x.shape.dims:
      return mtf.replace_dimensions(
          x, resplittable_dim, [new_dim_high, new_dim_low])
    else:
      return x

  return _my_reshape(q), _my_reshape(k), _my_reshape(v), _my_reshape(bias)
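# A worked example of the split arithmetic above, under assumed sizes: a
# "heads" dimension of size 16, already split 4 ways by the layout, with an
# unused mesh dimension of size 2. Plain Python; names and numbers are
# illustrative, not taken from any particular model configuration.
heads_size = 16    # resplittable_dim.size
num_splits = 4     # how many ways "heads" is already split by the layout
mesh_dim_size = 2  # size of the unused mesh dimension

# The dimension qualifies because 16 % (4 * 2) == 0.
assert heads_size % (num_splits * mesh_dim_size) == 0
# It is then reshaped into (heads: 4) x (new low-bits dim: 4); the new
# low-bits dimension (size 16 // 4) is what gets split over the unused
# mesh axis via the layout rules.
new_dim_high_size = num_splits               # 4
new_dim_low_size = heads_size // num_splits  # 4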
def main(_):
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

  print("mesh_shape : ", mesh_shape)
  print("layout_rules : ", layout_rules)
  print("FLAGS.gpus_per_node : ", FLAGS.gpus_per_node)
  print("FLAGS.gpus_per_task : ", FLAGS.gpus_per_task)
  print("FLAGS.tasks_per_node : ", FLAGS.tasks_per_node)

  # Resolve the cluster from the SLURM environment.
  cluster = tf.distribute.cluster_resolver.SlurmClusterResolver(
      {"mesh": mesh_shape.size // FLAGS.gpus_per_task},
      port_base=8822,
      gpus_per_node=FLAGS.gpus_per_node,
      gpus_per_task=FLAGS.gpus_per_task,
      tasks_per_node=FLAGS.tasks_per_node)
  cluster_spec = cluster.cluster_spec()

  # Create a server for all mesh members.
  server = tf.distribute.Server(cluster_spec, "mesh", cluster.task_id)

  # Only the master job takes care of the graph building;
  # everyone else can just chill for now.
  if cluster.task_id > 0:
    server.join()

  # Otherwise we are the main task; let's define the devices.
  mesh_devices = [
      "/job:mesh/task:%d/device:GPU:%d" % (i, j)
      for i in range(cluster_spec.num_tasks("mesh"))
      for j in range(FLAGS.gpus_per_task)
  ]
  print("List of devices", mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "fft_mesh")

  # Build the model.
  fft_err = benchmark_model(mesh)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, layout_rules, mesh_devices)
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  # Retrieve output of computation.
  result = lowering.export_to_tf_tensor(fft_err)
  print('Lowering done')

  with tf.Session(server.target) as sess:
    # Warm-up run.
    start = time.time()
    err = sess.run(result)
    end = time.time()

    # Since large meshes might take a lot of time, scale down the iteration
    # count with the cube size.
    niter = int(100 // np.log2(FLAGS.cube_size))
    start = time.time()
    for i in range(niter):
      err = sess.run(result)
    end = time.time()
    ttime = (end - start) / niter
    print('Time for ', mesh_shape, ' is : ', ttime)

    # Uncomment this to get the output of a profiler:
    # profiler = tf.profiler.Profiler(sess.graph)
    # run_meta = tf.RunMetadata()
    # err = sess.run(result,
    #                options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
    #                run_metadata=run_meta)
    # profiler.add_step(0, run_meta)
    # opts = (tf.profiler.ProfileOptionBuilder(
    #     tf.profiler.ProfileOptionBuilder.time_and_memory())
    #         .with_step(0)
    #         .with_timeline_output(FLAGS.output_file).build())
    # profiler.profile_graph(options=opts)
    #
    # This is another way of profiling:
    # profiler = tf.profiler.Profiler(sess.graph)
    # run_meta = tf.RunMetadata()
    # err = sess.run(result,
    #                options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
    #                run_metadata=run_meta)
    # # Create the Timeline object, and write it to a json.
    # tl = timeline.Timeline(run_meta.step_stats)
    # ctf = tl.generate_chrome_trace_format()
    # with open('timelinev2-%d-%d.json' % (FLAGS.cube_size, FLAGS.max_depth),
    #           'w') as f:
    #   f.write(ctf)
    # profiler.add_step(0, run_meta)
    # logfile = str(FLAGS.output_file) + '-%d' % FLAGS.max_depth
    # opts = (tf.profiler.ProfileOptionBuilder(
    #     tf.profiler.ProfileOptionBuilder.time_and_memory())
    #         .with_step(0)
    #         .with_timeline_output(logfile)
    #         .with_stdout_output()
    #         .with_max_depth(FLAGS.max_depth).build())
    # profiler.profile_graph(options=opts)

  print("Max absolute FFT error %f, with wall time %f" % (err, ttime))
  exit(-1)
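# A quick illustration of the device-list comprehension used above, with an
# assumed SLURM allocation of 2 mesh tasks and 2 GPUs per task. Pure Python;
# the numbers are hypothetical.
sketch_num_tasks = 2
sketch_gpus_per_task = 2
sketch_mesh_devices = [
    "/job:mesh/task:%d/device:GPU:%d" % (i, j)
    for i in range(sketch_num_tasks)
    for j in range(sketch_gpus_per_task)
]
# ['/job:mesh/task:0/device:GPU:0', '/job:mesh/task:0/device:GPU:1',
#  '/job:mesh/task:1/device:GPU:0', '/job:mesh/task:1/device:GPU:1']
assert len(sketch_mesh_devices) == sketch_num_tasks * sketch_gpus_per_task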
def get_layout():
  return mtf.convert_to_layout_rules(FLAGS.layout)
def dalle_model_fn(features, labels, mode, params):
  # Since we can simply infer labels here based on the input: features here
  # are the text input, and labels are the image input.
  global_step = tf.train.get_global_step()  # Get global step.

  mode_str = mode_to_str(mode)

  # Load the VAE into the tensorflow graph before mtf.
  vae, vae_checkpoint_path = load_vae_model(params, mode_str)

  initialize_vae_weights(vae_checkpoint_path)

  H = W = params["dataset"]["image_size"]
  image_seq_len = (vae.H // (2**len(vae.convblocks)))**2 // (
      vae.stack_factor**2)  # TODO: check this is correct
  batch_size = params[f"{mode_str}_batch_size"]
  n_channels = params.get("input_channels", 3)

  with tf.variable_scope("vae"):
    vae_logits = vae.forward(features, return_logits=True)

  # TODO: using argmax sampling for now, but is that optimal?
  tokens = tf.math.argmax(vae_logits, -1)
  img_tokens_reshaped = tf.cast(
      tf.reshape(tokens, (batch_size, image_seq_len)), tf.int32)

  # Construct mtf graph + mesh from params.
  graph = mtf.Graph()
  mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
  layout_rules = mtf.convert_to_layout_rules(params["layout"])

  # Mesh setup.
  if params["use_tpu"]:
    var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape, layout_rules)
  else:
    var_placer = None
    gpu_ids = params["gpu_ids"]
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, gpu_ids)

  # Build the mtf mesh object.
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  model = DALLE(
      n_embd=params["n_embd"],
      text_vocab_size=params["text_vocab_size"],
      image_vocab_size=params["image_vocab_size"],
      text_seq_len=params["text_seq_len"],
      image_seq_len=image_seq_len,
      n_layers=params["n_layers"],
      n_heads=params["n_heads"],
      batch_size=batch_size,
      bf_16=params["bf_16"],
      mode=mode_str,
      params=params,
  )

  # Build mtf_features & seq length dict for getting number of microbatches.
  # We need to pack inputs into a dict to pass into serialize_training_step.
  features_dict = {"image_inputs": features, "text_inputs": labels}
  mtf_features = {}
  for key, x in features_dict.items():
    if x is not None:
      if key == "text_inputs":
        text_tokens = tf.reshape(x, [batch_size, params["text_seq_len"]])
        x = tf.concat(
            (text_tokens, img_tokens_reshaped + model.text_vocab_size),
            axis=1)
        mtf_shape = mtf.Shape([
            model.dimensions["batch_dim"], model.dimensions["total_seq_dim"]
        ])
        mtf_features["tokens"] = mtf.import_fully_replicated(
            mesh, x, mtf_shape, name=key)
      if key == "image_inputs":
        mtf_shape = mtf.Shape([
            model.dimensions["batch_dim"],
            mtf.Dimension("img_height_dim", vae.H),
            mtf.Dimension("img_width_dim", vae.W),
            mtf.Dimension("img_channel_dim", vae.num_ch),
        ])
        x = tf.reshape(x, [batch_size, H, W, n_channels])  # NHWC
        mtf_features["image_inputs"] = mtf.import_fully_replicated(
            mesh, x, mtf_shape, name=key)
        scalar_summary("input_image", mtf_features["image_inputs"])

  if mode == tf.estimator.ModeKeys.PREDICT:
    raise NotImplementedError

  # We're not predicting, so we better be training or evaluating.
  assert (mode == tf.estimator.ModeKeys.TRAIN or
          mode == tf.estimator.ModeKeys.EVAL)

  if mode == tf.estimator.ModeKeys.TRAIN:
    # Gets the number of microbatches per batch for serialized training.
    # If the param tokens_per_mb_per_replica is None, this defaults to 1 and
    # no microbatching is performed.
    num_microbatches = int(
        mtf_transformer.utils.serialize_num_microbatches(
            batch_dim=model.dimensions["batch_dim"],
            sequence_length=model.total_seq_dim,
            mesh_shape=mesh_shape,
            layout_rules=layout_rules,
            tokens_per_microbatch_per_replica=params[
                "tokens_per_mb_per_replica"]))
  else:
    num_microbatches = 1

  # Add the number of microbatches to params.
  params["num_microbatches"] = num_microbatches

  if num_microbatches > 1:
    # For serialize_training_step we need to modify the model to output
    # results in a dict.
    def serialized_fn(mtf_features):
      loss, loss_batch = model.forward(mtf_features, return_loss=True)
      return {"loss": loss, "loss_batch": loss_batch}

    # Serialize the training step: gradients are accumulated locally and
    # reduced once.
    var_grads, output_dict = mtf.serialize_training_step(
        mtf_features, serialized_fn, model.dimensions["batch_dim"],
        num_microbatches)
    loss = output_dict["loss"]
    loss_batch = output_dict["loss_batch"]
  else:
    loss, loss_batch = model.forward(mtf_features, return_loss=True)

  del loss_batch  # TODO: may need this for some metrics - otherwise, remove from output

  if mode == tf.estimator.ModeKeys.TRAIN:
    # In TRAIN mode, get the optimizer.
    if num_microbatches > 1:
      # If we are splitting the batch into microbatches, var grads are
      # created in the serialize_training_step fn, so we pass them in here.
      _, update_ops, var_grads = get_optimizer(
          mesh,
          loss,
          params,
          variable_dtype=model.variable_dtype,
          inp_var_grads=var_grads)
    else:
      # Otherwise, they are created in the get_optimizer fn, so we leave
      # inp_var_grads blank.
      _, update_ops, var_grads = get_optimizer(
          mesh, loss, params, variable_dtype=model.variable_dtype)
    # Log summaries to tensorboard.
    scalar_summary("loss", loss)

  # Gets & prints info about the number of trainable vars in the model &
  # dimension names.
  get_graph_info(graph)

  # "Lowers" mtf tensors into a tf graph - this enables us to export results
  # as tf tensors.
  lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=False)

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.cast(tf_loss, tf.float32)

  if mode == tf.estimator.ModeKeys.TRAIN:
    # Use our patched version until mtf updates theirs.
    host_call = create_host_call(params["model_path"])
    mtf.utils.remove_summaries()

    # Creates the train_op.
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    # Need to manually increment global_step.
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      # Set up the checkpoint server and return the TPUEstimatorSpec.
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=params.get("max_checkpoints", 5),
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          params["model_path"],
          save_steps=params["steps_per_checkpoint"],
          saver=saver,
          listeners=[saver_listener])

      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          host_call=host_call,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=None)
def run(tpu_job_name,
        tpu,
        gcp_project,
        tpu_zone,
        model_dir,
        model_type="bitransformer",
        vocabulary=gin.REQUIRED,
        train_dataset_fn=None,
        eval_dataset_fn=None,
        dataset_split="train",
        autostack=True,
        checkpoint_path="",
        mode="train",
        iterations_per_loop=100,
        save_checkpoints_steps=1000,
        eval_steps=10,
        train_steps=1000000,
        batch_size=auto_batch_size,
        sequence_length=gin.REQUIRED,
        mesh_shape=gin.REQUIRED,
        layout_rules=gin.REQUIRED,
        get_components_fn=None):
  """Run training/eval/inference.

  Args:
    tpu_job_name: string, name of TPU worker binary
    tpu: string, the Cloud TPU to use for training
    gcp_project: string, project name for the Cloud TPU-enabled project
    tpu_zone: string, GCE zone where the Cloud TPU is located
    model_dir: string, estimator model_dir
    model_type: a string - either "bitransformer", "lm" or "aligned"
    vocabulary: a vocabulary.Vocabulary
    train_dataset_fn: A function returning a tf.data.Dataset.  Must be
      provided for mode=train
    eval_dataset_fn: A function returning a tf.data.Dataset.  Must be
      provided for mode=eval
    dataset_split: a string
    autostack: boolean, internally combine variables
    checkpoint_path: a string - which checkpoint to load for inference
    mode: string, train/continuous_eval/infer
    iterations_per_loop: integer, steps per train loop
    save_checkpoints_steps: integer, steps per checkpoint
    eval_steps: integer, number of evaluation steps
    train_steps: Total number of training steps.
    batch_size: An integer or a function with the same signature as
      auto_batch_size().  Mini-batch size for the training.  Note that this
      is the global batch size and not the per-shard batch size.
    sequence_length: an integer
    mesh_shape: an input to mtf.convert_to_shape()
    layout_rules: an input to mtf.convert_to_layout_rules()
    get_components_fn: an optional function that gets a list of tuples of
      (metric_names, component) for each component.  Required if mode is
      "continuous_eval".
  """
  if not isinstance(batch_size, int):
    batch_size = batch_size(sequence_length, mesh_shape, layout_rules)

  tf.logging.info("mode=%s" % mode)
  tf.logging.info("batch_size=%s" % batch_size)
  tf.logging.info("sequence_length=%s" % sequence_length)
  tf.logging.info("mesh_shape=%s" % mesh_shape)
  tf.logging.info("layout_rules=%s" % layout_rules)

  if mode == "train" and dataset_split != "train":
    raise ValueError("mode==\"train\" requires dataset_split==\"train\"")

  mesh_shape = mtf.convert_to_shape(mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(layout_rules)

  cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
      tpu if (tpu) else "", zone=tpu_zone, project=gcp_project)

  tf.logging.info("Building TPUConfig with tpu_job_name={}".format(
      tpu_job_name))
  my_tpu_config = tpu_config.TPUConfig(
      tpu_job_name=tpu_job_name,
      iterations_per_loop=iterations_per_loop,
      num_cores_per_replica=1,
      per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST,
  )
  run_config = tpu_config.RunConfig(
      cluster=cluster,
      model_dir=model_dir,
      save_checkpoints_steps=save_checkpoints_steps,
      tpu_config=my_tpu_config)

  transformer_model = build_model(
      model_type=model_type,
      vocab_size=vocabulary.vocab_size,
      layout_rules=layout_rules,
      mesh_shape=mesh_shape)

  model_fn = tpu_estimator_model_fn(
      model_type=model_type,
      transformer_model=transformer_model,
      model_dir=model_dir,
      use_tpu=tpu,
      mesh_shape=mesh_shape,
      layout_rules=layout_rules,
      batch_size=batch_size,
      sequence_length=sequence_length,
      autostack=autostack,
      metric_names=None)

  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      config=run_config,
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      predict_batch_size=batch_size,
      use_tpu=tpu,
      export_to_tpu=False,
      params={})

  if mode == "train":
    if train_dataset_fn is None:
      raise ValueError("Must provide train_dataset_fn through gin for train.")

    def input_fn(params):
      del params
      dataset = train_dataset_fn(
          batch_size=batch_size,
          sequence_length=sequence_length,
          vocabulary=vocabulary,
          dataset_split=dataset_split)
      return dataset

    estimator.train(input_fn=input_fn, max_steps=train_steps)
  elif mode == "continuous_eval":
    if get_components_fn is None:
      raise ValueError("Must provide get_components_fn through gin for eval.")
    if eval_dataset_fn is None:
      raise ValueError("Must provide eval_dataset_fn through gin for eval.")
    metrics_inputs = get_components_fn()
    for _ in tf.contrib.training.checkpoints_iterator(estimator.model_dir):
      for metric_names, component in metrics_inputs:
        tf.logging.info("Evaluating {}".format(component.__dict__))
        tf.logging.info("on split {}".format(dataset_split))
        # Prepend the eval tag and split name to metric names.
        metric_names = [
            "eval/{}/{}".format(dataset_split, n) for n in metric_names
        ]
        # Regenerate the estimator.
        model_fn = tpu_estimator_model_fn(
            model_type=model_type,
            transformer_model=transformer_model,
            model_dir=model_dir,
            use_tpu=tpu,
            mesh_shape=mesh_shape,
            layout_rules=layout_rules,
            batch_size=batch_size,
            sequence_length=sequence_length,
            autostack=autostack,
            metric_names=metric_names)
        estimator = tpu_estimator.TPUEstimator(
            model_fn=model_fn,
            config=run_config,
            train_batch_size=batch_size,
            eval_batch_size=batch_size,
            predict_batch_size=batch_size,
            use_tpu=tpu,
            export_to_tpu=False,
            params={})

        def input_fn(params):
          del params
          dataset = eval_dataset_fn(
              component,  # pylint: disable=cell-var-from-loop
              batch_size=batch_size,
              sequence_length=sequence_length,
              vocabulary=vocabulary,
              dataset_split=dataset_split,
              pack=False)
          return dataset

        eval_args = {"eval": (input_fn, eval_steps)}
        _ = evaluate(estimator, eval_args)
  elif mode == "infer":
    decode_from_file(
        estimator,
        vocabulary=vocabulary,
        model_type=model_type,
        batch_size=batch_size,
        sequence_length=sequence_length,
        checkpoint_path=checkpoint_path)
  else:
    raise ValueError(
        "unknown mode %s - must be train/continuous_eval/infer" % mode)
def run(tpu_job_name,
        tpu,
        gcp_project,
        tpu_zone,
        model_dir,
        model_type="bitransformer",
        vocabulary=gin.REQUIRED,
        train_dataset_fn=None,
        eval_dataset_fn=None,
        dataset_split="train",
        autostack=True,
        checkpoint_path="",
        mode="train",
        iterations_per_loop=100,
        save_checkpoints_steps=1000,
        keep_checkpoint_max=10,
        batch_size=("tokens_per_replica", 2048),
        train_steps=auto_train_steps,
        sequence_length=gin.REQUIRED,
        mesh_shape=gin.REQUIRED,
        layout_rules=gin.REQUIRED,
        num_eval_examples=None,
        get_components_fn=None,
        compute_metrics_from_file_fn=None,
        learning_rate_schedule=None,
        optimizer=None):
  """Run training/eval/inference.

  Args:
    tpu_job_name: string, name of TPU worker binary
    tpu: string, the Cloud TPU to use for training
    gcp_project: string, project name for the Cloud TPU-enabled project
    tpu_zone: string, GCE zone where the Cloud TPU is located
    model_dir: string, estimator model_dir
    model_type: a string - either "bitransformer", "bi_student_teacher",
      "lm" or "aligned"
    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
      targets_vocabulary) tuple.
    train_dataset_fn: A function returning a tf.data.Dataset.  Must be
      provided for mode=train
    eval_dataset_fn: A function returning a tf.data.Dataset.  Must be
      provided for mode=eval
    dataset_split: a string
    autostack: boolean, internally combine variables
    checkpoint_path: a string - which checkpoint to load for inference
    mode: string, train/continuous_eval/infer
    iterations_per_loop: integer, steps per train loop
    save_checkpoints_steps: integer, steps per checkpoint
    keep_checkpoint_max: an integer, keep up to this many checkpoints
    batch_size: An integer or a (method, value) pair to pass to
      compute_batch_size().  Note that this is the global batch size and not
      the per-shard batch size.
    train_steps: An integer or a function with the same signature as
      auto_train_steps().  Total number of training steps.
    sequence_length: an integer
    mesh_shape: an input to mtf.convert_to_shape()
    layout_rules: an input to mtf.convert_to_layout_rules()
    num_eval_examples: maximum number of examples per task to use for
      continuous eval.
    get_components_fn: an optional function that returns a list of tuples
      of (metric_names, component) for each component.  Required if mode is
      "continuous_eval".
    compute_metrics_from_file_fn: an optional function that takes in:
      component, metric names (list of strs), targets (list of strs),
      predictions (list of strs), dataset_split (str), and tb_summary_dir
      (str), runs metrics on targets and predictions, and returns a
      dictionary of metrics and their computed values.  Required if mode is
      "continuous_eval".
    learning_rate_schedule: an optional function taking the scalar named
      argument `step` and the numeric argument `total_train_steps` and
      returning the scalar learning rate
    optimizer: a class extending optimize.Optimizer, required for training
  """
  if not isinstance(batch_size, int):
    batch_size = compute_batch_size(sequence_length, mesh_shape,
                                    layout_rules, batch_size)

  if not isinstance(train_steps, int):
    train_steps = train_steps(batch_size, sequence_length)

  if callable(learning_rate_schedule):
    learning_rate_schedule = functools.partial(
        learning_rate_schedule, total_train_steps=train_steps)

  tf.logging.info("model_type=%s" % model_type)
  tf.logging.info("mode=%s" % mode)
  tf.logging.info("sequence_length=%s" % sequence_length)
  tf.logging.info("batch_size=%s" % batch_size)
  tf.logging.info("train_steps=%s" % train_steps)
  tf.logging.info("mesh_shape=%s" % mesh_shape)
  tf.logging.info("layout_rules=%s" % layout_rules)

  if mode == "train" and dataset_split != "train":
    raise ValueError("mode==\"train\" requires dataset_split==\"train\"")

  mesh_shape = mtf.convert_to_shape(mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(layout_rules)

  cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
      tpu if (tpu) else "", zone=tpu_zone, project=gcp_project)

  tf.logging.info("Building TPUConfig with tpu_job_name={}".format(
      tpu_job_name))
  my_tpu_config = tpu_config.TPUConfig(
      tpu_job_name=tpu_job_name,
      iterations_per_loop=iterations_per_loop,
      num_cores_per_replica=1,
      per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST,
  )
  run_config = tpu_config.RunConfig(
      cluster=cluster,
      model_dir=model_dir,
      tpu_config=my_tpu_config,
      # We use a saver hook, so disable checkpoints here to prevent double
      # saving.
      save_checkpoints_steps=None,
      save_checkpoints_secs=None)

  transformer_model = build_model(
      model_type=model_type,
      input_vocab_size=inputs_vocabulary(vocabulary).vocab_size,
      output_vocab_size=targets_vocabulary(vocabulary).vocab_size,
      layout_rules=layout_rules,
      mesh_shape=mesh_shape)

  model_fn = tpu_estimator_model_fn(
      model_type=model_type,
      transformer_model=transformer_model,
      model_dir=model_dir,
      use_tpu=tpu,
      mesh_shape=mesh_shape,
      layout_rules=layout_rules,
      batch_size=batch_size,
      sequence_length=sequence_length,
      autostack=autostack,
      learning_rate_schedule=learning_rate_schedule,
      keep_checkpoint_max=keep_checkpoint_max,
      save_checkpoints_steps=save_checkpoints_steps,
      optimizer=optimizer)

  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      config=run_config,
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      predict_batch_size=batch_size,
      use_tpu=tpu,
      export_to_tpu=False,
      params={})

  if mode == "train":
    if train_dataset_fn is None:
      raise ValueError("Must provide train_dataset_fn through gin for train.")

    def input_fn(params):
      del params
      dataset = train_dataset_fn(
          batch_size=batch_size,
          sequence_length=sequence_length,
          vocabulary=vocabulary,
          dataset_split=dataset_split)
      return dataset

    estimator.train(input_fn=input_fn, max_steps=train_steps)
  elif mode == "continuous_eval":
    if eval_dataset_fn is None:
      raise ValueError("Must provide eval_dataset_fn through gin for eval.")
    if get_components_fn is None:
      raise ValueError("Must provide get_components_fn through gin for eval.")
    if compute_metrics_from_file_fn is None:
      raise ValueError(
          "Must provide compute_metrics_from_file_fn through gin for eval.")
    metrics_inputs = get_components_fn()
    for ckpt in tf.contrib.training.checkpoints_iterator(estimator.model_dir):
      for metric_names, component in metrics_inputs:
        if not metric_names:
          tf.logging.info("Skipping %s", component.__dict__)
          continue
        tf.logging.info("Evaluating %s on metrics %s", component.tfds_name,
                        component.metric_names)
        tf.logging.info("on split %s", dataset_split)

        # Regenerate the estimator.
        model_fn = tpu_estimator_model_fn(
            model_type=model_type,
            transformer_model=transformer_model,
            model_dir=model_dir,
            use_tpu=tpu,
            mesh_shape=mesh_shape,
            layout_rules=layout_rules,
            batch_size=batch_size,
            sequence_length=sequence_length,
            autostack=autostack,
            keep_checkpoint_max=keep_checkpoint_max,
            save_checkpoints_steps=save_checkpoints_steps)
        estimator = tpu_estimator.TPUEstimator(
            model_fn=model_fn,
            config=run_config,
            train_batch_size=batch_size,
            eval_batch_size=batch_size,
            predict_batch_size=batch_size,
            use_tpu=tpu,
            export_to_tpu=False,
            params={})

        # Extra eval_dataset_fn call to get the dataset_size and an extra
        # dataset object to write out targets. We need to use a separate
        # graph because estimator finalizes the default graph after
        # iterating over the dataset.
        dataset_graph = tf.Graph()
        with dataset_graph.as_default():
          dataset, dataset_size, padded_dataset_size = eval_dataset_fn(
              component,  # pylint: disable=cell-var-from-loop
              batch_size=batch_size,
              sequence_length=sequence_length,
              vocabulary=vocabulary,
              dataset_split=dataset_split,
              pack=False,
              max_dataset_size=num_eval_examples)

        def input_fn(params):
          del params
          dataset, _, _ = eval_dataset_fn(
              component,  # pylint: disable=cell-var-from-loop
              batch_size=batch_size,
              sequence_length=sequence_length,
              vocabulary=vocabulary,
              dataset_split=dataset_split,
              pack=False,
              max_dataset_size=num_eval_examples)
          return dataset

        dataset_name = component.tfds_name.replace("/", "-").replace(":", "-")
        output_filename = os.path.join(
            model_dir, "{}-{}-decoded".format(dataset_name, dataset_split))
        pred_output_filename = output_filename + "-preds-test"
        target_output_filename = output_filename + "-targets-test"

        decodes = decode(
            estimator,
            input_fn,
            dataset_size,
            padded_dataset_size,
            batch_size,
            vocabulary,
            checkpoint_path=checkpoint_path)
        with dataset_graph.as_default():
          log_pred_target(
              decodes,
              dataset,
              dataset_size,
              vocabulary,
              pred_output_filename=pred_output_filename,
              target_output_filename=target_output_filename)

        tf.logging.info("Evaluating metrics: {}".format(metric_names))
        tb_summary_dir = os.path.join(
            model_dir, "{}_eval".format(
                "eval" if dataset_split == "validation" else dataset_split))
        summary_writer = tf.summary.FileWriter(tb_summary_dir)
        _ = compute_metrics_from_file_fn(
            component,
            pred_output_filename,
            target_output_filename,
            dataset_split,
            tb_summary_dir,
            ckpt,
            summary_writer=summary_writer)
  elif mode == "infer":
    decode_from_file(
        estimator,
        vocabulary=vocabulary,
        model_type=model_type,
        batch_size=batch_size,
        sequence_length=sequence_length,
        checkpoint_path=checkpoint_path)
  else:
    raise ValueError(
        "unknown mode %s - must be train/continuous_eval/infer" % mode)
def estimator_model_fn(cls,
                       hparams,
                       features,
                       labels,
                       mode,
                       config=None,
                       params=None,
                       decode_hparams=None,
                       use_tpu=False):
  hparams = copy.deepcopy(hparams)
  hparams.use_tpu = use_tpu
  # Merge decode_hparams into hparams if present.
  if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
    for k, v in six.iteritems(decode_hparams.values()):
      if hasattr(hparams, k) and getattr(hparams, k) != v:
        tf.logging.warning(
            "Overriding hparams.%s with %s from decode_hparams" % (k, v))
      setattr(hparams, k, v)

  # Instantiate model.
  data_parallelism = None
  if not use_tpu and config:
    data_parallelism = config.data_parallelism
  model = cls(
      hparams,
      mode,
      data_parallelism=data_parallelism,
      decode_hparams=decode_hparams)

  global_step = tf.train.get_global_step()

  mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(hparams.layout)
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica.
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    if data_parallelism is None or len(data_parallelism.ps_devices) == 1:
      mesh_devices = [""] * mesh_shape.size
    else:
      assert len(data_parallelism.ps_devices) == mesh_shape.size
      mesh_devices = data_parallelism.ps_devices
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  # PREDICT mode.
  if mode == tf.estimator.ModeKeys.PREDICT:
    return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)

  logits, loss = model.mtf_model_fn(features, mesh)
  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode.
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    lr = learning_rate.learning_rate_schedule(hparams)
    tf.summary.scalar("learning_rate", lr)
    mtf_lr = mtf.import_tf_tensor(
        mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
    optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
    update_ops = []
    for grad, var in zip(var_grads, graph.trainable_variables):
      update_ops.extend(optimizer.apply_grad(grad, var))

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if logits and mode != tf.estimator.ModeKeys.TRAIN:
    tf_logits = lowering.export_to_tf_tensor(logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        hparams.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])

    # EVAL mode.
    if mode == tf.estimator.ModeKeys.EVAL:
      tf_logits = lowering.export_to_tf_tensor(logits)
      return model.estimator_spec_eval(features, tf_logits, labels, tf_loss,
                                       restore_hook, use_tpu)

    if use_tpu:
      # TPU host call. Important: needs to be called before
      # remove_summaries().
      if hparams.tpu_enable_host_call:
        host_call = t2t_model.create_host_call(hparams.model_dir)
      else:
        host_call = None
      t2t_model.remove_summaries()
      return tpu_estimator.TPUEstimatorSpec(
          mode=tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          host_call=host_call,
          training_hooks=[restore_hook, saver_hook])
    else:
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_chief_hooks=[restore_hook, saver_hook])
def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator."""
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  logits, loss = model_backbone(features, labels, mesh)

  variables = graph._all_variables
  for v in variables:
    logger.debug("[parameter] (name,shape,dtype): ({},{},{})".format(
        v.name, v.shape, v.dtype.master_dtype))

  mesh_shape = mtf.convert_to_shape(args_opt.mesh_shape)
  # layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss])
  estimator = memory_estimator.MemoryEstimator(graph, mesh_shape,
                                               [logits, loss])
  optimizer = layout_optimizer.LayoutOptimizer(estimator,
                                               scheduler_alg="NAIVE")
  layout_rules = mtf.convert_to_layout_rules(optimizer.solve())

  logger.info("[auto mtf search] strategy: {}".format(layout_rules))
  mesh_devices = ["gpu:{}".format(i) for i in range(int(args_opt.num_gpus))]
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, layout_rules, mesh_devices)

  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.SgdOptimizer(0.01)
    # optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  restore_hook = mtf.MtfRestoreHook(lowering)

  tf_logits = lowering.export_to_tf_tensor(logits)
  if mode != tf.estimator.ModeKeys.PREDICT:
    tf_loss = lowering.export_to_tf_tensor(loss)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

    # Threshold the sigmoid outputs to get hard 0/1 predictions.
    predicts = tf.sigmoid(tf_logits)
    predicts = tf.where(predicts > 0.5, tf.ones_like(predicts),
                        tf.zeros_like(predicts))
    accuracy = tf.metrics.accuracy(labels=labels, predictions=predicts)

    # Name tensors to be logged with LoggingTensorHook.
    tf.identity(tf_loss, "cross_entropy")
    tf.identity(accuracy[1], name="train_accuracy")
    logging_hook = tf.train.LoggingTensorHook(
        every_n_iter=100,
        tensors={
            "loss": "cross_entropy",
            "acc": "train_accuracy"
        })

    # restore_hook must come before saver_hook.
    return tf.estimator.EstimatorSpec(
        tf.estimator.ModeKeys.TRAIN,
        loss=tf_loss,
        train_op=train_op,
        training_chief_hooks=[restore_hook, logging_hook])
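# A minimal sketch of the higher-level auto-layout entry point referenced in
# the commented-out line above: mtf.auto_mtf.layout wraps the
# MemoryEstimator + LayoutOptimizer pair into a single call. Sketch only;
# it assumes an already-built mtf `graph` with `logits` and `loss` tensors
# and a `mesh_shape`, as in the function above.
import mesh_tensorflow.auto_mtf  # registers mtf.auto_mtf

auto_layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss])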
def run(tpu_job_name, tpu, gcp_project, tpu_zone, model_dir, model_type="bitransformer", vocabulary=gin.REQUIRED, train_dataset_fn=None, eval_dataset_fn=None, dataset_split="train", autostack=True, checkpoint_step=None, mode="train", iterations_per_loop=100, save_checkpoints_steps=1000, keep_checkpoint_max=10, eval_summary_dir=None, batch_size=("tokens_per_replica", 2048), train_steps=auto_train_steps, sequence_length=gin.REQUIRED, mesh_shape=gin.REQUIRED, layout_rules=gin.REQUIRED, learning_rate_schedule=None, optimizer=None, predict_fn=None): """Run training/eval/inference. Args: tpu_job_name: string, name of TPU worker binary tpu: string, the Cloud TPU to use for training gcp_project: string, project name for the Cloud TPU-enabled project tpu_zone: string, GCE zone where the Cloud TPU is located in model_dir: string, estimator model_dir model_type: a string - either "bitransformer", "bi_student_teacher", lm" or "aligned" vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary, targets_vocabulary) tuple. train_dataset_fn: A function returning a tf.data.Dataset. Must be provided for mode="train". Should accept the following arguments: - batch_size: int, number of entries in each batch. - sequence_length: int, length of each packed or padded sequence. - vocabulary: Vocabulary instance to use for encoding. - dataset_split: str, which dataset split to load. eval_dataset_fn: A function returning a list of dataset.EvalDataset tuples. Must be provided for mode="eval". Should accept the following arguments: - batch_size: int, number of entries in each batch. - sequence_length: int, length of each packed or padded sequence. - vocabulary: Vocabulary instance to use for encoding. - dataset_split: str, which dataset split to load. dataset.EvalDataset tuples are namedtuples with the following fields: - name: string, the task name - dataset_fn: function which returns a tf.data.Dataset of tokenized and padded examples. Must not require any arguments and must include the feature keys 'inputs' and 'targets_plaintext'. - postprocess_fn: function which converts model outputs to evalable str - list_of_metric_fns: list of metric functions with the call signature `metric_fn(targets, predictions)` which return either a scalar value or a dict mapping submetric names to scalar values. TensorBoard summaries and other tags will be written out using `metric_fn.__name__`. - dataset_size: number of entries in the dataset. - padded_dataset_size: number of entries in the dataset after padding. dataset_split: a string autostack: boolean, internally combine variables checkpoint_step: int, list of ints, or None. Only used when mode="eval" or mode="infer". If an int or list of ints, evaluation or inference will be run on the checkpoint files in `model_dir` whose global steps are closest to the global steps provided. If None and mode="eval", run eval continuously waiting for new checkpoints via `tf.contrib.training.checkpoints_iterator`. mode: string, train/eval/infer iterations_per_loop: integer, steps per train loop save_checkpoints_steps: integer, steps per checkpoint keep_checkpoint_max: an integer, keep up to this many checkpoints eval_summary_dir: str, path to write TensorBoard events file summaries for eval. If None, use model_dir/eval_{split}. batch_size: An integer or a (method, value) pair to pass to compute_batch_size(). Note that this is the global batch size and not the per-shard batch size. train_steps: An integer or a function with the same signature as auto_train_steps(). Total number of training steps. 
    sequence_length: an integer
    mesh_shape: an input to mtf.convert_to_shape()
    layout_rules: an input to mtf.convert_to_layout_rules()
    learning_rate_schedule: an optional function taking the scalar name
      argument `step` and the numeric argument `total_train_steps` and
      returning the scalar learning rate
    optimizer: a class extending optimize.Optimizer, required for training
    predict_fn: an optional function that can be used to override the default
      transformer prediction behavior. Must return a tensor of shape
      [batch_dim, length_dim] that will be the prediction for each example.
      Must accept the following arguments:
      - model: a Unitransformer or Bitransformer
      - features: a dict representing an example. Every value will be an
        mtf.Tensor with shape [batch_dim, length_dim].
      - variable_dtype: an mtf.VariableDType
  """
  if not isinstance(batch_size, int):
    batch_size = compute_batch_size(
        sequence_length, mesh_shape, layout_rules, batch_size)

  if not isinstance(train_steps, int):
    train_steps = train_steps(batch_size, sequence_length)

  if callable(learning_rate_schedule):
    learning_rate_schedule = functools.partial(
        learning_rate_schedule, total_train_steps=train_steps)

  tf.logging.info("model_type=%s" % model_type)
  tf.logging.info("mode=%s" % mode)
  tf.logging.info("sequence_length=%s" % sequence_length)
  tf.logging.info("batch_size=%s" % batch_size)
  tf.logging.info("train_steps=%s" % train_steps)
  tf.logging.info("mesh_shape=%s" % mesh_shape)
  tf.logging.info("layout_rules=%s" % layout_rules)

  if mode == "train" and dataset_split != "train":
    raise ValueError("mode==\"train\" requires dataset_split==\"train\"")

  mesh_shape = mtf.convert_to_shape(mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(layout_rules)

  cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
      tpu if tpu else "", zone=tpu_zone, project=gcp_project)

  tf.logging.info(
      "Building TPUConfig with tpu_job_name={}".format(tpu_job_name))

  my_tpu_config = tpu_config.TPUConfig(
      tpu_job_name=tpu_job_name,
      iterations_per_loop=iterations_per_loop,
      num_cores_per_replica=1,
      per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST,
  )

  run_config = tpu_config.RunConfig(
      cluster=cluster,
      model_dir=model_dir,
      tpu_config=my_tpu_config,
      # We use a saver hook, so disable checkpoints here to prevent double
      # saving.
      save_checkpoints_steps=None,
      save_checkpoints_secs=None)

  transformer_model = build_model(
      model_type=model_type,
      input_vocab_size=inputs_vocabulary(vocabulary).vocab_size,
      output_vocab_size=targets_vocabulary(vocabulary).vocab_size,
      layout_rules=layout_rules,
      mesh_shape=mesh_shape)

  model_fn = tpu_estimator_model_fn(
      model_type=model_type,
      transformer_model=transformer_model,
      model_dir=model_dir,
      use_tpu=tpu,
      mesh_shape=mesh_shape,
      layout_rules=layout_rules,
      batch_size=batch_size,
      sequence_length=sequence_length,
      autostack=autostack,
      learning_rate_schedule=learning_rate_schedule,
      keep_checkpoint_max=keep_checkpoint_max,
      save_checkpoints_steps=save_checkpoints_steps,
      optimizer=optimizer,
      predict_fn=predict_fn)

  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      config=run_config,
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      predict_batch_size=batch_size,
      use_tpu=tpu,
      export_to_tpu=False,
      params={})

  if mode == "train":
    if train_dataset_fn is None:
      raise ValueError("Must provide train_dataset_fn through gin for train.")

    def input_fn(params):
      del params
      dataset = train_dataset_fn(batch_size=batch_size,
                                 sequence_length=sequence_length,
                                 vocabulary=vocabulary,
                                 dataset_split=dataset_split)
      return dataset

    estimator.train(input_fn=input_fn, max_steps=train_steps)
  elif mode == "eval":
    if eval_dataset_fn is None:
      raise ValueError("Must provide eval_dataset_fn through gin for eval.")

    eval_datasets = eval_dataset_fn(
        batch_size=batch_size,
        sequence_length=sequence_length,
        vocabulary=vocabulary,
        dataset_split=dataset_split,
    )

    # Pre-load all of the targets once before entering the continuous eval
    # loop.
    cached_targets = {}
    # Need to create a separate graph for loading in plaintext targets, or
    # else TF will complain that we modified the graph.
    with tf.Graph().as_default():
      for eval_dataset in eval_datasets:
        eval_dataset = transformer_dataset.EvalDataset(*eval_dataset)
        # Only cache targets for those tasks with eval functions provided.
        if eval_dataset.metric_fns:
          ds = eval_dataset.dataset_fn()
          # De-batch the dataset.
          ds = ds.flat_map(tf.data.Dataset.from_tensor_slices)
          ds = tfds.as_numpy(ds)
          targets = [
              eval_dataset.postprocess_fn(d["targets_plaintext"]) for d in ds
          ]
          targets = targets[:eval_dataset.dataset_size]
          cached_targets[eval_dataset.name] = targets

    for checkpoint_path in get_checkpoint_iterator(checkpoint_step,
                                                   model_dir):
      for eval_dataset in eval_datasets:
        eval_dataset = transformer_dataset.EvalDataset(*eval_dataset)

        if not eval_dataset.metric_fns:
          tf.logging.info("Skipping %s because metric_fns is empty",
                          eval_dataset.name)
          continue

        metric_names = [metric.__name__ for metric in eval_dataset.metric_fns]
        tf.logging.info("Evaluating %s on metrics %s", eval_dataset.name,
                        metric_names)
        tf.logging.info("on split %s", dataset_split)

        def input_fn(params):
          del params
          ds = eval_dataset.dataset_fn()
          # Only pass those features which will be used for decoding.
          ds = ds.map(
              lambda x: {k: v for k, v in x.items() if k in _INPUT_FEATURES})
          return ds

        decodes = decode(
            estimator,
            input_fn,
            eval_dataset.dataset_size,
            eval_dataset.padded_dataset_size,
            batch_size,
            vocabulary,
            checkpoint_path=checkpoint_path,
        )
        predictions = [eval_dataset.postprocess_fn(d) for d in decodes]
        # TODO(craffel): Log predictions and targets.
eval_summary_dir = eval_summary_dir or os.path.join( model_dir, "{}_eval".format(dataset_split) ) summary_writer = tf.summary.FileWriter(eval_summary_dir) global_step = int(get_step_from_checkpoint_path(checkpoint_path)) for metric_fn in eval_dataset.metric_fns: summary = tf.Summary() tag = "eval/{}/{}/{}".format( eval_dataset.name, dataset_split, metric_fn.__name__ ) targets = cached_targets[eval_dataset.name] metric_result = metric_fn(targets, predictions) if isinstance(metric_result, dict): tags = ["{}.{}".format(tag, key) for key in metric_result] metric_values = metric_result.values() else: tags, metric_values = [tag], [metric_result] for tag, metric_value in zip(tags, metric_values): tf.logging.info( "%s at step %d: %.3f", tag, global_step, metric_value ) summary.value.add(tag=tag, simple_value=metric_value) summary_writer.add_summary(summary, global_step) summary_writer.flush() elif mode == "infer": for checkpoint_path in get_checkpoint_iterator(checkpoint_step, model_dir): decode_from_file( estimator, vocabulary=vocabulary, model_type=model_type, batch_size=batch_size, sequence_length=sequence_length, checkpoint_path=checkpoint_path) else: raise ValueError( "unknown mode %s - must be train/eval/infer" % mode)
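# Distilled sketch of the TF1 summary-writing pattern used by the eval loop
# above: one tf.Summary proto per checkpoint, one value per metric tag. The
# helper name and arguments are illustrative, not part of the original code.
def write_eval_summary(summary_dir, tag, value, global_step):
  writer = tf.summary.FileWriter(summary_dir)
  summary = tf.Summary()
  summary.value.add(tag=tag, simple_value=value)
  writer.add_summary(summary, global_step)
  writer.flush()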
def model_fn(features, labels, mode, params):
  # Get global step
  global_step = tf.train.get_global_step()

  # Construct mtf graph + mesh from params
  graph = mtf.Graph()
  mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
  layout_rules = mtf.convert_to_layout_rules(params["layout"])

  # Mesh setup
  if params["use_tpu"]:
    var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape, layout_rules)
  else:
    var_placer = None
    gpu_ids = params["gpu_ids"]
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, gpu_ids)

  # Trainable variable precision:
  # store to checkpoints in master type, train in slice type, compute in
  # activation type.
  if params["precision"] == "bfloat16":
    variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16,
                                       slice_dtype=tf.float32,
                                       activation_dtype=tf.bfloat16)
  else:
    variable_dtype = mtf.VariableDType(master_dtype=tf.float32,
                                       slice_dtype=tf.float32,
                                       activation_dtype=tf.float32)

  # Build mtf mesh object
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  # Build mtf_features & seq length dict for getting number of microbatches.
  # We need to pack inputs into a dict to pass into serialize_training_step.
  features_dict = {"inputs": features, "labels": labels}
  sequence_length_dict = {"inputs": params["n_ctx"],
                          "labels": params["n_ctx"]}

  params = add_mode_to_params(params, mode)
  batch_size = get_batch_size(params)

  batch_dim = mtf.Dimension("batch", batch_size)
  batch_dims = [batch_dim]
  feature_length = sequence_length_dict["inputs"]
  length_dim = mtf.Dimension("sequence", feature_length)

  mtf_features = {}
  for key, x in features_dict.items():
    if x is not None:
      feature_shape = mtf.Shape(batch_dims + [length_dim])
      if isinstance(features_dict[key], dict):
        features_dict[key] = features_dict[key]["feature"]
      x = tf.cast(features_dict[key], tf.int32)
      x = tf.reshape(x, feature_shape.to_integer_list)
      mtf_features[key] = mtf.import_fully_replicated(
          mesh, x, feature_shape, name=key)

  # Instantiate dict for dimensions, bias, etc. that can be calculated here
  # once, then passed into the model.
  other_features = {}
  memory_length_dim = mtf.Dimension("memory_length", length_dim.size)

  attn_bias = biasmask_attn_weights(
      mesh, length_dim, memory_length_dim,
      variable_dtype) if params["causal"] else None

  # Add attn_bias into mtf_features
  other_features["attn_bias"] = attn_bias

  # Define other Dimensions that we'll need inside the model
  embd_dim = mtf.Dimension("embd", params["n_embd"])
  vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
  # We need this dim specifically for the weights, because gathering breaks
  # when both arguments share a dimension. It prevents the "Einsum has lhs
  # dimension without corresponding rhs or output dimension."
  # error.
  embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])

  other_features["embd_dim"] = embd_dim
  other_features["vocab_dim"] = vocab_dim
  other_features["embed_sequence_dim"] = embed_sequence_dim
  other_features["memory_length_dim"] = memory_length_dim

  if mode == tf.estimator.ModeKeys.PREDICT:
    # Set up the model for prediction
    inputs = mtf_features["inputs"]
    if params["remove_partial_sequences"] is None:
      params["remove_partial_sequences"] = False

    export = params.get("export", False)

    if not export:
      mtf_samples = sample_autoregressive(
          inputs,
          other_features=other_features,
          params=params,
          variable_dtype=variable_dtype,
          remove_partial_sequences=params["remove_partial_sequences"],
          stop_at_token=params["eos_id"],
          sampling_use_entmax=params['sampling_use_entmax'],
          max_steps=params["predict_max_steps"])
    else:
      with mtf.utils.outside_all_rewrites():
        with tf.variable_scope('gpt2'):
          mtf_samples, loss, loss_batch = gpt2.model(
              mtf_features, other_features, params, mesh,
              variable_dtype=variable_dtype, context=None)

    mtf_samples = mtf.anonymize(mtf_samples)
    inputs = mtf.anonymize(inputs)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
    inputs = lowering.export_to_tf_tensor(inputs)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {"inputs": inputs, "outputs": outputs}

    def scaffold_fn():
      return tf.train.Scaffold(
          local_init_op=tf.group(
              tf.train.Scaffold.default_local_init_op(),
              lowering.copy_masters_to_slices(),
              name="mtf_local_init_op"),
          ready_op=tf.concat([
              tf.report_uninitialized_variables(),
              resources.report_uninitialized_resources()
          ], axis=0, name="mtf_ready_op"))

    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        scaffold_fn=scaffold_fn,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  # We're not predicting, so we'd better be training or evaluating
  assert mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]

  if mode == tf.estimator.ModeKeys.TRAIN:
    # Gets number of microbatches per batch for serialized training.
    # If params["tokens_per_mb_per_replica"] is None, this defaults to 1 and
    # no microbatching is performed.
    num_microbatches = int(
        mtf_transformer.utils.serialize_num_microbatches(
            batch_dim=batch_dim,
            sequence_length=sequence_length_dict,
            mesh_shape=mesh_shape,
            layout_rules=layout_rules,
            tokens_per_microbatch_per_replica=params[
                "tokens_per_mb_per_replica"]))
  else:
    num_microbatches = 1

  # Add num microbatches to params
  params["num_microbatches"] = num_microbatches

  if num_microbatches > 1:
    # For serialize_training_step we need to modify the model to output
    # results in a dict.
    def serialized_fn(mtf_features):
      if params["model"] == "GPT":
        with tf.variable_scope('gpt2'):
          logits, loss, loss_batch = gpt2.model(
              mtf_features, other_features, params, mesh,
              variable_dtype=variable_dtype)
        return {"logits": logits, "loss": loss, "loss_batch": loss_batch}
      else:
        raise Exception(
            f"'{params['model']}' is not a valid model - "
            f"please select from [GPT]")

    # Serialize the training step - gradients are accumulated locally and
    # reduced once.
    var_grads, output_dict = mtf.serialize_training_step(
        mtf_features, serialized_fn, batch_dim, num_microbatches)
    loss = output_dict["loss"]
    loss_batch = output_dict["loss_batch"]
    logits = output_dict["logits"]
  else:
    # If we're not splitting into microbatches, return logits & loss as is
    if params["model"] == "GPT":
      with mtf.utils.outside_all_rewrites():
        with tf.variable_scope('gpt2'):
          logits, loss, loss_batch = gpt2.model(
              mtf_features, other_features, params, mesh,
              variable_dtype=variable_dtype, context=None)
    else:
      raise Exception(
          f"'{params['model']}' is not a valid model - "
          f"please select from [GPT]")

  # Auto layout generation
  if params["auto_layout"]:
    auto_layout(graph, mesh_shape, logits, loss)
  if params["auto_layout_and_mesh_shape"]:
    auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)

  if mode == tf.estimator.ModeKeys.TRAIN:
    # In TRAIN mode, get optimizer
    if params["num_microbatches"] > 1:
      # If we are splitting the batch into microbatches, var grads are created
      # in the serialize_training_step fn, so we pass them in here.
      _, update_ops, var_grads = get_optimizer(
          mesh, loss, params, variable_dtype=variable_dtype,
          inp_var_grads=var_grads)
    else:
      # Otherwise, they are created in the get_optimizer fn, so we leave
      # inp_var_grads blank.
      _, update_ops, var_grads = get_optimizer(
          mesh, loss, params, variable_dtype=variable_dtype)
    # Log summaries to tensorboard
    mtf.scalar_summary("loss", loss)
    # Log gradients if in params
    if params["log_grads"] not in [None, False]:
      for g in var_grads:
        grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
        mtf.scalar_summary("grads/norm/" + g.name[:-2], grad_norm)
  else:
    # For now, we can only export fully-replicated tensors.
    # This has to be done before lowering or they will not be included in the
    # graph.
    mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
    max_logits = mtf.argmax(logits, vocab_dim)
    del logits
    fully_replicated_mean_logits = mtf.anonymize(mean_logits)
    fully_replicated_max_logits = mtf.anonymize(max_logits)
    fully_replicated_loss_batch = mtf.anonymize(loss_batch)

  # Gets & prints info about the number of trainable vars in the model and
  # the dimension names.
  get_graph_info(graph)

  # "Lowers" mtf tensors into a tf graph - this enables us to export results
  # as tf tensors.
  lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.cast(tf_loss, tf.float32)

  if mode == tf.estimator.ModeKeys.TRAIN:
    # Use our patched version until mtf updates theirs
    host_call = create_host_call(params['model_path'])
    mtf.utils.remove_summaries()

    # Creates train_op
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    # Need to manually increment global_step
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info(f"tf_update_ops: {tf_update_ops}")
    train_op = tf.group(tf_update_ops)
  else:
    tf_mean_logits = lowering.export_to_tf_tensor(
        fully_replicated_mean_logits)
    tf_max_logits = lowering.export_to_tf_tensor(fully_replicated_max_logits)
    tf_loss_batch = tf.to_float(
        lowering.export_to_tf_tensor(fully_replicated_loss_batch))

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      # Set up the checkpoint saver and return the TPUEstimatorSpec
      saver = tf.train.Saver(tf.global_variables(),
                             sharded=True,
                             max_to_keep=10,
                             keep_checkpoint_every_n_hours=2,
                             defer_build=False,
                             save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          params["model_path"],
          save_steps=params["steps_per_checkpoint"],
          saver=saver,
          listeners=[saver_listener])

      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          host_call=host_call,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:
      # Evaluation metrics
      def _perplexity(loss):
        perplexity = tf.exp(loss)
        return tf.metrics.mean(perplexity)

      def _bits_per_byte(loss):
        bpb = loss * (0.29335 / math.log(2))
        return tf.metrics.mean(bpb)

      def _metric_fn(tf_mean_logits, tf_loss_batch):
        mean_logits = tf.metrics.mean(tf_mean_logits)
        loss = tf.reduce_mean(tf_loss_batch)
        perp = _perplexity(loss)
        bpb = _bits_per_byte(loss)
        return {
            "mean_logits": mean_logits,
            "perplexity": perp,
            "bits per byte": bpb
        }

      def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
        eos_token = params["eos_id"]
        answer_positions = tf.where(tf.math.not_equal(labels, eos_token))
        correct_answers = tf.gather_nd(
            tf.math.equal(tf_max_logits, labels), answer_positions)
        accuracy = tf.metrics.mean(tf.cast(correct_answers, tf.float32))

        # tf_loss_batch may have z_loss and other terms added to it, so this
        # may need to be calculated separately in the future.
        answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
        log_perplexity = tf.metrics.mean(answer_loss)
        return {
            "lambada_acc": accuracy,
            "lambada_log_ppl": log_perplexity
        }

      eval_task = params["eval_task"]
      if eval_task == "lambada":
        eval_metrics = (_lambada_metric_fn,
                        [labels, tf_max_logits, tf_loss_batch])
      else:
        eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])

      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
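import math

# Hypothetical helper making the _bits_per_byte conversion above explicit:
# per-token loss in nats -> bits per token (divide by ln 2) -> bits per byte
# (multiply by the tokens-per-byte ratio; 0.29335 is the dataset-specific
# constant assumed by the code above).
def bits_per_byte(loss_nats_per_token, tokens_per_byte=0.29335):
  return loss_nats_per_token * tokens_per_byte / math.log(2)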
def run(tpu_job_name,
        data_dir,
        master_dtype,
        slice_dtype,
        activation_dtype,
        tpu,
        gcp_project,
        tpu_zone,
        autostack,
        model_dir,
        mode=gin.REQUIRED,
        iterations_per_loop=gin.REQUIRED,
        save_checkpoints_steps=gin.REQUIRED,
        eval_steps=gin.REQUIRED,
        train_steps=gin.REQUIRED,
        batch_size=gin.REQUIRED,
        text2self=gin.REQUIRED,
        dataset=gin.REQUIRED):
  """Run training/eval/inference.

  Args:
    tpu_job_name: string, name of TPU worker binary
    data_dir: string, data_dir for TensorFlow Datasets
    master_dtype: string, datatype for checkpoints
    slice_dtype: string, datatype for variables in memory
    activation_dtype: string, datatype for activations
    tpu: string, the Cloud TPU to use for training
    gcp_project: string, project name for the Cloud TPU-enabled project
    tpu_zone: string, GCE zone where the Cloud TPU is located
    autostack: boolean, internally combine variables
    model_dir: string, estimator model_dir
    mode: string, train/evaluate/infer
    iterations_per_loop: integer, steps per train loop
    save_checkpoints_steps: integer, steps per checkpoint
    eval_steps: integer, number of evaluation steps
    train_steps: Total number of training steps.
    batch_size: Mini-batch size for the training. Note that this is the
      global batch size and not the per-shard batch size.
    text2self: Whether to train a language model (True) or encoder-decoder
      text-to-text model (False).
    dataset: TensorFlow Datasets dataset name.
  """
  cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
      tpu if tpu else "", zone=tpu_zone, project=gcp_project)

  my_tpu_config = tpu_config.TPUConfig(
      tpu_job_name=tpu_job_name,
      iterations_per_loop=iterations_per_loop,
      num_cores_per_replica=1,
      per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST,
  )

  run_config = tpu_config.RunConfig(
      cluster=cluster,
      model_dir=model_dir,
      save_checkpoints_steps=save_checkpoints_steps,
      tpu_config=my_tpu_config)

  dataset = transformer_dataset.TokenizedTFDSDataset(
      dataset, text2self=text2self, data_dir=data_dir or None)

  output_encoder = dataset.encoders["targets"]
  if text2self:
    input_encoder = output_encoder
  else:
    input_encoder = dataset.encoders["inputs"]

  transformer_model = model(
      input_vocab_size=transformer_dataset.padded_vocab_size(
          input_encoder.vocab_size, 128),
      output_vocab_size=transformer_dataset.padded_vocab_size(
          output_encoder.vocab_size, 128),
      text2self=text2self)

  mesh_shape = mtf.convert_to_shape(gin.query_parameter("model.mesh_shape"))
  layout_rules = mtf.convert_to_layout_rules(
      gin.query_parameter("model.layout"))

  # Data types used for variables and activations; see comments in the FLAGS.
  master_dtype = tf.as_dtype(master_dtype)
  if slice_dtype:
    slice_dtype = tf.as_dtype(slice_dtype)
  elif not tpu or mode == "train":
    slice_dtype = tf.float32
  else:
    slice_dtype = tf.bfloat16
  if activation_dtype:
    activation_dtype = tf.as_dtype(activation_dtype)
  else:
    activation_dtype = tf.bfloat16 if tpu else tf.float32
  variable_dtype = mtf.VariableDType(master_dtype=master_dtype,
                                     slice_dtype=slice_dtype,
                                     activation_dtype=activation_dtype)

  length_from_config = gin.query_parameter(
      "model.length") or gin.query_parameter("model.max_length")

  model_fn = tpu_estimator_model_fn(transformer_model=transformer_model,
                                    model_dir=model_dir,
                                    use_tpu=tpu,
                                    mesh_shape=mesh_shape,
                                    layout_rules=layout_rules,
                                    text2self=text2self,
                                    variable_dtype=variable_dtype,
                                    batch_size=batch_size,
                                    length=length_from_config,
                                    autostack=autostack)

  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      config=run_config,
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      predict_batch_size=batch_size,
      use_tpu=tpu,
      export_to_tpu=False,
      params={})

  def input_fn(params):
    del params
    return dataset.load(batch_size=batch_size,
                        length=length_from_config,
                        train=(mode == "train"),
                        pack=True)

  if mode == "train":
    estimator.train(input_fn=input_fn, max_steps=train_steps)
  elif mode == "evaluate":
    estimator.evaluate(
        input_fn=input_fn,
        steps=eval_steps,
    )
  elif mode == "infer":
    decode_from_file(
        estimator,
        batch_size=batch_size,
        length=length_from_config,
        inputs_encoder=dataset.encoders["targets" if text2self else "inputs"],
        targets_encoder=dataset.encoders["targets"],
        text2self=text2self)
  else:
    raise ValueError("unknown mode %s - must be train/evaluate/infer" % mode)
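# Sketch of how the gin bindings queried above (model.mesh_shape,
# model.layout, model.length / model.max_length) might be supplied; the
# binding values are illustrative only and assume `model` is gin-configurable.
import gin

gin.parse_config("""
model.mesh_shape = "model:2,batch:2"
model.layout = "batch:batch,d_ff:model,heads:model"
model.length = 256
model.max_length = 256
""")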
def estimator_model_fn(cls,
                       hparams,
                       features,
                       labels,
                       mode,
                       config=None,
                       params=None,
                       decode_hparams=None,
                       use_tpu=False):
  hparams = copy.deepcopy(hparams)
  hparams.use_tpu = use_tpu
  # Merge decode_hparams into hparams if present.
  if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
    for k, v in six.iteritems(decode_hparams.values()):
      if hasattr(hparams, k) and getattr(hparams, k) != v:
        tf.logging.warning(
            "Overriding hparams.%s with %s from decode_hparams" % (k, v))
      setattr(hparams, k, v)

  # Instantiate model.
  data_parallelism = None
  if not use_tpu and config:
    data_parallelism = config.data_parallelism
  model = cls(hparams,
              mode,
              data_parallelism=data_parallelism,
              decode_hparams=decode_hparams)

  global_step = tf.train.get_global_step()

  mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(hparams.layout)
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    if data_parallelism is None or len(data_parallelism.ps_devices) == 1:
      mesh_devices = [""] * mesh_shape.size
    else:
      assert len(data_parallelism.ps_devices) == mesh_shape.size
      mesh_devices = data_parallelism.ps_devices
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  # PREDICT mode
  if mode == tf.estimator.ModeKeys.PREDICT:
    return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)

  logits, loss = model.mtf_model_fn(features, mesh)
  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    lr = learning_rate.learning_rate_schedule(hparams)
    tf.summary.scalar("learning_rate", lr)
    mtf_lr = mtf.import_tf_tensor(
        mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
    optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
    update_ops = []
    for grad, var in zip(var_grads, graph.trainable_variables):
      update_ops.extend(optimizer.apply_grad(grad, var))

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if logits is not None and mode != tf.estimator.ModeKeys.TRAIN:
    tf_logits = lowering.export_to_tf_tensor(logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
restore_hook = mtf.MtfRestoreHook(lowering) saver = tf.train.Saver(tf.global_variables(), sharded=True, max_to_keep=10, keep_checkpoint_every_n_hours=2, defer_build=False, save_relative_paths=True) tf.add_to_collection(tf.GraphKeys.SAVERS, saver) saver_listener = mtf.MtfCheckpointSaverListener(lowering) saver_hook = tf.train.CheckpointSaverHook( hparams.model_dir, save_steps=1000, saver=saver, listeners=[saver_listener]) # EVAL mode if mode == tf.estimator.ModeKeys.EVAL: tf_logits = lowering.export_to_tf_tensor(logits) return model.estimator_spec_eval(features, tf_logits, labels, tf_loss, restore_hook, use_tpu) if use_tpu: # TPU host call. Important: need to be called before remove_summaries() if hparams.tpu_enable_host_call: host_call = t2t_model.create_host_call(hparams.model_dir) else: host_call = None t2t_model.remove_summaries() return tpu_estimator.TPUEstimatorSpec( mode=tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op, host_call=host_call, training_hooks=[restore_hook, saver_hook]) else: return tf.estimator.EstimatorSpec( tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op, training_chief_hooks=[restore_hook, saver_hook])
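# The ordering constraint noted in the function above, isolated as a sketch:
# create_host_call must scan the graph for tf.summary ops before
# remove_summaries() strips them, since the host call re-emits those summaries
# from the TPU host. Names follow the tensor2tensor helpers used above; the
# wrapper itself is hypothetical.
def make_host_call(model_dir, enable_host_call):
  host_call = (t2t_model.create_host_call(model_dir)
               if enable_host_call else None)
  t2t_model.remove_summaries()  # safe only after the capture above
  return host_call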
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator."""
  tf.logging.info("*** Features ***")
  for name in sorted(features.keys()):
    tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

  # MTF setup.
  graph = mtf.Graph()
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

  ctx = params["context"]
  num_hosts = ctx.num_hosts
  host_placement_fn = ctx.tpu_host_placement_function
  device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
  tf.logging.info("device_list = %s" % device_list)
  replica_cache_size = 300 * 1000000  # 300M per replica
  # Worker 0 caches all the TPU binaries.
  worker0_mem = replica_cache_size * ctx.num_replicas
  devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
  var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                devices_memory_usage)
  mesh_devices = [""] * mesh_shape.size
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                              mesh_devices,
                                              ctx.device_assignment)
  mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

  unique_ids = features["unique_ids"]
  input_ids = features["input_ids"]
  input_mask = features["input_mask"]
  segment_ids = features["segment_ids"]

  batch_size = input_ids.get_shape()[0].value
  batch_dim = mtf.Dimension("batch", batch_size)
  seq_length = input_ids.get_shape()[1].value
  seq_dim = mtf.Dimension("seq", seq_length)

  mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids, [batch_dim, seq_dim])
  mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                        [batch_dim, seq_dim])
  mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                         [batch_dim, seq_dim])

  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  (start_logits, end_logits) = create_model(
      bert_config=bert_config,
      is_training=is_training,
      input_ids=mtf_input_ids,
      input_mask=mtf_input_mask,
      segment_ids=mtf_segment_ids)

  if mode == tf.estimator.ModeKeys.TRAIN:

    def compute_loss(logits, positions):
      one_hot_positions = mtf.one_hot(positions, output_dim=seq_dim)
      log_probs = mtf.log_softmax(logits, seq_dim)
      loss = -mtf.reduce_mean(
          mtf.reduce_sum(one_hot_positions * log_probs, reduced_dim=seq_dim))
      return loss

    start_positions = features["start_positions"]
    mtf_start_positions = mtf.import_tf_tensor(mesh, start_positions,
                                               [batch_dim])
    end_positions = features["end_positions"]
    mtf_end_positions = mtf.import_tf_tensor(mesh, end_positions, [batch_dim])

    start_loss = compute_loss(start_logits, mtf_start_positions)
    end_loss = compute_loss(end_logits, mtf_end_positions)
    total_loss = (start_loss + end_loss) / 2.0

    _, update_ops = optimization_lib.create_optimizer(
        total_loss,
        learning_rate,
        num_train_steps,
        num_warmup_steps,
        max_optimized_variable_size=FLAGS.max_optimized_variable_size,
        optimizer=FLAGS.optimizer,
        clip_gradients=FLAGS.clip_gradients)
  elif mode == tf.estimator.ModeKeys.PREDICT:
    start_logits = mtf.anonymize(start_logits)
    end_logits = mtf.anonymize(end_logits)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))
    global_step = tf.train.get_global_step()
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
    train_op = tf.group(tf_update_ops)

  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  if init_checkpoint:
    (assignment_map, initialized_variable_names
    ) = bert_lib.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
    if use_tpu:

      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                    init_string)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          FLAGS.output_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])
      return tf.estimator.tpu.TPUEstimatorSpec(
          mode,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook],
          scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.PREDICT:
      predictions = {
          "unique_ids": unique_ids,
          "start_logits": lowering.export_to_tf_tensor(start_logits),
          "end_logits": lowering.export_to_tf_tensor(end_logits),
      }
      return tf.estimator.tpu.TPUEstimatorSpec(
          mode,
          prediction_hooks=[restore_hook],
          predictions=predictions,
          scaffold_fn=scaffold_fn)
    else:
      raise ValueError("Only TRAIN and PREDICT modes are supported: %s" %
                       (mode))
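# Illustrative only: tf.train.init_from_checkpoint takes a map from checkpoint
# scope names to current-graph scope names; an identity mapping like this
# restores all same-named variables under "bert/". The path is hypothetical.
tf.train.init_from_checkpoint("/tmp/bert/bert_model.ckpt", {"bert/": "bert/"})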
def model_fn(features, labels, mode, params):
  """A model is called by TpuEstimator."""
  del labels
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, 'my_mesh')
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  mesh_devices = [''] * mesh_shape.size
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape, mtf.convert_to_layout_rules(FLAGS.layout),
      mesh_devices, params['context'].device_assignment)
  with mtf.utils.outside_all_rewrites():
    logits, loss = toy_model(features, mesh)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.AdafactorOptimizer()
    update_ops = []
    for grad, var in zip(var_grads, graph.trainable_variables):
      update_ops.extend(optimizer.apply_grad(grad, var))
  else:
    # For now, we can only export fully-replicated tensors.
    fully_replicated_logits = mtf.anonymize(logits)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_loss = lowering.export_to_tf_tensor(loss)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  else:
    tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(tf.global_variables(),
                             sharded=True,
                             max_to_keep=10,
                             keep_checkpoint_every_n_hours=2,
                             defer_build=False,
                             save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(FLAGS.model_dir,
                                                save_steps=1000,
                                                saver=saver,
                                                listeners=[saver_listener])

      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(tf_logits):
        mean_logits = tf.metrics.mean(tf_logits)
        return {'mean_logits': mean_logits}

      eval_metrics = (metric_fn, [tf_logits])
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
def testConvertToLayoutRules(self, inputs): layout_rules = mtf.convert_to_layout_rules(inputs) self.assertEqual( layout_rules._pairs, mtf.LayoutRules([("d_ff", "model"), ("heads", "model")])._pairs)
def model_fn(features, labels, mode, params): """The model_fn argument for creating an Estimator.""" tf.logging.info("features = %s labels = %s mode = %s params=%s" % (features, labels, mode, params)) global_step = tf.train.get_global_step() graph = mtf.Graph() mesh = mtf.Mesh(graph, "my_mesh") logits, loss = mnist_model(features, labels, mesh) mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape) layout_rules = mtf.convert_to_layout_rules(FLAGS.layout) mesh_size = mesh_shape.size mesh_devices = [""] * mesh_size mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl( mesh_shape, layout_rules, mesh_devices) if mode == tf.estimator.ModeKeys.TRAIN: var_grads = mtf.gradients( [loss], [v.outputs[0] for v in graph.trainable_variables]) optimizer = mtf.optimize.AdafactorOptimizer() update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables) lowering = mtf.Lowering(graph, {mesh: mesh_impl}) restore_hook = mtf.MtfRestoreHook(lowering) tf_logits = lowering.export_to_tf_tensor(logits) if mode != tf.estimator.ModeKeys.PREDICT: tf_loss = lowering.export_to_tf_tensor(loss) tf.summary.scalar("loss", tf_loss) if mode == tf.estimator.ModeKeys.TRAIN: tf_update_ops = [lowering.lowered_operation(op) for op in update_ops] tf_update_ops.append(tf.assign_add(global_step, 1)) train_op = tf.group(tf_update_ops) saver = tf.train.Saver(tf.global_variables(), sharded=True, max_to_keep=10, keep_checkpoint_every_n_hours=2, defer_build=False, save_relative_paths=True) tf.add_to_collection(tf.GraphKeys.SAVERS, saver) saver_listener = mtf.MtfCheckpointSaverListener(lowering) saver_hook = tf.train.CheckpointSaverHook(FLAGS.model_dir, save_steps=1000, saver=saver, listeners=[saver_listener]) accuracy = tf.metrics.accuracy(labels=labels, predictions=tf.argmax(tf_logits, axis=1)) # Name tensors to be logged with LoggingTensorHook. tf.identity(tf_loss, "cross_entropy") tf.identity(accuracy[1], name="train_accuracy") # Save accuracy scalar to Tensorboard output. tf.summary.scalar("train_accuracy", accuracy[1]) # restore_hook must come before saver_hook return tf.estimator.EstimatorSpec( tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op, training_chief_hooks=[restore_hook, saver_hook]) if mode == tf.estimator.ModeKeys.PREDICT: predictions = { "classes": tf.argmax(tf_logits, axis=1), "probabilities": tf.nn.softmax(tf_logits), } return tf.estimator.EstimatorSpec( mode=tf.estimator.ModeKeys.PREDICT, predictions=predictions, prediction_hooks=[restore_hook], export_outputs={ "classify": tf.estimator.export.PredictOutput(predictions) }) if mode == tf.estimator.ModeKeys.EVAL: return tf.estimator.EstimatorSpec( mode=tf.estimator.ModeKeys.EVAL, loss=tf_loss, evaluation_hooks=[restore_hook], eval_metric_ops={ "accuracy": tf.metrics.accuracy(labels=labels, predictions=tf.argmax(tf_logits, axis=1)), })
def testConvertToLayoutRulesGenericInputs(self): with self.assertRaises(ValueError): mtf.convert_to_layout_rules("d_ff;heads")
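# Taken together, the two tests above pin down the accepted inputs to
# mtf.convert_to_layout_rules: a comma-separated "tensor_dim:mesh_dim" string
# or a list of (tensor_dim, mesh_dim) pairs; a ";"-separated string raises
# ValueError. A sketch of the two equivalent forms (assuming the list form is
# accepted, as the first test's expected value suggests):
rules_from_str = mtf.convert_to_layout_rules("d_ff:model,heads:model")
rules_from_pairs = mtf.convert_to_layout_rules(
    [("d_ff", "model"), ("heads", "model")])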
def __init__(
    self,
    model_dir,
    tpu,
    tpu_job_name=None,
    tpu_zone=None,
    gcp_project=None,
    tpu_topology="2x2",
    model_parallelism=8,
    batch_size=("tokens_per_batch", 1024),
    sequence_length=None,
    model_type="bitransformer",
    layout_rules="ensemble:ensemble,batch:batch,d_ff:model,heads:model,vocab:model,experts:batch",
    autostack=True,
    learning_rate_schedule=None,
    keep_checkpoint_max=None,
    save_checkpoints_steps=5000,
    optimizer=None,
    predict_fn=None,
    variable_filter=None,
    ensemble_inputs=None,
    iterations_per_loop=100):
  """Constructor for MtfModel class.

  Args:
    model_dir: str, directory to save the model.
    tpu: str, the TPU address to use.
    tpu_job_name: str, name of the TPU worker binary.
    tpu_zone: str, GCE zone where the Cloud TPU is located
    gcp_project: str, project name for the Cloud TPU-enabled project.
    tpu_topology: str, e.g. "2x2".
    model_parallelism: integer, the number of cores per model replica.
    batch_size: An integer or a (method, value) pair to pass to
      compute_batch_size(). Note that this is the global batch size and not
      the per-shard batch size.
    sequence_length: an integer or a dict from feature-key to integer giving
      the (packed) sequence length, e.g. {"inputs": 512, "targets": 128}
    model_type: str, a model type from mesh tf models.
    layout_rules: an input to mtf.convert_to_layout_rules()
    autostack: boolean, internally combine variables.
    learning_rate_schedule: an optional function taking the scalar name
      argument `step` and the numeric argument `total_train_steps` and
      returning the scalar learning rate.
    keep_checkpoint_max: an integer, maximum number of checkpoints to keep.
    save_checkpoints_steps: an integer, steps per checkpoint.
    optimizer: a class extending optimize.Optimizer, required for training.
    predict_fn: an optional function that can be used to override the default
      transformer prediction behavior. Must return a tensor of shape
      [batch_dim, length_dim] that will be the prediction for each example.
      Must accept the following arguments:
      - model: a Unitransformer or Bitransformer
      - features: a dict representing an example. Every value will be an
        mtf.Tensor with shape [batch_dim, length_dim].
      - variable_dtype: an mtf.VariableDType
    variable_filter: a str, a variable will only be trained if its name
      matches this regex. If None (default), train all trainable variables.
    ensemble_inputs: an integer, see `train_model` docstring for details.
    iterations_per_loop: integer, steps per train loop
  """
  mesh_shape = utils.tpu_mesh_shape(tpu_topology, model_parallelism)

  sequence_length = sequence_length or {"inputs": 512, "targets": 512}
  if isinstance(sequence_length, int):
    sequence_length = {"inputs": sequence_length,
                       "targets": sequence_length}

  if not isinstance(batch_size, int):
    self._batch_size = utils.compute_batch_size(
        sequence_length, mesh_shape, layout_rules, batch_size)
  else:
    self._batch_size = batch_size

  self._learning_rate_schedule = (
      learning_rate_schedule or
      learning_rate_schedules.learning_rate_schedule_noam)
  self._optimizer = optimizer or optimize.AdafactorOptimizer
  self._sequence_length = sequence_length
  self._model_dir = model_dir
  self._model_type = model_type
  self._ensemble_inputs = ensemble_inputs
  self._layout_rules = mtf.convert_to_layout_rules(layout_rules)
  self._mesh_shape = mtf.convert_to_shape(mesh_shape)
  self._autostack = autostack
  self._keep_checkpoint_max = keep_checkpoint_max
  self._save_checkpoints_steps = save_checkpoints_steps
  self._predict_fn = predict_fn
  self._variable_filter = variable_filter
  self._iterations_per_loop = iterations_per_loop
  self._cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
      tpu if tpu else "", zone=tpu_zone, project=gcp_project)
  self._tpu = tpu
  self._tpu_job_name = tpu_job_name
  self._estimator = None
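# Hypothetical construction showing how the defaults above interact: an int
# sequence_length is broadcast to inputs/targets, and a (method, value)
# batch_size is resolved against the mesh shape derived from the topology.
# All argument values here are illustrative.
model = MtfModel(
    model_dir="/tmp/mtf_model",
    tpu="my-tpu",
    tpu_topology="2x2",
    model_parallelism=8,
    sequence_length=512,  # becomes {"inputs": 512, "targets": 512}
    batch_size=("tokens_per_batch", 1024),
)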
def __call__(self, features, labels, mode, params):  # this is the model_fn
  """A model is called by TpuEstimator."""
  del labels
  global_step = tf.train.get_global_step()

  # Graph setup
  graph = mtf.Graph()
  mesh_shape = mtf.convert_to_shape(self.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(self.layout)
  if params["use_tpu"]:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # Worker 0 caches all the TPU binaries.
    replica_cache_size = 300 * 1024 * 1024  # 300M per replica.
    worker0_mem = replica_cache_size * 8 * num_hosts
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)
    mesh_devices = [""] * mesh_shape.size
    # Pass the device assignment here; the original passed
    # devices_memory_usage, which is not a valid device assignment.
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  # RUN Model
  with mtf.utils.outside_all_rewrites():
    logits, loss = self.model(mesh, features, params)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    if self.optimizer == "Adafactor":
      optimizer = mtf.optimize.AdafactorOptimizer()
    else:
      assert self.optimizer == "SGD"
      optimizer = mtf.optimize.SgdOptimizer(learning_rate=self.learning_rate)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
  else:
    # For now, we can only export fully-replicated tensors.
    fully_replicated_logits = mtf.anonymize(logits)

  # Convert back to TensorFlow format
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  else:
    tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

  # Create the estimator spec
  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True,
      )
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          self.model_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener],
      )

      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook],
      )
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(tf_logits):
        mean_logits = tf.metrics.mean(tf_logits)
        return {"mean_logits": mean_logits}

      eval_metrics = (metric_fn, [tf_logits])
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics,
      )
    elif mode == tf.estimator.ModeKeys.PREDICT:
      # Export the anonymized logits as predictions. PREDICT specs take
      # prediction_hooks and no loss or eval_metrics; the original referenced
      # an undefined eval_metrics here.
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.PREDICT,
          predictions={"logits": tf_logits},
          prediction_hooks=[restore_hook],
      )

@property
def dense_initializer(self):
  if self.config.initializer_range:
    return tf.truncated_normal_initializer(
        stddev=self.config.initializer_range)
  else:
    return mtf.layers.VarianceScalingInitializer(scale=0.4)

@property
def embedding_initializer(self):
  initializer = self.dense_initializer
  if isinstance(initializer, mtf.layers.DenseInitializer):
    # The embedding matrix is also used as the classifier weight matrix,
    # so scale it appropriately.
    return initializer(reduced_dims=[self.model_dim],
                       new_dims=[self.vocab_dim])
  else:
    return initializer

@property
def num_hidden_layers(self):
  return self.config.num_hidden_layers

def normalize(self, x, reduce_dim):
  return nn.layer_norm(
      x,
      reduce_dim,
      subtract_mean=self.config.use_bias,
      use_bias=self.config.use_bias,
  )

def model(self, mesh, x, y, params):
  # x :: [batch, io, vocab]
  if params["precision"] == "bfloat16":
    dtype = tf.bfloat16
    # Master has type float32; slice and activation have type bfloat16.
    variable_dtype = mtf.VariableDType(tf.float32, tf.bfloat16, tf.bfloat16)
  else:
    dtype = tf.float32
    # Master, slice and activation all have type float32.
    variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)

  # Build the actual model
  batch_dim = mtf.Dimension("batch", params["batch_size"])
  vocab_dim = mtf.Dimension("vocab", params["vocab_size"])
  io_dim = mtf.Dimension("sequence", params["io"])
  io_chan_dim = mtf.Dimension("io", params["io_channels"])

  # From input to mtf
  x = mtf.import_tf_tensor(mesh, x, mtf.Shape([batch_dim, io_dim, vocab_dim]))

  # Embeddings
  with tf.variable_scope("toy", default_name="seq2seq"):
    with tf.variable_scope("embeddings"):
      # Perform embedding lookup on the word ids.
      embedding_table = mtf.get_variable(
          mesh,
          "word_embeddings",
          mtf.Shape([vocab_dim, io_chan_dim]),
          initializer=self.embedding_initializer,
      )
      word_embedding_output = mtf.gather(
          embedding_table, x, dim=vocab_dim, output_shape=io_chan_dim)

      # Add positional embeddings, then layer-normalize and apply dropout.
      pos_embedding = mtf.get_variable(
          mesh,
          "pos_embeddings",
          mtf.Shape([io_dim, io_chan_dim]),
          initializer=self.embedding_initializer,
      )
      # Shift tokens by the positional embeddings.
      embedding_output = word_embedding_output + pos_embedding
      # Layer-norm over the channel dimension (io_chan_dim is assumed to be
      # the intended reduce_dim; the original call omitted it).
      embedding_output = self.normalize(embedding_output, io_chan_dim)
      embedding_output = mtf.dropout(
          embedding_output,
          keep_prob=1.0 - self.config.layer_output_dropout_prob,
      )
      x = mtf.cast(embedding_output, variable_dtype.activation_dtype)

  h = x
  for lnum in range(1, self.num_hidden_layers + 2):
    if lnum == self.num_hidden_layers + 1:
      # Output layer: project back to token dimensions.
      dim = io_dim
    elif lnum % 2 == 0:
      dim = mtf.Dimension("hidden_even", io_chan_dim.size)
    else:
      dim = mtf.Dimension("hidden_odd", io_chan_dim.size)
    h = mtf.layers.dense(
        h,
        dim,
        use_bias=False,
        master_dtype=variable_dtype.master_dtype,
        slice_dtype=variable_dtype.slice_dtype,
        name="layer_%d" % lnum,
    )
  prediction = h

  # Compute the mean squared loss between the input and the output.
  loss = mtf.reduce_mean(mtf.square(y - prediction))
  return prediction, loss
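# Why the alternating "hidden_even"/"hidden_odd" names in the stack above:
# mtf.layers.dense contracts an existing dimension of the input against the
# kernel, so two consecutive layers must not reuse one dimension name or the
# underlying einsum has nothing new to produce. A minimal sketch (sizes
# assumed, helper name hypothetical):
import mesh_tensorflow as mtf

def alternating_stack(h):
  hidden_a = mtf.Dimension("hidden_even", 256)
  hidden_b = mtf.Dimension("hidden_odd", 256)
  h = mtf.layers.dense(h, hidden_a, use_bias=False, name="layer_a")
  # Contracts hidden_even and produces hidden_odd.
  h = mtf.layers.dense(h, hidden_b, use_bias=False, name="layer_b")
  return h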