def testGraph(self):
  graph = mtf.Graph()
  self.assertLen(graph.operations, 0)
  self.assertLen(graph.tensors, 0)
  self.assertLen(graph.trainable_variables, 0)
  self.assertLen(graph.all_variables, 0)

  mesh = mtf.Mesh(graph, "mesh_test")
  _ = mtf.import_tf_tensor(mesh, tf_tensor=tf.constant(0.),
                           shape=mtf.Shape([]))
  self.assertLen(graph.operations, 1)
  self.assertLen(graph.tensors, 1)
  self.assertLen(graph.trainable_variables, 0)
  self.assertLen(graph.all_variables, 0)

  _ = mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
  self.assertLen(graph.operations, 2)
  self.assertLen(graph.tensors, 2)
  self.assertLen(graph.trainable_variables, 1)
  self.assertLen(graph.all_variables, 1)

  _ = mtf.get_variable(mesh, "variable_1", mtf.Shape([]), trainable=False)
  self.assertLen(graph.operations, 3)
  self.assertLen(graph.tensors, 3)
  self.assertLen(graph.trainable_variables, 1)
  self.assertLen(graph.all_variables, 2)
def testDense(self, units, use_bias):
  batch = 2
  channels = 3
  inputs = tf.random_normal([batch, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  channels_dim = mtf.Dimension("channels", channels)
  depth_dim = mtf.Dimension("depth", units)

  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
  mtf_outputs = mtf_layers.dense(mtf_inputs,
                                 output_dim=depth_dim,
                                 reduced_dims=[channels_dim],
                                 activation=mtf.relu,
                                 use_bias=use_bias)
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  expected_outputs = tf.keras.layers.Dense(units=units,
                                           activation=tf.nn.relu,
                                           use_bias=use_bias)(inputs)
  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  with self.test_session() as sess:
    sess.run(init)
    sess.run(tf_group)
    actual, expected = sess.run([actual_outputs, expected_outputs])

  self.assertEqual(actual.shape, expected.shape)
def testLayerNorm(self):
  batch = 2
  channels = 3
  inputs = tf.random_normal([batch, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  channels_dim = mtf.Dimension("channels", channels)

  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
  mtf_outputs = mtf_layers.layer_norm(mtf_inputs, dim=channels_dim)
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  expected_outputs = common_layers.layer_norm(inputs)
  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  with self.test_session() as sess:
    sess.run(init)
    sess.run(tf_group)
    actual, expected = sess.run([actual_outputs, expected_outputs])

  self.assertEqual(actual.shape, expected.shape)
def testWeightsNonzero(self):
  inputs = tf.constant([[3, 1, 0], [1, 0, 0]])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", inputs.shape.as_list()[0])
  channels_dim = mtf.Dimension("channels", inputs.shape.as_list()[1])

  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
  mtf_outputs = mtf_layers.weights_nonzero(mtf_inputs)
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  expected_outputs = common_layers.weights_nonzero(inputs)
  tf_group = lowering.copy_masters_to_slices()
  self.evaluate(tf_group)
  actual, expected = self.evaluate([actual_outputs, expected_outputs])

  self.assertAllEqual(actual, expected)
def testDenseReluDense(self):
  batch = 2
  channels = 3
  hidden = 5
  inputs = tf.random_normal([batch, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  channels_dim = mtf.Dimension("channels", channels)
  hidden_dim = mtf.Dimension("hidden", hidden)

  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
  mtf_outputs = mtf_layers.dense_relu_dense(
      mtf_inputs, hidden_channels=hidden_dim)
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  with self.test_session() as sess:
    sess.run(init)
    sess.run(tf_group)
    actual = sess.run(actual_outputs)

  self.assertEqual(actual.shape, inputs.shape)
def get_placement_mesh(hparams):
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
  mesh_devices = [""] * mesh_shape.size
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, hparams.layout, mesh_devices)
  return mesh, mesh_impl
def get_placement_mesh(hparams):
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  mesh_shape = mtf.parse_mesh_shape(hparams.mesh_shape)
  mesh_size = mtf.list_product(mesh_shape)
  mesh_devices = [""] * mesh_size
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, mtf.parse_layout(hparams.layout), mesh_devices)
  return mesh, mesh_impl
def testDotProductAttention(
    self, batch, heads, length_q, length_kv, depth_k, depth_v):
  query = tf.random_normal([batch, heads, length_q, depth_k])
  key = tf.random_normal([batch, heads, length_kv, depth_k])
  value = tf.random_normal([batch, heads, length_kv, depth_v])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  heads_dim = mtf.Dimension("heads", heads)
  length_q_dim = mtf.Dimension("length_q", length_q)
  length_kv_dim = mtf.Dimension("length_kv", length_kv)
  depth_k_dim = mtf.Dimension("depth_k", depth_k)
  depth_v_dim = mtf.Dimension("depth_v", depth_v)

  mtf_query = mtf.import_tf_tensor(
      mesh, query,
      shape=mtf.Shape([batch_dim, heads_dim, length_q_dim, depth_k_dim]))
  mtf_key = mtf.import_tf_tensor(
      mesh, key,
      shape=mtf.Shape([batch_dim, heads_dim, length_kv_dim, depth_k_dim]))
  mtf_value = mtf.import_tf_tensor(
      mesh, value,
      shape=mtf.Shape([batch_dim, heads_dim, length_kv_dim, depth_v_dim]))
  mtf_outputs = mtf_layers.dot_product_attention(
      mtf_query, mtf_key, mtf_value, mask=None)
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  with self.test_session() as sess:
    sess.run(init)
    sess.run(tf_group)
    actual = sess.run(actual_outputs)

  self.assertEqual(actual.shape, (batch, heads, length_q, depth_v))
def testLowering(self):
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  inputs = tf.constant(0.)
  mtf_inputs = mtf.import_tf_tensor(
      mesh, tf_tensor=inputs, shape=mtf.Shape([]))
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  outputs = lowering.export_to_tf_tensor(mtf_inputs)
  inputs_value, outputs_value = self.evaluate([inputs, outputs])
  self.assertEqual(inputs_value, outputs_value)

  # Check that methods run without error.
  _ = lowering.copy_masters_to_slices()
  _ = lowering.copy_slices_to_masters()
def testMaskedLocalAttention1D(self, batch, length, io_channels, kv_channels,
                               heads, block_length):
  length_q = length
  length_m = length
  query = tf.random_normal([batch, length_q, io_channels])
  memory = tf.random_normal([batch, length_m, io_channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  length_q_dim = mtf.Dimension("length_q", length_q)
  length_m_dim = mtf.Dimension("length_m", length_m)
  io_channels_dim = mtf.Dimension("io_channels", io_channels)
  kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
  heads_dim = mtf.Dimension("heads", heads)

  mtf_query = mtf.import_tf_tensor(
      mesh, query,
      shape=mtf.Shape([batch_dim, length_q_dim, io_channels_dim]))
  mtf_memory = mtf.import_tf_tensor(
      mesh, memory,
      shape=mtf.Shape([batch_dim, length_m_dim, io_channels_dim]))
  mtf_outputs = mtf_layers.masked_local_attention_1d(
      mtf_query,
      mtf_memory,
      kv_channels=kv_channels_dim,
      heads=heads_dim,
      block_length=block_length)
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  with self.test_session() as sess:
    sess.run(init)
    sess.run(tf_group)
    actual = sess.run(actual_outputs)

  self.assertEqual(actual.shape, (batch, length_q, io_channels))
def test_variable_placer(self):
  sizes = [100, 0, 0, 0]
  device_list = ['cpu:0', 'cpu:1', 'cpu:2', 'cpu:3']

  with tf.Graph().as_default() as g:
    var_placer = mtf_utils.BalancedVariablePlacer(device_list, sizes)
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

    hidden_dim = mtf.Dimension('hidden', 10)
    output_dim = mtf.Dimension('output_feature', 10)

    for i in xrange(5):
      # Each variable takes 400 Bytes, and will be placed from cpu:1.
      mtf.get_variable(mesh, 'w{}'.format(i), [hidden_dim, output_dim])

    for i in xrange(5):
      var = g.get_tensor_by_name('w{}:0'.format(i))
      device = (i + 1) % len(device_list)
      self.assertEqual('cpu:{}'.format(device), var.device)
def testMultiheadAttention(self, kv_channels, heads):
  batch = 2
  length = 8
  channels = 3
  query = tf.random_normal([batch, length, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  length_dim = mtf.Dimension("length", length)
  channels_dim = mtf.Dimension("channels", channels)
  kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
  heads_dim = mtf.Dimension("heads", heads)

  mtf_query = mtf.import_tf_tensor(
      mesh, query, shape=mtf.Shape([batch_dim, length_dim, channels_dim]))
  mtf_outputs = mtf_layers.multihead_attention(
      mtf_query,
      memory_antecedent=None,
      mask=None,
      kv_channels=kv_channels_dim,
      heads=heads_dim)
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  with self.test_session() as sess:
    sess.run(init)
    sess.run(tf_group)
    actual = sess.run(actual_outputs)

  self.assertEqual(actual.shape, query.shape)
def estimator_model_fn(cls,
                       hparams,
                       features,
                       labels,
                       mode,
                       config=None,
                       params=None,
                       decode_hparams=None,
                       use_tpu=False,
                       xla_compile=False):
  del xla_compile
  hparams = copy.deepcopy(hparams)
  hparams.use_tpu = use_tpu
  # merge decode_hparams into hparams if present
  if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
    for k, v in six.iteritems(decode_hparams.values()):
      if hasattr(hparams, k) and getattr(hparams, k) != v:
        tf.logging.warning(
            "Overriding hparams.%s with %s from decode_hparams" % (k, v))
        setattr(hparams, k, v)

  # Instantiate model
  data_parallelism = None
  if not use_tpu and config:
    data_parallelism = config.data_parallelism
  model = cls(
      hparams, mode, data_parallelism=data_parallelism,
      decode_hparams=decode_hparams)

  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(hparams.layout)
  if use_tpu:
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices,
        params["context"].device_assignment)
  else:
    if len(data_parallelism.ps_devices) == 1:
      mesh_devices = [""] * mesh_shape.size
    else:
      assert len(data_parallelism.ps_devices) == mesh_shape.size
      mesh_devices = data_parallelism.ps_devices
    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  # PREDICT mode
  if mode == tf.estimator.ModeKeys.PREDICT:
    return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)

  logits, loss = model.mtf_model_fn(features, mesh)
  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    lr = learning_rate.learning_rate_schedule(hparams)
    mtf_lr = mtf.import_tf_tensor(
        mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
    optimizer = mtf_optimize.make_optimizer(hparams, mtf_lr)
    update_ops = []
    for grad, var in zip(var_grads, graph.trainable_variables):
      update_ops.extend(optimizer.apply_grad(grad, var))

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if logits and mode != tf.estimator.ModeKeys.TRAIN:
    tf_logits = lowering.export_to_tf_tensor(logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
    train_op = tf.group(tf_update_ops)

  with mtf_utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        hparams.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])

    # EVAL mode
    if mode == tf.estimator.ModeKeys.EVAL:
      tf_logits = lowering.export_to_tf_tensor(logits)
      return model.estimator_spec_eval(features, tf_logits, labels, tf_loss,
                                       restore_hook, use_tpu)

    if use_tpu:
      _remove_summaries()
      return tpu_estimator.TPUEstimatorSpec(
          mode=tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    else:
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_chief_hooks=[restore_hook, saver_hook])
def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator."""
  tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                  (features, labels, mode, params))
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  logits, loss = mnist_model(features, labels, mesh)
  mesh_shape = mtf.parse_mesh_shape(FLAGS.mesh_shape)
  mesh_size = mtf.list_product(mesh_shape)
  mesh_devices = [""] * mesh_size
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, mtf.parse_layout(FLAGS.layout), mesh_devices)

  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf_optimize.AdafactorOptimizer()
    update_ops = []
    for grad, var in zip(var_grads, graph.trainable_variables):
      update_ops.extend(optimizer.apply_grad(grad, var))

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  restore_hook = mtf.MtfRestoreHook(lowering)

  tf_logits = lowering.export_to_tf_tensor(logits)
  if mode != tf.estimator.ModeKeys.PREDICT:
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf.summary.scalar("loss", tf_loss)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        FLAGS.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])
    accuracy = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(tf_logits, axis=1))

    # Name tensors to be logged with LoggingTensorHook.
    tf.identity(tf_loss, "cross_entropy")
    tf.identity(accuracy[1], name="train_accuracy")

    # Save accuracy scalar to Tensorboard output.
    tf.summary.scalar("train_accuracy", accuracy[1])

    # restore_hook must come before saver_hook
    return tf.estimator.EstimatorSpec(
        tf.estimator.ModeKeys.TRAIN,
        loss=tf_loss,
        train_op=train_op,
        training_chief_hooks=[restore_hook, saver_hook])

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        "classes": tf.argmax(tf_logits, axis=1),
        "probabilities": tf.nn.softmax(tf_logits),
    }
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[restore_hook],
        export_outputs={
            "classify": tf.estimator.export.PredictOutput(predictions)
        })

  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=tf_loss,
        evaluation_hooks=[restore_hook],
        eval_metric_ops={
            "accuracy":
                tf.metrics.accuracy(
                    labels=labels, predictions=tf.argmax(tf_logits, axis=1)),
        })
def testMesh(self):
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  self.assertEqual(mesh.graph, graph)
def model_fn(features, labels, mode, params):
  """A model is called by TpuEstimator."""
  del labels
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, 'my_mesh')
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  mesh_devices = [''] * mesh_shape.size
  mesh_impl = SimdMeshImpl(
      mesh_shape, mtf.convert_to_layout_rules(FLAGS.layout),
      mesh_devices, params['context'].device_assignment)

  with mtf_utils.outside_all_rewrites():
    logits, loss = toy_model(features, mesh)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf_optimize.AdafactorOptimizer()
    update_ops = []
    for grad, var in zip(var_grads, graph.trainable_variables):
      update_ops.extend(optimizer.apply_grad(grad, var))
  else:
    # for now, we can only export fully-replicated tensors.
    fully_replicated_logits = mtf.anonymize(logits)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_loss = lowering.export_to_tf_tensor(loss)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  else:
    tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

  with mtf_utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          FLAGS.model_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])

      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(tf_logits):
        mean_logits = tf.metrics.mean(tf_logits)
        return {'mean_logits': mean_logits}

      eval_metrics = (metric_fn, [tf_logits])
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)