def testNthSmallestReduceSecondDim(self):
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  a_dim = mtf.Dimension("a", 6)
  b_dim = mtf.Dimension("b", 2)
  inputs = tf.constant([[1, 10], [2, 9], [3, 8], [4, 7], [5, 6], [6, 5]])
  n = 0  # find smallest element (n is zero-indexed)
  reduced_dim = b_dim
  expected_outputs = tf.constant([1, 2, 3, 4, 5, 5])

  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([a_dim, b_dim]))
  mtf_outputs = mtf.nth_smallest_element(
      mtf_inputs, n, reduced_dim, "test_nth_smallest")
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape="all:2", layout="a:all", devices=["", ""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  self.assertAllEqual(self.evaluate(actual_outputs),
                      self.evaluate(expected_outputs))
def computation_fn():
  # Note: `device_assignment` and `optimization_lib` come from the
  # enclosing test scope.
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, 'my_mesh')
  mesh_shape = mtf.convert_to_shape('all:2')
  layout = 'none:all'
  mesh_devices = [''] * mesh_shape.size
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape, mtf.convert_to_layout_rules(layout), mesh_devices,
      device_assignment)
  hidden_dim = mtf.Dimension('hidden', 3)
  w = mtf.get_variable(
      mesh, 'w', shape=[hidden_dim],
      initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
  x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim], dtype=tf.float32)
  loss = mtf.reduce_mean(mtf.square(x - w))
  lr, update_ops = optimization_lib.create_optimizer(loss, 0.2, 100, 10)

  self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_update_ops = [
      self.lowering.lowered_operation(op) for op in update_ops
  ]
  tf_update_ops.append(
      tf.assign_add(tf.train.get_or_create_global_step(), 1))
  train_op = tf.group(tf_update_ops)
  return lr, train_op
def estimator_spec_predict(self, features, mesh, mesh_impl, use_tpu):
  mtf_samples = self.sample(features, mesh)
  lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
  outputs = lowering.export_to_tf_tensor(mtf_samples)
  if self.has_input:
    ndims = len(outputs.shape.as_list())
    actual_batch_size = tf.shape(features["inputs"])[0]
    outputs = tf.slice(outputs, [0] * ndims,
                       [actual_batch_size] + [-1] * (ndims - 1))
  predictions = {
      "outputs": outputs,
      "targets": features.get("infer_targets", features.get("inputs")),
      "inputs": features.get("inputs"),
  }
  if use_tpu:
    t2t_model.remove_summaries()
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])
  else:
    return tf.estimator.EstimatorSpec(
        tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])
def computation_fn():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, 'my_mesh')
  mesh_shape = mtf.convert_to_shape('all:2')
  layout = 'num_heads:all'
  mesh_devices = [''] * mesh_shape.size
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape, mtf.convert_to_layout_rules(layout), mesh_devices,
      device_assignment)
  batch_dim = mtf.Dimension('batch', batch_size)
  seq_dim = mtf.Dimension('seq', seq_length)

  input_ids = tf.random.uniform((batch_size, seq_length),
                                minval=0,
                                maxval=vocab_size,
                                dtype=tf.int32)
  mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids,
                                       [batch_dim, seq_dim])

  model = bert_lib.BertModel(
      config=bert_config,
      is_training=True,
      input_ids=mtf_input_ids,
      input_mask=None,
      token_type_ids=None)
  pooled = model.get_pooled_output()
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  return lowering.export_to_tf_tensor(pooled)
def testRecomputeGrad(self):
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")

  # let's differentiate x^2 + x
  # dy/dx = 2x + 1
  def x_squared_plus_x(x):
    return x * x + x

  x = tf.constant([5, 10], dtype=tf.float32)
  dy = tf.constant([2, 3], dtype=tf.float32)
  two = mtf.Dimension("two", 2)
  expected_y = tf.constant([30, 110], dtype=tf.float32)
  expected_dx = tf.constant([22, 63], dtype=tf.float32)

  mtf_x = mtf.import_fully_replicated(mesh, x, shape=mtf.Shape([two]))
  mtf_dy = mtf.import_tf_tensor(mesh, dy, shape=mtf.Shape([two]))
  mtf_y = mtf.recompute_grad(x_squared_plus_x, [mtf_x])
  [mtf_dx] = mtf.gradients([mtf_y], [mtf_x], [mtf_dy])

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape="processors:2", layout="two:processors", devices=["", ""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_y = lowering.export_to_tf_tensor(mtf_y)
  actual_dx = lowering.export_to_tf_tensor(mtf_dx)

  self.assertAllEqual(self.evaluate(actual_y), self.evaluate(expected_y))
  self.assertAllEqual(self.evaluate(actual_dx), self.evaluate(expected_dx))
def computation_fn():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, 'my_mesh')
  mesh_shape = mtf.convert_to_shape('all:2')
  layout = 'none:all'
  mesh_devices = [''] * mesh_shape.size
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape, mtf.convert_to_layout_rules(layout), mesh_devices,
      device_assignment)
  hidden_dim = mtf.Dimension('hidden', 3)
  w = mtf.get_variable(
      mesh, 'w', shape=[hidden_dim],
      initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
  x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim], dtype=tf.float32)
  loss = mtf.reduce_mean(mtf.square(x - w))

  var_grads = mtf.gradients(
      [loss], [v.outputs[0] for v in graph.trainable_variables])
  optimizer = mtf_optimize.AdamWeightDecayOptimizer(learning_rate=0.2)
  update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_update_ops = [
      self.lowering.lowered_operation(op) for op in update_ops
  ]
  return tf.group(tf_update_ops)
def testDynamicText2self_unpacked(self):
  batch = 2
  length = 5
  input_tensors = {
      "inputs": [[3, 1, 4, 1, 0], [1, 4, 3, 2, 1]],
      "targets": [[1, 1, 0, 0, 0], [9, 8, 1, 2, 1]],
  }
  expected_output_tensors = {
      "targets": [[3, 1, 4, 1, 1, 1, 0, 0, 0, 0],
                  [1, 4, 3, 2, 1, 9, 8, 1, 2, 1]],
  }
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  length_dim = mtf.Dimension("length", length)
  input_shape = mtf.Shape([batch_dim, length_dim])

  # Wrap the Python lists in tf.constant so import_tf_tensor gets tensors.
  mtf_features = {
      k: mtf.import_tf_tensor(mesh, tf.constant(v), input_shape)
      for k, v in input_tensors.items()
  }
  mtf_outputs = utils._dynamic_text2self(mtf_features)
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  for k, v in expected_output_tensors.items():
    out = lowering.export_to_tf_tensor(mtf_outputs[k])
    actual = self.evaluate(out)
    self.assertAllEqual(actual, v)
def testDense(self, units, use_bias):
  batch = 2
  channels = 3
  inputs = tf.random_normal([batch, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  channels_dim = mtf.Dimension("channels", channels)
  depth_dim = mtf.Dimension("depth", units)

  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
  mtf_outputs = mtf.layers.dense(
      mtf_inputs,
      output_dim=depth_dim,
      reduced_dims=[channels_dim],
      activation=mtf.relu,
      use_bias=use_bias)
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  expected_outputs = tf.keras.layers.Dense(
      units=units, activation=tf.nn.relu, use_bias=use_bias)(inputs)
  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(tf_group)
  actual, expected = self.evaluate([actual_outputs, expected_outputs])

  self.assertEqual(actual.shape, expected.shape)
def testMaskedLocalAttention1D(self, batch, length, io_channels, kv_channels,
                               heads, window_size):
  length_q = length
  query = tf.random_normal([batch, length_q, io_channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  length_q_dim = mtf.Dimension("length_q", length_q)
  io_channels_dim = mtf.Dimension("io_channels", io_channels)
  kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
  heads_dim = mtf.Dimension("heads", heads)

  mtf_query = mtf.import_tf_tensor(
      mesh, query,
      shape=mtf.Shape([batch_dim, length_q_dim, io_channels_dim]))
  mtf_outputs = mtf.layers.masked_local_attention_1d(
      mtf_query,
      kv_channels=kv_channels_dim,
      heads=heads_dim,
      window_size=window_size)
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(tf_group)
  actual = self.evaluate(actual_outputs)

  self.assertEqual(actual.shape, (batch, length_q, io_channels))
def model_fn(features, labels, mode, params): """A model is called by TpuEstimator.""" del labels del features mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape) layout_rules = mtf.convert_to_layout_rules(FLAGS.layout) ctx = params['context'] num_hosts = ctx.num_hosts host_placement_fn = ctx.tpu_host_placement_function device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)] tf.logging.info('device_list = %s' % device_list, ) mesh_devices = [''] * mesh_shape.size mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules, mesh_devices, ctx.device_assignment) graph = mtf.Graph() mesh = mtf.Mesh(graph, "fft_mesh") with mtf.utils.outside_all_rewrites(): fsum = benchmark_model(mesh) lowering = mtf.Lowering(graph, {mesh: mesh_impl}) tf_err = tf.to_float(lowering.export_to_tf_tensor(fsum)) with mtf.utils.outside_all_rewrites(): return tpu_estimator.TPUEstimatorSpec(mode, loss=tf_err)
def test_hidden_to_logits_computesLogitsCorrectly(self):
  seq_len = 1
  vocab_size = 4
  model_size = 3
  num_softmaxes = 2

  vocab_dim = mtf.Dimension('vocab', vocab_size)
  model_dim = mtf.Dimension('model', model_size)
  length_dim = mtf.Dimension('length', seq_len)

  embeddings = tf.constant(
      np.array([[1.0, 1.0, 2.0]]) / model_size**-0.5, dtype=tf.float32)
  mtf_embeddings = mtf.import_tf_tensor(
      self.mesh, embeddings, shape=mtf.Shape([length_dim, model_dim]))

  self.initializer_mock.side_effect = initialize_by_shape({
      # Embedding weights.
      (4, 3): [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]],
      # Mixture weights.
      (2, 3): [[1, 0, 0], [0, 1, 1]],
      # Context weights.
      (2, 3, 3): [
          [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
          [[0, 0, 1], [0, 1, 0], [1, 0, 0]],
      ],
  })

  vocab_embedding = vocab_embeddings.MixtureOfSoftmaxes(
      self.mesh,
      vocab_dim,
      output_dim=model_dim,
      variable_dtype=self.variable_dtype,
      name='embedding',
      ensemble_dim=None,
      num_softmaxes=num_softmaxes)
  mtf_logits = vocab_embedding.hidden_to_logits(mtf_embeddings, context=None)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[''])
  lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
  actual_logits = lowering.export_to_tf_tensor(mtf_logits)

  self.evaluate(tf.global_variables_initializer())
  self.evaluate(lowering.copy_masters_to_slices())
  actual, = self.evaluate([actual_logits])

  expected_priors = scipy.special.softmax([1, 3])
  expected_probs_1 = scipy.special.softmax(np.tanh([1, 1, 2, 2]))
  expected_probs_2 = scipy.special.softmax(np.tanh([2, 1, 1, 1]))
  expected_probs = (expected_priors[0] * expected_probs_1 +
                    expected_priors[1] * expected_probs_2)
  expected_logits = np.log(expected_probs)

  self.assertAllClose(actual, [expected_logits])
def convert_mtf_tensor_to_tf_tensor(self, mtf_tensor):
  """Convert an mtf.Tensor to a tf.Tensor."""
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
  return lowering, lowering.export_to_tf_tensor(mtf_tensor)
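# A minimal usage sketch for the helper above (a hypothetical test body, not
# from the original suite; it assumes `self.graph` and `self.mesh` are
# created in setUp, as the other tests here do). The helper returns the
# Lowering as well, so callers can copy master values to slices before
# evaluating the exported tf.Tensor.
def test_convert_roundtrip_sketch(self):
  dim = mtf.Dimension("d", 3)
  mtf_zeros = mtf.zeros(self.mesh, mtf.Shape([dim]))
  lowering, tf_zeros = self.convert_mtf_tensor_to_tf_tensor(mtf_zeros)
  self.evaluate(lowering.copy_masters_to_slices())
  self.assertAllEqual(self.evaluate(tf_zeros), [0.0, 0.0, 0.0])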
def testWeightsNonzero(self):
  inputs = tf.constant([[3, 1, 0], [1, 0, 0]])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", inputs.shape.as_list()[0])
  channels_dim = mtf.Dimension("channels", inputs.shape.as_list()[1])

  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
  mtf_outputs = mtf.layers.weights_nonzero(mtf_inputs)
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  expected_outputs = tf.cast(tf.not_equal(inputs, 0), tf.float32)
  tf_group = lowering.copy_masters_to_slices()
  self.evaluate(tf_group)
  actual, expected = self.evaluate([actual_outputs, expected_outputs])

  self.assertAllEqual(actual, expected)
def testDenseReluDense(self):
  batch = 2
  channels = 3
  hidden = 5
  inputs = tf.random_normal([batch, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  channels_dim = mtf.Dimension("channels", channels)
  hidden_dim = mtf.Dimension("hidden", hidden)

  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
  mtf_outputs = mtf.layers.dense_relu_dense(
      mtf_inputs, hidden_channels=hidden_dim, is_training=False)
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(tf_group)
  actual = self.evaluate(actual_outputs)

  self.assertEqual(actual.shape, inputs.shape)
def testCorr2DInput(self):
  batch = 4
  channels = 3
  inputs = tf.random_normal([batch, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  channels_dim = mtf.Dimension("channels", channels)

  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
  mtf_outputs = mtf.layers.corr(mtf_inputs, dim=channels_dim)
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  expected_outputs = tfp.stats.correlation(
      inputs, sample_axis=0, event_axis=1)
  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(tf_group)
  actual, expected = self.evaluate([actual_outputs, expected_outputs])

  self.assertEqual(actual.shape, expected.shape)
  self.assertAllClose(actual, expected)
def testMultiheadAttention(self, kv_channels, heads):
  batch = 2
  length = 8
  channels = 3
  query = tf.random_normal([batch, length, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  length_dim = mtf.Dimension("length", length)
  channels_dim = mtf.Dimension("channels", channels)
  kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
  heads_dim = mtf.Dimension("heads", heads)

  mtf_query = mtf.import_tf_tensor(
      mesh, query, shape=mtf.Shape([batch_dim, length_dim, channels_dim]))
  mtf_outputs = mtf.layers.multihead_attention(
      mtf_query,
      memory_antecedent=None,
      mask=None,
      kv_channels=kv_channels_dim,
      heads=heads_dim)
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(tf_group)
  actual = self.evaluate(actual_outputs)

  self.assertEqual(actual.shape, query.shape)
def testConv1d(self):
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  filter_size = 3
  depth_dim = mtf.Dimension("depth", 2)
  length_dim = mtf.Dimension("length", 4)
  output_dim = mtf.Dimension("output", 2)

  x = tf.constant([[1, 0], [0, 1], [1, 1], [2, 1]], dtype=tf.float32)
  mtf_x = mtf.import_tf_tensor(
      mesh, x, shape=mtf.Shape([length_dim, depth_dim]))

  initializer_mock = mock.MagicMock()
  initializer_mock.side_effect = initialize_by_shape({
      (1, 3, 2, 2): [[[[1, -1], [0, 0]], [[2, -2], [-1, 1]],
                      [[3, -3], [-2, 2]]]],
  })

  mtf_output = mtf.layers.conv1d(
      mtf_x,
      output_dim=output_dim,
      filter_size=filter_size,
      filter_initializer=initializer_mock)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_output = lowering.export_to_tf_tensor(mtf_output)

  self.evaluate(tf.global_variables_initializer())
  self.evaluate(lowering.copy_masters_to_slices())
  actual = self.evaluate(actual_output)

  self.assertAllClose(actual, [[0, 0], [1, -1], [5, -5], [4, -4]])
def main(_):
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

  # Resolve the cluster from the SLURM environment
  cluster = tf.distribute.cluster_resolver.SlurmClusterResolver(
      {"mesh": mesh_shape.size // FLAGS.gpus_per_task},
      port_base=8822,
      gpus_per_node=FLAGS.gpus_per_node,
      gpus_per_task=FLAGS.gpus_per_task,
      tasks_per_node=FLAGS.tasks_per_node)
  cluster_spec = cluster.cluster_spec()

  # Create a server for all mesh members
  server = tf.distribute.Server(cluster_spec, "mesh", cluster.task_id)

  # Only the master job takes care of the graph building,
  # everyone else can just chill for now
  if cluster.task_id > 0:
    server.join()

  # Otherwise we are the main task, let's define the devices
  mesh_devices = [
      "/job:mesh/task:%d/device:GPU:%d" % (i, j)
      for i in range(cluster_spec.num_tasks("mesh"))
      for j in range(FLAGS.gpus_per_node)
  ]
  print("List of devices", mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "fft_mesh")

  # Build the model
  fft_err = benchmark_model(mesh)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, layout_rules, mesh_devices)
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  # Retrieve output of computation
  result = lowering.export_to_tf_tensor(fft_err)

  with tf.Session(server.target) as sess:
    # First run warms up the graph; only the second run is timed.
    start = time.time()
    err = sess.run(result)
    end = time.time()
    time.sleep(1)

    start = time.time()
    err = sess.run(result)
    end = time.time()

  print("Max absolute FFT error %f, with wall time %f" % (err, (end - start)))
  time.sleep(1)
  exit(0)
def model_fn(features, labels, mode, params): """The model_fn argument for creating an Estimator.""" global_step = tf.train.get_global_step() graph = mtf.Graph() mesh = mtf.Mesh(graph, "my_mesh") logits, loss = model_backbone(features, labels, mesh) variables = graph._all_variables for v in variables: logger.debug("[parameter] (name,shape,dtype): ({},{},{})".format(v.name,v.shape,v.dtype.master_dtype)) mesh_shape = mtf.convert_to_shape(args_opt.mesh_shape) # layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss]) mesh_shape = mtf.convert_to_shape(mesh_shape) estimator = memory_estimator.MemoryEstimator(graph, mesh_shape, [logits, loss]) optimizer = layout_optimizer.LayoutOptimizer(estimator,scheduler_alg="NAIVE") layout_rules = mtf.convert_to_layout_rules(optimizer.solve()) # layout_rules=[('batch', 'b1')] logger.info("[auto mtf search] strategy: {}".format(layout_rules)) mesh_devices = ["gpu:{}".format(i) for i in range(int(args_opt.num_gpus))] mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(mesh_shape, layout_rules, mesh_devices) if mode == tf.estimator.ModeKeys.TRAIN: var_grads = mtf.gradients( [loss], [v.outputs[0] for v in graph.trainable_variables]) optimizer = mtf.optimize.SgdOptimizer(0.01) # optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer) update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables) lowering = mtf.Lowering(graph, {mesh: mesh_impl}) restore_hook = mtf.MtfRestoreHook(lowering) tf_logits = lowering.export_to_tf_tensor(logits) if mode != tf.estimator.ModeKeys.PREDICT: tf_loss = lowering.export_to_tf_tensor(loss) if mode == tf.estimator.ModeKeys.TRAIN: tf_update_ops = [lowering.lowered_operation(op) for op in update_ops] tf_update_ops.append(tf.assign_add(global_step, 1)) train_op = tf.group(tf_update_ops) accuracy = tf.metrics.accuracy( labels=labels, predictions=tf.argmax(tf_logits, axis=1)) # Name tensors to be logged with LoggingTensorHook. tf.identity(tf_loss, "cross_entropy") tf.identity(accuracy[1], name="train_accuracy") logging_hook = tf.train.LoggingTensorHook(every_n_iter=100,tensors={'loss': 'cross_entropy','acc':'train_accuracy'}) # profiling_hook = tf.estimator.ProfilerHook(save_steps=20, output_dir='./profiling/') # restore_hook must come before saver_hook return tf.estimator.EstimatorSpec( tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op, training_chief_hooks=[restore_hook,logging_hook])
def Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr):
  lowering = mtf.Lowering(graph, mesh_to_impl)
  out_tsr = lowering.export_to_tf_tensor(mtf_out_tsr)
  assert_op = tf.assert_equal(in_tsr, out_tsr)

  func_name = inspect.stack()[1].function
  print(f'Running test {func_name}')
  with tf.Session() as sess:
    sess.run(assert_op)
  print(f'Test {func_name} successful\n')
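# A usage sketch for Run (hypothetical, not from the original file): import a
# tf.Tensor into a trivially-placed mesh and check that it round-trips
# unchanged, so the tf.assert_equal inside Run holds.
def TestIdentitySketch():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, 'mesh0')
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[''])
  in_tsr = tf.constant([[1.0, 2.0], [3.0, 4.0]])
  rows, cols = mtf.Dimension('rows', 2), mtf.Dimension('cols', 2)
  mtf_tsr = mtf.import_tf_tensor(mesh, in_tsr, mtf.Shape([rows, cols]))
  Run(graph, {mesh: mesh_impl}, in_tsr, mtf_tsr)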
def test_ids_to_embedding_correctlyEmbeds(self):
  seq_len = 5
  vocab_size = 5
  model_size = 2
  gate_embedding_size = 1
  frequent_token_fraction = 0.4

  vocab_dim = mtf.Dimension('vocab', vocab_size)
  model_dim = mtf.Dimension('model', model_size)
  length_dim = mtf.Dimension('length', seq_len)

  context = mock.MagicMock()
  context.train = False

  ids = tf.constant([0, 1, 2, 3, 4], dtype=tf.int32)
  mtf_ids = mtf.import_tf_tensor(
      self.mesh, ids, shape=mtf.Shape([length_dim]))

  self.initializer_mock.side_effect = initialize_by_shape({
      # Embedding weights.
      (5, 2): list(range(10)),
      # Context weights.
      (4, 2, 2): list(range(16)),
      # Prior weights.
      (3, 1, 2): list(range(6)),
      # Prior vocab vector.
      (2, 1): list(range(2)),
      # Prior gates vector.
      (3, 2): list(range(6)),
      # Prior bias.
      (2, 3): list(range(6)),
  })

  vocab_embedding = vocab_embeddings.Mixtape(
      self.mesh,
      vocab_dim,
      output_dim=model_dim,
      variable_dtype=self.variable_dtype,
      name='embedding',
      ensemble_dim=None,
      gate_embedding_size=gate_embedding_size,
      frequent_token_fraction=frequent_token_fraction)
  mtf_embedding = vocab_embedding.ids_to_embedding(mtf_ids, context=None)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[''])
  lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
  actual_embedding = lowering.export_to_tf_tensor(mtf_embedding)

  self.evaluate(tf.global_variables_initializer())
  self.evaluate(lowering.copy_masters_to_slices())
  actual = self.evaluate([actual_embedding])[0]

  self.assertAllClose(actual, np.reshape(list(range(10)), (5, 2)))
def test_hidden_to_logits_computesLogitsCorrectly(self):
  seq_len = 4
  vocab_size = 5
  model_size = 2

  vocab_dim = mtf.Dimension('vocab', vocab_size)
  model_dim = mtf.Dimension('model', model_size)
  length_dim = mtf.Dimension('length', seq_len)

  embeddings = tf.constant([[1, 0], [0, 1], [1, 1], [2, 1]],
                           dtype=tf.float32)
  mtf_embeddings = mtf.import_tf_tensor(
      self.mesh, embeddings, shape=mtf.Shape([length_dim, model_dim]))

  self.initializer_mock.side_effect = initialize_by_shape({
      (2, 2): [[0, 1], [2, 0]],
      (3, 1): [[1], [2], [3]],
      (1, 2): [[1], [2]],
  })

  vocab_embedding = vocab_embeddings.AdaptiveVocabEmbedding(
      self.mesh,
      vocab_dim,
      output_dim=model_dim,
      variable_dtype=self.variable_dtype,
      name='embedding',
      ensemble_dim=None,
      clusters=[{
          'token_count': 2,
          'embedding_size': 2
      }, {
          'token_count': 3,
          'embedding_size': 1
      }])
  mtf_logits = vocab_embedding.hidden_to_logits(mtf_embeddings, context=None)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[''])
  lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
  actual_logits = lowering.export_to_tf_tensor(mtf_logits)

  self.evaluate(tf.global_variables_initializer())
  self.evaluate(lowering.copy_masters_to_slices())
  actual = self.evaluate([actual_logits])[0]

  self.assertAllClose(
      actual,
      model_size**-0.5 * np.array([[0, 2, 1, 2, 3], [1, 0, 2, 4, 6],
                                   [1, 2, 3, 6, 9], [1, 4, 4, 8, 12]]))
def test_entmax():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  length = mtf.Dimension("tensor_length", 8)
  tensor = mtf.range(mesh, length, tf.float32)
  output = entmax(tensor)
  grad = mtf.gradients([output], [tensor])[0]
  sample = sample_categorical(output, length)

  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  sample = lowering.export_to_tf_tensor(sample)
  grad = lowering.export_to_tf_tensor(grad)
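  # Hedged addition (not in the original test): the exported tensors are
  # never evaluated above, so a TF1-style session run like this would make
  # the entmax forward pass, its gradient, and the sampler actually execute.
  with tf.Session() as sess:
    sample_val, grad_val = sess.run([sample, grad])
    assert grad_val.shape == (8,)  # gradient has the shape of the input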
def main(_):
  num_tasks = int(os.environ['SLURM_NTASKS'])
  print('num_tasks:', num_tasks)

  # Resolve the cluster from the SLURM environment
  cluster = tf.distribute.cluster_resolver.SlurmClusterResolver(
      {"mesh": num_tasks},
      port_base=8822,
      gpus_per_node=FLAGS.gpus_per_node,
      gpus_per_task=FLAGS.gpus_per_task,
      tasks_per_node=FLAGS.tasks_per_node)
  cluster_spec = cluster.cluster_spec()
  print(cluster_spec)

  # Create a server for all mesh members
  server = tf.distribute.Server(cluster_spec, "mesh", cluster.task_id)
  print(server)

  if cluster.task_id > 0:
    server.join()

  # Otherwise we are the main task, let's define the devices
  devices = [
      "/job:mesh/task:%d/device:GPU:%d" % (i, j)
      for i in range(cluster_spec.num_tasks("mesh"))
      for j in range(FLAGS.gpus_per_task)
  ]
  print("List of devices", devices)

  # Defines the mesh structure
  mesh_shape = [("row", 4), ("col", 2)]
  layout_rules = [("nx_block", "row"), ("ny_block", "col")]
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, layout_rules, devices)

  # Create computational graphs
  net = model_fn(nc=FLAGS.nc, batch_size=FLAGS.batch_size)

  # Lower mesh computation
  graph = net.graph
  mesh = net.mesh
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  # Retrieve output of computation
  result = lowering.export_to_tf_tensor(net)

  # Perform some final processing in plain TensorFlow
  out = tf.reduce_mean(result)

  with tf.Session(server.target) as sess:
    r = sess.run(out)
  print("output of computation", r)
  exit(0)
def testTopK(self):
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  a_dim = mtf.Dimension("a", 6)
  b_dim = mtf.Dimension("b", 2)
  inputs = tf.constant([[1, 10], [2, 9], [3, 8], [4, 7], [5, 6], [6, 5]],
                       dtype=tf.float32)
  k_dim = mtf.Dimension("k", 2)
  d_values = tf.constant([[11, 12], [13, 14]], dtype=tf.float32)
  reduced_dim = a_dim
  expected_values = tf.constant([[6, 5], [10, 9]], dtype=tf.float32)
  expected_indices = tf.constant([[5, 4], [0, 1]])
  expected_d_inputs = tf.constant([[0, 13], [0, 14], [0, 0], [0, 0],
                                   [12, 0], [11, 0]], dtype=tf.float32)

  mtf_inputs = mtf.import_fully_replicated(
      mesh, inputs, shape=mtf.Shape([a_dim, b_dim]))
  mtf_d_values = mtf.import_tf_tensor(
      mesh, d_values, shape=mtf.Shape([b_dim, k_dim]))
  mtf_values, mtf_indices = mtf.top_k(
      mtf_inputs, reduced_dim=reduced_dim, k_dim=k_dim,
      name="test_nth_smallest")
  [mtf_d_inputs] = mtf.gradients([mtf_values], [mtf_inputs], [mtf_d_values])

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape="rows:2,cols:2", layout="a:rows,b:cols",
      devices=["", "", "", ""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_values = lowering.export_to_tf_tensor(mtf_values)
  actual_indices = lowering.export_to_tf_tensor(mtf_indices)
  actual_d_inputs = lowering.export_to_tf_tensor(mtf_d_inputs)
  actual_inputs = lowering.export_to_tf_tensor(mtf_inputs)

  self.assertAllEqual(self.evaluate(actual_inputs),
                      self.evaluate(inputs))
  self.assertAllEqual(self.evaluate(actual_values),
                      self.evaluate(expected_values))
  self.assertAllEqual(self.evaluate(actual_indices),
                      self.evaluate(expected_indices))
  self.assertAllEqual(self.evaluate(actual_d_inputs),
                      self.evaluate(expected_d_inputs))
def test_ids_to_embedding_correctlyEmbeds(self):
  seq_len = 6
  vocab_size = 5
  model_size = 2

  vocab_dim = mtf.Dimension('vocab', vocab_size)
  model_dim = mtf.Dimension('model', model_size)
  length_dim = mtf.Dimension('length', seq_len)

  ids = tf.constant([0, 1, 2, 3, 4, 0], dtype=tf.int32)
  mtf_ids = mtf.import_tf_tensor(
      self.mesh, ids, shape=mtf.Shape([length_dim]))

  self.initializer_mock.side_effect = initialize_by_shape({
      (3, 2): [[0, 1], [2, 0], [-1000, -4000]],
      (3, 1): [[1], [2], [3]],
      (1, 2): [[1], [2]],
  })

  vocab_embedding = adaptive_softmax.AdaptiveSoftmaxVocabEmbedding(
      self.mesh,
      vocab_dim,
      output_dim=model_dim,
      variable_dtype=self.variable_dtype,
      name='embedding',
      ensemble_dim=None,
      clusters=[{
          'token_count': 2,
          'embedding_size': 2
      }, {
          'token_count': 3,
          'embedding_size': 1
      }])
  mtf_embedding = vocab_embedding.ids_to_embedding(mtf_ids, context=None)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[''])
  lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
  actual_embedding = lowering.export_to_tf_tensor(mtf_embedding)

  self.evaluate(tf.global_variables_initializer())
  self.evaluate(lowering.copy_masters_to_slices())
  actual, = self.evaluate([actual_embedding])

  self.assertAllClose(actual,
                      [[0, 1], [2, 0], [1, 2], [2, 4], [3, 6], [0, 1]])
def test_ids_to_embedding_correctlyEmbeds(self):
  seq_len = 4
  vocab_size = 4
  model_size = 3
  num_softmaxes = 1

  vocab_dim = mtf.Dimension('vocab', vocab_size)
  model_dim = mtf.Dimension('model', model_size)
  length_dim = mtf.Dimension('length', seq_len)

  ids = tf.constant([0, 1, 2, 3], dtype=tf.int32)
  mtf_ids = mtf.import_tf_tensor(
      self.mesh, ids, shape=mtf.Shape([length_dim]))

  self.initializer_mock.side_effect = initialize_by_shape({
      # Embedding weights.
      (4, 3): [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 2]],
      # Mixture weights.
      (1, 3): [[1, 0, 0]],
      # Context weights.
      (1, 3, 3): [
          [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
      ],
  })

  vocab_embedding = vocab_embeddings.MixtureOfSoftmaxes(
      self.mesh,
      vocab_dim,
      output_dim=model_dim,
      variable_dtype=self.variable_dtype,
      name='embedding',
      ensemble_dim=None,
      num_softmaxes=num_softmaxes)
  mtf_embedding = vocab_embedding.ids_to_embedding(mtf_ids)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[''])
  lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
  actual_embedding = lowering.export_to_tf_tensor(mtf_embedding)

  self.evaluate(tf.global_variables_initializer())
  self.evaluate(lowering.copy_masters_to_slices())
  actual = self.evaluate([actual_embedding])[0]

  self.assertAllClose(actual,
                      [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 2]])
def main(_):
  # layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
  mesh_shape = [("row", FLAGS.nx), ("col", FLAGS.ny)]
  layout_rules = [("nx_lr", "row"), ("ny_lr", "col"),
                  ("nx", "row"), ("ny", "col"),
                  ("ty", "row"), ("tz", "col"),
                  ("ty_lr", "row"), ("tz_lr", "col"),
                  ("nx_block", "row"), ("ny_block", "col")]

  mesh_impl = HvdSimdMeshImpl(
      mtf.convert_to_shape(mesh_shape),
      mtf.convert_to_layout_rules(layout_rules))

  # Build the model: create the computational graph and some initializations
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "nbody_mesh")

  initial_conditions, mesh_final_field = lpt_prototype(
      mesh, bs=FLAGS.box_size, nc=FLAGS.nc, batch_size=FLAGS.batch_size)

  # Lower mesh computation
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  # Retrieve outputs of computation
  initc = lowering.export_to_tf_tensor(initial_conditions)
  result = lowering.export_to_tf_tensor(mesh_final_field)

  with tf.Session() as sess:
    start = time.time()
    a, c = sess.run([initc, result])
    end = time.time()

  ttime = end - start
  print('Time for', mesh_shape, 'is:', ttime)

  if comm.rank == 0:
    plt.figure(figsize=(9, 3))
    plt.subplot(121)
    plt.imshow(a[0].sum(axis=2))
    plt.title('Initial Conditions')

    plt.subplot(122)
    plt.imshow(c[0].sum(axis=2))
    plt.title('Mesh TensorFlow')
    plt.colorbar()

    plt.savefig("mesh_nbody_%d-row:%d-col:%d.png" %
                (FLAGS.nc, FLAGS.nx, FLAGS.ny))
    plt.close()

  exit(0)
def testBatchNorm(self):
  batch = 2
  channels = 3
  inputs = tf.constant([[0, 1, 2], [4, 5, 6]], dtype=np.float32)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  channels_dim = mtf.Dimension("channels", channels)

  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))

  mtf_outputs_0, _ = mtf.layers.batch_norm(
      mtf_inputs,
      is_training=True,
      momentum=0.95,
      epsilon=1e-6,
      dims_idx_start=0,
      dims_idx_end=1,
      name="bn0")
  mtf_outputs_1, _ = mtf.layers.batch_norm(
      mtf_outputs_0 * 2 + 1,
      is_training=True,
      momentum=0.95,
      epsilon=1e-6,
      dims_idx_start=0,
      dims_idx_end=1,
      name="bn1")

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  actual_outputs_0 = lowering.export_to_tf_tensor(mtf_outputs_0)
  actual_outputs_1 = lowering.export_to_tf_tensor(mtf_outputs_1)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(tf_group)
  [actual_0, actual_1] = self.evaluate([actual_outputs_0, actual_outputs_1])

  expected = np.array([[-1, -1, -1], [1, 1, 1]])
  self.assertAllClose(actual_0, expected)
  self.assertAllClose(actual_1, expected)
def testDotProductAttention(self, batch, heads, length_q, length_kv,
                            depth_k, depth_v):
  query = tf.random_normal([batch, heads, length_q, depth_k])
  key = tf.random_normal([batch, heads, length_kv, depth_k])
  value = tf.random_normal([batch, heads, length_kv, depth_v])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  heads_dim = mtf.Dimension("heads", heads)
  length_q_dim = mtf.Dimension("length_q", length_q)
  length_kv_dim = mtf.Dimension("length_kv", length_kv)
  depth_k_dim = mtf.Dimension("depth_k", depth_k)
  depth_v_dim = mtf.Dimension("depth_v", depth_v)

  mtf_query = mtf.import_tf_tensor(
      mesh, query,
      shape=mtf.Shape([batch_dim, heads_dim, length_q_dim, depth_k_dim]))
  mtf_key = mtf.import_tf_tensor(
      mesh, key,
      shape=mtf.Shape([batch_dim, heads_dim, length_kv_dim, depth_k_dim]))
  mtf_value = mtf.import_tf_tensor(
      mesh, value,
      shape=mtf.Shape([batch_dim, heads_dim, length_kv_dim, depth_v_dim]))
  mtf_outputs = mtf.layers.dot_product_attention(
      mtf_query, mtf_key, mtf_value, mask=None, is_training=False)
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(tf_group)
  actual = self.evaluate(actual_outputs)

  self.assertEqual(actual.shape, (batch, heads, length_q, depth_v))