def computation_fn():
    """Builds one AdamW update step for a 3-element variable on a 2-way mesh.

    Reads `self` and `device_assignment` from the enclosing scope and stores
    the lowering on `self.lowering`.

    Returns:
        A single grouped TF op covering every lowered update operation.
    """
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, 'my_mesh')

    # Mesh implementation: two slices, with blank device strings.
    shape = mtf.convert_to_shape('all:2')
    layout = 'none:all'
    devices = [''] * shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        shape, mtf.convert_to_layout_rules(layout), devices,
        device_assignment)

    # Model: mean squared difference between a constant and a variable.
    hidden = mtf.Dimension('hidden', 3)
    weights = mtf.get_variable(
        mesh, 'w', shape=[hidden],
        initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
    target = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden], dtype=tf.float32)
    loss = mtf.reduce_mean(mtf.square(target - weights))

    # Gradients with respect to the first output of each trainable variable.
    variable_outputs = []
    for variable in graph.trainable_variables:
        variable_outputs.append(variable.outputs[0])
    var_grads = mtf.gradients([loss], variable_outputs)

    optimizer = mtf_optimize.AdamWeightDecayOptimizer(learning_rate=0.2)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

    self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    lowered_updates = []
    for op in update_ops:
        lowered_updates.append(self.lowering.lowered_operation(op))
    return tf.group(lowered_updates)
def __init__(self, sess, use_tpu, mesh_shape, layout_rules):
    """Creates the mesh implementation for TPU or for placement on GPUs.

    Args:
        sess: session used to list CPU/GPU devices and, when `use_tpu` is
            set, to run the TPU system initialization.
        use_tpu: if truthy, build a SimdMeshImpl; otherwise a
            PlacementMeshImpl over the listed GPU devices.
        mesh_shape: any value accepted by `mtf.convert_to_shape`.
        layout_rules: layout rules, passed to the mesh implementation as-is.
    """
    super(MeshContext, self).__init__()
    self._use_tpu = use_tpu
    self._mesh_shape = mtf.convert_to_shape(mesh_shape)
    self._layout_rules = layout_rules
    self._d_assignment = None
    self._num_hosts = None
    self._num_cores = None
    self._cpu_devices, self._gpu_devices = self._list_cpu_gpu_devices(sess)

    if not self._use_tpu:
        self._mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            self._mesh_shape, self._layout_rules, self._gpu_devices)
        return

    serialized_topology = sess.run(tpu.initialize_system())
    topo = tpu.Topology(serialized=serialized_topology)
    self._num_cores = int(np.prod(topo.mesh_shape))
    self._num_hosts = int(topo.num_tasks)
    cores_per_host = int(self._num_cores // self._num_hosts)
    assert cores_per_host == int(topo.num_tpus_per_task)

    # Get a device_assignment object for mtf: one replica per core.
    unit_computation_shape = [1] * mtf.utils.topology_rank(serialized_topology)
    self._d_assignment = device_assignment.device_assignment(
        serialized_topology,
        computation_shape=unit_computation_shape,
        num_replicas=self._num_cores)

    self._mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        self._mesh_shape, self._layout_rules, None, self._d_assignment)
def computation_fn():
    """Builds a training step using `optimization_lib.create_optimizer`.

    Reads `self` and `device_assignment` from the enclosing scope and stores
    the lowering on `self.lowering`.

    Returns:
        A `(lr, train_op)` pair: the learning rate returned by
        `create_optimizer`, and a grouped op that runs the lowered updates
        plus a global-step increment.
    """
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, 'my_mesh')

    shape = mtf.convert_to_shape('all:2')
    layout = 'none:all'
    devices = [''] * shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        shape, mtf.convert_to_layout_rules(layout), devices,
        device_assignment)

    hidden = mtf.Dimension('hidden', 3)
    weights = mtf.get_variable(
        mesh, 'w', shape=[hidden],
        initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
    target = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden], dtype=tf.float32)
    loss = mtf.reduce_mean(mtf.square(target - weights))

    lr, update_ops = optimization_lib.create_optimizer(loss, 0.2, 100, 10)

    self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    lowered_updates = []
    for op in update_ops:
        lowered_updates.append(self.lowering.lowered_operation(op))

    # The global step is advanced together with the variable updates.
    step_increment = tf.assign_add(tf.train.get_or_create_global_step(), 1)
    lowered_updates.append(step_increment)

    return lr, tf.group(lowered_updates)
def computation_fn():
    """Builds a BERT model over random token ids and exports its pooled output.

    Reads `batch_size`, `seq_length`, `vocab_size`, `bert_config` and
    `device_assignment` from the enclosing scope.

    Returns:
        The TF tensor exported from the model's pooled output.
    """
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, 'my_mesh')

    shape = mtf.convert_to_shape('all:2')
    layout = 'num_heads:all'
    devices = [''] * shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        shape, mtf.convert_to_layout_rules(layout), devices,
        device_assignment)

    # Random int32 ids in [0, vocab_size), imported as a [batch, seq] tensor.
    input_dims = [
        mtf.Dimension('batch', batch_size),
        mtf.Dimension('seq', seq_length),
    ]
    random_ids = tf.random.uniform(
        (batch_size, seq_length), minval=0, maxval=vocab_size, dtype=tf.int32)
    mtf_input_ids = mtf.import_tf_tensor(mesh, random_ids, input_dims)

    model = bert_lib.BertModel(
        config=bert_config,
        is_training=True,
        input_ids=mtf_input_ids,
        input_mask=None,
        token_type_ids=None)
    pooled_output = model.get_pooled_output()

    return mtf.Lowering(graph, {mesh: mesh_impl}).export_to_tf_tensor(
        pooled_output)
def testMinimizePeakMemoryList_ZeroUseTensor(self):
    """LIST scheduling of a graph containing a tensor that nothing consumes."""
    mtf_graph = mtf.Graph()
    mesh = mtf.Mesh(mtf_graph, 'my_mesh')

    # X: a constant whose output is never used by another operation.
    mtf.Constant(
        mesh, 0, shape=mtf.convert_to_shape('a:4'), dtype=tf.int32, name='X')
    # Y: a constant consumed by the broadcast Z.
    y_op = mtf.Constant(
        mesh, 0, shape=mtf.convert_to_shape('b:3'), dtype=tf.int32, name='Y')
    mtf.BroadcastOperation(
        y_op.outputs[0], mtf.convert_to_shape('b:3,c:2'), name='Z')

    graph = graph_interface.GraphInterface(mtf_graph)
    schedule = list(scheduler.minimize_peak_memory(graph, 'LIST'))

    # When nothing is scheduled:
    #   X frees 0 entries
    #   Y frees -3 entries
    # Hence the schedule should be [X, Y, Z].
    self.assertEqual(schedule, [0, 1, 2])
def model_fn(features, labels, mode, params):
    """Model function for TPUEstimator that runs the FFT benchmark.

    Args:
        features: ignored.
        labels: ignored.
        mode: estimator mode, forwarded to the returned spec.
        params: dict whose 'context' entry supplies the host count, the host
            placement function and the device assignment.

    Returns:
        A TPUEstimatorSpec whose loss is the benchmark output cast to float.
    """
    del features, labels

    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    ctx = params['context']
    place_on_host = ctx.tpu_host_placement_function
    device_list = []
    for host_id in range(ctx.num_hosts):
        device_list.append(place_on_host(host_id=host_id))
    tf.logging.info('device_list = %s' % device_list)

    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, [''] * mesh_shape.size,
        ctx.device_assignment)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "fft_mesh")

    with mtf.utils.outside_all_rewrites():
        fsum = benchmark_model(mesh)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    tf_err = tf.to_float(lowering.export_to_tf_tensor(fsum))

    with mtf.utils.outside_all_rewrites():
        spec = tpu_estimator.TPUEstimatorSpec(mode, loss=tf_err)
        return spec
def testLayout(self):
    """Checks the automatically computed layout of a small einsum graph."""
    # Construct a Mesh TensorFlow graph and mesh.
    mtf_graph = mtf.Graph()
    mesh = mtf.Mesh(mtf_graph, "my_mesh")
    lhs = mtf.zeros(mesh, "a:10,b:5")
    rhs = mtf.zeros(mesh, "b:5,c:20")
    product = mtf.einsum([lhs, rhs], "a:10,c:20")

    # Decide on a mesh shape.
    mesh_shape = mtf.convert_to_shape("m1:4,m2:2")

    # Compute a layout based on the graph and mesh.
    # Note that knowing the identity of the outputs is important to the
    # optimization since they cannot be freed.
    layout = mtf.auto_mtf.layout(mtf_graph, mesh_shape, [product])

    dims = {
        name: mtf.convert_to_dimension((name, size))
        for name, size in (("a", 10), ("b", 5), ("c", 20))
    }

    def mesh_axis_of(name):
        return layout.tensor_dimension_to_mesh_axis(dims[name], mesh_shape)

    self.assertEqual(mesh_axis_of("a"), 1)
    self.assertIsNone(mesh_axis_of("b"))
    self.assertEqual(mesh_axis_of("c"), 0)
def main(_):
    """Runs the FFT benchmark across a SLURM-allocated GPU cluster.

    Every SLURM task starts a TF server for the "mesh" job. Tasks other than
    task 0 block in `server.join()`; task 0 builds the Mesh TensorFlow graph,
    lowers it onto all GPUs of all tasks, runs it twice (warm-up, then timed)
    and prints the error and wall time of the second run.
    """
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    # Resolve the cluster from SLURM environment
    cluster = tf.distribute.cluster_resolver.SlurmClusterResolver(
        {"mesh": mesh_shape.size // FLAGS.gpus_per_task},
        port_base=8822,
        gpus_per_node=FLAGS.gpus_per_node,
        gpus_per_task=FLAGS.gpus_per_task,
        tasks_per_node=FLAGS.tasks_per_node)

    cluster_spec = cluster.cluster_spec()
    # Create a server for all mesh members
    server = tf.distribute.Server(cluster_spec, "mesh", cluster.task_id)

    # Only the master job takes care of the graph building,
    # everyone else can just chill for now
    if cluster.task_id > 0:
        server.join()

    # Otherwise we are the main task, let's define the devices.
    # The number of tasks was derived from `gpus_per_task` above, so each task
    # contributes exactly that many GPUs. Iterating over `gpus_per_node` here
    # would produce a device list of the wrong length (and name GPUs a task
    # does not own) whenever more than one task shares a node.
    mesh_devices = [
        "/job:mesh/task:%d/device:GPU:%d" % (i, j)
        for i in range(cluster_spec.num_tasks("mesh"))
        for j in range(FLAGS.gpus_per_task)
    ]
    print("List of devices", mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "fft_mesh")

    # Build the model
    fft_err = benchmark_model(mesh)

    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    # Retrieve output of computation
    result = lowering.export_to_tf_tensor(fft_err)

    with tf.Session(server.target) as sess:
        # Warm-up run: its result and timing are deliberately discarded.
        sess.run(result)

        time.sleep(1)

        # Timed run.
        start = time.time()
        err = sess.run(result)
        end = time.time()

    print("Max absolute FFT error %f, with wall time %f" % (err, (end - start)))
    time.sleep(1)
    exit(0)
def get_placement_mesh(hparams):
    """Creates a mesh and a placement mesh implementation from hparams.

    Args:
        hparams: object providing `mesh_shape` and `layout`.

    Returns:
        A `(mesh, mesh_impl)` pair; the implementation has one blank device
        string per mesh slice.
    """
    mesh = mtf.Mesh(mtf.Graph(), "my_mesh")
    shape = mtf.convert_to_shape(hparams.mesh_shape)
    blank_devices = [""] * shape.size
    impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape, hparams.layout, blank_devices)
    return mesh, impl
def _tensor_dim_to_mesh_dim_size(hparams, tensor_dim):
    """Returns the number of ways `tensor_dim` is split under `hparams`.

    Args:
        hparams: object providing `layout` and `mesh_shape`.
        tensor_dim: the tensor dimension to look up in the layout rules.

    Returns:
        The size of the mesh dimension that `tensor_dim` maps to, or 1 when
        the layout rules map it to no mesh axis.
    """
    rules = mtf.convert_to_layout_rules(hparams.layout)
    shape = mtf.convert_to_shape(hparams.mesh_shape)
    axis = rules.tensor_dimension_to_mesh_axis(tensor_dim, shape)
    return 1 if axis is None else shape.dims[axis].size
def testOptimizeLayoutTiebreak(self):
    """Checks the solved layout when the mesh has an extra size-1 dimension."""
    lhs = mtf.zeros(self.mesh, "a:10,b:5")
    rhs = mtf.zeros(self.mesh, "b:5,c:20")
    mtf.einsum([lhs, rhs], "a:10,c:20")

    # Rewrite mesh_shape to have a dummy dimension.
    self.mesh_shape = mtf.convert_to_shape("m1:4,m2:2,m3:1")

    solved_layout = self.get_layout_optimizer().solve()
    self.assertEqual(solved_layout, "a:m2;b:m3;c:m1")
def get_placement_mesh(hparams):
    """Returns a fresh mesh and a PlacementMeshImpl configured from hparams.

    Args:
        hparams: object providing `mesh_shape` and `layout`.

    Returns:
        Tuple `(mesh, mesh_impl)`.
    """
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
    # One blank device string for every slice of the mesh.
    mesh_devices = ["" for _ in range(mesh_shape.size)]

    return mesh, mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, hparams.layout, mesh_devices)
def testMinimizePeakMemoryList(self):
    """Checks the LIST schedule of a five-operation graph with two outputs."""
    mtf_graph = mtf.Graph()
    mesh = mtf.Mesh(mtf_graph, 'my_mesh')

    def shape_of(spec):
        return mtf.convert_to_shape(spec)

    # Operations are created in the order X, Y, Z, W, V; the indices asserted
    # at the end refer to this order.
    x = mtf.Constant(
        mesh, 0, shape=shape_of('a:3,b:4'), dtype=tf.int32,
        name='X').outputs[0]
    y = mtf.Constant(
        mesh, 0, shape=shape_of('b:4,c:5'), dtype=tf.int32,
        name='Y').outputs[0]
    mtf.EinsumOperation([x, y], shape_of('a:3,b:4,c:5'), name='Z')
    w = mtf.EinsumOperation(
        [x, y], shape_of('a:3,c:5'), name='W').outputs[0]
    mtf.BroadcastOperation(w, shape_of('a:3,b:4,c:5'), name='V')

    graph = graph_interface.GraphInterface(mtf_graph)
    for final_tensor in ('Z:0', 'V:0'):
        graph.set_tensor_final(final_tensor)

    schedule = list(scheduler.minimize_peak_memory(graph, 'LIST'))

    # List Scheduler prefers to schedule things that free the most memory.
    # When nothing is scheduled:
    #   X frees -12 entries.
    #   Y frees -20 entries.
    # After [X] scheduled:
    #   Y frees -20 entries.
    # After [X, Y] scheduled:
    #   Z frees -60 entries.
    #   W frees -15 entries.
    # After [X, Y, W] scheduled:
    #   Z frees -28 entries.
    #   V frees -45 entries.
    # Hence the schedule should be [X, Y, W, Z, V].
    self.assertEqual(schedule, [0, 1, 3, 2, 4])
def testReturnsTopoSort(self, scheduler_alg):
    """Checks that both constants are scheduled before both einsums."""
    mtf_graph = mtf.Graph()
    mesh = mtf.Mesh(mtf_graph, 'my_mesh')

    x_op = mtf.Constant(
        mesh, 0, shape=mtf.convert_to_shape('a:3,b:4'), dtype=tf.int32,
        name='X')
    y_op = mtf.Constant(
        mesh, 0, shape=mtf.convert_to_shape('b:4,c:5'), dtype=tf.int32,
        name='Y')
    operands = [x_op.outputs[0], y_op.outputs[0]]

    # Two einsums over the same pair of inputs.
    mtf.EinsumOperation(operands, mtf.convert_to_shape('a:3,c:5'), name='Z1')
    mtf.EinsumOperation(operands, mtf.convert_to_shape('a:3,c:5'), name='Z2')

    graph = graph_interface.GraphInterface(mtf_graph)
    graph.set_tensor_final('Z1:0')
    graph.set_tensor_final('Z2:0')

    schedule = list(scheduler.minimize_peak_memory(graph, scheduler_alg))

    # Order within each pair is not pinned, only which pair comes first.
    self.assertCountEqual(schedule[:2], [0, 1])
    self.assertCountEqual(schedule[2:4], [2, 3])
def run_cifar():
    """Run the CIFAR training and evaluation loop.

    Trains for `FLAGS.train_epochs // FLAGS.epochs_between_evals` rounds. After
    every round the model is evaluated, and the round index, training time,
    accuracy, loss and `conv_shape` are appended to "./Het_CNN.txt".
    """
    import time

    cifar_classifier = tf.estimator.Estimator(model_fn=model_fn,
                                              model_dir=FLAGS.model_dir)
    dataset = cifar_dset()

    # Set up training and evaluation input functions.
    def train_input_fn():
        """Prepare data for training."""
        # When choosing shuffle buffer sizes, larger sizes result in better
        # randomness, while smaller sizes use less memory. A buffer of 50000
        # elements is used here.
        ds = dataset.train(FLAGS.data_dir)
        ds_batched = ds.cache().shuffle(buffer_size=50000).batch(
            FLAGS.batch_size)

        # Iterate through the dataset a set number (`epochs_between_evals`) of
        # times during each training session.
        return ds_batched.repeat(FLAGS.epochs_between_evals)

    def eval_input_fn():
        """Return the next batch from a one-shot pass over the test split."""
        return dataset.test(FLAGS.data_dir).batch(
            FLAGS.batch_size).make_one_shot_iterator().get_next()

    # Train and evaluate model.
    time_tot_start = time.time()

    # The context manager guarantees the log is flushed and closed even when
    # training or evaluation raises part-way through.
    with open("./Het_CNN.txt", "a+") as f:
        # NOTE(review): this header names six columns but each row below
        # writes five values (there is no "#Filters" value). Left unchanged
        # because existing log consumers may depend on it - confirm intent.
        f.write("#Filters\t#Epochs\t#Time\t#Accuracy\t#Loss\t#Shape\n")

        # Parsed only so that a malformed --mesh_shape still raises here; the
        # resulting shape is not used by this function.
        mtf.convert_to_shape(FLAGS.mesh_shape)
        conv_shape = []

        for ep in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
            time_epoch_start = time.time()
            cifar_classifier.train(input_fn=train_input_fn, hooks=None)
            epoch_seconds = time.time() - time_epoch_start

            eval_results = cifar_classifier.evaluate(input_fn=eval_input_fn)
            print("\nEvaluation results:\n\t%s\n" % eval_results)
            print(ep, "----------->", epoch_seconds)
            f.write("%d\t%0.6f\t%0.6f\t%0.6f\t%s\n" %
                    (ep, epoch_seconds, eval_results['accuracy'],
                     eval_results['loss'], conv_shape))

        total_seconds = time.time() - time_tot_start
        print("Total Time ", FLAGS.train_epochs, " Epochs", total_seconds)
def main(_):
    """Runs the LPT prototype on a row/col mesh and plots the fields on rank 0.

    The mesh shape comes from `FLAGS.nx` and `FLAGS.ny`. After the session
    run, rank 0 saves a two-panel figure of the initial conditions and the
    final field; every rank then exits.
    """
    mesh_shape = [("row", FLAGS.nx), ("col", FLAGS.ny)]

    # Tensor dimension names alternate between the "row" and "col" mesh axes.
    split_dims = ("nx_lr", "ny_lr", "nx", "ny", "ty", "tz", "ty_lr", "tz_lr",
                  "nx_block", "ny_block")
    layout_rules = list(zip(split_dims, ("row", "col") * 5))

    mesh_impl = HvdSimdMeshImpl(
        mtf.convert_to_shape(mesh_shape),
        mtf.convert_to_layout_rules(layout_rules))

    # Build the model
    # Create computational graphs and some initializations
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "nbody_mesh")
    initial_conditions, mesh_final_field = lpt_prototype(
        mesh, bs=FLAGS.box_size, nc=FLAGS.nc, batch_size=FLAGS.batch_size)

    # Lower mesh computation
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    # Retrieve output of computation
    fetches = [
        lowering.export_to_tf_tensor(initial_conditions),
        lowering.export_to_tf_tensor(mesh_final_field),
    ]

    with tf.Session() as sess:
        started_at = time.time()
        a, c = sess.run(fetches)
        ttime = time.time() - started_at
        print('Time for ', mesh_shape, ' is : ', ttime)

    if comm.rank == 0:
        plt.figure(figsize=(9, 3))
        panels = (
            (121, a, 'Initial Conditions'),
            (122, c, 'Mesh TensorFlow'),
        )
        for position, field, title in panels:
            plt.subplot(position)
            plt.imshow(field[0].sum(axis=2))
            plt.title(title)
        plt.colorbar()
        plt.savefig("mesh_nbody_%d-row:%d-col:%d.png" %
                    (FLAGS.nc, FLAGS.nx, FLAGS.ny))
        plt.close()

    exit(0)
def layout(mtf_graph, mesh_shape, mtf_outputs=()):
    """Compute layout rules based on a computational graph and mesh shape.

    Args:
        mtf_graph: a mtf.Graph.
        mesh_shape: an mtf.Shape, str, or listlike of mtf.Dimension.
        mtf_outputs: an optional iterable of mtf.Tensor, representing the
            outputs of the computation.

    Returns:
        a mtf.LayoutRules
    """
    shape = mtf.convert_to_shape(mesh_shape)
    # The memory estimator feeds the layout optimizer, whose solution is then
    # converted into layout rules.
    solver = layout_optimizer.LayoutOptimizer(
        memory_estimator.MemoryEstimator(mtf_graph, shape, mtf_outputs))
    solution = solver.solve()
    return mtf.convert_to_layout_rules(solution)
def main(_):
    """Runs TPU prediction and saves every predicted field as a .npy file.

    Builds a TPUEstimator around `model_fn` with one shard per mesh slice,
    feeds it an all-zero dataset, and writes each prediction's 'field' entry
    to `FLAGS.output_dir/field_<i>.npy`.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)

    # Resolve the TPU environment
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    run_config = tf.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=None,  # Disable the default saver
        save_checkpoints_secs=None,  # Disable the default saver
        log_step_count_steps=100,
        save_summary_steps=100,
        tpu_config=tpu_config.TPUConfig(
            num_shards=mesh_shape.size,
            iterations_per_loop=100,
            num_cores_per_replica=1,
            per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))

    model = tpu_estimator.TPUEstimator(
        use_tpu=True,
        model_fn=model_fn,
        config=run_config,
        predict_batch_size=1,
        train_batch_size=FLAGS.batch_size,
        eval_batch_size=FLAGS.batch_size)

    def dummy_input_fn(params):
        """Returns an all-zero dataset with `params['batch_size']` elements."""
        dset = tf.data.Dataset.from_tensor_slices(
            tf.zeros(shape=[params['batch_size'], 1], dtype=tf.float32))
        return dset

    # Run the predict loop; we will be connecting to this process using a
    # profiler.
    for i, f in enumerate(model.predict(input_fn=dummy_input_fn)):
        print(i)
        # Close the handle explicitly so the data is flushed as soon as the
        # array has been written, not whenever the object happens to be
        # garbage collected.
        with file_io.FileIO(FLAGS.output_dir + '/field_%d.npy' % i,
                            'w') as out_file:
            np.save(out_file, f['field'])
def model_fn(features, labels, mode, params):
    """Model function for TPUEstimator that predicts one n-body field.

    Args:
        features: ignored.
        labels: ignored.
        mode: estimator mode, forwarded to the returned spec.
        params: dict whose 'context' entry supplies the host count, the host
            placement function and the device assignment.

    Returns:
        A TPUEstimatorSpec whose predictions hold the first example of the
        field under the key 'field'.
    """
    del features, labels

    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    ctx = params['context']
    place_on_host = ctx.tpu_host_placement_function
    device_list = []
    for host_id in range(ctx.num_hosts):
        device_list.append(place_on_host(host_id=host_id))
    tf.logging.info('device_list = %s' % device_list)

    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, [''] * mesh_shape.size,
        ctx.device_assignment)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "fft_mesh")

    with mtf.utils.outside_all_rewrites():
        field = nbody_model(mesh)

    # The field is unpacked into exactly four dimensions; only the batch and
    # z dimensions are reused below.
    batch_dim, x_dim, y_dim, z_dim = field.shape
    output_dims = [
        mtf.Dimension("bs", 1),
        mtf.Dimension("nx_nosplit", FLAGS.cube_size),
        mtf.Dimension("ny_nosplit", FLAGS.cube_size),
        z_dim,
    ]

    # Until we implement distributed outputs, we only return one example
    first_example, _ = mtf.split(field, batch_dim,
                                 [1, FLAGS.batch_size - 1])
    field_slice = mtf.reshape(first_example, output_dims)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    tf_field = tf.to_float(lowering.export_to_tf_tensor(field_slice))

    with mtf.utils.outside_all_rewrites():
        spec = tpu_estimator.TPUEstimatorSpec(
            mode, predictions={'field': tf_field})
        return spec
def run_toy_model_tpu():
    """Train the toy model on TPU, evaluating after every checkpoint interval.

    When `FLAGS.steps_per_checkpoint` is 0 the model is trained straight to
    `FLAGS.train_steps` with no evaluation. Otherwise training proceeds in
    chunks of at most `steps_per_checkpoint` steps, each followed by an
    evaluation.
    """
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    iterations_per_loop = FLAGS.iterations
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)

    tpu_settings = tpu_config.TPUConfig(
        num_shards=mesh_shape.size,
        iterations_per_loop=iterations_per_loop,
        num_cores_per_replica=1,
        per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST)
    config = tpu_config.RunConfig(
        cluster=resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=None,  # Disable the default saver
        save_checkpoints_secs=None,  # Disable the default saver
        log_step_count_steps=iterations_per_loop,
        save_summary_steps=iterations_per_loop,
        tpu_config=tpu_settings)

    classifier = tpu_estimator.TPUEstimator(
        use_tpu=True,
        model_fn=model_fn,
        config=config,
        train_batch_size=FLAGS.batch_size,
        eval_batch_size=FLAGS.batch_size)

    # pylint: disable=protected-access
    current_step = estimator_lib._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)
    # pylint: enable=protected-access
    logging.info('Current step %d', current_step)

    if FLAGS.steps_per_checkpoint == 0:
        classifier.train(input_fn=ToyModelInput(), max_steps=FLAGS.train_steps)
        return

    while current_step < FLAGS.train_steps:
        next_checkpoint = min(current_step + FLAGS.steps_per_checkpoint,
                              FLAGS.train_steps)
        classifier.train(input_fn=ToyModelInput(), max_steps=next_checkpoint)
        current_step = next_checkpoint

        logging.info('Starting to evaluate.')
        # since we have 10000 examples and batch_size = 64 per host
        eval_results = classifier.evaluate(input_fn=ToyModelInput(), steps=156)
        logging.info('Eval results: %s', eval_results)
def main(_):
    """Runs the n-body simulation on a row/col mesh and saves the fields.

    Each rank writes its initial conditions and final field to
    `simulation_input_<rank>.npy` and `simulation_output_<rank>.npy`, then
    exits.
    """
    # Creating layout and mesh implementation
    mesh_shape = [("row", FLAGS.nx), ("col", FLAGS.ny)]
    layout_rules = [("nx_lr", "row"), ("ny_lr", "col"), ("nx", "row"),
                    ("ny", "col"), ("ty", "row"), ("tz", "col"),
                    ("ty_lr", "row"), ("tz_lr", "col"), ("nx_block", "row"),
                    ("ny_block", "col")]

    mesh_impl = HvdSimdMeshImpl(
        mtf.convert_to_shape(mesh_shape),
        mtf.convert_to_layout_rules(layout_rules))

    # Create the graph and mesh
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    # Load the initial power spectrum. The file is parsed once and its first
    # two columns are taken from the same array, instead of reading and
    # parsing the file separately for each column.
    power_spectrum = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T
    klin = power_spectrum[0]
    plin = power_spectrum[1]

    # Defines the computational graph for the nbody
    initial_conditions, final_field = nbody_fn(mesh, klin, plin)

    # Lower mesh computation
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    # Retrieve fields as tf tensors
    tf_initc = lowering.export_to_tf_tensor(initial_conditions)
    tf_final = lowering.export_to_tf_tensor(final_field)

    with tf.Session() as sess:
        start = time.time()
        init_conds, final = sess.run([tf_initc, tf_final])
        end = time.time()
        print('\n Time for the mesh run : %f \n' % (end - start))

    # Export these fields
    np.save('simulation_output_%d.npy' % comm.Get_rank(), final)
    np.save('simulation_input_%d.npy' % comm.Get_rank(), init_conds)

    exit(0)
def main(_):
    """Runs a long TPU evaluation of `model_fn` on all-zero inputs.

    The evaluation runs for 10000 steps so that a profiler can be attached
    to the process while it is busy.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)

    # Resolve the TPU environment
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    tpu_settings = tpu_config.TPUConfig(
        num_shards=mesh_shape.size,
        iterations_per_loop=100,
        num_cores_per_replica=1,
        per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST)
    run_config = tf.estimator.tpu.RunConfig(
        cluster=resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=None,  # Disable the default saver
        save_checkpoints_secs=None,  # Disable the default saver
        log_step_count_steps=100,
        save_summary_steps=100,
        tpu_config=tpu_settings)

    model = tpu_estimator.TPUEstimator(
        use_tpu=True,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.batch_size,
        eval_batch_size=FLAGS.batch_size)

    def dummy_input_fn(params):
        """Returns zero-filled (features, labels) of length batch_size."""
        zero_features = tf.zeros(
            shape=[params['batch_size']], dtype=tf.float32)
        zero_labels = tf.zeros(
            shape=[params['batch_size']], dtype=tf.float32)
        return zero_features, zero_labels

    # Run the evaluate loop for a long time; we will be connecting to this
    # process using a profiler.
    model.evaluate(input_fn=dummy_input_fn, steps=10000)
def train_and_eval():
    """Trains and evaluates MeshTensorflow model without TPUEstimator.

    Initializes the TPU system, builds a SIMD mesh implementation with one
    replica per core, then alternates train and eval phases
    `FLAGS.num_training_loops` times before shutting down.

    TODO(lehou): Pack everything nicely as a set of APIs.
    """
    tf.logging.info('FLAGS.master: {}'.format(FLAGS.master))

    # Open a session to get the list of CPU devices to hold master variables.
    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(target=FLAGS.master, config=session_config) as sess:
        topology = sess.run(tpu.initialize_system())
        cpu_devices = _list_cpu_devices(sess)

    topo = tf.contrib.tpu.Topology(serialized=topology)
    num_cores = int(np.prod(topo.mesh_shape))
    num_hosts = int(topo.num_tasks)
    cores_per_host = int(num_cores // num_hosts)
    assert cores_per_host == int(topo.num_tpus_per_task)

    # Get a device_assignment object for mtf.
    # NOTE(review): the computation shape is fixed at rank 3 here - confirm
    # this matches the rank of the topology in use.
    d_assignment = device_assignment.device_assignment(
        topology, computation_shape=[1, 1, 1], num_replicas=num_cores)

    # Get mesh_impl.
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mtf.convert_to_shape(FLAGS.mesh_shape),
        unet.get_layout(),
        None,
        d_assignment)

    phase_args = (mesh_impl, cpu_devices, d_assignment, num_hosts, num_cores)
    for _ in range(FLAGS.num_training_loops):
        _train_phase(*phase_args)
        _eval_phase(*phase_args)

    _shutdown()

    tf.logging.info('finished.')
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator.

    Builds the MNIST model as a Mesh TensorFlow graph, lowers it onto a
    placement mesh configured from FLAGS.mesh_shape and FLAGS.layout, and
    returns the EstimatorSpec matching `mode`.

    Args:
        features: input features, passed to `mnist_model`.
        labels: labels, passed to `mnist_model` and used for accuracy.
        mode: one of the tf.estimator.ModeKeys values.
        params: only logged here.

    Returns:
        A tf.estimator.EstimatorSpec for TRAIN, PREDICT or EVAL. Any other
        mode falls through every branch and returns None.
    """
    tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                    (features, labels, mode, params))
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    # wrapped graph named "my_mesh"
    mesh = mtf.Mesh(graph, "my_mesh")
    logits, loss = mnist_model(features, labels, mesh)
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    mesh_size = mesh_shape.size
    print("mesh_shape.size = ", mesh_shape.size)
    # One blank device string per mesh slice.
    mesh_devices = [""] * mesh_size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    # The optimizer update ops are added to the mtf graph before it is
    # lowered below; they exist only in TRAIN mode.
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdafactorOptimizer()
        update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)

    tf_logits = lowering.export_to_tf_tensor(logits)
    # `tf_loss` is defined for TRAIN and EVAL only; PREDICT never reads it.
    if mode != tf.estimator.ModeKeys.PREDICT:
        tf_loss = lowering.export_to_tf_tensor(loss)
        tf.summary.scalar("loss", tf_loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        # Advance the global step together with the variable updates.
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)
        saver = tf.train.Saver(tf.global_variables(),
                               sharded=True,
                               max_to_keep=10,
                               keep_checkpoint_every_n_hours=2,
                               defer_build=False,
                               save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(FLAGS.model_dir,
                                                  save_steps=1000,
                                                  saver=saver,
                                                  listeners=[saver_listener])

        accuracy = tf.metrics.accuracy(labels=labels,
                                       predictions=tf.argmax(tf_logits, axis=1))

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(tf_loss, "cross_entropy")
        tf.identity(accuracy[1], name="train_accuracy")

        # Save accuracy scalar to Tensorboard output.
        tf.summary.scalar("train_accuracy", accuracy[1])

        # restore_hook must come before saver_hook
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            "classes": tf.argmax(tf_logits, axis=1),
            "probabilities": tf.nn.softmax(tf_logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[restore_hook],
            export_outputs={
                "classify": tf.estimator.export.PredictOutput(predictions)
            })

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=tf_loss,
            evaluation_hooks=[restore_hook],
            eval_metric_ops={
                "accuracy":
                tf.metrics.accuracy(labels=labels,
                                    predictions=tf.argmax(tf_logits, axis=1)),
            })
def __init__(
    self,
    model_dir,
    tpu,
    tpu_job_name=None,
    tpu_zone=None,
    gcp_project=None,
    tpu_topology="v2-8",
    model_parallelism=8,
    batch_size=("sequences_per_batch", 1),
    sequence_length=None,
    model_type="bitransformer",
    layout_rules="ensemble:ensemble,batch:batch,d_ff:model,heads:model,vocab:model,experts:batch",
    mesh_shape=None,
    mesh_devices=None,
    autostack=True,
    learning_rate_schedule=None,
    keep_checkpoint_max=None,
    save_checkpoints_steps=5000,
    optimizer=None,
    predict_fn=None,
    variable_filter=None,
    ensemble_inputs=None,
    iterations_per_loop=100):
    """Constructor for MtfModel class.

    Args:
      model_dir: str, directory to save the model.
      tpu: str, the TPU address to use. If falsy, no TPU cluster resolver is
        created and the default mesh shape is the empty string.
      tpu_job_name: str, name of the TPU worker binary.
      tpu_zone: str, GCE zone where the Cloud TPU is located
      gcp_project: str, project name for the Cloud TPU-enabled project.
      tpu_topology: str, e.g. "2x2" or "v2-8".
      model_parallelism: integer, the number of cores per model replica.
      batch_size: An integer or a (method, value) pair to pass to
        compute_batch_size(). Note that this is the global batch size and not
        the per-shard batch size.
      sequence_length: an integer or a dict from feature-key to integer
        the (packed) sequence length, e.g. {"inputs": 512, "targets": 128}.
        Defaults to 512 for both "inputs" and "targets".
      model_type: str, a model type from mesh tf models.
      layout_rules: an input to mtf.convert_to_layout_rules()
      mesh_shape: an mtf.Shape or string (e.g., "model:2,batch:4") specifying
        how the data/model should be parallelized. If None (default), the mesh
        shape will be constructed using the supplied `tpu_topology` and
        `model_parallelism` arguments.
      mesh_devices: a list of strings, the device names to use for each mesh
        slice. Only required for GPU.
      autostack: boolean, internally combine variables.
      learning_rate_schedule: an optional function taking the scalar name
        argument `step` and the numeric argument `total_train_steps` and return
        the scalar learning rate.
      keep_checkpoint_max: an integer, maximum number of checkpoints to keep.
      save_checkpoints_steps: an integer, steps per checkpoint.
      optimizer: a class extending optimize.Optimizer, required for training.
      predict_fn: an optional function that can be used to override the default
        transformer prediction behavior. Must return a tensor of shape
        [batch_dim, length_dim] that will be the prediction for each example.
        Must accept the following arguments:
          - model: a Unitransformer or Bitransformer
          - features: a dict representing an example. Every value will be an
            mtf.Tensor with shape [batch_dim, length_dim].
          - variable_dtype: an mtf.VariableDType
      variable_filter: a str, a variable will only be trained if its name
        matches this regex. If None (default), train all trainable variables.
      ensemble_inputs: an integer, see `train_model` docstring for details.
      iterations_per_loop: integer, steps per train loop
    """
    mesh_shape = mesh_shape or (
        utils.tpu_mesh_shape(tpu_topology, model_parallelism) if tpu else "")

    sequence_length = sequence_length or {"inputs": 512, "targets": 512}
    # A bare integer applies the same length to both features.
    if isinstance(sequence_length, int):
        sequence_length = {"inputs": sequence_length,
                           "targets": sequence_length}

    self._learning_rate_schedule = (
        learning_rate_schedule or
        learning_rate_schedules.learning_rate_schedule_noam)
    self._optimizer = optimizer or optimize.AdafactorOptimizer

    self._sequence_length = sequence_length
    self._model_dir = model_dir
    self._model_type = model_type
    self._ensemble_inputs = ensemble_inputs

    self._layout_rules = mtf.convert_to_layout_rules(layout_rules)
    self._mesh_shape = mtf.convert_to_shape(mesh_shape)
    self._mesh_devices = mesh_devices

    self._autostack = autostack
    self._keep_checkpoint_max = keep_checkpoint_max
    self._save_checkpoints_steps = save_checkpoints_steps
    self._predict_fn = predict_fn
    self._variable_filter = variable_filter
    self._iterations_per_loop = iterations_per_loop

    self._cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu, zone=tpu_zone, project=gcp_project) if tpu else None
    self._tpu = tpu
    self._tpu_job_name = tpu_job_name
    self._estimator = None

    # Must be called after _sequence_length, _mesh_shape, and _layout_rules are
    # set.
    self.batch_size = batch_size
def model_fn(features, labels, mode, params):
    """A model is called by TpuEstimator.

    Builds the toy model as a Mesh TensorFlow graph, lowers it onto a SIMD
    mesh (when FLAGS.use_tpu is set) or a placement mesh, and returns the
    TPUEstimatorSpec for TRAIN or EVAL.

    Args:
        features: input features, passed to `toy_model`.
        labels: ignored.
        mode: one of the tf.estimator.ModeKeys values.
        params: dict; its 'context' entry is read when FLAGS.use_tpu is set.

    Returns:
        A TPUEstimatorSpec for TRAIN or EVAL. Any other mode falls through
        both branches and returns None.
    """
    del labels
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    if FLAGS.use_tpu:
        ctx = params['context']
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
        tf.logging.info('device_list = %s' % device_list,)
        # TODO(ylc): Better estimation of replica cache size?
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        # Initial usage per host: everything is attributed to host 0.
        devices_memeory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                      devices_memeory_usage)
        mesh_devices = [''] * mesh_shape.size
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
    else:
        var_placer = None
        mesh_devices = [''] * mesh_shape.size
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)

    mesh = mtf.Mesh(graph, 'my_mesh', var_placer)
    with mtf.utils.outside_all_rewrites():
        logits, loss = toy_model(features, mesh)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients([loss],
                                  [v.outputs[0] for v in graph.trainable_variables])
        if FLAGS.optimizer == 'Adafactor':
            optimizer = mtf.optimize.AdafactorOptimizer()
        else:
            assert FLAGS.optimizer == 'SGD'
            optimizer = mtf.optimize.SgdOptimizer(lr=FLAGS.lr)
        update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
    else:
        # for now, we can only export fully-replicated tensors.
        fully_replicated_logits = mtf.anonymize(logits)

    # Everything added to the mtf graph above is lowered to TF here.
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))

    # `train_op` exists only in TRAIN mode and `tf_logits` only otherwise;
    # the branches below read them accordingly.
    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        # Advance the global step together with the variable updates.
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
        train_op = tf.group(tf_update_ops)
    else:
        tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            saver = tf.train.Saver(
                tf.global_variables(),
                sharded=True,
                max_to_keep=10,
                keep_checkpoint_every_n_hours=2,
                defer_build=False,
                save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])
            # The restore hook is listed ahead of the saver hook.
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(tf_logits):
                """Returns the mean of the logits as the only eval metric."""
                mean_logits = tf.metrics.mean(tf_logits)
                return {'mean_logits': mean_logits}

            eval_metrics = (metric_fn, [tf_logits])
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
def test_get_laidout_tensors(self, is_eval_mode):
    """Check the per-core data fed by ``SimdMeshImplInputReader``.

    A 4x4 identity matrix is served two rows at a time by a dataset creator
    that advances to the next pair of rows on every call. The test asserts
    which rows core 0 and core 1 receive for the given ``is_eval_mode``.

    Args:
        is_eval_mode: forwarded to ``SimdMeshImplInputReader``; selects the
            expected data for core 1.
    """
    mesh_shape_spec = "mesh_x:2, mesh_y:1"
    layout_spec = "batch:mesh_x, io:mesh_y"
    batch_io_dim = 4

    with tf.Session() as sess:
        topology, num_cores = self.initialize_system(sess)

        # Get a device_assignment object for mtf.
        d_assignment = device_assignment.device_assignment(
            topology,
            computation_shape=[1] * mtf.utils.topology_rank(topology),
            num_replicas=num_cores)

        # Hacked dataset creator: creates different datasets for the first and
        # second call, in order to test SimdMeshImplInputReader.
        self.sub_batch_created_times = 0

        def stateful_ds_creator():
            identity = tf.eye(batch_io_dim, dtype=tf.float32)
            first_row = self.sub_batch_created_times * 2
            two_rows = tf.slice(identity, [first_row, 0], [2, 4])
            self.sub_batch_created_times += 1
            return tf.data.Dataset.from_tensors(two_rows).repeat().unbatch()

        mtf_input_shapes = [
            mtf.Shape([
                mtf.Dimension("batch", batch_io_dim),
                mtf.Dimension("io", batch_io_dim),
            ])
        ]

        # Get mesh_impl.
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mtf.convert_to_shape(mesh_shape_spec),
            mtf.convert_to_layout_rules(layout_spec),
            None,
            d_assignment)

        simd_input_reader = input_reader.SimdMeshImplInputReader(
            mesh_impl,
            stateful_ds_creator,
            mtf_input_shapes,
            external_worker=False,
            is_eval_mode=is_eval_mode)

        def _passthrough(features):
            # The replicated computation simply returns what was infed.
            return features

        replicated_computation = tpu.replicate(
            computation=_passthrough,
            inputs=[[]] * num_cores,
            infeed_queue=simd_input_reader.infeed_queue,
            device_assignment=d_assignment)

        simd_input_reader.start_infeed_thread(sess, 1)
        results = sess.run(replicated_computation)
        print("results: {}".format(results))

        core_0_data = results[0][0]
        core_1_data = results[1][0]
        print("core_0_data: {}".format(core_0_data))
        print("core_1_data: {}".format(core_1_data))

        top_rows = np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32)
        bottom_rows = np.array([[0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32)

        if is_eval_mode:
            # If there is only one dataset object, then the
            # stateful_ds_creator() should be called only once.
            expected_core_1 = top_rows
        else:
            # If there are two dataset objects, then the
            # stateful_ds_creator() should be called twice.
            expected_core_1 = bottom_rows

        self.assertAllClose(top_rows, core_0_data)
        self.assertAllClose(expected_core_1, core_1_data)

        sess.run(tf.tpu.shutdown_system())
def model_fn(features, labels, mode, params):
    """Estimator ``model_fn`` building the GPT model with Mesh TensorFlow.

    Args:
        features: input token tensor, or a dict holding it under
            ``"feature"``.
        labels: label token tensor (same accepted forms as ``features``) or
            ``None``.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: configuration dict. It is passed through
            ``add_mode_to_params`` and gains a ``"num_microbatches"`` entry.

    Returns:
        A ``tpu_estimator.TPUEstimatorSpec`` for PREDICT, TRAIN or EVAL.

    Raises:
        AssertionError: if ``mode`` is none of PREDICT, TRAIN, EVAL.
        Exception: if ``params["model"]`` is not ``"GPT"``.
    """
    # Get global step
    global_step = tf.train.get_global_step()

    # Construct mtf graph + mesh from params
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    layout_rules = mtf.convert_to_layout_rules(params["layout"])

    # Mesh setup
    if params["use_tpu"]:
        var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape,
                                                layout_rules)
    else:
        var_placer = None
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, params["gpu_ids"])

    # Trainable variable precision
    # Store to checkpoints in master type, train in slice type, compute in
    # activation type. Only master/activation differ between the two modes.
    reduced = params["precision"] == "bfloat16"
    master_and_activation = tf.bfloat16 if reduced else tf.float32
    variable_dtype = mtf.VariableDType(
        master_dtype=master_and_activation,
        slice_dtype=tf.float32,
        activation_dtype=master_and_activation)

    # Build mtf mesh object
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # Build mtf_features & seq length dict for getting number of microbatches
    # We need to pack inputs into a dict to pass into serialize_training_step
    features_dict = {"inputs": features, "labels": labels}
    sequence_length_dict = {
        "inputs": params["n_ctx"],
        "labels": params["n_ctx"],
    }

    params = add_mode_to_params(params, mode)
    batch_dim = mtf.Dimension("batch", get_batch_size(params))
    length_dim = mtf.Dimension("sequence", sequence_length_dict["inputs"])
    feature_shape = mtf.Shape([batch_dim] + [length_dim])

    mtf_features = {}
    for key, raw in features_dict.items():
        if raw is None:
            continue
        if type(raw) is dict:
            # Dict-wrapped inputs carry the tensor under "feature".
            raw = raw["feature"]
            features_dict[key] = raw
        as_int = tf.cast(raw, tf.int32)
        as_int = tf.reshape(as_int, feature_shape.to_integer_list)
        mtf_features[key] = mtf.import_fully_replicated(
            mesh, as_int, feature_shape, name=key)

    # Dimensions, bias, etc. that can be calculated here once then passed
    # into the model.
    memory_length_dim = mtf.Dimension("memory_length", length_dim.size)
    if params["causal"]:
        attn_bias = biasmask_attn_weights(mesh, length_dim,
                                          memory_length_dim, variable_dtype)
    else:
        attn_bias = None

    # Define other Dimensions that we'll need inside the model
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
    # We need this because gathering when both the args have the same
    # dimension in them breaks things. This dim is specifically for the
    # weights. This prevents the "Einsum has lhs dimension without
    # corresponding rhs or output dimension." error
    embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])

    other_features = {
        "attn_bias": attn_bias,
        "embd_dim": embd_dim,
        "vocab_dim": vocab_dim,
        "embed_sequence_dim": embed_sequence_dim,
        "memory_length_dim": memory_length_dim,
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Set up the model for prediction
        inputs = mtf_features["inputs"]
        if params["remove_partial_sequences"] is None:
            params["remove_partial_sequences"] = False

        if params.get("export", False):
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    mtf_samples, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype,
                        context=None)
        else:
            mtf_samples = sample_autoregressive(
                inputs,
                other_features=other_features,
                params=params,
                variable_dtype=variable_dtype,
                remove_partial_sequences=params["remove_partial_sequences"],
                stop_at_token=params["eos_id"],
                sampling_use_entmax=params['sampling_use_entmax'])

        mtf_samples = mtf.anonymize(mtf_samples)
        inputs = mtf.anonymize(inputs)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
        inputs = lowering.export_to_tf_tensor(inputs)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {"inputs": inputs, "outputs": outputs}

        def scaffold_fn():
            local_init_op = tf.group(
                tf.train.Scaffold.default_local_init_op(),
                lowering.copy_masters_to_slices(),
                name="mtf_local_init_op")
            ready_op = tf.concat(
                [
                    tf.report_uninitialized_variables(),
                    resources.report_uninitialized_resources(),
                ],
                axis=0,
                name="mtf_ready_op")
            return tf.train.Scaffold(local_init_op=local_init_op,
                                     ready_op=ready_op)

        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            scaffold_fn=scaffold_fn,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    # We're not predicting, so we better be training or evaluating
    assert (mode == tf.estimator.ModeKeys.TRAIN
            or mode == tf.estimator.ModeKeys.EVAL)
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    if is_training:
        # Gets number of microbatches per batch for serialized training
        # if param tokens_per_mb_per_replica = None, this defaults to 1 and
        # no microbatching is performed
        num_microbatches = int(
            mtf_transformer.utils.serialize_num_microbatches(
                batch_dim=batch_dim,
                sequence_length=sequence_length_dict,
                mesh_shape=mesh_shape,
                layout_rules=layout_rules,
                tokens_per_microbatch_per_replica=params[
                    "tokens_per_mb_per_replica"]))
    else:
        num_microbatches = 1

    # Add num microbatches to params
    params["num_microbatches"] = num_microbatches

    if num_microbatches > 1:

        # For serialize_training_step we need to modify the model to output
        # results in a dict
        def serialized_fn(mtf_features):
            if params["model"] != "GPT":
                raise Exception(
                    f"'{params['model']}' is not a valid model - please select from [GPT]"
                )
            with tf.variable_scope('gpt2'):
                logits, loss, loss_batch = gpt2.model(
                    mtf_features,
                    other_features,
                    params,
                    mesh,
                    variable_dtype=variable_dtype)
            return {
                "logits": logits,
                "loss": loss,
                "loss_batch": loss_batch,
            }

        # Serialize the training step - Gradients are accumulated locally
        # and reduced once.
        var_grads, output_dict = mtf.serialize_training_step(
            mtf_features, serialized_fn, batch_dim, num_microbatches)
        loss = output_dict["loss"]
        loss_batch = output_dict["loss_batch"]
        logits = output_dict["logits"]
    else:
        # If we're not splitting into microbatches, return logits & loss as is
        if params["model"] != "GPT":
            raise Exception(
                f"'{params['model']}' is not a valid model - please select from [GPT]"
            )
        with mtf.utils.outside_all_rewrites():
            with tf.variable_scope('gpt2'):
                logits, loss, loss_batch = gpt2.model(
                    mtf_features,
                    other_features,
                    params,
                    mesh,
                    variable_dtype=variable_dtype,
                    context=None)

    # Auto layout generation
    if params["auto_layout"]:
        auto_layout(graph, mesh_shape, logits, loss)
    if params["auto_layout_and_mesh_shape"]:
        auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)

    if is_training:
        # In TRAIN mode, get optimizer
        if params["num_microbatches"] > 1:
            # If we are splitting the batch into microbatches, var grads are
            # created in the serialize_training_step fn, so we pass them in
            _, update_ops, var_grads = get_optimizer(
                mesh,
                loss,
                params,
                variable_dtype=variable_dtype,
                inp_var_grads=var_grads)
        else:
            # Otherwise, they are created in the get_optimizer fn, so we
            # leave inp_var_grads blank
            _, update_ops, var_grads = get_optimizer(
                mesh, loss, params, variable_dtype=variable_dtype)

        # Log summaries to tensorboard
        mtf.scalar_summary("loss", loss)

        # Log gradients if in params
        if params["log_grads"] not in [None, False]:
            for g in var_grads:
                grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
                mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm)
    else:
        # For now, we can only export fully-replicated tensors.
        # This has to be done before lowering or they will not be included
        # in the graph
        mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
        max_logits = mtf.argmax(logits, vocab_dim)
        del logits
        fully_replicated_mean_logits = mtf.anonymize(mean_logits)
        fully_replicated_max_logits = mtf.anonymize(max_logits)
        fully_replicated_loss_batch = mtf.anonymize(loss_batch)

    # Gets & prints info about no. trainable vars in the model & dimension
    # names
    get_graph_info(graph)

    # 'lowers' mtf tensors into a tf graph - this enables us to export
    # results as tf tensors
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
    tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)

    if is_training:
        # Use our patched version until mtf updates theirs
        host_call = create_host_call(params['model_path'])
        mtf.utils.remove_summaries()

        # Creates train_op
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        # Need to manually increment global_step
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf.logging.info(f"tf_update_ops: {tf_update_ops}")
        train_op = tf.group(tf_update_ops)
    else:
        tf_mean_logits = lowering.export_to_tf_tensor(
            fully_replicated_mean_logits)
        tf_max_logits = lowering.export_to_tf_tensor(
            fully_replicated_max_logits)
        tf_loss_batch = tf.to_float(
            lowering.export_to_tf_tensor(fully_replicated_loss_batch))

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)

        if is_training:
            # Set up the checkpoint server and return the TPUEstimatorSpec
            saver = tf.train.Saver(
                tf.global_variables(),
                sharded=True,
                max_to_keep=10,
                keep_checkpoint_every_n_hours=2,
                defer_build=False,
                save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_hook = tf.train.CheckpointSaverHook(
                params["model_path"],
                save_steps=params["steps_per_checkpoint"],
                saver=saver,
                listeners=[mtf.MtfCheckpointSaverListener(lowering)])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                host_call=host_call,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])

        elif mode == tf.estimator.ModeKeys.EVAL:
            # Evaluation metrics
            def _perplexity(loss):
                return tf.metrics.mean(tf.exp(loss))

            def _bits_per_byte(loss):
                return tf.metrics.mean(loss * (0.29335 / math.log(2)))

            def _metric_fn(tf_mean_logits, tf_loss_batch):
                mean_logits = tf.metrics.mean(tf_mean_logits)
                loss = tf.reduce_mean(tf_loss_batch)
                perp = _perplexity(loss)
                bpb = _bits_per_byte(loss)
                return {
                    "mean_logits": mean_logits,
                    "perplexity": perp,
                    "bits per byte": bpb,
                }

            def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
                eos_token = params["eos_id"]
                answer_positions = tf.where(
                    tf.math.not_equal(labels, eos_token))

                is_correct = tf.math.equal(tf_max_logits, labels)
                correct_answers = tf.gather_nd(is_correct, answer_positions)
                accuracy = tf.metrics.mean(
                    tf.cast(correct_answers, tf.float32))

                # I guess tf_loss_batch has z_loss and maybe other stuff
                # added to it so maybe this should be calculated separately
                # in the future
                answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
                log_perplexity = tf.metrics.mean(answer_loss)

                return {
                    "lambada_acc": accuracy,
                    "lambada_log_ppl": log_perplexity,
                }

            eval_task = params["eval_task"]
            if eval_task == "lambada":
                eval_metrics = (_lambada_metric_fn,
                                [labels, tf_max_logits, tf_loss_batch])
            else:
                eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
def estimator_model_fn(cls,
                       hparams,
                       features,
                       labels,
                       mode,
                       config=None,
                       params=None,
                       decode_hparams=None):
    """Build the Estimator spec for a Mesh TensorFlow model.

    Args:
        cls: the model class; instantiated as
            ``cls(hparams, mode, data_parallelism=..., decode_hparams=...)``.
        hparams: hyperparameters; deep-copied before being modified.
        features: passed to the model's ``mtf_model_fn`` /
            ``estimator_spec_predict`` / ``estimator_spec_eval``.
        labels: passed to ``estimator_spec_eval``.
        mode: a ``tf.estimator.ModeKeys`` value.
        config: optional run config; ``config.data_parallelism`` is read when
            not running on TPU.
        params: optional dict; ``"use_tpu"`` selects the TPU path and
            ``"context"`` is read on that path.
        decode_hparams: optional decode hyperparameters, merged into
            ``hparams`` in PREDICT mode.

    Returns:
        PREDICT: the result of ``model.estimator_spec_predict``.
        EVAL: the result of ``model.estimator_spec_eval``.
        TRAIN: a ``TPUEstimatorSpec`` on TPU, otherwise an ``EstimatorSpec``.
    """
    hparams = copy.deepcopy(hparams)
    # Normalise to a real bool: `params and ...` alone evaluates to None (or
    # an empty dict) when `params` is falsy, and that value would be stored
    # on hparams and forwarded to the model.
    use_tpu = bool(params and params.get("use_tpu", False))
    hparams.use_tpu = use_tpu

    # merge decode_hparams into hparams if present
    if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
        for k, v in six.iteritems(decode_hparams.values()):
            if hasattr(hparams, k) and getattr(hparams, k) != v:
                tf.logging.warning(
                    "Overriding hparams.%s with %s from decode_hparams" %
                    (k, v))
            setattr(hparams, k, v)

    # Instantiate model
    data_parallelism = None
    if not use_tpu and config:
        data_parallelism = config.data_parallelism
    model = cls(hparams,
                mode,
                data_parallelism=data_parallelism,
                decode_hparams=decode_hparams)

    global_step = tf.train.get_global_step()

    mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(hparams.layout)
    if use_tpu:
        ctx = params["context"]
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [
            host_placement_fn(host_id=t) for t in range(num_hosts)
        ]
        # TODO(ylc): Better estimation of replica cache size?
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                      devices_memory_usage)
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                                    mesh_devices,
                                                    ctx.device_assignment)
    else:
        var_placer = None
        # `data_parallelism` stays None when no `config` was supplied; use
        # default placement in that case instead of failing on `.ps_devices`.
        if data_parallelism is None or len(data_parallelism.ps_devices) == 1:
            mesh_devices = [""] * mesh_shape.size
        else:
            assert len(data_parallelism.ps_devices) == mesh_shape.size
            mesh_devices = data_parallelism.ps_devices
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # PREDICT mode
    if mode == tf.estimator.ModeKeys.PREDICT:
        return model.estimator_spec_predict(features, mesh, mesh_impl,
                                            use_tpu)

    logits, loss = model.mtf_model_fn(features, mesh)
    if use_tpu and logits is not None:
        logits = mtf.anonymize(logits)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        lr = learning_rate.learning_rate_schedule(hparams)
        tf.summary.scalar("learning_rate", lr)
        mtf_lr = mtf.import_tf_tensor(
            mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
        optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
        update_ops = []
        for grad, var in zip(var_grads, graph.trainable_variables):
            update_ops.extend(optimizer.apply_grad(grad, var))

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.to_float(tf_loss)
    if logits and mode != tf.estimator.ModeKeys.TRAIN:
        tf_logits = lowering.export_to_tf_tensor(logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [
            lowering.lowered_operation(op) for op in update_ops
        ]
        # The global step is incremented together with the variable updates.
        tf_update_ops.append(tf.assign_add(global_step, 1))
        # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
        train_op = tf.group(tf_update_ops)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        saver = tf.train.Saver(tf.global_variables(),
                               sharded=True,
                               max_to_keep=10,
                               keep_checkpoint_every_n_hours=2,
                               defer_build=False,
                               save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(hparams.model_dir,
                                                  save_steps=1000,
                                                  saver=saver,
                                                  listeners=[saver_listener])

    # EVAL mode
    if mode == tf.estimator.ModeKeys.EVAL:
        tf_logits = lowering.export_to_tf_tensor(logits)
        return model.estimator_spec_eval(features, tf_logits, labels,
                                         tf_loss, restore_hook, use_tpu)

    if use_tpu:
        # TPU host call. Important: need to be called before
        # remove_summaries()
        if hparams.tpu_enable_host_call:
            host_call = t2t_model.create_host_call(hparams.model_dir)
        else:
            host_call = None

        t2t_model.remove_summaries()
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            host_call=host_call,
            training_hooks=[restore_hook, saver_hook])
    else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])
def maybe_reshape_attention_input_for_2d_sharding(
        context, q, k, v, bias, unsplittable_dims):
    """Reshape the inputs to attention to split over an unused mesh dimension.

    In the case where the attention computation is unnecessarily replicated,
    this function reshapes the attention inputs to remove the unnecessary
    replication.

    This becomes relevant when doing 2-dimensional model parallelism.
    d_model is sharded over one mesh dimension and [vocab, num_heads, d_ff]
    are sharded over the other mesh dimension. This fully distributes all of
    the einsum operations, except for the internals of the attention
    computation.

    To distribute that computation, this function creates a new
    tensor-dimension from the low bits of either the batch dimension or the
    num_heads dimension, and then splits that dimension over the unused mesh
    dimension.

    Args:
        context: a transformer.Context
        q: a Tensor
        k: a Tensor
        v: a Tensor
        bias: a Tensor, or None (returned unchanged in that case)
        unsplittable_dims: a list of tensor-dimensions not to split. The
            key/value dimensions should be passed here.

    Returns:
        reshaped_q: a Tensor
        reshaped_k: a Tensor
        reshaped_v: a Tensor
        reshaped_bias: a Tensor (or None if `bias` was None)
    """
    original_inputs = q, k, v, bias
    # we need to know the layout and mesh-shape to figure out what to do.
    if not context or not context.model.layout or not context.model.mesh_shape:
        return original_inputs
    mesh_shape = mtf.convert_to_shape(context.model.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(context.model.layout)

    # find a mesh dim that is unused (no tensor-dimension is split across it)
    mesh_axis_used = [False] * mesh_shape.ndims
    for x in original_inputs:
        if x is None:
            # An absent input (e.g. no bias) occupies no mesh axis; the
            # reshape helper below tolerates it too.
            continue
        tensor_layout = layout_rules.tensor_layout(x.shape, mesh_shape)
        for mesh_axis in tensor_layout.tensor_axis_to_mesh_axis:
            if mesh_axis is not None:
                mesh_axis_used[mesh_axis] = True
    if False not in mesh_axis_used:
        return original_inputs
    mesh_dim = mesh_shape.dims[mesh_axis_used.index(False)]

    # Choose an appropriate name for the new tensor-dimension so that the
    # layout will know to split it across the unused mesh dimension.
    tensor_dim_names = (
        layout_rules.mesh_dimension_name_to_tensor_dimension_names(
            mesh_dim.name))
    if not tensor_dim_names:
        return original_inputs
    tensor_dim_name = tensor_dim_names[0]

    # Find a tensor-dimension that we can further split, by breaking off the
    # lower bits into our new tensor-dimension.
    # This resplittable tensor-dimension must be present in all of q, k, v
    # and must be large enough to be further split.
    resplittable_dim = None
    for d in q.shape.dims:
        if d in k.shape.dims and d in v.shape.dims and d not in unsplittable_dims:
            num_splits = mtf.tensor_dim_to_mesh_dim_size(
                context.model.layout, context.model.mesh_shape, d)
            if d.size % (num_splits * mesh_dim.size) == 0:
                resplittable_dim = d
                break
    if not resplittable_dim:
        return original_inputs

    # `num_splits` still holds the value computed for `resplittable_dim`,
    # because the loop breaks right after selecting it.
    new_dim_high = mtf.Dimension(resplittable_dim.name, num_splits)
    new_dim_low = mtf.Dimension(tensor_dim_name,
                                resplittable_dim.size // num_splits)

    def _my_reshape(x):
        if x is not None and resplittable_dim in x.shape.dims:
            return mtf.replace_dimensions(
                x, resplittable_dim, [new_dim_high, new_dim_low])
        return x

    return _my_reshape(q), _my_reshape(k), _my_reshape(v), _my_reshape(bias)
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Imports the pre-training features into a Mesh TensorFlow graph, builds
    the BERT model with its masked-LM and next-sentence heads, lowers the
    graph and returns the spec for the requested mode.

    Args:
        features: dict of input tensors; the keys read are "input_ids",
            "input_mask", "segment_ids", "masked_lm_positions",
            "masked_lm_ids", "masked_lm_weights" and "next_sentence_labels".
        labels: unused.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: dict; ``params["context"]`` supplies host and device
            information.

    Returns:
        A ``tf.estimator.tpu.TPUEstimatorSpec`` for TRAIN and EVAL. Any other
        mode falls off the end of the function and yields ``None``.
    """
    tf.logging.info("*** Features ***")
    for feature_name in sorted(features.keys()):
        tf.logging.info(" name = %s, shape = %s" %
                        (feature_name, features[feature_name].shape))

    # MTF setup.
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    ctx = params["context"]
    host_count = ctx.num_hosts
    place_on_host = ctx.tpu_host_placement_function
    device_list = [place_on_host(host_id=host) for host in range(host_count)]
    tf.logging.info("device_list = %s" % device_list)
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (host_count - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    physical_shape = list(ctx.device_assignment.topology.mesh_shape)
    logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
        mesh_shape.to_integer_list, physical_shape)
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape,
        layout_rules,
        mesh_devices,
        ctx.device_assignment,
        logical_to_physical=logical_to_physical)
    mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]
    next_sentence_labels = tf.squeeze(features["next_sentence_labels"], 1)

    # Dimension sizes are taken from the static shapes of the inputs.
    ids_shape = input_ids.get_shape()
    batch_dim = mtf.Dimension("batch", ids_shape[0].value)
    seq_dim = mtf.Dimension("seq", ids_shape[1].value)
    max_predictions_per_seq_dim = mtf.Dimension(
        "max_pred_seq", masked_lm_positions.get_shape()[1].value)

    def _import(tensor, dims):
        return mtf.import_tf_tensor(mesh, tensor, dims)

    mtf_input_ids = _import(input_ids, [batch_dim, seq_dim])
    mtf_input_mask = _import(input_mask, [batch_dim, seq_dim])
    mtf_segment_ids = _import(segment_ids, [batch_dim, seq_dim])
    mtf_masked_lm_positions = _import(
        masked_lm_positions, [batch_dim, max_predictions_per_seq_dim])
    mtf_masked_lm_ids = _import(
        masked_lm_ids, [batch_dim, max_predictions_per_seq_dim])
    mtf_masked_lm_weights = _import(
        masked_lm_weights, [batch_dim, max_predictions_per_seq_dim])
    mtf_next_sentence_labels = _import(next_sentence_labels, [batch_dim])

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    model = bert_lib.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=mtf_input_ids,
        input_mask=mtf_input_mask,
        token_type_ids=mtf_segment_ids,
        layout=layout_rules,
        mesh_shape=mesh_shape)

    masked_lm_output = model.get_masked_lm_output(
        mtf_masked_lm_positions, mtf_masked_lm_ids, mtf_masked_lm_weights)
    masked_lm_loss, masked_lm_example_loss, masked_lm_logits = masked_lm_output

    next_sentence_output = model.get_next_sentence_output(
        mtf_next_sentence_labels)
    (next_sentence_loss, next_sentence_example_loss,
     next_sentence_logits) = next_sentence_output

    extra_loss = model.get_extra_loss()

    total_loss = masked_lm_loss + next_sentence_loss
    total_loss = mtf.anonymize(total_loss)
    masked_lm_example_loss = mtf.anonymize(masked_lm_example_loss)
    masked_lm_logits = mtf.anonymize(masked_lm_logits)
    next_sentence_example_loss = mtf.anonymize(next_sentence_example_loss)
    next_sentence_logits = mtf.anonymize(next_sentence_logits)

    # TRAIN mode
    if is_training:
        _, update_ops = optimization_lib.create_optimizer(
            total_loss + extra_loss,
            learning_rate,
            num_train_steps,
            num_warmup_steps,
            optimizer=FLAGS.optimizer,
            clip_gradients=FLAGS.clip_gradients)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

    if is_training:
        global_step = tf.train.get_global_step()
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        # The global step is incremented together with the variable updates.
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
        train_op = tf.group(tf_update_ops)
    elif mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(masked_lm_example_loss, masked_lm_logits,
                      masked_lm_ids, masked_lm_weights,
                      next_sentence_example_loss, next_sentence_logits,
                      next_sentence_labels):
            """Computes the loss and accuracy of the model."""
            masked_lm_logits = tf.reshape(
                masked_lm_logits, [-1, masked_lm_logits.shape[-1]])
            masked_lm_predictions = tf.argmax(
                masked_lm_logits, axis=-1, output_type=tf.int32)
            masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
            masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
            masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
            masked_lm_accuracy = tf.metrics.accuracy(
                labels=masked_lm_ids,
                predictions=masked_lm_predictions,
                weights=masked_lm_weights)
            masked_lm_mean_loss = tf.metrics.mean(
                values=masked_lm_example_loss, weights=masked_lm_weights)

            next_sentence_logits = tf.reshape(
                next_sentence_logits, [-1, next_sentence_logits.shape[-1]])
            next_sentence_predictions = tf.argmax(
                next_sentence_logits, axis=-1, output_type=tf.int32)
            next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
            next_sentence_accuracy = tf.metrics.accuracy(
                labels=next_sentence_labels,
                predictions=next_sentence_predictions)
            next_sentence_mean_loss = tf.metrics.mean(
                values=next_sentence_example_loss)

            return {
                "masked_lm_accuracy": masked_lm_accuracy,
                "masked_lm_loss": masked_lm_mean_loss,
                "next_sentence_accuracy": next_sentence_accuracy,
                "next_sentence_loss": next_sentence_mean_loss,
            }

        metric_inputs = [
            lowering.export_to_tf_tensor(masked_lm_example_loss),
            lowering.export_to_tf_tensor(masked_lm_logits),
            masked_lm_ids,
            masked_lm_weights,
            lowering.export_to_tf_tensor(next_sentence_example_loss),
            lowering.export_to_tf_tensor(next_sentence_logits),
            next_sentence_labels,
        ]
        eval_metrics = (metric_fn, metric_inputs)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)

        if is_training:
            saver = tf.train.Saver(
                tf.global_variables(),
                sharded=True,
                max_to_keep=10,
                keep_checkpoint_every_n_hours=2,
                defer_build=False,
                save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_hook = tf.train.CheckpointSaverHook(
                FLAGS.output_dir,
                save_steps=1000,
                saver=saver,
                listeners=[mtf.MtfCheckpointSaverListener(lowering)])

            return tf.estimator.tpu.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])

        if mode == tf.estimator.ModeKeys.EVAL:
            return tf.estimator.tpu.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
def estimator_model_fn(cls,
                       hparams,
                       features,
                       labels,
                       mode,
                       config=None,
                       params=None,
                       decode_hparams=None,
                       use_tpu=False):
    """Build the Estimator spec for a Mesh TensorFlow model.

    Args:
        cls: the model class; instantiated as
            ``cls(hparams, mode, data_parallelism=..., decode_hparams=...)``.
        hparams: hyperparameters; deep-copied before being modified.
        features: passed to the model's ``mtf_model_fn`` /
            ``estimator_spec_predict`` / ``estimator_spec_eval``.
        labels: passed to ``estimator_spec_eval``.
        mode: a ``tf.estimator.ModeKeys`` value.
        config: optional run config; ``config.data_parallelism`` is read when
            not running on TPU.
        params: dict; ``params["context"]`` is read on the TPU path.
        decode_hparams: optional decode hyperparameters, merged into
            ``hparams`` in PREDICT mode.
        use_tpu: whether to build a SIMD (TPU) mesh implementation.

    Returns:
        PREDICT: the result of ``model.estimator_spec_predict``.
        EVAL: the result of ``model.estimator_spec_eval``.
        TRAIN: a ``TPUEstimatorSpec`` on TPU, otherwise an ``EstimatorSpec``.
    """
    hparams = copy.deepcopy(hparams)
    hparams.use_tpu = use_tpu
    training = mode == tf.estimator.ModeKeys.TRAIN

    # merge decode_hparams into hparams if present
    if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
        for key, value in six.iteritems(decode_hparams.values()):
            if hasattr(hparams, key) and getattr(hparams, key) != value:
                tf.logging.warning(
                    "Overriding hparams.%s with %s from decode_hparams" %
                    (key, value))
            setattr(hparams, key, value)

    # Instantiate model
    data_parallelism = config.data_parallelism if (not use_tpu
                                                   and config) else None
    model = cls(
        hparams,
        mode,
        data_parallelism=data_parallelism,
        decode_hparams=decode_hparams)

    global_step = tf.train.get_global_step()

    mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(hparams.layout)
    if use_tpu:
        ctx = params["context"]
        host_count = ctx.num_hosts
        place_on_host = ctx.tpu_host_placement_function
        device_list = [
            place_on_host(host_id=host) for host in range(host_count)
        ]
        # TODO(ylc): Better estimation of replica cache size?
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memory_usage = [worker0_mem] + [0] * (host_count - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                      devices_memory_usage)
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
    else:
        var_placer = None
        if (data_parallelism is not None
                and len(data_parallelism.ps_devices) != 1):
            # Several parameter-server devices: one per mesh slot.
            assert len(data_parallelism.ps_devices) == mesh_shape.size
            mesh_devices = data_parallelism.ps_devices
        else:
            mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # PREDICT mode
    if mode == tf.estimator.ModeKeys.PREDICT:
        return model.estimator_spec_predict(features, mesh, mesh_impl,
                                            use_tpu)

    logits, loss = model.mtf_model_fn(features, mesh)
    if use_tpu and logits is not None:
        logits = mtf.anonymize(logits)

    # TRAIN mode
    if training:
        var_grads = mtf.gradients(
            [loss],
            [variable.outputs[0] for variable in graph.trainable_variables])
        lr = learning_rate.learning_rate_schedule(hparams)
        tf.summary.scalar("learning_rate", lr)
        mtf_lr = mtf.import_tf_tensor(
            mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
        optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
        update_ops = [
            op
            for grad, var in zip(var_grads, graph.trainable_variables)
            for op in optimizer.apply_grad(grad, var)
        ]

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))
    if logits and not training:
        tf_logits = lowering.export_to_tf_tensor(logits)

    if training:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        # The global step is incremented together with the variable updates.
        tf_update_ops.append(tf.assign_add(global_step, 1))
        # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
        train_op = tf.group(tf_update_ops)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_hook = tf.train.CheckpointSaverHook(
            hparams.model_dir,
            save_steps=1000,
            saver=saver,
            listeners=[mtf.MtfCheckpointSaverListener(lowering)])

    # EVAL mode
    if mode == tf.estimator.ModeKeys.EVAL:
        tf_logits = lowering.export_to_tf_tensor(logits)
        return model.estimator_spec_eval(features, tf_logits, labels,
                                         tf_loss, restore_hook, use_tpu)

    if not use_tpu:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])

    # TPU host call. Important: need to be called before remove_summaries()
    if hparams.tpu_enable_host_call:
        host_call = t2t_model.create_host_call(hparams.model_dir)
    else:
        host_call = None

    t2t_model.remove_summaries()
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=tf_loss,
        train_op=train_op,
        host_call=host_call,
        training_hooks=[restore_hook, saver_hook])