def create_graph_mesh_and_mesh_impl(self):
  """Creates mtf graph, mesh, and mesh impl.

  This function can be called inside model_fn, which might be tpu_rewritten.

  Returns:
    graph, mesh, mesh_impl
  """
  if self._use_tpu:
    assert self._d_assignment
    graph = mtf.Graph()

    # Worker 0 caches all the TPU binaries.
    replica_cache_size = 300 * 1024 * 1024  # 300M per replica.
    worker0_mem = replica_cache_size * 8 * self._num_hosts
    devices_memory_usage = [worker0_mem] + [0] * (self._num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(self._cpu_devices,
                                                  devices_memory_usage)
    mesh = mtf.Mesh(graph, 'my_mesh', var_placer)
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        self._mesh_shape, self._layout_rules, None, self._d_assignment)
    return graph, mesh, mesh_impl
  else:
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, 'my_mesh', None)
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        self._mesh_shape, self._layout_rules, self._gpu_devices)
    return graph, mesh, mesh_impl

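# A minimal call-site sketch (hypothetical; assumes an object exposing the
# method above plus a `loss` tensor built on the returned mesh):
#
#   graph, mesh, mesh_impl = self.create_graph_mesh_and_mesh_impl()
#   ...build Mesh TensorFlow ops on `mesh`, producing `loss`...
#   lowering = mtf.Lowering(graph, {mesh: mesh_impl})
#   tf_loss = lowering.export_to_tf_tensor(loss)
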
def Replication2(in_tsr):
  graph = mtf.Graph()
  mesh0 = mtf.Mesh(graph, 'mesh0')
  mesh1 = mtf.Mesh(graph, 'mesh1')
  mesh_to_impl = {mesh0: GetMeshImpl([1, 4]),
                  mesh1: GetMeshImpl([2, 4])}

  shape = in_tsr.get_shape().as_list()
  mtf_shape = GetShape(shape)
  mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
  mtf_out_tsr = mt.ReplaceMeshWithDuplicates(mtf_in_tsr, mesh1)
  Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)

def Split3(in_tsr):
  graph = mtf.Graph()
  mesh0 = mtf.Mesh(graph, 'mesh0')
  mesh1 = mtf.Mesh(graph, 'mesh1')
  mesh_to_impl = {mesh0: GetMeshImpl([2, 2], [0, 2, 4, 6]),
                  mesh1: GetMeshImpl([2, 4])}

  shape = in_tsr.get_shape().as_list()
  mtf_shape = GetShape(shape[:2] + [('axis1', shape[2])] + shape[3:])
  mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
  mtf_out_tsr = mt.ReplaceMeshWithConcatSplit(mtf_in_tsr, mesh1,
                                              mtf_shape.dimension_names)
  Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)

def NoConcatSplit(in_tsr):
  graph = mtf.Graph()
  mesh0 = mtf.Mesh(graph, 'mesh0')
  mesh1 = mtf.Mesh(graph, 'mesh1')
  mesh_to_impl = {mesh0: GetMeshImpl([4, 2]),
                  mesh1: GetMeshImpl([4, 2])}

  shape = in_tsr.get_shape().as_list()
  mtf_shape = GetShape([shape[0], ('axis0', shape[1])] + shape[2:])
  mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
  mtf_out_tsr = mt.ReplaceMeshWithConcatSplit(mtf_in_tsr, mesh1,
                                              mtf_shape.dimension_names)
  Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)

def Replication5(in_tsr):
  graph = mtf.Graph()
  mesh0 = mtf.Mesh(graph, 'mesh0')
  mesh1 = mtf.Mesh(graph, 'mesh1')
  mesh_to_impl = {mesh0: GetMeshImpl([2, 1], [0, 4]),
                  mesh1: GetMeshImpl([2, 4])}

  shape = in_tsr.get_shape().as_list()
  mtf_shape = GetShape([('axis0', shape[0])] + shape[1:])
  mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
  mtf_out_tsr = mt.ReplaceMeshWithDuplicates(mtf_in_tsr, mesh1,
                                             mtf_shape.dimension_names)
  Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)

def Transpose1(in_tsr):
  graph = mtf.Graph()
  mesh0 = mtf.Mesh(graph, 'mesh0')
  mesh1 = mtf.Mesh(graph, 'mesh1')
  mesh_to_impl = {mesh0: GetMeshImpl([4, 2]),
                  mesh1: GetMeshImpl([4, 2])}

  shape = in_tsr.get_shape().as_list()
  mtf_shape = GetShape([('axis0', shape[0]), ('axis1', shape[1]), *shape[2:]])
  mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
  mtf_out_tsr = mt.ReplaceMeshWithIndependentAxes(
      mtf_in_tsr, mesh1, [RandName(), RandName(), 'axis0', 'axis1'])
  Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)

def Contract2(in_tsr):
  graph = mtf.Graph()
  mesh0 = mtf.Mesh(graph, 'mesh0')
  mesh1 = mtf.Mesh(graph, 'mesh1')
  mesh_to_impl = {mesh0: GetMeshImpl([2, 4]),
                  mesh1: GetMeshImpl([4, 2])}

  shape = in_tsr.get_shape().as_list()
  mtf_shape = GetShape(shape)
  mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
  mtf_out_tsr = mt.ReplaceMeshWithIndependentAxes(
      mtf_in_tsr, mesh1, [RandName(), 'axis0', 'axis1', RandName()])
  Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)

def MoreDevices(in_tsr):
  graph = mtf.Graph()
  mesh0 = mtf.Mesh(graph, 'mesh0')
  mesh1 = mtf.Mesh(graph, 'mesh1')
  mesh_to_impl = {mesh0: GetMeshImpl([2]),
                  mesh1: GetMeshImpl([8])}

  shape = in_tsr.get_shape().as_list()
  mtf_shape = GetShape(shape[:-1] + [('axis0', shape[-1])])
  mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
  mtf_out_tsr = mt.ReplaceMeshWithIndependentAxes(
      mtf_in_tsr, mesh1, [RandName(), 'axis0', RandName(), RandName()])
  Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)

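# The helpers GetMeshImpl, GetShape, RandName, and Run used in the tests above
# belong to the surrounding test module and are not reproduced here. As a
# rough, hypothetical sketch of what a Run helper could do (lower both meshes,
# execute the transformed tensor, and compare it with the untouched input):
#
#   import numpy as np
#
#   def Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr):
#     lowering = mtf.Lowering(graph, mesh_to_impl)
#     out_tsr = lowering.export_to_tf_tensor(mtf_out_tsr)
#     with tf.Session() as sess:
#       expected, actual = sess.run([in_tsr, out_tsr])
#     np.testing.assert_allclose(actual, expected)
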
def computation_fn():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, 'my_mesh')
  mesh_shape = mtf.convert_to_shape('all:2')
  layout = 'none:all'
  mesh_devices = [''] * mesh_shape.size
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape, mtf.convert_to_layout_rules(layout),
      mesh_devices, device_assignment)
  hidden_dim = mtf.Dimension('hidden', 3)
  w = mtf.get_variable(mesh, 'w', shape=[hidden_dim],
                       initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
  x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim], dtype=tf.float32)
  loss = mtf.reduce_mean(mtf.square(x - w))

  var_grads = mtf.gradients(
      [loss], [v.outputs[0] for v in graph.trainable_variables])
  optimizer = mtf_optimize.AdamWeightDecayOptimizer(learning_rate=0.2)
  update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
  self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_update_ops = [
      self.lowering.lowered_operation(op) for op in update_ops
  ]
  return tf.group(tf_update_ops)

def testLayout(self):
  # Construct a Mesh TensorFlow graph and mesh.
  mtf_graph = mtf.Graph()
  mesh = mtf.Mesh(mtf_graph, "my_mesh")
  x = mtf.zeros(mesh, "a:10,b:5")
  y = mtf.zeros(mesh, "b:5,c:20")
  z = mtf.einsum([x, y], "a:10,c:20")

  # Decide on a mesh shape.
  mesh_shape = mtf.convert_to_shape("m1:4,m2:2")

  # Compute a layout based on the graph and mesh.
  # Note that knowing the identity of the outputs is important to the
  # optimization since they cannot be freed.
  layout = mtf.auto_mtf.layout(mtf_graph, mesh_shape, [z])

  a_dim = mtf.convert_to_dimension(("a", 10))
  b_dim = mtf.convert_to_dimension(("b", 5))
  c_dim = mtf.convert_to_dimension(("c", 20))

  self.assertEqual(
      layout.tensor_dimension_to_mesh_axis(a_dim, mesh_shape), 1)
  self.assertIsNone(
      layout.tensor_dimension_to_mesh_axis(b_dim, mesh_shape))
  self.assertEqual(
      layout.tensor_dimension_to_mesh_axis(c_dim, mesh_shape), 0)

def testMinimizePeakMemoryList_SingleUseTensor(self):
  mtf_graph = mtf.Graph()
  mesh = mtf.Mesh(mtf_graph, 'my_mesh')
  mtf.Constant(mesh, 0, shape=mtf.convert_to_shape('a:4'),
               dtype=tf.int32, name='X')
  y = mtf.Constant(mesh, 0, shape=mtf.convert_to_shape('b:3'),
                   dtype=tf.int32, name='Y').outputs[0]
  mtf.BroadcastOperation(y, mtf.convert_to_shape('b:3,c:2'), name='Z')
  graph = graph_interface.GraphInterface(mtf_graph)
  graph.set_tensor_final('X:0')
  graph.set_tensor_final('Z:0')

  schedule = list(scheduler.minimize_peak_memory(graph, 'LIST'))

  # When nothing is scheduled:
  #   X frees -4 entries
  #   Y frees -3 entries
  # After [Y] scheduled:
  #   X frees -4 entries
  #   Z frees -3 entries
  # Hence the schedule should be [Y, Z, X].
  self.assertEqual(schedule, [1, 2, 0])

def computation_fn():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, 'my_mesh')
  mesh_shape = mtf.convert_to_shape('all:2')
  layout = 'num_heads:all'
  mesh_devices = [''] * mesh_shape.size
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape, mtf.convert_to_layout_rules(layout),
      mesh_devices, device_assignment)
  batch_dim = mtf.Dimension('batch', batch_size)
  seq_dim = mtf.Dimension('seq', seq_length)

  input_ids = tf.random.uniform((batch_size, seq_length),
                                minval=0,
                                maxval=vocab_size,
                                dtype=tf.int32)
  mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids, [batch_dim, seq_dim])

  model = bert_lib.BertModel(config=bert_config,
                             is_training=True,
                             input_ids=mtf_input_ids,
                             input_mask=None,
                             token_type_ids=None)
  pooled = model.get_pooled_output()
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  return lowering.export_to_tf_tensor(pooled)

def testNthSmallestReduceSecondDim(self):
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  a_dim = mtf.Dimension("a", 6)
  b_dim = mtf.Dimension("b", 2)
  inputs = tf.constant([[1, 10], [2, 9], [3, 8], [4, 7], [5, 6], [6, 5]])
  n = 0  # find smallest element (n is zero-indexed)
  reduced_dim = b_dim
  expected_outputs = tf.constant([1, 2, 3, 4, 5, 5])

  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([a_dim, b_dim]))
  mtf_outputs = mtf.nth_smallest_element(
      mtf_inputs, n, reduced_dim, "test_nth_smallest")
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape="all:2", layout="a:all", devices=["", ""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
  self.assertAllEqual(self.evaluate(actual_outputs),
                      self.evaluate(expected_outputs))

def computation_fn():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, 'my_mesh')
  mesh_shape = mtf.convert_to_shape('all:2')
  layout = 'none:all'
  mesh_devices = [''] * mesh_shape.size
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape, mtf.convert_to_layout_rules(layout),
      mesh_devices, device_assignment)
  hidden_dim = mtf.Dimension('hidden', 3)
  w = mtf.get_variable(mesh, 'w', shape=[hidden_dim],
                       initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
  x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim], dtype=tf.float32)
  loss = mtf.reduce_mean(mtf.square(x - w))

  lr, update_ops = optimization_lib.create_optimizer(loss, 0.2, 100, 10)
  self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_update_ops = [
      self.lowering.lowered_operation(op) for op in update_ops
  ]
  tf_update_ops.append(
      tf.assign_add(tf.train.get_or_create_global_step(), 1))
  train_op = tf.group(tf_update_ops)

  return lr, train_op

def testLayoutAndMeshShape(self):
  # Same as previous test, but don't specify a 4x2 mesh.
  mtf_graph = mtf.Graph()
  mesh = mtf.Mesh(mtf_graph, "my_mesh")
  x = mtf.zeros(mesh, "a:10,b:5")
  y = mtf.zeros(mesh, "b:5,c:20")
  z = mtf.einsum([x, y], "a:10,c:20")

  layout, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(mtf_graph, 8, [z])

  a_dim = mtf.convert_to_dimension(("a", 10))
  b_dim = mtf.convert_to_dimension(("b", 5))
  c_dim = mtf.convert_to_dimension(("c", 20))

  self.assertEqual(layout.tensor_dimension_to_mesh_axis(a_dim, mesh_shape), 1)
  self.assertIsNone(layout.tensor_dimension_to_mesh_axis(b_dim, mesh_shape))
  self.assertEqual(layout.tensor_dimension_to_mesh_axis(c_dim, mesh_shape), 0)

  self.assertCountEqual(mesh_shape.dims,
                        [mtf.Dimension("mesh_0", 4),
                         mtf.Dimension("mesh_1", 2)])

  layout, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(
      mtf_graph, 8, [z], 1)

  self.assertIsNone(layout.tensor_dimension_to_mesh_axis(a_dim, mesh_shape))
  self.assertIsNone(layout.tensor_dimension_to_mesh_axis(b_dim, mesh_shape))
  self.assertIsNone(layout.tensor_dimension_to_mesh_axis(c_dim, mesh_shape))

  self.assertCountEqual(mesh_shape.dims, [mtf.Dimension("mesh_0", 8)])

def testDenseReluDense(self):
  batch = 2
  channels = 3
  hidden = 5
  inputs = tf.random_normal([batch, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  channels_dim = mtf.Dimension("channels", channels)
  hidden_dim = mtf.Dimension("hidden", hidden)

  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
  mtf_outputs = mtf.layers.dense_relu_dense(mtf_inputs,
                                            hidden_channels=hidden_dim,
                                            is_training=False)
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(tf_group)
  actual = self.evaluate(actual_outputs)

  self.assertEqual(actual.shape, inputs.shape)

def testCorr2DInput(self):
  batch = 4
  channels = 3
  inputs = tf.random_normal([batch, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  channels_dim = mtf.Dimension("channels", channels)

  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
  mtf_outputs = mtf.layers.corr(mtf_inputs, dim=channels_dim)
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  expected_outputs = tfp.stats.correlation(
      inputs, sample_axis=0, event_axis=1)
  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(tf_group)
  actual, expected = self.evaluate([actual_outputs, expected_outputs])

  self.assertEqual(actual.shape, expected.shape)
  self.assertAllClose(actual, expected)

def testGraph(self):
  graph = mtf.Graph()
  self.assertLen(graph.operations, 0)
  self.assertLen(graph.tensors, 0)
  self.assertLen(graph.trainable_variables, 0)
  self.assertLen(graph.all_variables, 0)
  mesh = mtf.Mesh(graph, "mesh_test")
  _ = mtf.import_tf_tensor(mesh,
                           tf_tensor=tf.constant(0.),
                           shape=mtf.Shape([]))
  self.assertLen(graph.operations, 1)
  self.assertLen(graph.tensors, 1)
  self.assertLen(graph.trainable_variables, 0)
  self.assertLen(graph.all_variables, 0)
  _ = mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
  self.assertLen(graph.operations, 2)
  self.assertLen(graph.tensors, 2)
  self.assertLen(graph.trainable_variables, 1)
  self.assertLen(graph.all_variables, 1)
  _ = mtf.get_variable(mesh, "variable_1", mtf.Shape([]), trainable=False)
  self.assertLen(graph.operations, 3)
  self.assertLen(graph.tensors, 3)
  self.assertLen(graph.trainable_variables, 1)
  self.assertLen(graph.all_variables, 2)

def testWeightsNonzero(self):
  inputs = tf.constant([[3, 1, 0], [1, 0, 0]])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", inputs.shape.as_list()[0])
  channels_dim = mtf.Dimension("channels", inputs.shape.as_list()[1])

  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
  mtf_outputs = mtf.layers.weights_nonzero(mtf_inputs)
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  expected_outputs = tf.cast(tf.not_equal(inputs, 0), tf.float32)
  tf_group = lowering.copy_masters_to_slices()
  self.evaluate(tf_group)
  actual, expected = self.evaluate([actual_outputs, expected_outputs])

  self.assertAllEqual(actual, expected)

def model_fn(features, labels, mode, params):
  """A model_fn called by TPUEstimator."""
  del labels
  del features

  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

  ctx = params['context']
  num_hosts = ctx.num_hosts
  host_placement_fn = ctx.tpu_host_placement_function
  device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
  tf.logging.info('device_list = %s' % device_list)

  mesh_devices = [''] * mesh_shape.size
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "fft_mesh")

  with mtf.utils.outside_all_rewrites():
    fsum = benchmark_model(mesh)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_err = tf.to_float(lowering.export_to_tf_tensor(fsum))

  with mtf.utils.outside_all_rewrites():
    return tpu_estimator.TPUEstimatorSpec(mode, loss=tf_err)

def testMaskedLocalAttention1D(self, batch, length, io_channels, kv_channels,
                               heads, window_size):
  length_q = length
  query = tf.random_normal([batch, length_q, io_channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  length_q_dim = mtf.Dimension("length_q", length_q)
  io_channels_dim = mtf.Dimension("io_channels", io_channels)
  kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
  heads_dim = mtf.Dimension("heads", heads)

  mtf_query = mtf.import_tf_tensor(
      mesh, query,
      shape=mtf.Shape([batch_dim, length_q_dim, io_channels_dim]))
  mtf_outputs = mtf.layers.masked_local_attention_1d(
      mtf_query,
      kv_channels=kv_channels_dim,
      heads=heads_dim,
      window_size=window_size)
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(tf_group)
  actual = self.evaluate(actual_outputs)

  self.assertEqual(actual.shape, (batch, length_q, io_channels))

def testDynamicText2self_unpacked(self):
  batch = 2
  length = 5
  input_tensors = {
      "inputs": [[3, 1, 4, 1, 0], [1, 4, 3, 2, 1]],
      "targets": [[1, 1, 0, 0, 0], [9, 8, 1, 2, 1]],
  }
  expected_output_tensors = {
      "targets": [[3, 1, 4, 1, 1, 1, 0, 0, 0, 0],
                  [1, 4, 3, 2, 1, 9, 8, 1, 2, 1]],
  }
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  length_dim = mtf.Dimension("length", length)

  input_shape = mtf.Shape([batch_dim, length_dim])
  mtf_features = {
      k: mtf.import_tf_tensor(mesh, v, input_shape)
      for k, v in input_tensors.items()
  }
  mtf_outputs = utils._dynamic_text2self(mtf_features)
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  for k, v in expected_output_tensors.items():
    out = lowering.export_to_tf_tensor(mtf_outputs[k])
    actual = self.evaluate(out)
    self.assertAllEqual(actual, v)

def testMinimizePeakMemoryList(self):
  mtf_graph = mtf.Graph()
  mesh = mtf.Mesh(mtf_graph, 'my_mesh')
  x = mtf.Constant(mesh, 0, shape=mtf.convert_to_shape('a:3,b:4'),
                   dtype=tf.int32, name='X').outputs[0]
  y = mtf.Constant(mesh, 0, shape=mtf.convert_to_shape('b:4,c:5'),
                   dtype=tf.int32, name='Y').outputs[0]
  mtf.EinsumOperation([x, y], mtf.convert_to_shape('a:3,b:4,c:5'), name='Z')
  w = mtf.EinsumOperation([x, y], mtf.convert_to_shape('a:3,c:5'),
                          name='W').outputs[0]
  mtf.BroadcastOperation(w, mtf.convert_to_shape('a:3,b:4,c:5'), name='V')
  graph = graph_interface.GraphInterface(mtf_graph)
  graph.set_tensor_final('Z:0')
  graph.set_tensor_final('V:0')

  schedule = list(scheduler.minimize_peak_memory(graph, 'LIST'))

  # The list scheduler prefers to schedule operations that free the most
  # memory.
  # When nothing is scheduled:
  #   X frees -12 entries.
  #   Y frees -20 entries.
  # After [X] scheduled:
  #   Y frees -20 entries.
  # After [X, Y] scheduled:
  #   Z frees -60 entries.
  #   W frees -15 entries.
  # After [X, Y, W] scheduled:
  #   Z frees -28 entries.
  #   V frees -45 entries.
  # Hence the schedule should be [X, Y, W, Z, V].
  self.assertEqual(schedule, [0, 1, 3, 2, 4])

def testDense(self, units, use_bias):
  batch = 2
  channels = 3
  inputs = tf.random_normal([batch, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  channels_dim = mtf.Dimension("channels", channels)
  depth_dim = mtf.Dimension("depth", units)

  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
  mtf_outputs = mtf.layers.dense(mtf_inputs,
                                 output_dim=depth_dim,
                                 reduced_dims=[channels_dim],
                                 activation=mtf.relu,
                                 use_bias=use_bias)
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  expected_outputs = tf.keras.layers.Dense(units=units,
                                           activation=tf.nn.relu,
                                           use_bias=use_bias)(inputs)
  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(tf_group)
  actual, expected = self.evaluate([actual_outputs, expected_outputs])

  self.assertEqual(actual.shape, expected.shape)

def testMultiheadAttention(self, kv_channels, heads):
  batch = 2
  length = 8
  channels = 3
  query = tf.random_normal([batch, length, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  length_dim = mtf.Dimension("length", length)
  channels_dim = mtf.Dimension("channels", channels)
  kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
  heads_dim = mtf.Dimension("heads", heads)

  mtf_query = mtf.import_tf_tensor(
      mesh, query,
      shape=mtf.Shape([batch_dim, length_dim, channels_dim]))
  mtf_outputs = mtf.layers.multihead_attention(
      mtf_query,
      memory_antecedent=None,
      mask=None,
      kv_channels=kv_channels_dim,
      heads=heads_dim)
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(tf_group)
  actual = self.evaluate(actual_outputs)

  self.assertEqual(actual.shape, query.shape)

def testRecomputeGrad(self):
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")

  # Let's differentiate x^2 + x, so dy/dx = 2x + 1.
  def x_squared_plus_x(x):
    return x * x + x

  x = tf.constant([5, 10], dtype=tf.float32)
  dy = tf.constant([2, 3], dtype=tf.float32)
  two = mtf.Dimension("two", 2)
  expected_y = tf.constant([30, 110], dtype=tf.float32)
  expected_dx = tf.constant([22, 63], dtype=tf.float32)

  mtf_x = mtf.import_fully_replicated(mesh, x, shape=mtf.Shape([two]))
  mtf_dy = mtf.import_tf_tensor(mesh, dy, shape=mtf.Shape([two]))
  mtf_y = mtf.recompute_grad(x_squared_plus_x, [mtf_x])
  [mtf_dx] = mtf.gradients([mtf_y], [mtf_x], [mtf_dy])

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape="processors:2", layout="two:processors", devices=["", ""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_y = lowering.export_to_tf_tensor(mtf_y)
  actual_dx = lowering.export_to_tf_tensor(mtf_dx)

  self.assertAllEqual(self.evaluate(actual_y), self.evaluate(expected_y))
  self.assertAllEqual(self.evaluate(actual_dx), self.evaluate(expected_dx))

def testConv1d(self):
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  filter_size = 3
  depth_dim = mtf.Dimension("depth", 2)
  length_dim = mtf.Dimension("length", 4)
  output_dim = mtf.Dimension("output", 2)

  x = tf.constant([[1, 0], [0, 1], [1, 1], [2, 1]], dtype=tf.float32)
  mtf_x = mtf.import_tf_tensor(
      mesh, x, shape=mtf.Shape([length_dim, depth_dim]))

  initializer_mock = mock.MagicMock()
  initializer_mock.side_effect = initialize_by_shape({
      (1, 3, 2, 2): [[[[1, -1], [0, 0]],
                      [[2, -2], [-1, 1]],
                      [[3, -3], [-2, 2]]]],
  })

  mtf_output = mtf.layers.conv1d(
      mtf_x,
      output_dim=output_dim,
      filter_size=filter_size,
      filter_initializer=initializer_mock)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_output = lowering.export_to_tf_tensor(mtf_output)

  self.evaluate(tf.global_variables_initializer())
  self.evaluate(lowering.copy_masters_to_slices())
  actual = self.evaluate(actual_output)

  self.assertAllClose(actual, [[0, 0], [1, -1], [5, -5], [4, -4]])

def main(_):
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

  # Resolve the cluster from the SLURM environment.
  cluster = tf.distribute.cluster_resolver.SlurmClusterResolver(
      {"mesh": mesh_shape.size // FLAGS.gpus_per_task},
      port_base=8822,
      gpus_per_node=FLAGS.gpus_per_node,
      gpus_per_task=FLAGS.gpus_per_task,
      tasks_per_node=FLAGS.tasks_per_node)
  cluster_spec = cluster.cluster_spec()

  # Create a server for all mesh members.
  server = tf.distribute.Server(cluster_spec, "mesh", cluster.task_id)

  # Only the master job takes care of the graph building; every other task
  # just joins the server and waits.
  if cluster.task_id > 0:
    server.join()

  # Otherwise we are the main task, so define the devices.
  mesh_devices = [
      "/job:mesh/task:%d/device:GPU:%d" % (i, j)
      for i in range(cluster_spec.num_tasks("mesh"))
      for j in range(FLAGS.gpus_per_node)
  ]
  print("List of devices", mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "fft_mesh")

  # Build the model.
  fft_err = benchmark_model(mesh)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, layout_rules, mesh_devices)
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  # Retrieve the output of the computation.
  result = lowering.export_to_tf_tensor(fft_err)

  with tf.Session(server.target) as sess:
    start = time.time()
    err = sess.run(result)
    end = time.time()

    time.sleep(1)

    # Time a second run, after warm-up.
    start = time.time()
    err = sess.run(result)
    end = time.time()

  print("Max absolute FFT error %f, with wall time %f" % (err, (end - start)))
  time.sleep(1)
  exit(0)

def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator."""
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  logits, loss = model_backbone(features, labels, mesh)

  variables = graph._all_variables
  for v in variables:
    logger.debug("[parameter] (name, shape, dtype): ({}, {}, {})".format(
        v.name, v.shape, v.dtype.master_dtype))

  mesh_shape = mtf.convert_to_shape(args_opt.mesh_shape)
  # layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss])
  mesh_shape = mtf.convert_to_shape(mesh_shape)
  estimator = memory_estimator.MemoryEstimator(graph, mesh_shape,
                                               [logits, loss])
  optimizer = layout_optimizer.LayoutOptimizer(estimator,
                                               scheduler_alg="NAIVE")
  layout_rules = mtf.convert_to_layout_rules(optimizer.solve())
  # layout_rules = [('batch', 'b1')]
  logger.info("[auto mtf search] strategy: {}".format(layout_rules))

  mesh_devices = ["gpu:{}".format(i) for i in range(int(args_opt.num_gpus))]
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, layout_rules, mesh_devices)

  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.SgdOptimizer(0.01)
    # optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  restore_hook = mtf.MtfRestoreHook(lowering)

  tf_logits = lowering.export_to_tf_tensor(logits)
  if mode != tf.estimator.ModeKeys.PREDICT:
    tf_loss = lowering.export_to_tf_tensor(loss)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

    accuracy = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(tf_logits, axis=1))

    # Name tensors to be logged with LoggingTensorHook.
    tf.identity(tf_loss, "cross_entropy")
    tf.identity(accuracy[1], name="train_accuracy")
    logging_hook = tf.train.LoggingTensorHook(
        every_n_iter=100,
        tensors={'loss': 'cross_entropy', 'acc': 'train_accuracy'})
    # profiling_hook = tf.estimator.ProfilerHook(save_steps=20, output_dir='./profiling/')

    # restore_hook must come before saver_hook.
    return tf.estimator.EstimatorSpec(
        tf.estimator.ModeKeys.TRAIN,
        loss=tf_loss,
        train_op=train_op,
        training_chief_hooks=[restore_hook, logging_hook])

def get_placement_mesh(hparams):
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)

  mesh_devices = [""] * mesh_shape.size
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, hparams.layout, mesh_devices)
  return mesh, mesh_impl
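
# Usage sketch for get_placement_mesh (hypothetical caller; `hparams` is
# assumed to carry `mesh_shape` and `layout` strings, and `my_mtf_model` is a
# placeholder for whatever builds ops on the mesh):
#
#   mesh, mesh_impl = get_placement_mesh(hparams)
#   logits = my_mtf_model(mesh)
#   lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
#   tf_logits = lowering.export_to_tf_tensor(logits)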