def testReturnsTopoSort(self, scheduler_alg):
    """Checks the returned schedule is a topological order of the graph.

    Builds two constants (X, Y, ops 0 and 1) and two einsums (Z1, Z2,
    ops 2 and 3) that both consume X and Y.  Any valid schedule must place
    both constants before both einsums, in either relative order.

    Args:
      scheduler_alg: scheduling-algorithm name forwarded to
        scheduler.minimize_peak_memory (parameterized by the test runner).
    """
    mtf_graph = mtf.Graph()
    mesh = mtf.Mesh(mtf_graph, 'my_mesh')
    x = mtf.Constant(mesh, 0, shape=mtf.convert_to_shape('a:3,b:4'),
                     dtype=tf.int32, name='X').outputs[0]
    y = mtf.Constant(mesh, 0, shape=mtf.convert_to_shape('b:4,c:5'),
                     dtype=tf.int32, name='Y').outputs[0]
    mtf.EinsumOperation([x, y], mtf.convert_to_shape('a:3,c:5'), name='Z1')
    mtf.EinsumOperation([x, y], mtf.convert_to_shape('a:3,c:5'), name='Z2')
    graph = graph_interface.GraphInterface(mtf_graph)
    # Mark both einsum outputs as final so the scheduler cannot free them.
    graph.set_tensor_final('Z1:0')
    graph.set_tensor_final('Z2:0')
    schedule = list(scheduler.minimize_peak_memory(graph, scheduler_alg))
    # Constants first (either order), then the einsums (either order).
    self.assertCountEqual(schedule[0:2], [0, 1])
    self.assertCountEqual(schedule[2:4], [2, 3])
def testMinimizePeakMemoryList_SingleUseTensor(self):
    """LIST scheduler picks, at each step, the op freeing the most memory."""
    mtf_graph = mtf.Graph()
    mesh = mtf.Mesh(mtf_graph, 'my_mesh')
    # X (op 0) has no consumers; Y (op 1) feeds the broadcast Z (op 2).
    mtf.Constant(mesh, 0, shape=mtf.convert_to_shape('a:4'), dtype=tf.int32,
                 name='X')
    y = mtf.Constant(mesh, 0, shape=mtf.convert_to_shape('b:3'),
                     dtype=tf.int32, name='Y').outputs[0]
    mtf.BroadcastOperation(y, mtf.convert_to_shape('b:3,c:2'), name='Z')
    graph = graph_interface.GraphInterface(mtf_graph)
    graph.set_tensor_final('X:0')
    graph.set_tensor_final('Z:0')
    schedule = list(scheduler.minimize_peak_memory(graph, 'LIST'))
    # When nothing is scheduled:
    #   X frees -4 entries
    #   Y frees -3 entries
    # After [Y] scheduled:
    #   X frees -4 entries
    #   Z frees -3 entries
    # Hence the schedule should be [Y, Z, X].
    self.assertEqual(schedule, [1, 2, 0])
def setUp(self):
    """Builds the shared graph, dimensions and input tensors for the tests."""
    super(OperationSplittabilityTest, self).setUp()
    self.graph = mtf.Graph()
    self.mesh = mtf.Mesh(self.graph, "my_mesh")
    # Generic dimensions and a 2-D zeros input.
    self.a_dim = mtf.Dimension("a", 5)
    self.b_dim = mtf.Dimension("b", 10)
    self.c_dim = mtf.Dimension("c", 15)
    self.ab_shape = mtf.Shape([self.a_dim, self.b_dim])
    self.x = mtf.zeros(self.mesh, self.ab_shape)
    # Image-style dimensions (batch, grid, filter, in/out channels)
    # and a 4-D zeros "image" input for convolution-op tests.
    self.batch_dim = mtf.Dimension("batch", 100)
    self.grid_h_dim = mtf.Dimension("grid_h", 10)
    self.grid_w_dim = mtf.Dimension("grid_w", 10)
    self.filter_h_dim = mtf.Dimension("filter_h", 5)
    self.filter_w_dim = mtf.Dimension("filter_w", 5)
    self.in_dim = mtf.Dimension("in", 10)
    self.out_dim = mtf.Dimension("out", 10)
    self.image = mtf.zeros(self.mesh, [self.batch_dim, self.grid_h_dim,
                                       self.grid_w_dim, self.in_dim])
def setUp(self):
    """Builds a graph with a concat op plus an anonymous-shape tensor."""
    super(LayoutValidatorTest, self).setUp()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    a_dim = mtf.Dimension("a", 5)
    b_dim = mtf.Dimension("b", 10)
    # Two dimensions share the name "concat" with different sizes; they are
    # joined along that axis by ConcatOperation below.
    concat_dim1 = mtf.Dimension("concat", 15)
    concat_dim2 = mtf.Dimension("concat", 20)
    x1 = mtf.zeros(mesh, mtf.Shape([a_dim, b_dim, concat_dim1]))
    x2 = mtf.zeros(mesh, mtf.Shape([a_dim, b_dim, concat_dim2]))
    mtf.ConcatOperation([x1, x2], "concat")
    # We add a tensor with anonymous shape, which is supposed to be
    # unsplittable (i.e. none of its dimensions show up during
    # test_SplittableMtfDimensionNames).
    _ = mtf.zeros(mesh, mtf.anonymous_shape(mtf.Shape([a_dim, b_dim])))
    mesh_shape = mtf.Shape([("m1", 4), ("m2", 2)])
    self.valid_layouts = valid_layouts.LayoutValidator(graph, mesh_shape)
def WrongShape(in_tsr):
    """Negative test: a target shape with a duplicated axis name must fail.

    The mtf shape reuses 'axis0' for two different dimensions, which is
    invalid; Run() is expected to raise ValueError.
    """
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0: GetMeshImpl([4, 2]), mesh1: GetMeshImpl([8])}
    shape = in_tsr.get_shape().as_list()
    # Deliberately wrong: 'axis0' appears twice in the shape.
    mtf_shape = GetShape(shape[:-2] + [('axis0', shape[2]),
                                       ('axis0', shape[3])])
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    # NOTE(review): `mt` is presumably the module under test — confirm alias.
    mtf_out_tsr = mt.ReplaceMeshWithIndependentAxes(
        mtf_in_tsr, mesh1, [RandName(), 'axis0', RandName(), RandName()])
    try:
        Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
        assert False  # This test should fail with ValueError
    except ValueError:
        return
def test_variable_placer(self):
    """BalancedVariablePlacer places each new variable on the lightest CPU."""
    # cpu:0 starts with 100 bytes already accounted for; others start empty.
    sizes = [100, 0, 0, 0]
    device_list = ['cpu:0', 'cpu:1', 'cpu:2', 'cpu:3']
    with tf.Graph().as_default() as g:
        var_placer = mtf.utils.BalancedVariablePlacer(device_list, sizes)
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, 'my_mesh', var_placer)
        hidden_dim = mtf.Dimension('hidden', 10)
        output_dim = mtf.Dimension('output_feature', 10)
        for i in xrange(5):
            # Each variable takes 400 Bytes, and will be placed from cpu:1.
            mtf.get_variable(mesh, 'w{}'.format(i), [hidden_dim, output_dim])
        for i in xrange(5):
            var = g.get_tensor_by_name('w{}:0'.format(i))
            # Expected placement cycles cpu:1, cpu:2, cpu:3, cpu:0, cpu:1.
            device = (i + 1) % len(device_list)
            self.assertEqual('cpu:{}'.format(device), var.device)
def testMaskedLocalAttention1D(self, batch, length, io_channels, kv_channels,
                               heads, block_length):
    """masked_local_attention_1d keeps the query's (batch, len, ch) shape."""
    length_q = length
    length_m = length
    query = tf.random_normal([batch, length_q, io_channels])
    memory = tf.random_normal([batch, length_m, io_channels])
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    length_q_dim = mtf.Dimension("length_q", length_q)
    length_m_dim = mtf.Dimension("length_m", length_m)
    io_channels_dim = mtf.Dimension("io_channels", io_channels)
    kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
    heads_dim = mtf.Dimension("heads", heads)
    mtf_query = mtf.import_tf_tensor(
        mesh, query,
        shape=mtf.Shape([batch_dim, length_q_dim, io_channels_dim]))
    mtf_memory = mtf.import_tf_tensor(
        mesh, memory,
        shape=mtf.Shape([batch_dim, length_m_dim, io_channels_dim]))
    mtf_outputs = mtf.layers.masked_local_attention_1d(
        mtf_query, mtf_memory, kv_channels=kv_channels_dim, heads=heads_dim,
        block_length=block_length)
    # Lower on a trivial single-device placement mesh.
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    self.evaluate(init)
    self.evaluate(tf_group)
    actual = self.evaluate(actual_outputs)
    self.assertEqual(actual.shape, (batch, length_q, io_channels))
def main(_):
    """Runs the mesh nbody simulation and saves per-rank input/output fields."""
    # Creating layout and mesh implementation
    mesh_shape = [("row", FLAGS.nx), ("col", FLAGS.ny)]
    layout_rules = [("nx_lr", "row"), ("ny_lr", "col"),
                    ("nx", "row"), ("ny", "col"),
                    ("ty", "row"), ("tz", "col"),
                    ("ty_lr", "row"), ("tz_lr", "col"),
                    ("nx_block", "row"), ("ny_block", "col")]
    mesh_impl = HvdSimdMeshImpl(mtf.convert_to_shape(mesh_shape),
                                mtf.convert_to_layout_rules(layout_rules))
    # Create the graph and mesh
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    ## Load initial power spectrum
    klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0]
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    # Defines the computational graph for the nbody
    initial_conditions, final_field = nbody_fn(mesh, klin, plin)
    # Lower mesh computation
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    # Retrieve fields as tf tensors
    tf_initc = lowering.export_to_tf_tensor(initial_conditions)
    tf_final = lowering.export_to_tf_tensor(final_field)
    with tf.Session() as sess:
        start = time.time()
        init_conds, final = sess.run([tf_initc, tf_final])
        end = time.time()
        print('\n Time for the mesh run : %f \n' % (end - start))
    # Export these fields
    np.save('simulation_output_%d.npy' % comm.Get_rank(), final)
    np.save('simulation_input_%d.npy' % comm.Get_rank(), init_conds)
    exit(0)
def testWeightsNonzero(self):
    """mtf.layers.weights_nonzero matches the common_layers reference."""
    inputs = tf.constant([[3, 1, 0], [1, 0, 0]])
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", inputs.shape.as_list()[0])
    channels_dim = mtf.Dimension("channels", inputs.shape.as_list()[1])
    mtf_inputs = mtf.import_tf_tensor(
        mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
    mtf_outputs = mtf.layers.weights_nonzero(mtf_inputs)
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
    # Reference implementation run on the raw tf input.
    expected_outputs = common_layers.weights_nonzero(inputs)
    tf_group = lowering.copy_masters_to_slices()
    self.evaluate(tf_group)
    actual, expected = self.evaluate([actual_outputs, expected_outputs])
    self.assertAllEqual(actual, expected)
def testNthSmallestReduceSecondDim(self):
    """nth_smallest_element with n=0 returns the row-wise minimum."""
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    a_dim = mtf.Dimension("a", 6)
    b_dim = mtf.Dimension("b", 2)
    inputs = tf.constant([[1, 10], [2, 9], [3, 8], [4, 7], [5, 6], [6, 5]])
    n = 0  # find smallest element (n is zero-indexed)
    reduced_dim = b_dim
    expected_outputs = tf.constant([1, 2, 3, 4, 5, 5])
    mtf_inputs = mtf.import_tf_tensor(mesh, inputs,
                                      shape=mtf.Shape([a_dim, b_dim]))
    mtf_outputs = mtf.nth_smallest_element(mtf_inputs, n, reduced_dim,
                                           "test_nth_smallest")
    # Two devices; dimension "a" is laid out across the single mesh axis.
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape="all:2", layout="a:all", devices=["", ""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
    self.assertAllEqual(self.evaluate(actual_outputs),
                        self.evaluate(expected_outputs))
def model_fn(nc=64, batch_size=1):
    """Example of function implementing a CNN and returning a value.

    Args:
      nc: cube size; each spatial block size is nc divided by the number of
        blocks along that axis (4, 2 and 1 for x, y, z respectively).
      batch_size: size of the "batch" dimension.

    Returns:
      An mtf.Tensor of shape [batch, h] — the conv3d output summed over all
      spatial/block dimensions.
    """
    # Create the mesh TF graph
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    # Define the named dimensions
    n_block_x = 4
    n_block_y = 2
    n_block_z = 1
    # BUG FIX: removed a stray backtick after `batch_size` that made this
    # statement a syntax error.
    batch_dim = mtf.Dimension("batch", batch_size)
    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)
    sx_dim = mtf.Dimension('sx_block', nc//n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc//n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc//n_block_z)
    image_c_dim = mtf.Dimension('image_c', 3)
    hidden_dim = mtf.Dimension('h', 128)

    # Create some input data
    data = mtf.random_uniform(mesh, [batch_dim,
                                     nx_dim, ny_dim, nz_dim,
                                     sx_dim, sy_dim, sz_dim,
                                     image_c_dim])
    net = mtf.layers.conv3d_with_blocks(data, hidden_dim,
                                        filter_size=(3, 3, 3),
                                        strides=(1, 1, 1), padding='SAME',
                                        d_blocks_dim=nx_dim,
                                        h_blocks_dim=ny_dim)
    # Reduce everything except batch and hidden down to a single value.
    net = mtf.reduce_sum(net, output_shape=[batch_dim, hidden_dim])
    return net
def testBatchNorm(self):
    """Two stacked batch_norm layers both normalize this input to +-1."""
    batch = 2
    channels = 3
    inputs = tf.constant([[0, 1, 2], [4, 5, 6]], dtype=np.float32)
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    channels_dim = mtf.Dimension("channels", channels)
    mtf_inputs = mtf.import_tf_tensor(
        mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
    mtf_outputs_0, _ = mtf.layers.batch_norm(
        mtf_inputs, is_training=True, momentum=0.95, epsilon=1e-6,
        dims_idx_start=0, dims_idx_end=1, name="bn0")
    # Affine-shifted input (x*2 + 1); batch norm should undo the transform,
    # so bn1 is expected to produce the same normalized values as bn0.
    mtf_outputs_1, _ = mtf.layers.batch_norm(
        mtf_outputs_0 * 2 + 1, is_training=True, momentum=0.95, epsilon=1e-6,
        dims_idx_start=0, dims_idx_end=1, name="bn1")
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs_0 = lowering.export_to_tf_tensor(mtf_outputs_0)
    actual_outputs_1 = lowering.export_to_tf_tensor(mtf_outputs_1)
    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    self.evaluate(init)
    self.evaluate(tf_group)
    [actual_0, actual_1] = self.evaluate([actual_outputs_0, actual_outputs_1])
    expected = np.array([[-1, -1, -1], [1, 1, 1]])
    self.assertAllClose(actual_0, expected)
    self.assertAllClose(actual_1, expected)
def testMultiheadAttention(self, kv_channels, heads):
    """Self-attention output has the same shape as the query."""
    batch = 2
    length = 8
    channels = 3
    query = tf.random_normal([batch, length, channels])
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    length_dim = mtf.Dimension("length", length)
    channels_dim = mtf.Dimension("channels", channels)
    kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
    heads_dim = mtf.Dimension("heads", heads)
    mtf_query = mtf.import_tf_tensor(
        mesh, query, shape=mtf.Shape([batch_dim, length_dim, channels_dim]))
    # memory_antecedent=None means self-attention over the query.
    mtf_outputs = mtf.layers.multihead_attention(
        mtf_query, memory_antecedent=None, mask=None,
        kv_channels=kv_channels_dim, heads=heads_dim, is_training=False)
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    self.evaluate(init)
    self.evaluate(tf_group)
    actual = self.evaluate(actual_outputs)
    self.assertEqual(actual.shape, query.shape)
def testLayoutAndMeshShape(self):
    """auto_mtf picks both a layout and a mesh shape for the einsum graph."""
    # Same as previous test, but don't specify a 4x2 mesh.
    mtf_graph = mtf.Graph()
    mesh = mtf.Mesh(mtf_graph, "my_mesh")
    x = mtf.zeros(mesh, "a:10,b:5")
    y = mtf.zeros(mesh, "b:5,c:20")
    z = mtf.einsum([x, y], "a:10,c:20")
    layout, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(
        mtf_graph, 8, [z])
    a_dim = mtf.convert_to_dimension(("a", 10))
    b_dim = mtf.convert_to_dimension(("b", 5))
    c_dim = mtf.convert_to_dimension(("c", 20))
    # With 8 devices: a 4x2 mesh, "a" on axis 1, "c" on axis 0, "b" unsplit.
    self.assertEqual(
        layout.tensor_dimension_to_mesh_axis(a_dim, mesh_shape), 1)
    self.assertIsNone(
        layout.tensor_dimension_to_mesh_axis(b_dim, mesh_shape))
    self.assertEqual(
        layout.tensor_dimension_to_mesh_axis(c_dim, mesh_shape), 0)
    self.assertCountEqual(
        mesh_shape.dims,
        [mtf.Dimension("mesh_0", 4), mtf.Dimension("mesh_1", 2)])
    # With the extra argument 1 (max mesh-shape dimensions): a single
    # 8-wide axis and nothing gets split.
    layout, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(
        mtf_graph, 8, [z], 1)
    self.assertIsNone(
        layout.tensor_dimension_to_mesh_axis(a_dim, mesh_shape))
    self.assertIsNone(
        layout.tensor_dimension_to_mesh_axis(b_dim, mesh_shape))
    self.assertIsNone(
        layout.tensor_dimension_to_mesh_axis(c_dim, mesh_shape))
    self.assertCountEqual(mesh_shape.dims, [mtf.Dimension("mesh_0", 8)])
def test_model():
    """gpt2.model builds logits and lowers them without raising."""
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    # NOTE(review): `params` comes from module scope — shared test config.
    seq_len = params["n_ctx"]
    batch_dim = mtf.Dimension("batch", 1)
    sequence_dim = mtf.Dimension("sequence", seq_len)
    features = {
        'inputs': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)),
                           tf.int32),
        'labels': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)),
                           tf.int32)
    }
    # create mask
    num_mem_kv = params.get('num_mem_kv', 0)
    length_dim = mtf.Dimension('sequence', seq_len)
    memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
    embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
    other_features = {}
    variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)
    other_features["attn_bias"] = biasmask_attn_weights(
        mesh, length_dim, memory_length_dim, variable_dtype)
    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim
    # Building and lowering must not raise.
    with not_raises(Exception):
        logits, _, _ = gpt2.model(features, other_features, params, mesh,
                                  variable_dtype=variable_dtype)
        mesh_impl = placement_mesh_impl.PlacementMeshImpl(
            shape=[], layout={}, devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        logits = lowering.export_to_tf_tensor(logits)
def testDense(self, units, use_bias, new_dim_name):
    """mtf.layers.dense output shape matches an equivalent keras Dense."""
    batch = 2
    channels = 3
    inputs = tf.random_normal([batch, channels])
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    channels_dim = mtf.Dimension("channels", channels)
    new_dim = mtf.Dimension(new_dim_name, units)
    mtf_inputs = mtf.import_tf_tensor(
        mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
    mtf_outputs = mtf.layers.dense(
        mtf_inputs,
        new_dims=new_dim,
        reduced_dims=[channels_dim],
        activation=mtf.relu,
        use_bias=use_bias,
    )
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
    # Keras layer uses its own initialization: only shapes are compared.
    expected_outputs = tf.keras.layers.Dense(units=units,
                                             activation=tf.nn.relu,
                                             use_bias=use_bias)(inputs)
    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    self.evaluate(init)
    self.evaluate(tf_group)
    actual, expected = self.evaluate([actual_outputs, expected_outputs])
    self.assertEqual(actual.shape, expected.shape)
def testDynamicText2self(self):
    """_dynamic_text2self interleaves packed inputs/targets per segment."""
    batch = 2
    length = 5
    input_tensors = {
        "inputs": [[3, 1, 4, 1, 0], [1, 4, 3, 2, 1]],
        "inputs_segmentation": [[1, 1, 2, 2, 0], [1, 2, 2, 2, 2]],
        "inputs_position": [[0, 1, 0, 1, 0], [0, 0, 1, 2, 3]],
        "targets": [[1, 1, 0, 0, 0], [9, 8, 1, 2, 1]],
        "targets_segmentation": [[1, 2, 0, 0, 0], [1, 1, 1, 2, 2]],
        "targets_position": [[0, 0, 0, 0, 0], [0, 1, 2, 0, 1]]
    }
    # Each segment's inputs are followed by its targets; output length 2x.
    expected_output_tensors = {
        "targets": [[3, 1, 1, 4, 1, 1, 0, 0, 0, 0],
                    [1, 9, 8, 1, 4, 3, 2, 1, 2, 1]],
        "targets_segmentation": [[1, 1, 1, 2, 2, 2, 0, 0, 0, 0],
                                 [1, 1, 1, 1, 2, 2, 2, 2, 2, 2]],
        "targets_position": [[0, 1, 2, 0, 1, 2, 0, 0, 0, 0],
                             [0, 1, 2, 3, 0, 1, 2, 3, 4, 5]]
    }
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    length_dim = mtf.Dimension("length", length)
    input_shape = mtf.Shape([batch_dim, length_dim])
    mtf_features = {
        k: mtf.import_tf_tensor(mesh, v, input_shape)
        for k, v in input_tensors.items()
    }
    mtf_outputs = utils._dynamic_text2self(mtf_features)
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    for k, v in expected_output_tensors.items():
        out = lowering.export_to_tf_tensor(mtf_outputs[k])
        actual = self.evaluate(out)
        self.assertAllEqual(actual, v)
def test_sampling():
    """sample_autoregressive (entmax sampling) builds and lowers cleanly."""
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", 1)
    sequence_dim = mtf.Dimension("sequence", 1)
    inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32)
    # Pad the single-token prompt out along the sequence dimension.
    inputs = mtf.pad(inputs, [0, 3], sequence_dim.name)
    # create mask
    # NOTE(review): `params` comes from module scope — shared test config.
    seq_len = params["n_ctx"]
    num_mem_kv = params.get('num_mem_kv', 0)
    length_dim = mtf.Dimension('sequence', seq_len)
    memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
    embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
    other_features = {}
    other_features["attn_bias"] = biasmask_attn_weights(
        mesh, length_dim, memory_length_dim, mtf.VariableDType(tf.float32))
    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim
    params["mode"] = "predict"
    with not_raises(Exception):
        samples = sample_autoregressive(
            inputs, other_features=other_features, params=params,
            variable_dtype=mtf.VariableDType(),
            remove_partial_sequences=params["remove_partial_sequences"],
            stop_at_token=params["eos_id"], sampling_use_entmax=True)
        mesh_impl = placement_mesh_impl.PlacementMeshImpl(
            shape=[], layout={}, devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        samples = lowering.export_to_tf_tensor(samples)
def testLayout(self):
    """auto_mtf.layout splits a and c across the 4x2 mesh, leaves b unsplit."""
    # Construct a Mesh TensorFlow graph and mesh.
    mtf_graph = mtf.Graph()
    mesh = mtf.Mesh(mtf_graph, "my_mesh")
    x = mtf.zeros(mesh, "a:10,b:5")
    y = mtf.zeros(mesh, "b:5,c:20")
    z = mtf.einsum([x, y], "a:10,c:20")
    # Decide on a mesh shape.
    mesh_shape = mtf.convert_to_shape("m1:4,m2:2")
    # Compute a layout based on the graph and mesh.
    # Note that knowing the identity of the outputs is important to the
    # optimization since they cannot be freed.
    layout = mtf.auto_mtf.layout(mtf_graph, mesh_shape, [z])
    a_dim = mtf.convert_to_dimension(("a", 10))
    b_dim = mtf.convert_to_dimension(("b", 5))
    c_dim = mtf.convert_to_dimension(("c", 20))
    self.assertEqual(layout.tensor_dimension_to_mesh_axis(a_dim, mesh_shape),
                     1)
    self.assertIsNone(layout.tensor_dimension_to_mesh_axis(b_dim, mesh_shape))
    self.assertEqual(layout.tensor_dimension_to_mesh_axis(c_dim, mesh_shape),
                     0)
def testGraph(self):
    """A new graph is empty and registers ops/variables as they are built."""
    g = mtf.Graph()

    def expect(num_ops, num_trainable, num_all):
        # Check the graph's three bookkeeping collections at once.
        self.assertLen(g.operations, num_ops)
        self.assertLen(g.trainable_variables, num_trainable)
        self.assertLen(g.all_variables, num_all)

    expect(0, 0, 0)
    mesh = mtf.Mesh(g, "mesh_test")
    # An imported tensor adds one operation but no variables.
    _ = mtf.import_tf_tensor(mesh, tf_tensor=tf.constant(0.),
                             shape=mtf.Shape([]))
    expect(1, 0, 0)
    # A trainable variable shows up in both variable collections.
    _ = mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
    expect(2, 1, 1)
    # A non-trainable variable only shows up in all_variables.
    _ = mtf.get_variable(mesh, "variable_1", mtf.Shape([]), trainable=False)
    expect(3, 1, 2)
def CreateMeshes(inputs, labels, num_nodes, num_gpus, batch_size):
    """Builds three meshes (two GPU halves plus their interleaving) and
    imports the inputs/labels onto the combined mesh.

    Args:
      inputs: 2-D tf.Tensor of inputs; must have the same shape as `labels`.
      labels: 2-D tf.Tensor of labels.
      num_nodes: number of hosts; must divide num_gpus.
      num_gpus: total number of GPUs; must be even.
      batch_size: size of the batch ('axis0') dimension.

    Returns:
      Tuple (graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels).
    """
    graph = mtf.Graph()
    meshes = []
    mesh_to_impl = {}

    assert num_gpus % num_nodes == 0
    assert num_gpus % 2 == 0
    gpus_per_node = num_gpus // num_nodes
    devices = utils.GetDeviceList(num_gpus, gpus_per_node)

    # Idiom fix: the mesh names were f-strings without placeholders;
    # plain string literals are equivalent and clearer.
    # Mesh 0: first half of the devices.
    mesh = mtf.Mesh(graph, 'mesh0')
    meshes.append(mesh)
    mesh_to_impl[mesh] = utils.GetMeshImpl(
        [num_gpus//2], devices=devices[:num_gpus//2],
        gpus_per_node=gpus_per_node)

    # Mesh 1: second half of the devices.
    mesh = mtf.Mesh(graph, 'mesh1')
    meshes.append(mesh)
    mesh_to_impl[mesh] = utils.GetMeshImpl(
        [num_gpus//2], devices=devices[num_gpus//2:],
        gpus_per_node=gpus_per_node)

    # Mesh 2: all devices, interleaving the two halves.
    mesh = mtf.Mesh(graph, 'mesh2')
    meshes.append(mesh)
    mesh_to_impl[mesh] = utils.GetMeshImpl(
        [num_gpus],
        devices=utils.FlattenList(utils.TransposeLists(
            [devices[:num_gpus//2], devices[num_gpus//2:]])),
        gpus_per_node=gpus_per_node)

    assert len(inputs.shape) == 2
    assert inputs.shape == labels.shape

    # Inputs use the 'axis0' batch-dimension name; labels get a fresh random
    # name for the same dimension — presumably so only the inputs' batch is
    # associated with the mesh axis (TODO confirm against GetMeshImpl layout).
    shape = utils.ConvertToShape(
        [('axis0', batch_size), inputs.shape.as_list()[1]])
    mtf_inputs = mtf.import_tf_tensor(meshes[2], inputs, shape)
    shape = shape.rename_dimension('axis0', utils.RandName())
    mtf_labels = mtf.import_tf_tensor(meshes[2], labels, shape)

    return graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels
def testSeparableConv1d(self, random_normal_initializer_mock):
    """separable_conv1d with mocked initializers yields known outputs."""
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    depth_dim = mtf.Dimension("depth", 2)
    length_dim = mtf.Dimension("length", 4)
    output_dim = mtf.Dimension("output", 2)
    x = tf.constant([[1, 0], [0, 1], [1, 1], [2, 1]], dtype=tf.float32)
    mtf_x = mtf.import_tf_tensor(mesh, x,
                                 shape=mtf.Shape([length_dim, depth_dim]))
    # Make variable initialization deterministic, keyed by requested shape.
    initializer_mock = mock.MagicMock()
    random_normal_initializer_mock.return_value = initializer_mock
    initializer_mock.side_effect = initialize_by_shape({
        (2, ): [1, 2],
        (2, 2): [[1, 0], [1, -1]],
    })
    mtf_output = mtf.layers.separable_conv1d(mtf_x, output_dim,
                                             min_relative_pos=-1,
                                             max_relative_pos=1,
                                             use_bias=True)
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_output = lowering.export_to_tf_tensor(mtf_output)
    self.evaluate(tf.global_variables_initializer())
    self.evaluate(lowering.copy_masters_to_slices())
    actual = self.evaluate(actual_output)
    self.assertAllClose(actual, [[3, -2], [6, -4], [9, -6], [7, -4]])
def model_fn(features, labels, mode, params):
    """A model is called by TpuEstimator."""
    del labels
    del features

    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    # TPU topology / host placement from the estimator context.
    ctx = params['context']
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    tf.logging.info('device_list = %s' % device_list,)

    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "fft_mesh")

    # NOTE(review): the extent of this `with` block was inferred from a
    # whitespace-mangled source — confirm which statements it covers.
    with mtf.utils.outside_all_rewrites():
        field = nbody_model(mesh)
        batch_dim, x_dim, y_dim, z_dim = field.shape
        x_dim_nosplit = mtf.Dimension("nx_nosplit", FLAGS.cube_size)
        y_dim_nosplit = mtf.Dimension("ny_nosplit", FLAGS.cube_size)
        # Until we implement distributed outputs, we only return one example
        field_slice, _ = mtf.split(field, batch_dim, [1, FLAGS.batch_size-1])
        field_slice = mtf.reshape(field_slice,
                                  [mtf.Dimension("bs", 1), x_dim_nosplit,
                                   y_dim_nosplit, z_dim])
        #field_slice = field

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    tf_field = tf.to_float(lowering.export_to_tf_tensor(field_slice))

    with mtf.utils.outside_all_rewrites():
        return tpu_estimator.TPUEstimatorSpec(
            mode, predictions={'field': tf_field})
def testConv1d(self):
    """conv1d with a fixed filter produces the expected outputs."""
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    filter_size = 3
    depth_dim = mtf.Dimension("depth", 2)
    length_dim = mtf.Dimension("length", 4)
    output_dim = mtf.Dimension("output", 2)
    x = tf.constant([[1, 0], [0, 1], [1, 1], [2, 1]], dtype=tf.float32)
    mtf_x = mtf.import_tf_tensor(mesh, x,
                                 shape=mtf.Shape([length_dim, depth_dim]))
    # Deterministic filter weights keyed by requested variable shape.
    initializer_mock = mock.MagicMock()
    initializer_mock.side_effect = initialize_by_shape({
        (1, 3, 2, 2): [[[[1, -1], [0, 0]], [[2, -2], [-1, 1]],
                        [[3, -3], [-2, 2]]]],
    })
    mtf_output = mtf.layers.conv1d(mtf_x, output_dim=output_dim,
                                   filter_size=filter_size,
                                   filter_initializer=initializer_mock)
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_output = lowering.export_to_tf_tensor(mtf_output)
    self.evaluate(tf.global_variables_initializer())
    self.evaluate(lowering.copy_masters_to_slices())
    actual = self.evaluate(actual_output)
    self.assertAllClose(actual, [[0, 0], [1, -1], [5, -5], [4, -4]])
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""
    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    # MTF setup.
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    tf.logging.info("device_list = %s" % device_list, )

    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memeory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memeory_usage)
    mesh_devices = [""] * mesh_shape.size
    physical_shape = list(ctx.device_assignment.topology.mesh_shape)
    logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
        mesh_shape.to_integer_list, physical_shape)
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment,
        logical_to_physical=logical_to_physical)
    mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

    # Import the estimator's feature tensors into the mtf mesh.
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]
    next_sentence_labels = tf.squeeze(features["next_sentence_labels"], 1)

    batch_size = input_ids.get_shape()[0].value
    batch_dim = mtf.Dimension("batch", batch_size)
    seq_length = input_ids.get_shape()[1].value
    seq_dim = mtf.Dimension("seq", seq_length)
    max_predictions_per_seq = masked_lm_positions.get_shape()[1].value
    max_predictions_per_seq_dim = mtf.Dimension("max_pred_seq",
                                                max_predictions_per_seq)

    mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids,
                                         [batch_dim, seq_dim])
    mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                          [batch_dim, seq_dim])
    mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                           [batch_dim, seq_dim])
    mtf_masked_lm_positions = mtf.import_tf_tensor(
        mesh, masked_lm_positions, [batch_dim, max_predictions_per_seq_dim])
    mtf_masked_lm_ids = mtf.import_tf_tensor(
        mesh, masked_lm_ids, [batch_dim, max_predictions_per_seq_dim])
    mtf_masked_lm_weights = mtf.import_tf_tensor(
        mesh, masked_lm_weights, [batch_dim, max_predictions_per_seq_dim])
    mtf_next_sentence_labels = mtf.import_tf_tensor(
        mesh, next_sentence_labels, [batch_dim])

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    # NOTE(review): bert_config / learning_rate / num_train_steps /
    # num_warmup_steps are taken from the enclosing scope — confirm.
    model = bert_lib.BertModel(config=bert_config, is_training=is_training,
                               input_ids=mtf_input_ids,
                               input_mask=mtf_input_mask,
                               token_type_ids=mtf_segment_ids,
                               layout=layout_rules, mesh_shape=mesh_shape)
    (masked_lm_loss, masked_lm_example_loss,
     masked_lm_logits) = model.get_masked_lm_output(
         mtf_masked_lm_positions, mtf_masked_lm_ids, mtf_masked_lm_weights)
    (next_sentence_loss, next_sentence_example_loss,
     next_sentence_logits) = model.get_next_sentence_output(
         mtf_next_sentence_labels)
    extra_loss = model.get_extra_loss()

    total_loss = masked_lm_loss + next_sentence_loss
    total_loss = mtf.anonymize(total_loss)
    masked_lm_example_loss = mtf.anonymize(masked_lm_example_loss)
    masked_lm_logits = mtf.anonymize(masked_lm_logits)
    next_sentence_example_loss = mtf.anonymize(next_sentence_example_loss)
    next_sentence_logits = mtf.anonymize(next_sentence_logits)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
        _, update_ops = optimization_lib.create_optimizer(
            total_loss + extra_loss, learning_rate, num_train_steps,
            num_warmup_steps, optimizer=FLAGS.optimizer,
            clip_gradients=FLAGS.clip_gradients)

    # Lowering happens at function level: the EVAL branch below also uses it.
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_global_step()
        tf_update_ops = [
            lowering.lowered_operation(op) for op in update_ops
        ]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
        train_op = tf.group(tf_update_ops)
    elif mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(masked_lm_example_loss, masked_lm_logits, masked_lm_ids,
                      masked_lm_weights, next_sentence_example_loss,
                      next_sentence_logits, next_sentence_labels):
            """Computes the loss and accuracy of the model."""
            masked_lm_logits = tf.reshape(masked_lm_logits,
                                          [-1, masked_lm_logits.shape[-1]])
            masked_lm_predictions = tf.argmax(masked_lm_logits, axis=-1,
                                              output_type=tf.int32)
            masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
            masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
            masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
            masked_lm_accuracy = tf.metrics.accuracy(
                labels=masked_lm_ids, predictions=masked_lm_predictions,
                weights=masked_lm_weights)
            masked_lm_mean_loss = tf.metrics.mean(
                values=masked_lm_example_loss, weights=masked_lm_weights)
            next_sentence_logits = tf.reshape(
                next_sentence_logits, [-1, next_sentence_logits.shape[-1]])
            next_sentence_predictions = tf.argmax(next_sentence_logits,
                                                  axis=-1,
                                                  output_type=tf.int32)
            next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
            next_sentence_accuracy = tf.metrics.accuracy(
                labels=next_sentence_labels,
                predictions=next_sentence_predictions)
            next_sentence_mean_loss = tf.metrics.mean(
                values=next_sentence_example_loss)
            return {
                "masked_lm_accuracy": masked_lm_accuracy,
                "masked_lm_loss": masked_lm_mean_loss,
                "next_sentence_accuracy": next_sentence_accuracy,
                "next_sentence_loss": next_sentence_mean_loss,
            }

        eval_metrics = (metric_fn, [
            lowering.export_to_tf_tensor(masked_lm_example_loss),
            lowering.export_to_tf_tensor(masked_lm_logits), masked_lm_ids,
            masked_lm_weights,
            lowering.export_to_tf_tensor(next_sentence_example_loss),
            lowering.export_to_tf_tensor(next_sentence_logits),
            next_sentence_labels
        ])

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            saver = tf.train.Saver(tf.global_variables(), sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                FLAGS.output_dir, save_steps=1000, saver=saver,
                listeners=[saver_listener])
            return tf.estimator.tpu.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
                training_hooks=[restore_hook, saver_hook])
        elif mode == tf.estimator.ModeKeys.EVAL:
            return tf.estimator.tpu.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL, evaluation_hooks=[restore_hook],
                loss=tf_loss, eval_metrics=eval_metrics)
def estimator_model_fn(cls, hparams, features, labels, mode, config=None,
                       params=None, decode_hparams=None):
  """Builds an Estimator/TPUEstimator spec for a Mesh TensorFlow T2T model.

  Constructs the mtf graph and mesh implementation (SIMD mesh on TPU,
  placement mesh otherwise), runs the model, and returns the spec for the
  requested mode.

  Args:
    hparams: model hyperparameters; deep-copied so the caller's object is
      never mutated.
    features: input features dict, passed through to the model.
    labels: labels; only consumed by `estimator_spec_eval` here.
    mode: a `tf.estimator.ModeKeys` value.
    config: optional RunConfig; supplies `data_parallelism` on non-TPU runs.
    params: optional dict; carries "use_tpu" and, on TPU, the "context".
    decode_hparams: optional decoding hyperparameters, merged into `hparams`
      in PREDICT mode.

  Returns:
    A `tpu_estimator.TPUEstimatorSpec` or `tf.estimator.EstimatorSpec`.
  """
  hparams = copy.deepcopy(hparams)
  use_tpu = params and params.get("use_tpu", False)
  hparams.use_tpu = use_tpu
  # Merge decode_hparams into hparams if present.
  if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
    for k, v in six.iteritems(decode_hparams.values()):
      if hasattr(hparams, k) and getattr(hparams, k) != v:
        tf.logging.warning(
            "Overriding hparams.%s with %s from decode_hparams" % (k, v))
        setattr(hparams, k, v)
  # Instantiate the model.
  data_parallelism = None
  if not use_tpu and config:
    data_parallelism = config.data_parallelism
  model = cls(hparams,
              mode,
              data_parallelism=data_parallelism,
              decode_hparams=decode_hparams)
  global_step = tf.train.get_global_step()
  mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(hparams.layout)
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica.
    # Worker 0 caches all the TPU binaries, so it gets the full budget;
    # the remaining hosts are given zero.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memeory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memeory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    # NOTE(review): this branch dereferences data_parallelism, which is None
    # when `config` was falsy — presumably non-TPU callers always pass a
    # config; confirm.
    if len(data_parallelism.ps_devices) == 1:
      mesh_devices = [""] * mesh_shape.size
    else:
      assert len(data_parallelism.ps_devices) == mesh_shape.size
      mesh_devices = data_parallelism.ps_devices
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)
  # PREDICT mode returns early: the model builds its own spec.
  if mode == tf.estimator.ModeKeys.PREDICT:
    return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)
  logits, loss = model.mtf_model_fn(features, mesh)
  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)
  # TRAIN mode: build gradients and optimizer update ops (still mtf ops;
  # they are lowered to TF below).
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    lr = learning_rate.learning_rate_schedule(hparams)
    tf.summary.scalar("learning_rate", lr)
    mtf_lr = mtf.import_tf_tensor(
        mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
    optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
    update_ops = []
    for grad, var in zip(var_grads, graph.trainable_variables):
      update_ops.extend(optimizer.apply_grad(grad, var))
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  # NOTE(review): truthiness test on `logits` — presumably logits is either
  # None or an mtf.Tensor (always truthy); confirm mtf.Tensor defines no
  # __bool__ that could make this False for a real tensor.
  if logits and mode != tf.estimator.ModeKeys.TRAIN:
    tf_logits = lowering.export_to_tf_tensor(logits)
  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    # Lowered mtf update ops do not advance the TF global step; do it here.
    tf_update_ops.append(tf.assign_add(global_step, 1))
    # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(tf.global_variables(),
                           sharded=True,
                           max_to_keep=10,
                           keep_checkpoint_every_n_hours=2,
                           defer_build=False,
                           save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(hparams.model_dir,
                                              save_steps=1000,
                                              saver=saver,
                                              listeners=[saver_listener])
    # EVAL mode.
    if mode == tf.estimator.ModeKeys.EVAL:
      tf_logits = lowering.export_to_tf_tensor(logits)
      return model.estimator_spec_eval(features, tf_logits, labels, tf_loss,
                                       restore_hook, use_tpu)
    if use_tpu:
      # TPU host call. Important: needs to be created before
      # remove_summaries() strips the summary ops.
      if hparams.tpu_enable_host_call:
        host_call = t2t_model.create_host_call(hparams.model_dir)
      else:
        host_call = None
      t2t_model.remove_summaries()
      return tpu_estimator.TPUEstimatorSpec(
          mode=tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          host_call=host_call,
          training_hooks=[restore_hook, saver_hook])
    else:
      # restore_hook must run before saver_hook.
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_chief_hooks=[restore_hook, saver_hook])
def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator.

  Builds the MNIST model as a Mesh TensorFlow graph on a placement mesh,
  lowers it to TF, and returns an EstimatorSpec for TRAIN / PREDICT / EVAL.

  Args:
    features: input images tensor, forwarded to `mnist_model`.
    labels: integer class labels.
    mode: a `tf.estimator.ModeKeys` value.
    params: unused here beyond logging (mesh config comes from FLAGS).

  Returns:
    A `tf.estimator.EstimatorSpec` for the requested mode.
  """
  tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                  (features, labels, mode, params))
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  # Wrapped graph named "my_mesh".
  mesh = mtf.Mesh(graph, "my_mesh")
  logits, loss = mnist_model(features, labels, mesh)
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
  mesh_size = mesh_shape.size
  print("mesh_shape.size = ", mesh_shape.size)
  mesh_devices = [""] * mesh_size
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, layout_rules, mesh_devices)
  # Gradients/updates must be added to the mtf graph before lowering.
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.AdafactorOptimizer()
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  restore_hook = mtf.MtfRestoreHook(lowering)
  tf_logits = lowering.export_to_tf_tensor(logits)
  if mode != tf.estimator.ModeKeys.PREDICT:
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf.summary.scalar("loss", tf_loss)
  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    # The lowered mtf ops do not touch the TF global step; bump it manually.
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)
    saver = tf.train.Saver(tf.global_variables(),
                           sharded=True,
                           max_to_keep=10,
                           keep_checkpoint_every_n_hours=2,
                           defer_build=False,
                           save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(FLAGS.model_dir,
                                              save_steps=1000,
                                              saver=saver,
                                              listeners=[saver_listener])
    accuracy = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(tf_logits, axis=1))
    # Name tensors to be logged with LoggingTensorHook.
    tf.identity(tf_loss, "cross_entropy")
    tf.identity(accuracy[1], name="train_accuracy")
    # Save accuracy scalar to Tensorboard output.
    tf.summary.scalar("train_accuracy", accuracy[1])
    # restore_hook must come before saver_hook.
    return tf.estimator.EstimatorSpec(
        tf.estimator.ModeKeys.TRAIN,
        loss=tf_loss,
        train_op=train_op,
        training_chief_hooks=[restore_hook, saver_hook])
  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        "classes": tf.argmax(tf_logits, axis=1),
        "probabilities": tf.nn.softmax(tf_logits),
    }
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[restore_hook],
        export_outputs={
            "classify": tf.estimator.export.PredictOutput(predictions)
        })
  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=tf_loss,
        evaluation_hooks=[restore_hook],
        eval_metric_ops={
            "accuracy":
                tf.metrics.accuracy(
                    labels=labels, predictions=tf.argmax(tf_logits, axis=1)),
        })
def model_fn(features, labels, mode, params):
  """A model_fn called by TPUEstimator.

  Builds the toy model as a Mesh TensorFlow graph (SIMD mesh on TPU,
  placement mesh otherwise), lowers it, and returns a TPUEstimatorSpec
  for TRAIN or EVAL.

  Args:
    features: input features, forwarded to `toy_model`.
    labels: unused (deleted immediately).
    mode: a `tf.estimator.ModeKeys` value.
    params: on TPU, carries the TPU "context" object.

  Returns:
    A `tpu_estimator.TPUEstimatorSpec`.
  """
  del labels
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
  if FLAGS.use_tpu:
    ctx = params['context']
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    tf.logging.info('device_list = %s' % device_list,)
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica.
    # Worker 0 caches all the TPU binaries, so it gets the whole budget.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memeory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memeory_usage)
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)
  mesh = mtf.Mesh(graph, 'my_mesh', var_placer)
  with mtf.utils.outside_all_rewrites():
    logits, loss = toy_model(features, mesh)
  # TRAIN mode: build mtf gradient/update ops before lowering.
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients([loss],
                              [v.outputs[0] for v in graph.trainable_variables])
    if FLAGS.optimizer == 'Adafactor':
      optimizer = mtf.optimize.AdafactorOptimizer()
    else:
      assert FLAGS.optimizer == 'SGD'
      optimizer = mtf.optimize.SgdOptimizer(lr=FLAGS.lr)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
  else:
    # For now, we can only export fully-replicated tensors.
    fully_replicated_logits = mtf.anonymize(logits)
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))
  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    # Lowered mtf ops do not advance the TF global step; do it manually.
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  else:
    tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)
  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          FLAGS.model_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(tf_logits):
        # Reports the mean of the exported logits as the only eval metric.
        mean_logits = tf.metrics.mean(tf_logits)
        return {'mean_logits': mean_logits}

      eval_metrics = (metric_fn, [tf_logits])
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
def model_fn(features, labels, mode, params):
    """Estimator model_fn for the GPT model (Mesh TensorFlow).

    Builds the mtf graph and mesh, converts `features`/`labels` into
    fully-replicated mtf tensors, runs the GPT model (optionally with
    gradient-accumulating microbatches in TRAIN mode), lowers the graph,
    and returns a TPUEstimatorSpec for PREDICT / TRAIN / EVAL.

    Args:
        features: input token tensor (or dict with a "feature" key).
        labels: target token tensor; also used by the lambada eval metric.
        mode: a `tf.estimator.ModeKeys` value.
        params: configuration dict (mesh shape/layout, model dims, precision,
            microbatching, checkpointing, eval task, ...).

    Returns:
        A `tpu_estimator.TPUEstimatorSpec`.
    """
    # Get global step
    global_step = tf.train.get_global_step()

    # Construct mtf graph + mesh from params
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    layout_rules = mtf.convert_to_layout_rules(params["layout"])

    # Mesh setup
    if params["use_tpu"]:
        var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape,
                                                layout_rules)
    else:
        var_placer = None
        gpu_ids = params["gpu_ids"]
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, gpu_ids)

    # Trainable variable precision
    # Store to checkpoints in master type, train in slice type, compute in
    # activation type
    if params["precision"] == "bfloat16":
        variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.bfloat16)
    else:
        variable_dtype = mtf.VariableDType(master_dtype=tf.float32,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.float32)

    # Build mtf mesh object
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # Build mtf_features & seq length dict for getting number of microbatches
    # We need to pack inputs into a dict to pass into serialize_training_step
    features_dict = {"inputs": features, "labels": labels}
    sequence_length_dict = {
        "inputs": params["n_ctx"],
        "labels": params["n_ctx"]
    }

    params = add_mode_to_params(params, mode)
    batch_size = get_batch_size(params)

    batch_dim = mtf.Dimension("batch", batch_size)
    batch_dims = [batch_dim]
    feature_length = sequence_length_dict["inputs"]
    length_dim = mtf.Dimension("sequence", feature_length)

    mtf_features = {}
    for key, x in features_dict.items():
        if x is not None:
            feature_shape = mtf.Shape(batch_dims + [length_dim])
            # Some input pipelines wrap the tensor in a {"feature": ...} dict;
            # unwrap before casting.
            if type(features_dict[key]) == dict:
                features_dict[key] = features_dict[key]["feature"]
            x = tf.cast(features_dict[key], tf.int32)
            x = tf.reshape(x, feature_shape.to_integer_list)
            mtf_features[key] = mtf.import_fully_replicated(
                mesh, x, feature_shape, name=key)

    # Instantiate dict for dimensions, bias, etc that can be calculated here
    # once then passed into model
    other_features = {}
    memory_length_dim = mtf.Dimension("memory_length", length_dim.size)

    attn_bias = biasmask_attn_weights(
        mesh, length_dim, memory_length_dim,
        variable_dtype) if params["causal"] else None

    # Add attn_bias into mtf_features
    other_features["attn_bias"] = attn_bias

    # Define other Dimensions that we'll need inside the model
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
    # We need this because gathering when both the args have the same
    # dimension in them breaks things
    # This dim is specifically for the weights
    # This prevents the "Einsum has lhs dimension without corresponding rhs
    # or output dimension." error
    embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])

    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Set up the model for prediction
        inputs = mtf_features["inputs"]
        if params["remove_partial_sequences"] is None:
            params["remove_partial_sequences"] = False

        export = params.get("export", False)

        if not export:
            mtf_samples = sample_autoregressive(
                inputs,
                other_features=other_features,
                params=params,
                variable_dtype=variable_dtype,
                remove_partial_sequences=params["remove_partial_sequences"],
                stop_at_token=params["eos_id"],
                sampling_use_entmax=params['sampling_use_entmax'])
        else:
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    mtf_samples, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype,
                        context=None)

        # Only fully-replicated (anonymized) tensors can be exported.
        mtf_samples = mtf.anonymize(mtf_samples)
        inputs = mtf.anonymize(inputs)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
        inputs = lowering.export_to_tf_tensor(inputs)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {"inputs": inputs, "outputs": outputs}

        def scaffold_fn():
            # Custom scaffold so master->slice variable copies run as part of
            # local init, and mtf resources are included in the ready check.
            return tf.train.Scaffold(
                local_init_op=tf.group(
                    tf.train.Scaffold.default_local_init_op(),
                    lowering.copy_masters_to_slices(),
                    name="mtf_local_init_op"),
                ready_op=tf.concat([
                    tf.report_uninitialized_variables(),
                    resources.report_uninitialized_resources()
                ],
                                   axis=0,
                                   name="mtf_ready_op"))

        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            scaffold_fn=scaffold_fn,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    # We're not predicting, so we better be training or evaluating
    assert (mode == tf.estimator.ModeKeys.TRAIN
            or mode == tf.estimator.ModeKeys.EVAL)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Gets number of microbatches per batch for serialized training
        # if param tokens_per_mb_per_replica = None, this defaults to 1 and
        # no microbatching is performed
        num_microbatches = int(
            mtf_transformer.utils.serialize_num_microbatches(
                batch_dim=batch_dim,
                sequence_length=sequence_length_dict,
                mesh_shape=mesh_shape,
                layout_rules=layout_rules,
                tokens_per_microbatch_per_replica=params[
                    "tokens_per_mb_per_replica"]))
    else:
        num_microbatches = 1

    params[
        "num_microbatches"] = num_microbatches  # Add num microbatches to params

    if num_microbatches > 1:
        # For serialize_training_step we need to modify the model to output
        # results in a dict
        def serialized_fn(mtf_features):
            if params["model"] == "GPT":
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype)
                return {
                    "logits": logits,
                    "loss": loss,
                    "loss_batch": loss_batch
                }
            else:
                raise Exception(
                    f"'{params['model']}' is not a valid model - please select from [GPT]"
                )

        # Serialize the training step - Gradients are accumulated locally and
        # reduced once.
        var_grads, output_dict = mtf.serialize_training_step(
            mtf_features, serialized_fn, batch_dim, num_microbatches)
        loss = output_dict["loss"]
        loss_batch = output_dict["loss_batch"]
        logits = output_dict["logits"]
    else:
        # If we're not splitting into microbatches, return logits & loss as is
        if params["model"] == "GPT":
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype,
                        context=None)
        else:
            raise Exception(
                f"'{params['model']}' is not a valid model - please select from [GPT]"
            )

    # Auto layout generation
    if params["auto_layout"]:
        auto_layout(graph, mesh_shape, logits, loss)
    if params["auto_layout_and_mesh_shape"]:
        auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # In TRAIN mode, get optimizer
        if params["num_microbatches"] > 1:
            # If we are splitting the batch into microbatches, var grads are
            # created in the serialize_training_step fn
            # So we pass them in here
            _, update_ops, var_grads = get_optimizer(
                mesh,
                loss,
                params,
                variable_dtype=variable_dtype,
                inp_var_grads=var_grads)
        else:
            # Otherwise, they are created in the get_optimizer fn, so we
            # leave inp_var_grads blank
            _, update_ops, var_grads = get_optimizer(
                mesh, loss, params, variable_dtype=variable_dtype)
        # Log summaries to tensorboard
        mtf.scalar_summary("loss", loss)
        # Log gradients if in params
        if params["log_grads"] not in [None, False]:
            for g in var_grads:
                grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
                # NOTE(review): g.name[:-2] presumably strips a trailing
                # ":0"-style suffix, and the summary tag has no separator
                # between "norm" and the name — confirm intended tag format.
                mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm)
    else:
        # For now, we can only export fully-replicated tensors.
        # This has to be done before lowering or they will not be included in
        # the graph
        mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
        max_logits = mtf.argmax(logits, vocab_dim)
        del logits
        fully_replicated_mean_logits = mtf.anonymize(mean_logits)
        fully_replicated_max_logits = mtf.anonymize(max_logits)
        fully_replicated_loss_batch = mtf.anonymize(loss_batch)

    # Gets & prints info about no. trainable vars in the model & dimension
    # names
    get_graph_info(graph)

    # 'lowers' mtf tensors into a tf graph - this enables us to export
    # results as tf tensors
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.cast(tf_loss, tf.float32)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Use our patched version until mtf updates theirs
        host_call = create_host_call(params['model_path'])
        mtf.utils.remove_summaries()

        # Creates train_op
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(
            global_step, 1))  # Need to manually increment global_step
        tf.logging.info(f"tf_update_ops: {tf_update_ops}")
        train_op = tf.group(tf_update_ops)
    else:
        tf_mean_logits = lowering.export_to_tf_tensor(
            fully_replicated_mean_logits)
        tf_max_logits = lowering.export_to_tf_tensor(
            fully_replicated_max_logits)
        tf_loss_batch = tf.to_float(
            lowering.export_to_tf_tensor(fully_replicated_loss_batch))

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Set up the checkpoint server and return the TPUEstimatorSpec
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                params["model_path"],
                save_steps=params["steps_per_checkpoint"],
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                host_call=host_call,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])
        elif mode == tf.estimator.ModeKeys.EVAL:
            # Evaluation metrics
            def _perplexity(loss):
                # Perplexity is exp of the (mean) cross-entropy loss.
                perplexity = tf.exp(loss)
                return tf.metrics.mean(perplexity)

            def _bits_per_byte(loss):
                # Dataset-specific nats->bits-per-byte conversion factor.
                bpb = loss * (0.29335 / math.log(2))
                return tf.metrics.mean(bpb)

            def _metric_fn(tf_mean_logits, tf_loss_batch):
                mean_logits = tf.metrics.mean(tf_mean_logits)
                loss = tf.reduce_mean(tf_loss_batch)
                perp = _perplexity(loss)
                bpb = _bits_per_byte(loss)
                return {
                    "mean_logits": mean_logits,
                    "perplexity": perp,
                    "bits per byte": bpb
                }

            def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
                # Positions whose label is not the EOS token are the answer
                # tokens the lambada task scores.
                eos_token = params["eos_id"]
                answer_positions = tf.where(
                    tf.math.not_equal(labels, eos_token))

                correct_answers = tf.gather_nd(
                    tf.math.equal(tf_max_logits, labels), answer_positions)
                accuracy = tf.metrics.mean(
                    tf.cast(correct_answers, tf.float32))

                # I guess tf_loss_batch has z_loss and maybe other stuff
                # added to it
                # so maybe this should be calculated separately in the future
                answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
                log_perplexity = tf.metrics.mean(answer_loss)

                return {
                    "lambada_acc": accuracy,
                    "lambada_log_ppl": log_perplexity
                }

            eval_task = params["eval_task"]
            if eval_task == "lambada":
                eval_metrics = (_lambada_metric_fn,
                                [labels, tf_max_logits, tf_loss_batch])
            else:
                eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
def testPool(self, pooling_method):
  """Checks mtf pooling ops (forward and gradient) against tf.keras layers.

  Args:
    pooling_method: one of "MAX_2D", "AVG_2D", "MAX_3D", "AVG_3D".

  Raises:
    ValueError: if `pooling_method` is not one of the supported names
      (previously this fell through and died later with a NameError on
      `mtf_outputs`).
  """
  # Input is 5D: [batch, depth, height, width, channels].
  batch = 2
  depth = 3
  height = 4
  width = 6
  channels = 3
  tf.random.set_random_seed(1234)
  inputs = tf.random_normal([batch, depth, height, width, channels])

  stride_d = 3
  stride_h = 2
  stride_w = 3

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  depth_dim = mtf.Dimension("depth", depth)
  height_dim = mtf.Dimension("height", height)
  width_dim = mtf.Dimension("width", width)
  channels_dim = mtf.Dimension("channels", channels)

  mtf_inputs = mtf.import_tf_tensor(
      mesh,
      inputs,
      shape=mtf.Shape(
          [batch_dim, depth_dim, height_dim, width_dim, channels_dim]))

  if pooling_method == "MAX_2D":
    mtf_outputs = mtf.layers.max_pool2d(
        mtf_inputs, ksize=(stride_h, stride_w))
    # 2D pooling acts on [height, width]; fold depth into batch for the
    # keras reference op, then restore the original 5D layout.
    inputs = tf.reshape(inputs, [batch * depth, height, width, channels])
    expected_outputs = tf.keras.layers.MaxPooling2D(
        (stride_h, stride_w))(inputs)
    expected_outputs = tf.reshape(
        expected_outputs,
        [batch, depth, height // stride_h, width // stride_w, channels])
  elif pooling_method == "AVG_2D":
    mtf_outputs = mtf.layers.avg_pool2d(
        mtf_inputs, ksize=(stride_h, stride_w))
    inputs = tf.reshape(inputs, [batch * depth, height, width, channels])
    expected_outputs = tf.keras.layers.AveragePooling2D(
        (stride_h, stride_w))(inputs)
    expected_outputs = tf.reshape(
        expected_outputs,
        [batch, depth, height // stride_h, width // stride_w, channels])
  elif pooling_method == "MAX_3D":
    mtf_outputs = mtf.layers.max_pool3d(
        mtf_inputs, ksize=[stride_d, stride_h, stride_w])
    expected_outputs = tf.keras.layers.MaxPooling3D(
        [stride_d, stride_h, stride_w])(inputs)
  elif pooling_method == "AVG_3D":
    mtf_outputs = mtf.layers.avg_pool3d(
        mtf_inputs, ksize=[stride_d, stride_h, stride_w])
    expected_outputs = tf.keras.layers.AveragePooling3D(
        [stride_d, stride_h, stride_w])(inputs)
  else:
    # Fail fast with a clear message instead of a NameError on
    # `mtf_outputs` below when an unknown method name slips in.
    raise ValueError("Unknown pooling_method: %s" % pooling_method)

  mtf_gradient = mtf.gradients([mtf_outputs], [mtf_inputs])[0]

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
  actual_gradient = lowering.export_to_tf_tensor(mtf_gradient)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(tf_group)
  actual, expected = self.evaluate([actual_outputs, expected_outputs])
  self.assertAllClose(actual, expected)

  actual = self.evaluate(actual_gradient)
  if pooling_method == "MAX_2D":
    # Max pooling routes each window's gradient to exactly one entry,
    # so one nonzero per pooling window. Use integer division: the
    # count is exact.
    expected_non_zeros = (
        batch * depth * height * width * channels // (stride_h * stride_w))
    self.assertEqual(np.count_nonzero(actual), expected_non_zeros)
  elif pooling_method == "AVG_2D":
    # Average pooling spreads the gradient uniformly over each window.
    expected = np.ones(
        (batch, depth, height, width, channels),
        dtype=np.float32) / stride_h / stride_w
    self.assertAllClose(actual, expected)
  elif pooling_method == "MAX_3D":
    expected_non_zeros = (
        batch * depth * height * width * channels //
        (stride_d * stride_h * stride_w))
    self.assertEqual(np.count_nonzero(actual), expected_non_zeros)
  elif pooling_method == "AVG_3D":
    expected = np.ones(
        (batch, depth, height, width, channels),
        dtype=np.float32) / stride_d / stride_h / stride_w
    self.assertAllClose(actual, expected)