def mnist_model(image, labels, mesh):
  """Builds the VGG classifier graph and, if labels are given, its loss.

  Args:
    image: tf.Tensor of images; reshaped to
      [FLAGS.batch_size, image_height, image_width, num_channels].
    labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None.
    mesh: a mtf.Mesh

  Returns:
    logits: a float32 mtf.Tensor with shape [batch, classesnum]
    loss: a mtf.Tensor with shape [], or None when `labels` is None
  """
  batch_size = FLAGS.batch_size
  batch_dim = mtf.Dimension("batch", batch_size)
  rows_dim = mtf.Dimension("rows_size", image_height)
  cols_dim = mtf.Dimension("cols_size", image_width)
  channel_dim = mtf.Dimension("image_channel", num_channels)
  classes_dim = mtf.Dimension(name='classesnum', size=classesnum)

  image_nhwc = tf.reshape(
      image, [batch_size, image_height, image_width, num_channels])
  x = mtf.import_tf_tensor(
      mesh, image_nhwc,
      mtf.Shape([batch_dim, rows_dim, cols_dim, channel_dim]))

  # The backbone may run in reduced precision; the loss is computed in fp32.
  logits = mtf.cast(VGG(x, classes_dim=classes_dim, depth=depth),
                    dtype=tf.float32)

  if labels is None:
    return logits, None

  mtf_labels = mtf.import_tf_tensor(
      mesh, tf.reshape(labels, [batch_size]), mtf.Shape([batch_dim]))
  per_example_loss = mtf.layers.softmax_cross_entropy_with_logits(
      logits, mtf.one_hot(mtf_labels, classes_dim), classes_dim)
  return logits, mtf.reduce_mean(per_example_loss)
def get_bspline_kernel(x, channels, transpose=False, dtype=tf.float32,
                       order=4):
  """Creates a separable 3-D b-spline (binomial) smoothing kernel.

  The 1-D taps are the binomial coefficients of `order`, e.g. 1, 4, 6, 4, 1
  for order 4. Their outer product over three axes is normalized to sum to 1.
  Each channel is filtered independently: the kernel is the identity across
  channels.

  Args:
    x: mtf.Tensor. It supplies the mesh, and its last dimension is used as
      the kernel's input-channel dimension.
    channels: mtf.Dimension of the output channels. Its size must equal the
      size of the last dimension of `x`.
    transpose: if True, the two channel dims are swapped in the result
      ([fd, fh, fw, channels, in_dim]).
    dtype: tf.DType of the kernel elements.
    order: 2, 4, 6 or 8. Any other value falls back to order 4.

  Returns:
    mtf.Tensor of shape [fd, fh, fw, in_dim, channels], or
    [fd, fh, fw, channels, in_dim] when `transpose`. Here fd = fh = fw =
    order + 1, which is 5 for the default order.

  Raises:
    ValueError: if `channels` and the last dimension of `x` differ in size.
  """
  mesh = x.mesh
  in_dim = x.shape[-1]
  num_channels = channels.size
  if in_dim.size != num_channels:
    raise ValueError(
        "get_bspline_kernel builds a per-channel (diagonal) kernel, so "
        "channels.size ({}) must equal the last dimension of x ({})".format(
            num_channels, in_dim.size))

  taps_by_order = {
      2: (1., 2., 1.),
      6: (1., 6., 15., 20., 15., 6., 1.),
      8: (1., 8., 28., 56., 70., 56., 28., 8., 1.),
  }
  # `as_numpy_dtype` is a property holding the numpy type; it is not called.
  taps = np.array(taps_by_order.get(order, (1., 4., 6., 4., 1.)),
                  dtype=dtype.as_numpy_dtype)
  size = len(taps)

  kernel = np.einsum('ij,k->ijk', np.outer(taps, taps), taps)
  kernel /= np.sum(kernel)
  kernel = kernel[:, :, :, np.newaxis, np.newaxis]
  # [size, size, size, 1, 1] * [C, C] broadcasts to [size, size, size, C, C],
  # nonzero only where input channel == output channel.
  kernel = tf.constant(kernel, dtype=dtype) * tf.eye(num_channels, dtype=dtype)

  fd_dim = mtf.Dimension("fd", size)
  fh_dim = mtf.Dimension("fh", size)
  fw_dim = mtf.Dimension("fw", size)
  if transpose:
    shape = [fd_dim, fh_dim, fw_dim, channels, in_dim]
  else:
    shape = [fd_dim, fh_dim, fw_dim, in_dim, channels]
  return mtf.import_tf_tensor(mesh, kernel, shape=shape)
def mnist_model(image, labels, mesh):
  """Builds a one-hidden-layer MNIST classifier (batch size fixed at 100).

  Args:
    image: tf.Tensor with 100*28*28 elements, e.g. shape [100, 28*28] or
      [100, 28, 28]; it is reshaped to [100, 28, 28].
    labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None.
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape [], or None when `labels` is None
  """
  batch_dim = mtf.Dimension("batch", 100)
  rows_dim = mtf.Dimension("rows", 28)
  cols_dim = mtf.Dimension("cols", 28)
  hidden_dim = mtf.Dimension("hidden", 1024)
  classes_dim = mtf.Dimension("classes", 10)

  # Reshape so a flat [batch, 28*28] image matches the imported
  # [batch, rows, cols] shape; this is a no-op for already 3-D input.
  images = mtf.import_tf_tensor(
      mesh,
      tf.reshape(image, [batch_dim.size, rows_dim.size, cols_dim.size]),
      shape=[batch_dim, rows_dim, cols_dim])

  w1 = mtf.get_variable(mesh, "w1", [rows_dim, cols_dim, hidden_dim])
  w2 = mtf.get_variable(mesh, "w2", [hidden_dim, classes_dim])
  # mtf.einsum takes a *list* of tensors as its first argument; dims absent
  # from output_shape (rows/cols, then hidden) are summed over.
  hidden = mtf.relu(
      mtf.einsum([images, w1], output_shape=[batch_dim, hidden_dim]))
  logits = mtf.einsum([hidden, w2], output_shape=[batch_dim, classes_dim])

  if labels is None:
    return logits, None

  labels = mtf.import_tf_tensor(mesh, labels, [batch_dim])
  loss = mtf.reduce_mean(
      mtf.layers.softmax_cross_entropy_with_logits(
          logits, mtf.one_hot(labels, classes_dim), classes_dim))
  return logits, loss
def mnist_model(image, labels, mesh):
  """Builds a single-layer MNIST classifier.

  Args:
    image: tf.Tensor with shape [batch, 28*28]; reshaped to
      [FLAGS.batch_size, 28, 28].
    labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None.
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape [], or None when `labels` is None
  """
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  rows_dim = mtf.Dimension("rows", 28)
  cols_dim = mtf.Dimension("cols", 28)
  classes_dim = mtf.Dimension("classes", 10)

  x = mtf.import_tf_tensor(mesh, tf.reshape(image, [FLAGS.batch_size, 28, 28]),
                           [batch_dim, rows_dim, cols_dim])

  w1 = mtf.get_variable(mesh, "w1", [rows_dim, cols_dim, classes_dim])
  b1 = mtf.get_variable(mesh, "b1", [classes_dim])
  # NOTE(review): ReLU on the final logits clamps every negative class score
  # to 0 and zeroes its gradient. It is kept to preserve the existing model;
  # confirm that this is intended.
  logits = mtf.relu(mtf.einsum([x, w1], [batch_dim, classes_dim]) + b1)

  if labels is None:
    return logits, None

  # Import labels only when they exist. tf.reshape(None, ...) raises, so an
  # unconditional import would make inference (labels=None) impossible.
  y = mtf.import_tf_tensor(mesh, tf.reshape(labels, [FLAGS.batch_size]),
                           [batch_dim])
  loss = mtf.layers.softmax_cross_entropy_with_logits(
      logits, mtf.one_hot(y, classes_dim), classes_dim)
  return logits, mtf.reduce_mean(loss)
def model_backbone(features, labels, mesh):
  """Builds the network selected by `args_opt.model` and its training loss.

  Args:
    features: a pair (id_hldr, wt_hldr) of tf.Tensors. Each is reshaped to
      [args_opt.batch_size, 39] and imported as [batch, field]. Presumably
      these are per-field feature ids and feature weights; confirm against
      the input pipeline.
    labels: a tf.Tensor with args_opt.batch_size elements (reshaped to
      [batch]), or None for inference.
    mesh: a mtf.Mesh

  Returns:
    logits: float32 mtf.Tensor produced by the selected network (output dim
      "outdim" of size 1; the exact shape is decided by the network).
    loss: scalar mtf.Tensor, or None when `labels` is None.
  """
  id_hldr, wt_hldr = features
  batch_size = args_opt.batch_size
  batch_dim = mtf.Dimension("batch", batch_size)
  field_dim = mtf.Dimension("field", size=39)
  vocab_dim = mtf.Dimension("vocab_size", 200000)
  embed_dim = mtf.Dimension("embed_size", 80)
  outdim = mtf.Dimension("outdim", 1)

  feature_shape = mtf.Shape([batch_dim, field_dim])
  id_hldr = mtf.import_tf_tensor(
      mesh, tf.reshape(id_hldr, [batch_size, field_dim.size]), feature_shape)
  wt_hldr = mtf.import_tf_tensor(
      mesh, tf.reshape(wt_hldr, [batch_size, field_dim.size]), feature_shape)

  if args_opt.fp16:
    float16 = mtf.VariableDType(tf.float16, tf.float16, tf.float16)
    wt_hldr = mtf.cast(wt_hldr, dtype=tf.float16)
  else:
    float16 = None

  logits, embedding_table = network[args_opt.model](
      id_hldr, wt_hldr, vocab_dim, embed_dim, outdim, float16=float16)
  logits = mtf.cast(logits, dtype=tf.float32)
  embedding_table = mtf.cast(embedding_table, dtype=tf.float32)

  if labels is None:
    # There is no loss without labels; summing two Nones would raise.
    return logits, None

  labels = mtf.import_tf_tensor(
      mesh, tf.reshape(labels, [batch_size]), mtf.Shape([batch_dim]))
  wide_loss = mtf.reduce_mean(
      mtf.layers.sigmoid_cross_entropy_with_logits(logits, labels))
  l2_penalty = mtf.reduce_mean(mtf.square(embedding_table)) / 2
  deep_loss = wide_loss + 8e-5 * l2_penalty
  # NOTE(review): the returned total is 2 * cross_entropy + 8e-5 * l2, so the
  # cross-entropy term is counted twice. This is kept to preserve training
  # behavior; confirm that it is intended.
  return logits, wide_loss + deep_loss
def mnist_model(image, labels, mesh, hs_t):
  """Builds a single-layer vanilla RNN classifier over MNIST image rows.

  The image is read as 28 timesteps of 28 inputs each.

  Args:
    image: tf.Tensor with shape [batch, 28*28]; reshaped to
      [FLAGS.batch_size, 28, 28] = [batch, timesteps, input].
    labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None.
    mesh: a mtf.Mesh
    hs_t: a tf.Tensor holding the initial hidden state; imported as
      [batch, hidden_1] (hidden_1 has size FLAGS.hidden_size).

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape [], or None when `labels` is None
    hs_t: the final hidden state, a mtf.Tensor with shape [batch, hidden_1]
  """
  input_num = 28
  timesteps_num = 28
  classes_num = 10

  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  input_dim = mtf.Dimension("input", input_num)
  timesteps_dim = mtf.Dimension("timesteps", timesteps_num)
  classes_dim = mtf.Dimension("classes", classes_num)
  hidden_dim_1 = mtf.Dimension("hidden_1", FLAGS.hidden_size)
  hidden_dim_2 = mtf.Dimension("hidden_2", FLAGS.hidden_size)

  x = mtf.import_tf_tensor(mesh, tf.reshape(image, [FLAGS.batch_size, 28, 28]),
                           [batch_dim, timesteps_dim, input_dim])
  hs_t = mtf.import_tf_tensor(mesh, hs_t, [batch_dim, hidden_dim_1])

  Wxh = mtf.get_variable(mesh, "Wxh", [input_dim, hidden_dim_2])
  Whh = mtf.get_variable(mesh, "Whh", [hidden_dim_1, hidden_dim_2])
  Why = mtf.get_variable(mesh, "Why", [hidden_dim_2, classes_dim])
  bh = mtf.get_variable(mesh, "bh", [hidden_dim_2])
  by = mtf.get_variable(mesh, "by", [classes_dim])

  for xs_t in mtf.unstack(x, timesteps_dim):
    # hs_t is on hidden_1, so this einsum contracts hidden_1 against Whh.
    hs_next = mtf.tanh(
        mtf.einsum([xs_t, Wxh], [batch_dim, hidden_dim_2]) +
        mtf.einsum([hs_t, Whh], [batch_dim, hidden_dim_2]) + bh)
    # Move the new state back onto hidden_1. A state left on hidden_2 would
    # be multiplied elementwise with Whh (hidden_2 is in the output shape)
    # instead of being matrix-multiplied on the next step.
    hs_t = mtf.rename_dimension(hs_next, hidden_dim_2.name, hidden_dim_1.name)

  final_state = mtf.rename_dimension(hs_t, hidden_dim_1.name, hidden_dim_2.name)
  logits = mtf.einsum([final_state, Why], [batch_dim, classes_dim]) + by

  if labels is None:
    return logits, None, hs_t

  # Labels are imported only when present; tf.reshape(None, ...) would raise.
  y = mtf.import_tf_tensor(mesh, tf.reshape(labels, [FLAGS.batch_size]),
                           [batch_dim])
  loss = mtf.layers.softmax_cross_entropy_with_logits(
      logits, mtf.one_hot(y, classes_dim), classes_dim)
  return logits, mtf.reduce_mean(loss), hs_t
def testMultiheadAttention(self, kv_channels, heads):
  """Self-attention via multihead_attention preserves the query's shape."""
  batch, length, channels = 2, 8, 3
  query = tf.random_normal([batch, length, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  query_shape = mtf.Shape([
      mtf.Dimension("batch", batch),
      mtf.Dimension("length", length),
      mtf.Dimension("channels", channels),
  ])
  mtf_query = mtf.import_tf_tensor(mesh, query, shape=query_shape)
  mtf_outputs = mtf.layers.multihead_attention(
      mtf_query,
      memory_antecedent=None,
      mask=None,
      kv_channels=mtf.Dimension("kv_channels", kv_channels),
      heads=mtf.Dimension("heads", heads))

  # Single-device placement: empty mesh shape, no layout rules.
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(tf_group)
  actual = self.evaluate(actual_outputs)

  self.assertEqual(actual.shape, query.shape)
def testConv1d(self):
  """conv1d with a mocked 1x3x2x2 filter yields the hand-computed output."""
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")

  depth_dim = mtf.Dimension("depth", 2)
  length_dim = mtf.Dimension("length", 4)
  output_dim = mtf.Dimension("output", 2)

  x = tf.constant([[1, 0], [0, 1], [1, 1], [2, 1]], dtype=tf.float32)
  mtf_x = mtf.import_tf_tensor(
      mesh, x, shape=mtf.Shape([length_dim, depth_dim]))

  # The filter initializer returns fixed weights, keyed by requested shape.
  filter_values = [[[[1, -1], [0, 0]],
                    [[2, -2], [-1, 1]],
                    [[3, -3], [-2, 2]]]]
  initializer_mock = mock.MagicMock()
  initializer_mock.side_effect = initialize_by_shape(
      {(1, 3, 2, 2): filter_values})

  mtf_output = mtf.layers.conv1d(
      mtf_x,
      output_dim=output_dim,
      filter_size=3,
      filter_initializer=initializer_mock)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_output = lowering.export_to_tf_tensor(mtf_output)

  self.evaluate(tf.global_variables_initializer())
  self.evaluate(lowering.copy_masters_to_slices())
  actual = self.evaluate(actual_output)

  self.assertAllClose(actual, [[0, 0], [1, -1], [5, -5], [4, -4]])
def testDense(self, units, use_bias):
  """mtf.layers.dense matches the output shape of a Keras Dense layer."""
  batch, channels = 2, 3
  inputs = tf.random_normal([batch, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  channels_dim = mtf.Dimension("channels", channels)
  depth_dim = mtf.Dimension("depth", units)

  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
  mtf_outputs = mtf.layers.dense(
      mtf_inputs,
      output_dim=depth_dim,
      reduced_dims=[channels_dim],
      activation=mtf.relu,
      use_bias=use_bias)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  keras_dense = tf.keras.layers.Dense(
      units=units, activation=tf.nn.relu, use_bias=use_bias)
  expected_outputs = keras_dense(inputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(tf_group)
  actual, expected = self.evaluate([actual_outputs, expected_outputs])

  # Weights differ between the two layers, so only shapes are compared.
  self.assertEqual(actual.shape, expected.shape)
def test_hidden_to_logits_computesLogitsCorrectly(self):
  """MixtureOfSoftmaxes logits equal log of the prior-weighted softmax mix."""
  seq_len, vocab_size, model_size, num_softmaxes = 1, 4, 3, 2

  vocab_dim = mtf.Dimension('vocab', vocab_size)
  model_dim = mtf.Dimension('model', model_size)
  length_dim = mtf.Dimension('length', seq_len)

  # Dividing by model_size**-0.5 multiplies the hidden state by
  # sqrt(model_size); presumably this cancels a 1/sqrt(model_size) scale
  # inside the layer (not visible here).
  hidden_values = np.array([[1.0, 1.0, 2.0]]) / model_size**-0.5
  embeddings = tf.constant(hidden_values, dtype=tf.float32)
  mtf_embeddings = mtf.import_tf_tensor(
      self.mesh, embeddings, shape=mtf.Shape([length_dim, model_dim]))

  embedding_weights = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]]
  mixture_weights = [[1, 0, 0], [0, 1, 1]]
  context_weights = [
      [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
      [[0, 0, 1], [0, 1, 0], [1, 0, 0]],
  ]
  self.initializer_mock.side_effect = initialize_by_shape({
      (4, 3): embedding_weights,
      (2, 3): mixture_weights,
      (2, 3, 3): context_weights,
  })

  vocab_embedding = vocab_embeddings.MixtureOfSoftmaxes(
      self.mesh,
      vocab_dim,
      output_dim=model_dim,
      variable_dtype=self.variable_dtype,
      name='embedding',
      ensemble_dim=None,
      num_softmaxes=num_softmaxes)
  mtf_logits = vocab_embedding.hidden_to_logits(mtf_embeddings, context=None)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[''])
  lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
  actual_logits = lowering.export_to_tf_tensor(mtf_logits)

  self.evaluate(tf.global_variables_initializer())
  self.evaluate(lowering.copy_masters_to_slices())
  actual = self.evaluate([actual_logits])[0]

  priors = scipy.special.softmax([1, 3])
  probs_1 = scipy.special.softmax(np.tanh([1, 1, 2, 2]))
  probs_2 = scipy.special.softmax(np.tanh([2, 1, 1, 1]))
  mixed_probs = priors[0] * probs_1 + priors[1] * probs_2
  self.assertAllClose(actual, [np.log(mixed_probs)])
def testDenseReluDense(self):
  """dense_relu_dense maps [batch, channels] back to [batch, channels]."""
  batch, channels, hidden = 2, 3, 5
  inputs = tf.random_normal([batch, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  input_shape = mtf.Shape([
      mtf.Dimension("batch", batch),
      mtf.Dimension("channels", channels),
  ])
  mtf_inputs = mtf.import_tf_tensor(mesh, inputs, shape=input_shape)
  mtf_outputs = mtf.layers.dense_relu_dense(
      mtf_inputs,
      hidden_channels=mtf.Dimension("hidden", hidden),
      is_training=False)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(tf_group)
  actual = self.evaluate(actual_outputs)

  self.assertEqual(actual.shape, inputs.shape)
def testMaskedLocalAttention1D(self, batch, length, io_channels, kv_channels,
                               heads, window_size):
  """masked_local_attention_1d returns [batch, length_q, io_channels]."""
  length_q = length
  query = tf.random_normal([batch, length_q, io_channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  query_shape = mtf.Shape([
      mtf.Dimension("batch", batch),
      mtf.Dimension("length_q", length_q),
      mtf.Dimension("io_channels", io_channels),
  ])
  mtf_query = mtf.import_tf_tensor(mesh, query, shape=query_shape)
  mtf_outputs = mtf.layers.masked_local_attention_1d(
      mtf_query,
      kv_channels=mtf.Dimension("kv_channels", kv_channels),
      heads=mtf.Dimension("heads", heads),
      window_size=window_size)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(tf_group)
  actual = self.evaluate(actual_outputs)

  self.assertEqual(actual.shape, (batch, length_q, io_channels))
def testGraph(self):
  """Graph bookkeeping grows as imports and variables are added."""
  graph = mtf.Graph()

  def assert_counts(num_ops, num_tensors, num_trainable, num_variables):
    self.assertLen(graph.operations, num_ops)
    self.assertLen(graph.tensors, num_tensors)
    self.assertLen(graph.trainable_variables, num_trainable)
    self.assertLen(graph.all_variables, num_variables)

  assert_counts(0, 0, 0, 0)

  mesh = mtf.Mesh(graph, "mesh_test")
  _ = mtf.import_tf_tensor(mesh, tf_tensor=tf.constant(0.),
                           shape=mtf.Shape([]))
  assert_counts(1, 1, 0, 0)

  _ = mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
  assert_counts(2, 2, 1, 1)

  # A non-trainable variable counts toward all_variables only.
  _ = mtf.get_variable(mesh, "variable_1", mtf.Shape([]), trainable=False)
  assert_counts(3, 3, 1, 2)
def computation_fn():
  """Builds a BERT model on a 2-way SIMD mesh and exports its pooled output.

  batch_size, seq_length, vocab_size, bert_config and device_assignment are
  free variables resolved from the enclosing scope.
  """
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, 'my_mesh')

  mesh_shape = mtf.convert_to_shape('all:2')
  layout_rules = mtf.convert_to_layout_rules('num_heads:all')
  devices = [''] * mesh_shape.size
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape, layout_rules, devices, device_assignment)

  ids_shape = [mtf.Dimension('batch', batch_size),
               mtf.Dimension('seq', seq_length)]
  input_ids = tf.random.uniform((batch_size, seq_length), minval=0,
                                maxval=vocab_size, dtype=tf.int32)
  mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids, ids_shape)

  model = bert_lib.BertModel(
      config=bert_config,
      is_training=True,
      input_ids=mtf_input_ids,
      input_mask=None,
      token_type_ids=None)
  pooled = model.get_pooled_output()

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  return lowering.export_to_tf_tensor(pooled)
def testCorr2DInput(self):
  """mtf.layers.corr over channels matches tfp.stats.correlation."""
  batch, channels = 4, 3
  inputs = tf.random_normal([batch, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  channels_dim = mtf.Dimension("channels", channels)

  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
  mtf_outputs = mtf.layers.corr(mtf_inputs, dim=channels_dim)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  # Samples along axis 0 (batch), variables along axis 1 (channels).
  expected_outputs = tfp.stats.correlation(inputs, sample_axis=0,
                                           event_axis=1)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(tf_group)
  actual, expected = self.evaluate([actual_outputs, expected_outputs])

  self.assertEqual(actual.shape, expected.shape)
  self.assertAllClose(actual, expected)
def toy_model(features, mesh):
  """A toy model implemented by mesh tensorflow.

  Stacks FLAGS.num_hidden_layers + 1 bias-free dense layers that map the
  `io` dimension back onto itself. The loss is the mean squared difference
  between the output and the input.

  Args:
    features: tf.Tensor imported with shape [FLAGS.batch_size, FLAGS.io_size].
    mesh: a mtf.Mesh

  Returns:
    y: mtf.Tensor with shape [batch, io] (the last layer's output).
    loss: scalar mtf.Tensor.
  """
  batch_dim = mtf.Dimension('batch', FLAGS.batch_size)
  io_dim = mtf.Dimension('io', FLAGS.io_size)

  master_dtype = tf.as_dtype(FLAGS.master_dtype)
  slice_dtype = tf.as_dtype(FLAGS.slice_dtype)
  activation_dtype = tf.as_dtype(FLAGS.activation_dtype)

  x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
  x = mtf.cast(x, activation_dtype)
  h = x
  num_layers = FLAGS.num_hidden_layers + 1
  # Python 3 `range` (the sibling toy_model already uses it); `xrange` is
  # undefined on Python 3.
  for lnum in range(1, num_layers + 1):
    if lnum == num_layers:
      dim = io_dim  # output layer
    elif lnum % 2 == 0:
      dim = mtf.Dimension('hidden_even', FLAGS.hidden_size)
    else:
      dim = mtf.Dimension('hidden_odd', FLAGS.hidden_size)
    h = mtf.layers.dense(
        h, dim,
        use_bias=False,
        master_dtype=master_dtype,
        slice_dtype=slice_dtype,
        name='layer_%d' % lnum)
  y = h

  loss = mtf.reduce_mean(mtf.square(y - x))
  return y, loss
def testRecomputeGrad(self):
  """recompute_grad gives correct forward values and gradients.

  Uses f(x) = x^2 + x, whose derivative is 2x + 1; the upstream gradient is
  dy = [2, 3].
  """
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")

  def x_squared_plus_x(x):
    return x * x + x

  x = tf.constant([5, 10], dtype=tf.float32)
  dy = tf.constant([2, 3], dtype=tf.float32)
  two = mtf.Dimension("two", 2)
  vector_shape = mtf.Shape([two])

  expected_y = tf.constant([30, 110], dtype=tf.float32)
  # dx = dy * (2x + 1) = [2 * 11, 3 * 21]
  expected_dx = tf.constant([22, 63], dtype=tf.float32)

  mtf_x = mtf.import_fully_replicated(mesh, x, shape=vector_shape)
  mtf_dy = mtf.import_tf_tensor(mesh, dy, shape=vector_shape)
  mtf_y = mtf.recompute_grad(x_squared_plus_x, [mtf_x])
  [mtf_dx] = mtf.gradients([mtf_y], [mtf_x], [mtf_dy])

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape="processors:2", layout="two:processors", devices=["", ""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_y = lowering.export_to_tf_tensor(mtf_y)
  actual_dx = lowering.export_to_tf_tensor(mtf_dx)

  self.assertAllEqual(self.evaluate(actual_y), self.evaluate(expected_y))
  self.assertAllEqual(self.evaluate(actual_dx), self.evaluate(expected_dx))
def testNthSmallestReduceSecondDim(self):
  """nth_smallest_element with n=0 over `b` yields each row's minimum."""
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  a_dim = mtf.Dimension("a", 6)
  b_dim = mtf.Dimension("b", 2)

  inputs = tf.constant([[1, 10], [2, 9], [3, 8], [4, 7], [5, 6], [6, 5]])
  expected_outputs = tf.constant([1, 2, 3, 4, 5, 5])
  smallest_index = 0  # n is zero-indexed, so 0 selects the smallest element

  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([a_dim, b_dim]))
  mtf_outputs = mtf.nth_smallest_element(
      mtf_inputs, smallest_index, b_dim, "test_nth_smallest")

  # Split `a` across two processors so the reduction runs per shard.
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape="all:2", layout="a:all", devices=["", ""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  self.assertAllEqual(self.evaluate(actual_outputs),
                      self.evaluate(expected_outputs))
def testWeightsNonzero(self):
  """weights_nonzero is 1.0 where the input is nonzero and 0.0 elsewhere."""
  inputs = tf.constant([[3, 1, 0], [1, 0, 0]])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  num_rows, num_cols = inputs.shape.as_list()
  input_shape = mtf.Shape([
      mtf.Dimension("batch", num_rows),
      mtf.Dimension("channels", num_cols),
  ])
  mtf_inputs = mtf.import_tf_tensor(mesh, inputs, shape=input_shape)
  mtf_outputs = mtf.layers.weights_nonzero(mtf_inputs)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
  expected_outputs = tf.cast(tf.not_equal(inputs, 0), tf.float32)

  self.evaluate(lowering.copy_masters_to_slices())
  actual, expected = self.evaluate([actual_outputs, expected_outputs])

  self.assertAllEqual(actual, expected)
def testDynamicText2self_unpacked(self):
  """_dynamic_text2self joins each row's inputs and targets into "targets".

  The expected rows are the non-padding inputs followed by the targets,
  zero-padded to twice the original length.
  """
  batch, length = 2, 5
  input_tensors = {
      "inputs": [[3, 1, 4, 1, 0], [1, 4, 3, 2, 1]],
      "targets": [[1, 1, 0, 0, 0], [9, 8, 1, 2, 1]],
  }
  expected_output_tensors = {
      "targets": [[3, 1, 4, 1, 1, 1, 0, 0, 0, 0],
                  [1, 4, 3, 2, 1, 9, 8, 1, 2, 1]],
  }

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  input_shape = mtf.Shape([mtf.Dimension("batch", batch),
                           mtf.Dimension("length", length)])

  mtf_features = {}
  for name, values in input_tensors.items():
    mtf_features[name] = mtf.import_tf_tensor(mesh, values, input_shape)
  mtf_outputs = utils._dynamic_text2self(mtf_features)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  for name, expected in expected_output_tensors.items():
    exported = lowering.export_to_tf_tensor(mtf_outputs[name])
    self.assertAllEqual(self.evaluate(exported), expected)
def testDotProductAttention(self, batch, heads, length_q, length_kv, depth_k,
                            depth_v):
  """dot_product_attention returns [batch, heads, length_q, depth_v]."""
  query = tf.random_normal([batch, heads, length_q, depth_k])
  key = tf.random_normal([batch, heads, length_kv, depth_k])
  value = tf.random_normal([batch, heads, length_kv, depth_v])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")

  def named_shape(*name_size_pairs):
    return mtf.Shape([mtf.Dimension(n, s) for n, s in name_size_pairs])

  leading = (("batch", batch), ("heads", heads))
  mtf_query = mtf.import_tf_tensor(
      mesh, query,
      shape=named_shape(*leading, ("length_q", length_q),
                        ("depth_k", depth_k)))
  mtf_key = mtf.import_tf_tensor(
      mesh, key,
      shape=named_shape(*leading, ("length_kv", length_kv),
                        ("depth_k", depth_k)))
  mtf_value = mtf.import_tf_tensor(
      mesh, value,
      shape=named_shape(*leading, ("length_kv", length_kv),
                        ("depth_v", depth_v)))

  mtf_outputs = mtf.layers.dot_product_attention(
      mtf_query, mtf_key, mtf_value, mask=None, is_training=False)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(tf_group)
  actual = self.evaluate(actual_outputs)

  self.assertEqual(actual.shape, (batch, heads, length_q, depth_v))
def generate_heterogeneous_expert_masks(mask_info, num_experts, experts_dim,
                                        mesh, expert_width):
  """Generates the heterogeneous expert masks.

  Experts are split into consecutive groups, one per entry of `mask_info`.
  An expert in a group keeps the first `layers` expert layers and, within
  each, the first `width * expert_width` hidden units. All other mask
  entries are zero.

  Example mask_info format:
    mask_info = [{'percent_number': .5, 'layers': 1, 'width': 1},
                 {'percent_number': .5, 'layers': 2, 'width': 2}]

  Args:
    mask_info: list of dicts with keys 'percent_number', 'layers' and
      'width'. A 'percent_number' below 1.0 is a fraction of `num_experts`;
      otherwise it is an absolute expert count. The last group always
      absorbs whatever experts remain.
    num_experts: number of experts in the model
    experts_dim: mtf dimension for experts (partitioned)
    mesh: mesh object
    expert_width: int, number of hidden units per unit of 'width'.

  Returns:
    mtf.Tensor of shape [expert_hidden, num_expert_layers, experts_dim], with
    sizes [max(width) * expert_width, max(layers), num_experts]. It holds 1
    for kept entries and 0 for masked ones.
  """
  max_layers = max(m["layers"] for m in mask_info)
  # The hidden axis is measured in hidden units (width * expert_width), the
  # same unit as the slice below. Using max(width) alone made every slice
  # cover the whole axis, so no width masking happened.
  max_width = int(max(m["width"] for m in mask_info) * expert_width)

  # Built up along the last axis: [max_width, max_layers, experts so far].
  expert_mask = np.zeros([max_width, max_layers, 0])
  for idx, mask_i in enumerate(mask_info):
    if mask_i["percent_number"] < 1.0:
      num_experts_in_mask = int(num_experts * mask_i["percent_number"])
    else:
      num_experts_in_mask = int(mask_i["percent_number"])

    # The last group takes every remaining expert. This also covers
    # percent_number == 1 (homogeneous experts, or a single expert), which
    # becomes one group of all num_experts.
    if idx == (len(mask_info) - 1):
      num_experts_in_mask_tmp = num_experts - expert_mask.shape[2]
      if num_experts_in_mask_tmp != num_experts_in_mask:
        tf.logging.info(
            "Expert layer probabilities do not evenly divide "
            "the number of experts: {} {}".format(
                num_experts_in_mask, num_experts_in_mask_tmp))
      num_experts_in_mask = num_experts_in_mask_tmp

    mask = np.zeros([int(max_width), int(max_layers), num_experts_in_mask])
    # Enable the first width*expert_width hidden units of the first `layers`
    # layers; everything after them stays zero.
    mask[:int(mask_i["width"] * expert_width), :mask_i["layers"], :] = 1
    expert_mask = np.concatenate([expert_mask, mask], axis=2)  # expert dim

  assert expert_mask.shape[2] == num_experts
  tf.logging.info("heterogeneous mask: {}".format(expert_mask))

  # Now import the numpy mask into Mesh TF.
  layers_dim = mtf.Dimension("num_expert_layers", max_layers)
  width_dim = mtf.Dimension("expert_hidden", max_width)
  expert_mask_tf = tf.convert_to_tensor(expert_mask)
  expert_mask_mtf = mtf.import_tf_tensor(
      mesh, tf_tensor=expert_mask_tf,
      shape=[width_dim, layers_dim, experts_dim])
  return expert_mask_mtf
def CreateMeshes(inputs, labels, num_nodes, num_gpus, batch_size):
  """Builds a one-mesh graph and imports 2-D `inputs` and `labels` into it.

  Args:
    inputs: rank-2 tf.Tensor.
    labels: tf.Tensor with the same shape as `inputs`.
    num_nodes: number of nodes; num_gpus // num_nodes is passed as
      gpus_per_node.
    num_gpus: total number of GPUs, used as the mesh shape [num_gpus].
    batch_size: size of the first dimension, named 'axis0'.

  Returns:
    (graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels)
  """
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, 'mesh0')
  meshes = [mesh]
  mesh_to_impl = {
      mesh: utils.GetMeshImpl([num_gpus],
                              gpus_per_node=num_gpus // num_nodes),
  }

  assert len(inputs.shape) == 2
  assert inputs.shape == labels.shape
  num_columns = inputs.shape.as_list()[1]
  shape = utils.ConvertToShape([('axis0', batch_size), num_columns])

  mtf_inputs = mtf.import_tf_tensor(mesh, inputs, shape)
  mtf_labels = mtf.import_tf_tensor(mesh, labels, shape)
  return graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels
def test_ids_to_embedding_correctlyEmbeds(self):
  """Mixtape.ids_to_embedding returns the rows of the embedding table."""
  seq_len = 5
  vocab_size = 5
  model_size = 2
  gate_embedding_size = 1
  frequent_token_fraction = 0.4

  vocab_dim = mtf.Dimension('vocab', vocab_size)
  model_dim = mtf.Dimension('model', model_size)
  length_dim = mtf.Dimension('length', seq_len)

  # NOTE(review): this mock is built but never used; ids_to_embedding below
  # is called with context=None.
  context = mock.MagicMock()
  context.train = False

  ids = tf.constant([0, 1, 2, 3, 4], dtype=tf.int32)
  mtf_ids = mtf.import_tf_tensor(
      self.mesh, ids, shape=mtf.Shape([length_dim]))

  self.initializer_mock.side_effect = initialize_by_shape({
      (5, 2): list(range(10)),   # Embedding weights.
      (4, 2, 2): list(range(16)),  # Context weights.
      (3, 1, 2): list(range(6)),   # Prior weights.
      (2, 1): list(range(2)),      # Prior vocab vector.
      (3, 2): list(range(6)),      # Prior gates vector.
      (2, 3): list(range(6)),      # Prior bias.
  })

  vocab_embedding = vocab_embeddings.Mixtape(
      self.mesh,
      vocab_dim,
      output_dim=model_dim,
      variable_dtype=self.variable_dtype,
      name='embedding',
      ensemble_dim=None,
      gate_embedding_size=gate_embedding_size,
      frequent_token_fraction=frequent_token_fraction)
  mtf_embedding = vocab_embedding.ids_to_embedding(mtf_ids, context=None)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[''])
  lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
  actual_embedding = lowering.export_to_tf_tensor(mtf_embedding)

  self.evaluate(tf.global_variables_initializer())
  self.evaluate(lowering.copy_masters_to_slices())
  (actual,) = self.evaluate([actual_embedding])

  # Ids 0..4 in order, so the result is the full 5x2 embedding table.
  self.assertAllClose(actual, np.reshape(list(range(10)), (5, 2)))
def test_hidden_to_logits_computesLogitsCorrectly(self):
  """AdaptiveVocabEmbedding logits match a hand-computed, scaled matrix."""
  seq_len, vocab_size, model_size = 4, 5, 2

  vocab_dim = mtf.Dimension('vocab', vocab_size)
  model_dim = mtf.Dimension('model', model_size)
  length_dim = mtf.Dimension('length', seq_len)

  embeddings = tf.constant([[1, 0], [0, 1], [1, 1], [2, 1]],
                           dtype=tf.float32)
  mtf_embeddings = mtf.import_tf_tensor(
      self.mesh, embeddings, shape=mtf.Shape([length_dim, model_dim]))

  self.initializer_mock.side_effect = initialize_by_shape({
      (2, 2): [[0, 1], [2, 0]],
      (3, 1): [[1], [2], [3]],
      (1, 2): [[1], [2]],
  })

  # Two clusters: 2 tokens with 2-d embeddings, then 3 tokens with 1-d ones.
  clusters = [
      {'token_count': 2, 'embedding_size': 2},
      {'token_count': 3, 'embedding_size': 1},
  ]
  vocab_embedding = vocab_embeddings.AdaptiveVocabEmbedding(
      self.mesh,
      vocab_dim,
      output_dim=model_dim,
      variable_dtype=self.variable_dtype,
      name='embedding',
      ensemble_dim=None,
      clusters=clusters)
  mtf_logits = vocab_embedding.hidden_to_logits(mtf_embeddings, context=None)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[''])
  lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl})
  actual_logits = lowering.export_to_tf_tensor(mtf_logits)

  self.evaluate(tf.global_variables_initializer())
  self.evaluate(lowering.copy_masters_to_slices())
  (actual,) = self.evaluate([actual_logits])

  unscaled = np.array([[0, 2, 1, 2, 3],
                       [1, 0, 2, 4, 6],
                       [1, 2, 3, 6, 9],
                       [1, 4, 4, 8, 12]])
  self.assertAllClose(actual, model_size**-0.5 * unscaled)
def toy_model(features, mesh):
  """A toy model implemented by mesh tensorflow.

  Two bias-free dense layers, io -> hidden -> io. The loss is the summed
  squared difference between the output and the input.

  Args:
    features: tf.Tensor imported as [FLAGS.batch_size, FLAGS.io_size].
    mesh: a mtf.Mesh

  Returns:
    y: mtf.Tensor with shape [batch, io].
    loss: scalar mtf.Tensor.
  """
  batch_dim = mtf.Dimension('batch', FLAGS.batch_size)
  hidden_dim = mtf.Dimension('hidden', FLAGS.hidden_size)
  io_dim = mtf.Dimension('io', FLAGS.io_size)

  x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
  hidden = mtf.layers.dense(x, hidden_dim, name='layer1', use_bias=False)
  y = mtf.layers.dense(hidden, io_dim, name='layer2', use_bias=False)

  squared_error = mtf.square(y - x)
  return y, mtf.reduce_sum(squared_error)
def Replication2(in_tsr):
  """Moves `in_tsr` from a [1, 4] mesh to a [2, 4] mesh by duplication.

  The transfer is done with mt.ReplaceMeshWithDuplicates, and the result is
  checked with Run.
  """
  graph = mtf.Graph()
  src_mesh = mtf.Mesh(graph, 'mesh0')
  dst_mesh = mtf.Mesh(graph, 'mesh1')
  mesh_to_impl = {
      src_mesh: GetMeshImpl([1, 4]),
      dst_mesh: GetMeshImpl([2, 4]),
  }

  mtf_shape = GetShape(in_tsr.get_shape().as_list())
  mtf_in_tsr = mtf.import_tf_tensor(src_mesh, in_tsr, mtf_shape)
  mtf_out_tsr = mt.ReplaceMeshWithDuplicates(mtf_in_tsr, dst_mesh)

  Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
def toy_model(features, mesh):
  """A toy model implemented by mesh tensorflow.

  Stacks FLAGS.num_hidden_layers + 1 bias-free dense layers mapping `io`
  back onto itself. The loss is the mean squared difference between output
  and input. When FLAGS.step_with_nan >= 0, a 0/0 term is added to the
  output at that global step, to exercise NaN handling.

  Args:
    features: tf.Tensor imported as [FLAGS.batch_size, FLAGS.io_size].
    mesh: a mtf.Mesh

  Returns:
    y: mtf.Tensor with shape [batch, io].
    loss: scalar mtf.Tensor.
  """
  batch_dim = mtf.Dimension('batch', FLAGS.batch_size)
  io_dim = mtf.Dimension('io', FLAGS.io_size)

  master_dtype = tf.as_dtype(FLAGS.master_dtype)
  slice_dtype = tf.as_dtype(FLAGS.slice_dtype)
  activation_dtype = tf.as_dtype(FLAGS.activation_dtype)

  x = mtf.cast(
      mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim])),
      activation_dtype)

  h = x
  last_layer = FLAGS.num_hidden_layers + 1
  for lnum in range(1, last_layer + 1):
    if lnum == last_layer:
      dim = io_dim  # the output layer projects back to io
    elif lnum % 2 == 0:
      dim = mtf.Dimension('hidden_even', FLAGS.hidden_size)
    else:
      dim = mtf.Dimension('hidden_odd', FLAGS.hidden_size)
    h = mtf.layers.dense(h, dim,
                         use_bias=False,
                         master_dtype=master_dtype,
                         slice_dtype=slice_dtype,
                         name='layer_%d' % lnum)
  y = h

  g = tf.train.get_global_step()
  if FLAGS.step_with_nan >= 0:
    # Divides 0 by 0 (NaN) on the chosen step and 0 by 1 on every other
    # step, so MeshTensorFlow's handling of an occasional NaN can be tested.
    denominator = tf.cond(tf.equal(g, FLAGS.step_with_nan),
                          lambda: 0., lambda: 1.)
    y += mtf.import_tf_tensor(mesh, tf.divide(0.0, denominator),
                              mtf.Shape([]))

  loss = mtf.reduce_mean(mtf.square(y - x))
  return y, loss
def Split3(in_tsr):
  """Moves `in_tsr` from a [2, 2] mesh to a [2, 4] mesh via concat/split.

  The third dimension is named 'axis1'. The source mesh is built with
  GetMeshImpl([2, 2], [0, 2, 4, 6]).
  """
  graph = mtf.Graph()
  src_mesh = mtf.Mesh(graph, 'mesh0')
  dst_mesh = mtf.Mesh(graph, 'mesh1')
  mesh_to_impl = {
      src_mesh: GetMeshImpl([2, 2], [0, 2, 4, 6]),
      dst_mesh: GetMeshImpl([2, 4]),
  }

  dims = in_tsr.get_shape().as_list()
  named_dims = dims[:2] + [('axis1', dims[2])] + dims[3:]
  mtf_shape = GetShape(named_dims)

  mtf_in_tsr = mtf.import_tf_tensor(src_mesh, in_tsr, mtf_shape)
  mtf_out_tsr = mt.ReplaceMeshWithConcatSplit(
      mtf_in_tsr, dst_mesh, mtf_shape.dimension_names)

  Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
def NoConcatSplit(in_tsr):
    """Run ReplaceMeshWithConcatSplit between two identically configured meshes.

    Both 'mesh0' and 'mesh1' are implemented by GetMeshImpl([4, 2]). The
    second axis (index 1) of the tensor is passed to GetShape as
    ('axis0', size); presumably this gives that dimension the name 'axis0'.
    Confirm against GetShape. All dimension names are forwarded to
    ReplaceMeshWithConcatSplit, and the graph is executed through Run.
    """
    graph = mtf.Graph()
    src_mesh = mtf.Mesh(graph, 'mesh0')
    dst_mesh = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {
        src_mesh: GetMeshImpl([4, 2]),
        dst_mesh: GetMeshImpl([4, 2]),
    }

    named_dims = list(in_tsr.get_shape().as_list())
    named_dims[1] = ('axis0', named_dims[1])
    src_shape = GetShape(named_dims)

    src_tsr = mtf.import_tf_tensor(src_mesh, in_tsr, src_shape)
    dst_tsr = mt.ReplaceMeshWithConcatSplit(src_tsr, dst_mesh,
                                            src_shape.dimension_names)
    Run(graph, mesh_to_impl, in_tsr, dst_tsr)
def import_to_batch_by_length(x, name):
    """Import TF tensor `x` onto `mesh` with mtf shape [batch_dim, self.length_dim].

    NOTE(review): `mesh`, `batch_dim` and `self` are not parameters. They are
    resolved from an enclosing scope, so this appears to be a closure defined
    inside a method. Confirm at the definition site.

    Args:
      x: the tf.Tensor to import.
      name: name passed through to mtf.import_tf_tensor.

    Returns:
      The imported mtf.Tensor.
    """
    batch_by_length = mtf.Shape([batch_dim, self.length_dim])
    return mtf.import_tf_tensor(mesh, x, batch_by_length, name=name)
def mtf_model_fn(self, features, mesh):
    """Build the convolutional Mesh TensorFlow graph and its classification loss.

    Args:
      features: dict of tf.Tensors. "inputs" and "targets" are read. The dict
        is shallow-copied first.
      mesh: the mtf.Mesh the graph is built on.

    Returns:
      A tuple (logits, loss):
        logits: reshaped to [batch, one_channel (size 1), classes (size 10)].
        loss: the mean softmax cross-entropy against one-hot targets.
    """
    features = copy.copy(features)
    tf.logging.info("features = %s" % features)
    hparams = self._hparams
    activation_dtype = self.set_activation_type()
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

    # Declare all the dimensions
    batch_dim = mtf.Dimension("batch", hparams.batch_size)
    hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
    filter_h_dim = mtf.Dimension("filter_height", 7)
    filter_w_dim = mtf.Dimension("filter_width", 7)
    filters = mtf.Dimension("filters", hparams.filter_sizes[0])
    rows_dim = mtf.Dimension("rows_size", hparams.rows_size)
    cols_dim = mtf.Dimension("cols_size", hparams.cols_size)
    row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
    col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
    # Number of classes and input channels are hard-coded here, not read from hparams.
    classes_dim = mtf.Dimension("classes", 10)
    channels_dim = mtf.Dimension("channels", 3)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    inputs = features["inputs"]
    # NOTE(review): the tf.reshape sizes below do not in general match the
    # declared mtf dimension sizes:
    #   - rows_dim has size rows_size, but the reshape uses
    #     rows_size // row_blocks (they agree only when row_blocks == 1);
    #   - cols_dim has size cols_size, but the reshape uses
    #     num_channels * cols_size // col_blocks (e.g. they agree when
    #     num_channels == col_blocks);
    #   - channels_dim is fixed at 3, but the reshape uses hparams.num_channels.
    # Confirm these against the hparams actually used with this model.
    x = mtf.import_tf_tensor(
        mesh, tf.reshape(inputs, [
            hparams.batch_size,
            hparams.row_blocks,
            hparams.rows_size // hparams.row_blocks,
            hparams.col_blocks,
            hparams.num_channels*hparams.cols_size // hparams.col_blocks,
            hparams.num_channels]),
        mtf.Shape(
            [batch_dim, row_blocks_dim, rows_dim,
             col_blocks_dim, cols_dim, channels_dim]))
    # Put the block dims ahead of the within-block spatial dims.
    x = mtf.transpose(x, [batch_dim, row_blocks_dim, col_blocks_dim,
                          rows_dim, cols_dim, channels_dim])

    x = mtf.to_float(x)
    # 7x7 stem convolution from channels_dim to filter_sizes[0] filters.
    initial_filters = mtf.get_variable(
        mesh, "init_filters",
        mtf.Shape([filter_h_dim, filter_w_dim, channels_dim, filters]))
    x = mtf.conv2d_with_blocks(
        x,
        initial_filters,
        strides=[1, 1, 1, 1],
        padding="SAME",
        h_blocks_dim=None, w_blocks_dim=col_blocks_dim)

    x = batch_norm_relu(x, is_training)

    # Conv blocks
    # [block - strided block layer - strided block layer] x n
    # NOTE(review): despite the "strided" wording above, all three block_layer
    # calls below pass strides=[1, 1, 1, 1].
    for layer in range(hparams.num_layers):
        layer_name = "block_layer_%d" % layer
        with tf.variable_scope(layer_name):
            # Residual block layer
            x = block_layer(
                inputs=x,
                filters=hparams.filter_sizes[0],
                blocks=hparams.layer_sizes[0],
                strides=[1, 1, 1, 1],
                is_training=is_training,
                name="block_layer1",
                row_blocks_dim=None,
                col_blocks_dim=None)
            x = block_layer(
                inputs=x,
                filters=hparams.filter_sizes[1],
                blocks=hparams.layer_sizes[1],
                strides=[1, 1, 1, 1],
                is_training=is_training,
                name="block_layer2",
                row_blocks_dim=None,
                col_blocks_dim=None)
            x = block_layer(
                inputs=x,
                filters=hparams.filter_sizes[2],
                blocks=hparams.layer_sizes[2],
                strides=[1, 1, 1, 1],
                is_training=is_training,
                name="block_layer3",
                row_blocks_dim=None,
                col_blocks_dim=None)

    # Calculate the logits and loss.
    out = x
    # The dense layer contracts the last 5 dims of `out`. Presumably these are
    # everything except batch; confirm against block_layer's output layout.
    outputs = mtf.layers.dense(
        out, hidden_dim,
        reduced_dims=out.shape.dims[-5:],
        activation=mtf.relu, name="dense")

    # We assume fixed vocab size for targets
    # squeeze on axes [2, 3] requires "targets" to have size-1 axes there.
    labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [hparams.batch_size]), mtf.Shape([batch_dim]))

    logits = mtf.layers.dense(outputs, classes_dim, name="logits")
    soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, classes_dim)

    # Reshape logits so it doesn't break inside t2t.
    logits = mtf.reshape(
        logits,
        mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
    loss = mtf.reduce_mean(loss)
    return logits, loss
def estimator_model_fn(cls,
                       hparams,
                       features,
                       labels,
                       mode,
                       config=None,
                       params=None,
                       decode_hparams=None,
                       use_tpu=False):
    """Build an Estimator / TPUEstimator spec for a Mesh TensorFlow model.

    Args:
      cls: the model class, instantiated as
        ``cls(hparams, mode, data_parallelism=..., decode_hparams=...)``.
      hparams: model hyperparameters. They are deep-copied, so the caller's
        object is not mutated. ``mesh_shape`` and ``layout`` are read to
        build the mesh implementation; ``model_dir`` is read for
        checkpointing.
      features: dict of input tensors handed to the model.
      labels: forwarded to ``model.estimator_spec_eval`` in EVAL mode.
      mode: a ``tf.estimator.ModeKeys`` value.
      config: optional run config. Off-TPU, its ``data_parallelism`` supplies
        the parameter-server devices used as mesh devices.
      params: on TPU, a dict whose ``"context"`` entry provides host count,
        host placement, replica count and device assignment.
      decode_hparams: optional hparams. In PREDICT mode their values override
        (and are added to) ``hparams``.
      use_tpu: build a ``SimdMeshImpl`` and ``TPUEstimatorSpec`` when True,
        otherwise a ``PlacementMeshImpl`` and ``tf.estimator.EstimatorSpec``.

    Returns:
      PREDICT: the result of ``model.estimator_spec_predict``.
      EVAL: the result of ``model.estimator_spec_eval``.
      TRAIN: a ``TPUEstimatorSpec`` (TPU) or ``tf.estimator.EstimatorSpec``.

    Raises:
      ValueError: off-TPU, when there is more than one parameter-server device
        and their number differs from the mesh size.
    """
    hparams = copy.deepcopy(hparams)
    hparams.use_tpu = use_tpu
    # Merge decode_hparams into hparams if present.
    if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
        for k, v in six.iteritems(decode_hparams.values()):
            if hasattr(hparams, k) and getattr(hparams, k) != v:
                tf.logging.warning(
                    "Overriding hparams.%s with %s from decode_hparams", k, v)
            setattr(hparams, k, v)

    # Instantiate model.
    data_parallelism = None
    if not use_tpu and config:
        data_parallelism = config.data_parallelism
    model = cls(
        hparams,
        mode,
        data_parallelism=data_parallelism,
        decode_hparams=decode_hparams)

    global_step = tf.train.get_global_step()

    mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(hparams.layout)
    if use_tpu:
        ctx = params["context"]
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
        # TODO(ylc): Better estimation of replica cache size?
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                      devices_memory_usage)
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
    else:
        var_placer = None
        if data_parallelism is None or len(data_parallelism.ps_devices) == 1:
            mesh_devices = [""] * mesh_shape.size
        else:
            # Explicit check (not `assert`) so it still runs under `python -O`.
            if len(data_parallelism.ps_devices) != mesh_shape.size:
                raise ValueError(
                    "Number of ps_devices (%d) must equal the mesh size (%d)."
                    % (len(data_parallelism.ps_devices), mesh_shape.size))
            mesh_devices = data_parallelism.ps_devices
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # PREDICT mode
    if mode == tf.estimator.ModeKeys.PREDICT:
        return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)

    logits, loss = model.mtf_model_fn(features, mesh)
    if use_tpu and logits is not None:
        logits = mtf.anonymize(logits)

    # TRAIN mode: build mtf gradient and optimizer-update ops before lowering.
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        lr = learning_rate.learning_rate_schedule(hparams)
        tf.summary.scalar("learning_rate", lr)
        mtf_lr = mtf.import_tf_tensor(
            mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
        optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
        update_ops = []
        for grad, var in zip(var_grads, graph.trainable_variables):
            update_ops.extend(optimizer.apply_grad(grad, var))

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.to_float(tf_loss)
    # Logits are exported only in the EVAL branch below. PREDICT has already
    # returned, and TRAIN does not need them.

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(
            hparams.model_dir,
            save_steps=1000,
            saver=saver,
            listeners=[saver_listener])

    # EVAL mode
    if mode == tf.estimator.ModeKeys.EVAL:
        tf_logits = lowering.export_to_tf_tensor(logits)
        return model.estimator_spec_eval(features, tf_logits, labels, tf_loss,
                                         restore_hook, use_tpu)

    if use_tpu:
        # TPU host call. Important: need to be called before remove_summaries()
        if hparams.tpu_enable_host_call:
            host_call = t2t_model.create_host_call(hparams.model_dir)
        else:
            host_call = None

        t2t_model.remove_summaries()
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            host_call=host_call,
            training_hooks=[restore_hook, saver_hook])
    else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])