def testLayout(self): # Construct a Mesh TensorFlow graph and mesh. mtf_graph = mtf.Graph() mesh = mtf.Mesh(mtf_graph, "my_mesh") x = mtf.zeros(mesh, "a:10,b:5") y = mtf.zeros(mesh, "b:5,c:20") z = mtf.einsum([x, y], "a:10,c:20") # Decide on a mesh shape. mesh_shape = mtf.convert_to_shape("m1:4,m2:2") # Compute a layout based on the graph and mesh. # Note that knowing the identity of the outputs is important to the # optimization since they cannot be freed. layout = mtf.auto_mtf.layout(mtf_graph, mesh_shape, [z]) a_dim = mtf.convert_to_dimension(("a", 10)) b_dim = mtf.convert_to_dimension(("b", 5)) c_dim = mtf.convert_to_dimension(("c", 20)) self.assertEqual( layout.tensor_dimension_to_mesh_axis(a_dim, mesh_shape), 1) self.assertIsNone( layout.tensor_dimension_to_mesh_axis(b_dim, mesh_shape)) self.assertEqual( layout.tensor_dimension_to_mesh_axis(c_dim, mesh_shape), 0)
def testLayoutAndMeshShape(self): # Same as previous test, but don't specify a 4x2 mesh. mtf_graph = mtf.Graph() mesh = mtf.Mesh(mtf_graph, "my_mesh") x = mtf.zeros(mesh, "a:10,b:5") y = mtf.zeros(mesh, "b:5,c:20") z = mtf.einsum([x, y], "a:10,c:20") layout, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(mtf_graph, 8, [z]) a_dim = mtf.convert_to_dimension(("a", 10)) b_dim = mtf.convert_to_dimension(("b", 5)) c_dim = mtf.convert_to_dimension(("c", 20)) self.assertEqual(layout.tensor_dimension_to_mesh_axis(a_dim, mesh_shape), 1) self.assertIsNone(layout.tensor_dimension_to_mesh_axis(b_dim, mesh_shape)) self.assertEqual(layout.tensor_dimension_to_mesh_axis(c_dim, mesh_shape), 0) self.assertCountEqual(mesh_shape.dims, [mtf.Dimension("mesh_0", 4), mtf.Dimension("mesh_1", 2)]) layout, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape( mtf_graph, 8, [z], 1) self.assertIsNone(layout.tensor_dimension_to_mesh_axis(a_dim, mesh_shape)) self.assertIsNone(layout.tensor_dimension_to_mesh_axis(b_dim, mesh_shape)) self.assertIsNone(layout.tensor_dimension_to_mesh_axis(c_dim, mesh_shape)) self.assertCountEqual(mesh_shape.dims, [mtf.Dimension("mesh_0", 8)])
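# A usage sketch (not part of the tests above): once auto_mtf has chosen a
# layout and a mesh shape, they are typically handed to a MeshImpl so the
# graph can be lowered onto concrete devices.  The repeated "/cpu:0" device
# list below is a placeholder assumption; any list whose length matches the
# mesh size would do.
def example_lower_with_auto_layout():
  import mesh_tensorflow as mtf
  import mesh_tensorflow.auto_mtf  # makes mtf.auto_mtf available
  from mesh_tensorflow import placement_mesh_impl

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  x = mtf.zeros(mesh, "a:10,b:5")
  y = mtf.zeros(mesh, "b:5,c:20")
  z = mtf.einsum([x, y], "a:10,c:20")

  layout, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(graph, 8, [z])
  devices = ["/cpu:0"] * mesh_shape.size  # placeholder device list
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, layout, devices)
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  return lowering.export_to_tf_tensor(z)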
def testOptimizeLayoutRepetition(self):
  x1 = mtf.zeros(self.mesh, "a:10,b:5")
  x2 = mtf.zeros(self.mesh, "b:5,c:20")
  for _ in six.moves.xrange(100):
    mtf.einsum([x1, x2], "a:10,c:20")
  optimizer = self.get_layout_optimizer()
  self.assertGreaterEqual(
      len(list(optimizer._graph.get_all_operation_names())), 50)
  self.assertLessEqual(len(optimizer._model.Proto().variables), 50)
  # Same checks as in testOptimizeLayout.
  layout = optimizer.solve()
  self.assertEqual(layout, "a:m2;c:m1")
  layout_value = optimizer.evaluate_layout(layout)
  self.assertLessEqual(layout_value, optimizer.evaluate_layout("a:m1;b:m2"))
  self.assertLessEqual(layout_value, optimizer.evaluate_layout("a:m1;c:m2"))
  self.assertLessEqual(layout_value, optimizer.evaluate_layout("b:m1;a:m2"))
  self.assertLessEqual(layout_value, optimizer.evaluate_layout("b:m1;c:m2"))
  self.assertLessEqual(layout_value, optimizer.evaluate_layout("c:m1;b:m2"))
  self.assertEqual(layout_value, optimizer.evaluate_layout("c:m1;a:m2"))
def model(params, inputs, labels): # MTF mesh assert len(inputs.shape) == 2 graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels = CreateMeshes( inputs, labels, params.num_nodes, params.num_gpus, params.batch_size) embed_mesh, lstm0_mesh, lstm1_mesh, proj_mesh = meshes batch_dim_name, n_dim_name, k_dim_name = 'axis0', 'axis1', 'axis2' # RNN weights num_units = params.num_units w_shape = utils.ConvertToShape([(k_dim_name, 2*num_units), (n_dim_name, 4*num_units)]) rnn_w0 = mtf.get_variable(lstm0_mesh, 'rnn_w0', w_shape) rnn_w1 = mtf.get_variable(lstm1_mesh, 'rnn_w1', w_shape) # RNN initial states h_shape = mtf.Shape([mtf.Dimension(batch_dim_name, params.batch_size), mtf.Dimension(k_dim_name, num_units)]) c_shape = mtf.Shape([mtf.Dimension(batch_dim_name, params.batch_size), mtf.Dimension(n_dim_name, num_units)]) states0 = [mtf.zeros(lstm0_mesh, h_shape), mtf.zeros(lstm0_mesh, c_shape)] states1 = [mtf.zeros(lstm1_mesh, h_shape), mtf.zeros(lstm1_mesh, c_shape)] # Model - embedding vocab_dim = mtf.Dimension(k_dim_name, params.vocab_size) embed_dim = mtf.Dimension(n_dim_name, params.num_units) assert mtf_inputs.mesh == embed_mesh embedding = mtf.layers.embedding(mtf_inputs, vocab_dim, embed_dim, tf.float32) assert embedding.shape[-1].name == n_dim_name shape = embedding.shape.rename_dimension(n_dim_name, k_dim_name) embedding = mesh_trans.ReplaceMeshWithIndependentAxes( embedding, lstm0_mesh, shape.dimension_names) # Model - RNN [y] = RNNOperation(embedding, rnn_w0, rnn_w1, num_units, states=states0 + states1).outputs assert y.mesh == lstm1_mesh assert y.shape[-1].name == k_dim_name assert mesh_to_impl[proj_mesh].shape[-1] == mtf.Dimension(k_dim_name, 1) rand_dim_name = utils.RandName() y = mt.rename_dimension(y, k_dim_name, rand_dim_name) shape = y.shape.rename_dimension(rand_dim_name, k_dim_name) y = mesh_trans.ReplaceMeshWithIndependentAxes( y, proj_mesh, shape.dimension_names) # Model - Dense + loss assert y.shape[-1].name == k_dim_name vocab_dim = mtf.Dimension(n_dim_name, params.vocab_size) y = mtf.layers.dense(y, vocab_dim, reduced_dims=y.shape[-1:], use_bias=False) assert mtf_labels.mesh == proj_mesh mtf_cross_ent = mtf.layers.softmax_cross_entropy_with_logits( y, mtf_labels, vocab_dim) mtf_loss = mtf.reduce_mean(mtf_cross_ent) model.soft_placement = True return graph, mesh_to_impl, mtf_loss
def testOptimizeLayoutTiebreak(self): x1 = mtf.zeros(self.mesh, "a:10,b:5") x2 = mtf.zeros(self.mesh, "b:5,c:20") mtf.einsum([x1, x2], "a:10,c:20") # Rewrite mesh_shape to have a dummy dimension. self.mesh_shape = mtf.convert_to_shape("m1:4,m2:2,m3:1") optimizer = self.get_layout_optimizer() layout = optimizer.solve() self.assertEqual(layout, "a:m2;b:m3;c:m1")
def testConcatOperation(self): concat_dim1 = mtf.Dimension("concat", 5) concat_dim2 = mtf.Dimension("concat", 7) x1 = mtf.zeros(self.mesh, mtf.Shape([self.a_dim, self.b_dim, concat_dim1])) x2 = mtf.zeros(self.mesh, mtf.Shape([self.a_dim, self.b_dim, concat_dim2])) concat_operation = mtf.ConcatOperation([x1, x2], "concat") self.assertEqual(concat_operation.splittable_dims, frozenset(["a", "b"])) self.assertEqual(concat_operation.unsplittable_dims, frozenset(["concat"]))
def testOptimizeLayoutUnsplittable(self):
  x1 = mtf.zeros(self.mesh, "a:10,b:5")
  x2 = mtf.zeros(self.mesh, "b:5,c:20")
  mtf.UnstackOperation(x1, mtf.Dimension("a", 10))
  mtf.UnstackOperation(x2, mtf.Dimension("c", 20))
  optimizer = self.get_layout_optimizer()
  # No dimensions can be split, because a and c are unstack dimensions and
  # b has size 5 (so there are divisibility issues).
  self.assertEqual(optimizer.solve(), "")
def model(params, inputs, labels): # Mtf mesh assert len(inputs.shape) == 2 graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels = CreateMeshes( inputs, labels, params.num_nodes, params.num_gpus, params.batch_size) # Embedding dimensions vocab_dim = mtf.Dimension(utils.RandName(), params.vocab_size) embed_dim = mtf.Dimension(utils.RandName(), params.num_units) batch_dim_name = mtf_inputs.shape[0].name k_dim_name = embed_dim.name n_dim_name = utils.RandName() # RNN weights num_units = params.num_units w_shape = utils.ConvertToShape( [(k_dim_name, 2*num_units), (n_dim_name, 4*num_units)]) rnn_w0 = mtf.get_variable(meshes[0], 'rnn_w0', w_shape) rnn_w1 = mtf.get_variable(meshes[1], 'rnn_w1', w_shape) # RNN initial states h_shape = mtf.Shape([mtf.Dimension(batch_dim_name, params.batch_size), mtf.Dimension(k_dim_name, num_units)]) c_shape = mtf.Shape([mtf.Dimension(batch_dim_name, params.batch_size), mtf.Dimension(n_dim_name, num_units)]) states0 = [mtf.zeros(meshes[0], h_shape), mtf.zeros(meshes[0], c_shape)] states1 = [mtf.zeros(meshes[1], h_shape), mtf.zeros(meshes[1], c_shape)] # Model embedding = mtf.layers.embedding(mtf_inputs, vocab_dim, embed_dim, tf.float32) assert embedding.mesh == meshes[2] embedding = ReplaceRNNMesh(embedding, meshes[0]).outputs[0] [y] = RNNOperation(embedding, rnn_w0, rnn_w1, num_units, states=states0+states1).outputs assert y.mesh == meshes[1] assert y.shape[0].name == 'axis0' y = mt.rename_dimension(y, 'axis0', mtf_labels.shape[0].name) y = mesh_trans.ReplaceMeshWithSimpleReplication(y, meshes[2]) vocab_dim = mtf.Dimension('axis0', params.vocab_size) y = mtf.layers.dense(y, vocab_dim, reduced_dims=y.shape[-1:], use_bias=False) assert y.mesh == mtf_labels.mesh mtf_cross_ent = mtf.layers.softmax_cross_entropy_with_logits( y, mtf_labels, vocab_dim) mtf_loss = mtf.reduce_mean(mtf_cross_ent) model.soft_placement = True return graph, mesh_to_impl, mtf_loss
def testConv2dOperations(self): conv_input = mtf.zeros( self.mesh, mtf.Shape([ self.batch_dim, self.grid_h_dim, self.grid_w_dim, self.in_dim ])) conv_filter = mtf.zeros( self.mesh, mtf.Shape([ self.filter_h_dim, self.filter_w_dim, self.in_dim, self.out_dim ])) strides = [1, 1, 1, 1] padding = "SAME" conv2d_operation = mtf.Conv2dOperation(conv_input, conv_filter, strides, padding) self.assertEqual(conv2d_operation.splittable_dims, frozenset(["batch", "in", "out"])) self.assertEqual( conv2d_operation.unsplittable_dims, frozenset(["filter_h", "filter_w", "grid_h", "grid_w"])) output = conv2d_operation.outputs[0] d_output = mtf.zeros(self.mesh, output.shape) conv2d_backprop_input_operation = mtf.Conv2or3dBackpropInputOperation( 2, False, conv_input.shape, conv_filter, d_output, strides, padding) self.assertEqual( conv2d_backprop_input_operation.splittable_dims, frozenset([ "batch", "filter_h", "filter_w", "grid_h", "grid_w", "in", "out" ])) self.assertEqual(conv2d_backprop_input_operation.unsplittable_dims, frozenset()) conv2d_backprop_filter_operation = mtf.Conv2or3dBackpropFilterOperation( 2, False, conv_input, conv_filter.shape, d_output, strides, padding) self.assertEqual( conv2d_backprop_filter_operation.splittable_dims, frozenset([ "batch", "filter_h", "filter_w", "grid_h", "grid_w", "in", "out" ])) self.assertEqual(conv2d_backprop_filter_operation.unsplittable_dims, frozenset())
def testEinsumOperation(self): x2 = mtf.zeros(self.mesh, mtf.Shape([self.a_dim, self.c_dim])) einsum_operation = mtf.EinsumOperation([self.x, x2], mtf.Shape([self.b_dim, self.c_dim])) self.assertEqual(einsum_operation.splittable_dims, frozenset(["a", "b", "c"])) self.assertEqual(einsum_operation.unsplittable_dims, frozenset())
def testOptimizeLayout(self): x1 = mtf.zeros(self.mesh, "a:10,b:5") x2 = mtf.zeros(self.mesh, "b:5,c:20") mtf.einsum([x1, x2], "a:10,c:20") optimizer = self.get_layout_optimizer() # Cut dimensions to make them equally sized. layout = optimizer.solve() self.assertEqual(layout, "a:m2;c:m1") # This optimal layout should have the lowest value. layout_value = optimizer.evaluate_layout(layout) self.assertLessEqual(layout_value, optimizer.evaluate_layout("a:m1;b:m2")) self.assertLessEqual(layout_value, optimizer.evaluate_layout("a:m1;c:m2")) self.assertLessEqual(layout_value, optimizer.evaluate_layout("b:m1;a:m2")) self.assertLessEqual(layout_value, optimizer.evaluate_layout("b:m1;c:m2")) self.assertLessEqual(layout_value, optimizer.evaluate_layout("c:m1;b:m2")) self.assertEqual(layout_value, optimizer.evaluate_layout("c:m1;a:m2"))
def testOneHotOperation(self): x = mtf.zeros(self.mesh, self.ab_shape, dtype=tf.int32) one_hot_operation = mtf.OneHotOperation(x, self.c_dim, 1, 0, dtype=tf.bool) self.assertEqual(one_hot_operation.splittable_dims, frozenset(["a", "b", "c"])) self.assertEqual(one_hot_operation.unsplittable_dims, frozenset())
def setUp(self): super(OperationSplittabilityTest, self).setUp() self.graph = mtf.Graph() self.mesh = mtf.Mesh(self.graph, "my_mesh") self.a_dim = mtf.Dimension("a", 5) self.b_dim = mtf.Dimension("b", 10) self.c_dim = mtf.Dimension("c", 15) self.ab_shape = mtf.Shape([self.a_dim, self.b_dim]) self.x = mtf.zeros(self.mesh, self.ab_shape) self.batch_dim = mtf.Dimension("batch", 100) self.grid_h_dim = mtf.Dimension("grid_h", 10) self.grid_w_dim = mtf.Dimension("grid_w", 10) self.filter_h_dim = mtf.Dimension("filter_h", 5) self.filter_w_dim = mtf.Dimension("filter_w", 5) self.in_dim = mtf.Dimension("in", 10) self.out_dim = mtf.Dimension("out", 10) self.image = mtf.zeros(self.mesh, [self.batch_dim, self.grid_h_dim, self.grid_w_dim, self.in_dim])
def setUp(self): super(LayoutValidatorTest, self).setUp() graph = mtf.Graph() mesh = mtf.Mesh(graph, "my_mesh") a_dim = mtf.Dimension("a", 5) b_dim = mtf.Dimension("b", 10) concat_dim1 = mtf.Dimension("concat", 15) concat_dim2 = mtf.Dimension("concat", 20) x1 = mtf.zeros(mesh, mtf.Shape([a_dim, b_dim, concat_dim1])) x2 = mtf.zeros(mesh, mtf.Shape([a_dim, b_dim, concat_dim2])) mtf.ConcatOperation([x1, x2], "concat") # We add a tensor with anonymous shape, which is supposed to be # unsplittable (i.e. none of its dimensions show up during # test_SplittableMtfDimensionNames). _ = mtf.zeros(mesh, mtf.anonymous_shape(mtf.Shape([a_dim, b_dim]))) mesh_shape = mtf.Shape([("m1", 4), ("m2", 2)]) self.valid_layouts = valid_layouts.LayoutValidator(graph, mesh_shape)
def testBinaryOpWithBroadcasting(self): x2 = mtf.zeros(self.mesh, mtf.Shape([self.a_dim, self.c_dim])) binary_op_with_broadcasting = mtf.BinaryOpWithBroadcasting( tf.less, self.x, x2, mtf.Shape([self.a_dim, self.b_dim, self.c_dim]), tf.bool, name="less with broadcasting") self.assertEqual(binary_op_with_broadcasting.splittable_dims, frozenset(["a", "b", "c"])) self.assertEqual(binary_op_with_broadcasting.unsplittable_dims, frozenset())
def __init__(self, config, is_training, input_ids, input_mask=None, token_type_ids=None, scope=None, mesh_shape="", layout=""): self.config = copy.deepcopy(config) del config if not is_training: self.config.layer_output_dropout_prob = 0.0 self.config.attention_probs_dropout_prob = 0.0 self.config.feedforward_intermediate_dropout_prob = 0.0 input_shape = input_ids.shape assert input_shape.ndims == 2 self._seq_dim = input_shape.dims[1] self._memory_seq_dim = mtf.Dimension("memory_seq", self.seq_dim.size) self._extra_losses = [] mesh = input_ids.mesh if token_type_ids is None: token_type_ids = mtf.zeros(mesh, input_shape, dtype=tf.int32) with tf.variable_scope(scope, default_name="bert"): with tf.variable_scope("embeddings"): # Perform embedding lookup on the word ids. self.embedding_table = mtf.get_variable( mesh, "word_embeddings", mtf.Shape([self.vocab_dim, self.model_dim]), initializer=self.embedding_initializer) self.word_embedding_output = mtf.gather( self.embedding_table, input_ids, self.vocab_dim) # Add positional embeddings and token type embeddings, then layer # normalize and perform dropout. self.embedding_output = self.word_embedding_output token_type_table = mtf.get_variable( mesh, "token_type_embeddings", mtf.Shape([self.token_type_vocab_dim, self.model_dim]), initializer=self.embedding_initializer) if token_type_ids is not None: self.embedding_output += mtf.gather( token_type_table, token_type_ids, self.token_type_vocab_dim) if self.config.position_signal == "embedding": full_position_table = mtf.get_variable( mesh, "position_embeddings", mtf.Shape([self.max_position_embeddings_dim, self.model_dim]), initializer=self.embedding_initializer) short_position_table = mtf.rename_dimension( mtf.slice(full_position_table, 0, self.seq_dim.size, self.max_position_embeddings_dim.name), self.max_position_embeddings_dim.name, self.seq_dim.name) self.embedding_output += short_position_table self.embedding_output = self.normalize(self.embedding_output) self.embedding_output = mtf.dropout( self.embedding_output, is_training, keep_prob=1.0 - self.config.layer_output_dropout_prob) with tf.variable_scope("encoder"): attention_biases = [] if input_mask: # [batch_dim, memory_seq_dim] attention_biases.append( (1.0 - mtf.to_float(mtf.replace_dimensions( input_mask, self.seq_dim, self.memory_seq_dim))) * -10000.0) if self.config.position_signal == "relative_attention_bias": buckets_dim = mtf.Dimension("buckets", 32) rp_bucket = _relative_position_bucket( mtf.range(mesh, self.memory_seq_dim, tf.int32) - mtf.range(mesh, self.seq_dim, tf.int32), num_buckets=buckets_dim.size) bias_var = mtf.get_variable( mesh, "relative_attention_bias", [self.num_heads_dim, buckets_dim], initializer=tf.zeros_initializer()) attention_biases.append(mtf.gather(bias_var, rp_bucket, buckets_dim)) attention_bias = mtf.add_n(attention_biases) prev_layer_output = self.embedding_output self.all_encoder_layers = [] for block_num in range(self.config.num_blocks): with tf.variable_scope("block_%d" % block_num): for layer_idx, layer_type in enumerate(self.config.block_layers): layer_name = layer_type count = self.config.block_layers[:layer_idx].count(layer_type) if count: layer_name += "_%d" % count with tf.variable_scope(layer_name): x = prev_layer_output if self.config.residual_structure == "direct": x = self.normalize(x) if layer_type == "attention": x = self.self_attention(x, attention_bias) elif layer_type == "feedforward": x = self.feedforward(x) elif layer_type == "moe": x = self.moe(x, layout, mesh_shape, input_mask, 
is_training) else: raise ValueError("unknown layer type " + layer_type) x = mtf.dropout( x, is_training, keep_prob=1.0 - self.config.layer_output_dropout_prob) layer_output = prev_layer_output + x if self.config.residual_structure == "original": layer_output = self.normalize(layer_output) prev_layer_output = layer_output self.all_encoder_layers.append(layer_output) self.sequence_output = prev_layer_output if self.config.residual_structure == "direct": self.sequence_output = self.normalize(self.sequence_output) # The "pooler" converts the encoded sequence tensor of shape # [batch_dim, seq_dim, hidden_size] to a tensor of shape # [batch_dim, hidden_size]. This is necessary for segment-level # (or segment-pair-level) classification tasks where we need a fixed # dimensional representation of the segment. with tf.variable_scope("pooler"): # We "pool" the model by simply taking the hidden state corresponding # to the first token. We assume that this has been pre-trained first_token_tensor = mtf.gather(self.sequence_output, 0, self.seq_dim) self.pooled_output = mtf.layers.dense( first_token_tensor, reduced_dims=[self.model_dim], new_dims=[self.model_dim], activation=mtf.tanh, kernel_initializer=self.dense_initializer, use_bias=self.config.use_bias)
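# Small numeric illustration of the padding bias constructed in the encoder
# above: positions where input_mask == 0 receive a large negative additive
# bias, so softmax attention effectively ignores them.  Plain NumPy with a
# made-up mask; not part of the model code.
def padding_bias_demo():
  import numpy as np
  input_mask = np.array([1, 1, 1, 0, 0], dtype=np.float32)
  bias = (1.0 - input_mask) * -10000.0
  return bias  # [-0., -0., -0., -10000., -10000.]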
def _sample(self, features, mesh): hparams = self._hparams (inputs_embedding_var, targets_embedding_var, softmax_var, positional_embedding_var) = self._embedding_and_softmax_vars(mesh) if hparams.transformer_type == "encdec": inputs = features["inputs"] while len(inputs.shape.as_list()) > 2: inputs = tf.squeeze(inputs, axis=2) actual_batch_size = tf.shape(inputs)[0] actual_length = tf.shape(inputs)[1] inputs = tf.pad(inputs, [[0, hparams.batch_size - actual_batch_size], [0, hparams.max_length - actual_length]]) inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams) x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) + mtf.reshape(positional_embedding_var, mtf.Shape([self.length_dim, self.model_dim]))) encoder_attention_mask = (mtf.layers.attention_mask_ignore_padding( inputs, dtype=self.activation_dtype)) with tf.variable_scope("encoder"): x = self._layer_stack( x, hparams.encoder_layers, self_attention_mask=encoder_attention_mask) encoder_output = mtf.rename_dimension(x, self.length_dim.name, self.memory_length_dim.name) encdec_tensors = [] for layer_num, layer_type in enumerate(hparams.decoder_layers): if layer_type == "enc_att": with tf.variable_scope("decoder/enc_att_%d/enc_att" % layer_num): q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars( mesh, self.heads_dim, self.model_dim, self.kv_dim, self.master_dtype, self.slice_dtype, self.activation_dtype) k = mtf.einsum([encoder_output, k_var], mtf.Shape(self.batch_dims + [ self.heads_dim, self.memory_length_dim, self.kv_dim ])) v = mtf.einsum([encoder_output, v_var], mtf.Shape(self.batch_dims + [ self.heads_dim, self.memory_length_dim, self.kv_dim ])) encdec_tensors.append((q_var, o_var, k, v)) else: encdec_tensors.append(None) partial_targets = None elif hparams.transformer_type == "decoder": encdec_tensors = None encoder_output = None encoder_attention_mask = None # Prepare partial targets. # In either features["inputs"] or features["targets"]. # We force the outputs to begin with these sequences. 
partial_targets = features.get("inputs", None) if partial_targets is None: partial_targets = features.get("targets", None) if partial_targets is not None: partial_targets = common_layers.expand_squeeze_to_nd( partial_targets, 2) partial_targets = tf.to_int32(partial_targets) partial_targets_batch = tf.shape(partial_targets)[0] partial_targets_length = tf.shape(partial_targets)[1] partial_targets = tf.pad( partial_targets, [[0, hparams.batch_size - partial_targets_batch], [0, hparams.max_length - partial_targets_length]]) partial_targets = self._import_to_batch_by_length( partial_targets, "partial_targets", mesh, hparams) else: raise ValueError("hparams.model_type = %s not yet supported" % hparams.transformer_type) local_attention_window = mtf.Dimension( "local_attention_window", hparams.local_attention_window_size) if hparams.beam_size == 1: ids_shape = mtf.Shape(self.batch_dims + [self.length_dim]) kv_shape = mtf.Shape( self.batch_dims + [self.heads_dim, self.memory_length_dim, self.kv_dim]) local_kv_shape = mtf.Shape( self.batch_dims + [self.heads_dim, local_attention_window, self.kv_dim]) else: beam_dim = mtf.Dimension("beam", hparams.beam_size) ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim]) kv_shape = mtf.Shape(self.batch_dims + [ beam_dim, self.heads_dim, self.memory_length_dim, self.kv_dim ]) local_kv_shape = mtf.Shape(self.batch_dims + [ beam_dim, self.heads_dim, local_attention_window, self.kv_dim ]) initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32) initial_states = [] for layer in hparams.decoder_layers: if layer == "att": initial_states.extend( [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2) elif layer == "local_att": initial_states.extend([ mtf.zeros( mesh, local_kv_shape, dtype=self.activation_dtype) ] * 2) def logits_fn(step_num, ids, states): """Produce logits for this step, and new states.""" ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim) x = (mtf.gather(targets_embedding_var, ids_this_step, self.targets_vocab_dim) + mtf.gather(positional_embedding_var, step_num, self.max_length_dim)) with tf.variable_scope("decoder"): x, new_states = self._layer_stack( x, hparams.decoder_layers, encdec_attention_mask=encoder_attention_mask, step_num=step_num, encdec_tensors=encdec_tensors, states=states) logits = mtf.matmul(x, softmax_var) return logits, new_states if hparams.beam_size == 1: temperature = (0.0 if hparams.sampling_method == "argmax" else hparams.sampling_temp) return mtf.beam_search.greedy_decode(logits_fn, initial_ids, temperature=temperature, initial_states=initial_states, forced_ids=partial_targets, use_tpu=hparams.use_tpu) else: if hparams.transformer_type == "encdec": input_length = mtf.reduce_sum(mtf.to_float( mtf.cast(inputs, tf.bool)), reduced_dim=self.length_dim) max_input_length = mtf.reduce_max(input_length) decode_length = mtf.cast( max_input_length * hparams.decode_length_multiplier + hparams.decode_length_constant, tf.int32) else: decode_length = None beams, unused_scores = mtf.beam_search.beam_search( logits_fn, initial_ids, hparams.alpha, states=initial_states, decode_length=decode_length, use_tpu=hparams.use_tpu, dtype=self.activation_dtype) return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
def force_single(state, lr_shape, hr_shape, kvec_lr, halo_size, cosmology=Planck15, pm_nc_factor=1, **kwargs): """ Estimate force on the particles given a state. Parameters: ----------- state: tensor Input state tensor of shape (3, batch_size, npart, 3) boxsize: float Size of the simulation volume (Mpc/h) TODO: check units cosmology: astropy.cosmology Cosmology object pm_nc_factor: int TODO: @modichirag please add doc """ X, P, F = state #TODO: support different factor assert pm_nc_factor == 1 lnc = lr_shape[-1].size part_shape = X.shape # Paint the particles on the high resolution mesh field = mtf.zeros(X.mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name) field = mesh_utils.cic_paint(field, X, halo_size) for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]): field = mesh_ops.halo_reduce(field, blocks_dim, block_size_dim, halo_size) # Remove borders for block_size_dim in hr_shape[-3:]: field = mtf.slice(field, halo_size, block_size_dim.size, block_size_dim.name) # Hack usisng custom reshape because mesh is pretty dumb lr_field = mtf.slicewise(lambda x: x[:, 0, 0, 0], [field], output_dtype=tf.float32, output_shape=lr_shape, name='my_dumb_reshape', splittable_dims=lr_shape[:-1] + hr_shape[:4]) k_dims_lr = [d.shape[0] for d in kvec_lr] k_dims_lr = [k_dims_lr[2], k_dims_lr[0], k_dims_lr[1]] lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr) kfield_lr = mesh_kernels.apply_gradient_laplace_kernel(lr_kfield, kvec_lr) # Reorder the low res FFTs which where transposed# y,z,x kfield_lr = [kfield_lr[2], kfield_lr[0], kfield_lr[1]] displacement = [] for f in kfield_lr: f = mesh_utils.c2r3d(f, lr_shape[-3:]) f = mtf.slicewise( lambda x: tf.expand_dims( tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1), [f], output_dtype=tf.float32, output_shape=mtf.Shape(hr_shape[0:4] + [ mtf.Dimension('sx_block', lnc // hr_shape[1].size), mtf.Dimension('sy_block', lnc // hr_shape[2].size), mtf.Dimension('sz_block', lnc // hr_shape[3].size) ]), name='my_reshape', splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]) for block_size_dim in hr_shape[-3:]: f = mtf.pad(f, [halo_size, halo_size], block_size_dim.name) for blocks_dim, block_size_dim in zip(hr_shape[1:4], f.shape[-3:]): f = mesh_ops.halo_reduce(f, blocks_dim, block_size_dim, halo_size) d = mesh_utils.cic_readout(f, X, halo_size) displacement.append(d) # Readout the force to particle positions F = mtf.stack([d for d in displacement], "ndim", axis=4) F = F * 1.5 * cosmology.Om0 return X, P, F
def force(state, lr_shape, hr_shape, kvec_lr, kvec_hr, halo_size, cosmology=Planck15, downsampling_factor=2, pm_nc_factor=1, antialias=True, **kwargs): """ Estimate force on the particles given a state. Parameters: ----------- state: tensor Input state tensor of shape (3, batch_size, npart, 3) boxsize: float Size of the simulation volume (Mpc/h) TODO: check units cosmology: astropy.cosmology Cosmology object pm_nc_factor: int TODO: @modichirag please add doc """ X, P, F = state #TODO: support different factor assert pm_nc_factor == 1 lnc = lr_shape[-1].size part_shape = X.shape k_dims_lr = [d.shape[0] for d in kvec_lr] k_dims_hr = [d.shape[0] for d in kvec_hr] # Reorder the FFTs which where transposed# y,z,x k_dims_lr = [k_dims_lr[2], k_dims_lr[0], k_dims_lr[1]] k_dims_hr = [k_dims_hr[2], k_dims_hr[0], k_dims_hr[1]] # Paint the particles on the high resolution mesh field = mtf.zeros(X.mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name) field = mesh_utils.cic_paint(field, X, halo_size) for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]): field = mesh_ops.halo_reduce(field, blocks_dim, block_size_dim, halo_size) # Split the field into low and high resolution field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)]) high = field low = mesh_utils.downsample(field, downsampling_factor, antialias=True) low = mtf.reshape(low, low.shape[:-1]) hr_field = mtf.reshape(high, high.shape[:-1]) for block_size_dim in hr_shape[-3:]: low = mtf.slice(low, halo_size // 2**downsampling_factor, block_size_dim.size // 2**downsampling_factor, block_size_dim.name) # Hack usisng custom reshape because mesh is pretty dumb lr_field = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low], output_dtype=tf.float32, output_shape=lr_shape, name='my_dumb_reshape', splittable_dims=lr_shape[:-1] + hr_shape[:4]) lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr) hr_kfield = mesh_utils.r2c3d(hr_field, k_dims_hr) kfield_lr = mesh_kernels.apply_longrange_kernel(lr_kfield, kvec_lr, r_split=0) kfield_lr = mesh_kernels.apply_gradient_laplace_kernel(lr_kfield, kvec_lr) kfield_hr = mesh_kernels.apply_longrange_kernel(hr_kfield, kvec_hr, r_split=0) kfield_hr = mesh_kernels.apply_gradient_laplace_kernel(kfield_hr, kvec_hr) # Reorder the low res FFTs which where transposed# y,z,x kfield_lr = [kfield_lr[2], kfield_lr[0], kfield_lr[1]] kfield_hr = [kfield_hr[2], kfield_hr[0], kfield_hr[1]] displacement = [] for f, g in zip(kfield_lr, kfield_hr): f = mesh_utils.c2r3d(f, lr_shape[-3:]) f = mtf.slicewise( lambda x: tf.expand_dims( tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1), [f], output_dtype=tf.float32, output_shape=mtf.Shape(hr_shape[0:4] + [ mtf.Dimension('sx_block', lnc // hr_shape[1].size), mtf.Dimension('sy_block', lnc // hr_shape[2].size), mtf.Dimension('sz_block', lnc // hr_shape[3].size) ]), name='my_reshape', splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]) for block_size_dim in hr_shape[-3:]: f = mtf.pad(f, [ halo_size // 2**downsampling_factor, halo_size // 2**downsampling_factor ], block_size_dim.name) for blocks_dim, block_size_dim in zip(hr_shape[1:4], f.shape[-3:]): f = mesh_ops.halo_reduce(f, blocks_dim, block_size_dim, halo_size // 2**downsampling_factor) f = mtf.reshape(f, f.shape + [mtf.Dimension('h_dim', 1)]) f = mesh_utils.upsample(f, downsampling_factor) f = mtf.reshape(f, f.shape[:-1]) g = mesh_utils.c2r3d(g, f.shape[-3:]) high_shape = g.shape # And now we remove the large scales g = 
mtf.reshape(g, g.shape + [mtf.Dimension('h_dim', 1)]) _low = mesh_utils.downsample(g, downsampling_factor, antialias=antialias) g = g - mtf.reshape(mesh_utils.upsample(_low, downsampling_factor), g.shape) g = mtf.reshape(g, high_shape) d = mesh_utils.cic_readout(f + g, X, halo_size) displacement.append(d) # Readout the force to particle positions F = mtf.stack([d for d in displacement], "ndim", axis=4) F = F * 1.5 * cosmology.Om0 return X, P, F
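# Rough 1-D NumPy analogue of the scale separation above: the short-range part
# of the high-resolution field is the field minus an upsampled copy of its
# downsampled self.  Block averaging and repetition stand in for
# mesh_utils.downsample / mesh_utils.upsample (which are antialiased), so this
# only illustrates the idea, not the exact kernels.
def split_scales_1d(field, factor=2):
  import numpy as np
  field = np.asarray(field, dtype=np.float32)   # length must divide by factor
  low = field.reshape(-1, factor).mean(axis=1)  # crude downsample
  high = field - np.repeat(low, factor)         # remove the large scales
  return low, high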
def local_attention_1d(q,
                       k,
                       v,
                       length_dim,
                       key_dim,
                       value_dim,
                       fully_autoregressive=True,
                       length_dim_num_splits=1,
                       radius=128,
                       sequence_id=1,
                       write_priority=None,
                       read_priority=None,
                       attention_kwargs=None):
  """Attention to a neighborhood around the source.

  If fully_autoregressive, then query position p can only see memory positions
  in the range (p - radius, p].

  If not fully_autoregressive, then query position p can only see memory
  positions in the range (p - radius, p + radius].

  In addition, if write_priority and read_priority are provided, then attention
  is limited to position pairs where
  read_priority[query position] >= write_priority[memory position].

  Args:
    q: a Tensor containing length_dim
    k: a Tensor containing length_dim
    v: an optional Tensor containing length_dim.  If None then uses v=k.
    length_dim: a Dimension
    key_dim: a Dimension (the channels dimension of q and k)
    value_dim: a Dimension (the channels dimension of v)
    fully_autoregressive: a boolean
    length_dim_num_splits: an optional integer indicating how many ways the
      length dimension is split
    radius: an integer
    sequence_id: a Tensor or an integer
    write_priority: an optional Tensor containing length_dim
    read_priority: an optional Tensor containing length_dim
    attention_kwargs: optional keyword arguments for attention()

  Returns:
    a Tensor with the shape q.shape - key_dim + value_dim

  Raises:
    ValueError: if channels or depth don't match.
  """
  # Choose a suitable block size.
  # We choose the greatest divisor of length_per_split less than or equal
  # to max(radius, 128).
  length_per_split = length_dim.size // length_dim_num_splits
  block_length = max(radius, 128)
  while length_per_split % block_length != 0:
    block_length -= 1
  query_block_length = mtf.Dimension("query_block_length", block_length)
  memory_block_length = mtf.Dimension("memory_block_length", block_length)
  # The num_blocks dimension gets the same name as the length dimension,
  # so it will be split in the same way.
num_blocks = mtf.Dimension(length_dim.name, length_dim.size // block_length) def _reshape_query(x): return mtf.replace_dimensions( x, length_dim, [num_blocks, query_block_length]) def _reshape_memory(x): x = mtf.replace_dimensions( x, length_dim, [num_blocks, memory_block_length]) return (mtf.left_halo_exchange if fully_autoregressive else mtf.halo_exchange)( x, num_blocks, memory_block_length, radius) q = _reshape_query(q) k = _reshape_memory(k) if v: v = _reshape_memory(v) else: v = k if sequence_id is None: sequence_id = 1 if (not isinstance(sequence_id, mtf.Tensor) or length_dim not in sequence_id.shape.dims): sequence_id += mtf.zeros(q.mesh, [length_dim], tf.int32) q_sequence_id = _reshape_query(sequence_id) m_sequence_id = _reshape_memory(sequence_id) pos = mtf.range(q.mesh, length_dim, dtype=tf.int32) q_pos = _reshape_query(pos) m_pos = _reshape_memory(pos) padded_memory_block_length = mtf.Dimension( "memory_block_length", (1 if fully_autoregressive else 2) * radius + block_length) relative_position = m_pos - q_pos visible = mtf.equal(q_sequence_id, m_sequence_id) visible = mtf.logical_and(visible, mtf.greater(relative_position, -radius)) visible = mtf.logical_and(visible, mtf.less_equal( relative_position, 0 if fully_autoregressive else radius)) if read_priority is not None: write_priority = _reshape_memory(write_priority) read_priority = _reshape_query(read_priority) visible = mtf.logical_and( visible, mtf.greater_equal(read_priority, write_priority)) bias = visibility_mask_to_attention_bias(visible, q.dtype) o = attention(q, k, v, padded_memory_block_length, key_dim, value_dim, bias, **attention_kwargs) return mtf.replace_dimensions(o, [num_blocks, query_block_length], length_dim)
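# Illustrative usage sketch for local_attention_1d (not from the library).  It
# assumes it lives in the same module, so attention() and
# visibility_mask_to_attention_bias() are in scope; the dimension names and
# sizes below are arbitrary.  attention_kwargs is passed as an empty dict
# because the default of None would fail on the **attention_kwargs expansion
# above.
def example_local_attention_1d():
  import mesh_tensorflow as mtf

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "attn_mesh")
  batch = mtf.Dimension("batch", 2)
  heads = mtf.Dimension("heads", 4)
  length = mtf.Dimension("length", 512)
  d_k = mtf.Dimension("d_k", 64)
  d_v = mtf.Dimension("d_v", 64)

  q = mtf.zeros(mesh, mtf.Shape([batch, heads, length, d_k]))
  k = mtf.zeros(mesh, mtf.Shape([batch, heads, length, d_k]))
  v = mtf.zeros(mesh, mtf.Shape([batch, heads, length, d_v]))

  # The result has q's shape with key_dim replaced by value_dim.
  return local_attention_1d(
      q, k, v,
      length_dim=length,
      key_dim=d_k,
      value_dim=d_v,
      fully_autoregressive=True,
      radius=128,
      attention_kwargs={})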
def fn(x): with tf.variable_scope(scope): nx = x.shape[-1] # Grab last dimension from input if use_rezero: prenorm = identity elif use_scale_norm: prenorm = scale_norm else: prenorm = layer_norm pre_residual_fn = rezero if use_rezero else identity attention_type = params["attention_types"][layer_num] if macaron_attention: mult = 0.5 mlp_fn = mlp_glu if use_mlp_glu else mlp intermediate_size = nx.size * 4 * (1 if not use_mlp_glu else 2) # Define intermediate layer of mlp - to split dim_intermediate_expanded = mtf.Dimension("intermediate_expanded", intermediate_size) m = mlp_fn(x, "mlp_macaron", dim_intermediate_expanded, variable_dtype=variable_dtype, params=params) x = x + (m * mult) else: mult = 1 if attention_type != "none": res_x = prenorm(x, "norm_1", variable_dtype=variable_dtype, params=params) a = attn(res_x, "attn", nx, attention_type=attention_type, params=params, bias=bias, dim_seq=sequence_dim, memory_length_dim=memory_length_dim, variable_dtype=variable_dtype, context=context) else: a = x x = x + pre_residual_fn(a, "norm_rezero_1", dtype=variable_dtype) res_x = prenorm(x, "norm_2", variable_dtype=variable_dtype, params=params) if use_moe: moe_params = mtf.transformer.moe.HParams() mtf.transformer.moe.set_default_moe_hparams(moe_params) # Override defaults for k, v in params["moe_params"].items(): moe_params.add_hparam(k, v) moe_train = params["mode"] == "train" m, aux_loss = mtf.transformer.moe.transformer_moe_layer_v1(res_x, x.shape[-1], moe_params, train=moe_train, mesh_shape=params["mesh_shape"], layout=params["layout"], activation=params.get("moe_activation", "relu"), variable_dtype=variable_dtype) m = mtf.dropout(m, rate=params["res_dropout"], name="moe_dropout") else: mlp_fn = mlp_glu if use_mlp_glu else mlp intermediate_size = nx.size * 4 * (1 if not use_mlp_glu else 2) # Define intermediate layer of mlp - to split dim_intermediate_expanded = mtf.Dimension("intermediate_expanded", intermediate_size) m = mlp_fn(res_x, "mlp", dim_intermediate_expanded, variable_dtype=variable_dtype, params=params) aux_loss = mtf.zeros(x.mesh, mtf.Shape([]), dtype=variable_dtype.slice_dtype) x = x + pre_residual_fn((m*mult), "norm_rezero_2", variable_dtype) return x, aux_loss
def nbody_prototype(mesh, infield=False, nc=FLAGS.nc, bs=FLAGS.box_size, batch_size=FLAGS.batch_size, a0=FLAGS.a0, a=FLAGS.af, nsteps=FLAGS.nsteps, dtype=tf.float32): """ Prototype of function computing LPT deplacement. Returns output tensorflow and mesh tensorflow tensors """ # Compute a few things first, using simple tensorflow stages = np.linspace(a0, a, nsteps, endpoint=True) # Define the named dimensions # Parameters of the small scales decomposition n_block_x = FLAGS.nx n_block_y = FLAGS.ny n_block_z = 1 halo_size = FLAGS.hsize # Parameters of the large scales decomposition fx_dim = mtf.Dimension("nx", nc) fy_dim = mtf.Dimension("ny", nc) fz_dim = mtf.Dimension("nz", nc) tfx_dim = mtf.Dimension("tx", nc) tfy_dim = mtf.Dimension("ty", nc) tfz_dim = mtf.Dimension("tz", nc) tx_dim = mtf.Dimension("tx_lr", nc) ty_dim = mtf.Dimension("ty_lr", nc) tz_dim = mtf.Dimension("tz_lr", nc) nx_dim = mtf.Dimension('nx_block', n_block_x) ny_dim = mtf.Dimension('ny_block', n_block_y) nz_dim = mtf.Dimension('nz_block', n_block_z) sx_dim = mtf.Dimension('sx_block', nc // n_block_x) sy_dim = mtf.Dimension('sy_block', nc // n_block_y) sz_dim = mtf.Dimension('sz_block', nc // n_block_z) k_dims = [tx_dim, ty_dim, tz_dim] batch_dim = mtf.Dimension("batch", batch_size) klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0] plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1] ipklin = iuspline(klin, plin) pk_dim = mtf.Dimension("npk", len(plin)) pk = mtf.import_tf_tensor(mesh, plin, shape=[pk_dim]) # Compute necessary Fourier kernels kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False) kx = mtf.import_tf_tensor(mesh, kvec[0].squeeze().astype('float32'), shape=[tfx_dim]) ky = mtf.import_tf_tensor(mesh, kvec[1].squeeze().astype('float32'), shape=[tfy_dim]) kz = mtf.import_tf_tensor(mesh, kvec[2].squeeze().astype('float32'), shape=[tfz_dim]) kv = [ky, kz, kx] # kvec for low resolution grid kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False) kx_lr = mtf.import_tf_tensor(mesh, kvec_lr[0].squeeze().astype('float32'), shape=[tx_dim]) ky_lr = mtf.import_tf_tensor(mesh, kvec_lr[1].squeeze().astype('float32'), shape=[ty_dim]) kz_lr = mtf.import_tf_tensor(mesh, kvec_lr[2].squeeze().astype('float32'), shape=[tz_dim]) kv_lr = [ky_lr, kz_lr, kx_lr] shape = [batch_dim, fx_dim, fy_dim, fz_dim] lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim] hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim] part_shape = [batch_dim, fx_dim, fy_dim, fz_dim] # Begin simulation ## Compute initial initial conditions distributed input_field = tf.placeholder(dtype, [batch_size, nc, nc, nc]) if infield: initc = mtf.import_tf_tensor(mesh, input_field, shape=part_shape) else: initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv) # Here we can run our nbody if FLAGS.nbody: state = mtfpm.lpt_init_single( initc, a0, kv_lr, halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True, ) # Here we can run our nbody final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, kv_lr, halo_size) else: final_state = mtfpm.lpt_init_single( initc, stages[-1], kv_lr, halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True, ) # paint the field final_field = mtf.zeros(mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: final_field = mtf.pad(final_field, [halo_size, halo_size], block_size_dim.name) final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size) # Halo exchange for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]): final_field = 
mpm.halo_reduce(final_field, blocks_dim, block_size_dim, halo_size) # Remove borders for block_size_dim in hr_shape[-3:]: final_field = mtf.slice(final_field, halo_size, block_size_dim.size, block_size_dim.name) final_field = mtf.slicewise( lambda x: x[:, 0, 0, 0], [final_field], output_dtype=dtype, output_shape=[batch_dim, fx_dim, fy_dim, fz_dim], name='my_dumb_reshape', splittable_dims=part_shape[:-1] + hr_shape[:4]) return initc, final_field, input_field
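# 1-D NumPy analogue of the pad / cic_paint / halo_reduce / slice pattern used
# above (illustrative only, non-periodic for brevity).  Each block is padded by
# `halo` cells; mass deposited into a block's halo actually belongs to the
# neighbouring block's interior, so it is summed into the neighbour before the
# borders are sliced off.  mesh_ops.halo_reduce / mpm.halo_reduce do this along
# mtf block dimensions; plain arrays stand in here.
def halo_reduce_1d(padded_blocks, halo):
  # padded_blocks: list of 1-D NumPy arrays of length block_size + 2 * halo.
  blocks = [b.copy() for b in padded_blocks]
  for left, right in zip(blocks[:-1], blocks[1:]):
    left[-2 * halo:-halo] += right[:halo]  # right block's left halo -> left interior
    right[halo:2 * halo] += left[-halo:]   # left block's right halo -> right interior
  return [b[halo:-halo] for b in blocks]   # remove borders, as mtf.slice does above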
def nbody_fn(mesh, klin, plin, nc=FLAGS.nc, bs=FLAGS.box_size, batch_size=FLAGS.batch_size, a0=FLAGS.a0, a=FLAGS.af, nsteps=FLAGS.nsteps, dtype=tf.float32): """ Pyramid N-body function """ stages = np.linspace(a0, a, nsteps, endpoint=True) # Define the named dimensions # Parameters of the small scales decomposition n_block_x = FLAGS.nx n_block_y = FLAGS.ny n_block_z = 1 halo_size = FLAGS.hsize # Parameters of the large scales decomposition downsampling_factor = FLAGS.dsample lnc = nc // 2**downsampling_factor fx_dim = mtf.Dimension("nx", nc) fy_dim = mtf.Dimension("ny", nc) fz_dim = mtf.Dimension("nz", nc) tfx_dim = mtf.Dimension("tx", nc) tfy_dim = mtf.Dimension("ty", nc) tfz_dim = mtf.Dimension("tz", nc) # Dimensions of the low resolution grid x_dim = mtf.Dimension("nx_lr", lnc) y_dim = mtf.Dimension("ny_lr", lnc) z_dim = mtf.Dimension("nz_lr", lnc) tx_dim = mtf.Dimension("tx_lr", lnc) ty_dim = mtf.Dimension("ty_lr", lnc) tz_dim = mtf.Dimension("tz_lr", lnc) nx_dim = mtf.Dimension('nx_block', n_block_x) ny_dim = mtf.Dimension('ny_block', n_block_y) nz_dim = mtf.Dimension('nz_block', n_block_z) sx_dim = mtf.Dimension('sx_block', nc // n_block_x) sy_dim = mtf.Dimension('sy_block', nc // n_block_y) sz_dim = mtf.Dimension('sz_block', nc // n_block_z) batch_dim = mtf.Dimension("batch", batch_size) pk_dim = mtf.Dimension("npk", len(plin)) pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim]) # Compute necessary Fourier kernels kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False) kx = mtf.import_tf_tensor( mesh, kvec[0].squeeze().astype('float32'), shape=[tfx_dim]) ky = mtf.import_tf_tensor( mesh, kvec[1].squeeze().astype('float32'), shape=[tfy_dim]) kz = mtf.import_tf_tensor( mesh, kvec[2].squeeze().astype('float32'), shape=[tfz_dim]) kv = [ky, kz, kx] # kvec for low resolution grid kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc], symmetric=False) kx_lr = mtf.import_tf_tensor( mesh, kvec_lr[0].squeeze().astype('float32') / 2**downsampling_factor, shape=[tx_dim]) ky_lr = mtf.import_tf_tensor( mesh, kvec_lr[1].squeeze().astype('float32') / 2**downsampling_factor, shape=[ty_dim]) kz_lr = mtf.import_tf_tensor( mesh, kvec_lr[2].squeeze().astype('float32') / 2**downsampling_factor, shape=[tz_dim]) kv_lr = [ky_lr, kz_lr, kx_lr] # kvec for high resolution blocks padded_sx_dim = mtf.Dimension('padded_sx_block', nc // n_block_x + 2 * halo_size) padded_sy_dim = mtf.Dimension('padded_sy_block', nc // n_block_y + 2 * halo_size) padded_sz_dim = mtf.Dimension('padded_sz_block', nc // n_block_z + 2 * halo_size) kvec_hr = flowpm.kernels.fftk([ nc // n_block_x + 2 * halo_size, nc // n_block_y + 2 * halo_size, nc // n_block_z + 2 * halo_size ], symmetric=False) kx_hr = mtf.import_tf_tensor( mesh, kvec_hr[0].squeeze().astype('float32'), shape=[padded_sx_dim]) ky_hr = mtf.import_tf_tensor( mesh, kvec_hr[1].squeeze().astype('float32'), shape=[padded_sy_dim]) kz_hr = mtf.import_tf_tensor( mesh, kvec_hr[2].squeeze().astype('float32'), shape=[padded_sz_dim]) kv_hr = [ky_hr, kz_hr, kx_hr] shape = [batch_dim, fx_dim, fy_dim, fz_dim] lr_shape = [batch_dim, x_dim, y_dim, z_dim] hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim] part_shape = [batch_dim, fx_dim, fy_dim, fz_dim] # Compute initial initial conditions distributed initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv) # Reshaping array into high resolution mesh field = mtf.slicewise( lambda x: tf.expand_dims( tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1), [initc], output_dtype=tf.float32, 
output_shape=hr_shape, name='my_reshape', splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]) for block_size_dim in hr_shape[-3:]: field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name) for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]): field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size) field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)]) high = field low = mesh_utils.downsample(field, downsampling_factor, antialias=True) low = mtf.reshape(low, low.shape[:-1]) high = mtf.reshape(high, high.shape[:-1]) for block_size_dim in hr_shape[-3:]: low = mtf.slice(low, halo_size // 2**downsampling_factor, block_size_dim.size // 2**downsampling_factor, block_size_dim.name) # Hack usisng custom reshape because mesh is pretty dumb low = mtf.slicewise( lambda x: x[:, 0, 0, 0], [low], output_dtype=tf.float32, output_shape=lr_shape, name='my_dumb_reshape', splittable_dims=lr_shape[:-1] + hr_shape[:4]) state = mtfpm.lpt_init( low, high, 0.1, kv_lr, kv_hr, halo_size, hr_shape, lr_shape, part_shape[1:], downsampling_factor=downsampling_factor, antialias=True, ) final_state = mtfpm.nbody( state, stages, lr_shape, hr_shape, kv_lr, kv_hr, halo_size, downsampling_factor=downsampling_factor) # paint the field final_field = mtf.zeros(mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: final_field = mtf.pad(final_field, [halo_size, halo_size], block_size_dim.name) final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size) # Halo exchange for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]): final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim, halo_size) # Remove borders for block_size_dim in hr_shape[-3:]: final_field = mtf.slice(final_field, halo_size, block_size_dim.size, block_size_dim.name) final_field = mtf.slicewise( lambda x: x[:, 0, 0, 0], [final_field], output_dtype=tf.float32, output_shape=[batch_dim, fx_dim, fy_dim, fz_dim], name='my_dumb_reshape', splittable_dims=part_shape[:-1] + hr_shape[:4]) return initc, final_field
def act_layer(self, context, x, mask): """Build a Universal Transformer ACT layer.""" state = x act_max_steps = self.act_max_steps threshold = 1.0 - self.act_epsilon state_shape_static = state.shape.dims state_slice = slice(0, 3) if self.act_type == "global": state_slice = slice(0, 2) # Dynamic shape for update tensors below update_shape = state_shape_static[state_slice] # Halting probabilities (p_t^n in the paper) halting_probability = mtf.zeros(context.mesh, update_shape, dtype=context.activation_dtype) # Remainders (R(t) in the paper) remainders = mtf.zeros(context.mesh, update_shape, dtype=context.activation_dtype) # Number of updates performed (N(t) in the paper) n_updates = mtf.zeros(context.mesh, update_shape, dtype=context.activation_dtype) # Previous cell states (s_t in the paper) previous_state = mtf.zeros_like(state) step = mtf.constant(context.mesh, 0, dtype=tf.int32) def ut_function(state, step, halting_probability, remainders, n_updates, previous_state): """implements act (position-wise halting). Args: state: 3-D Tensor: [batch_size, length, channel] step: indicates number of steps taken so far halting_probability: halting probability remainders: act remainders n_updates: act n_updates previous_state: previous state Returns: transformed_state: transformed state step: step+1 halting_probability: halting probability remainders: act remainders n_updates: act n_updates new_state: new state """ state = self.step_preprocess(context, state, step) if self.act_type == "random": # random as halting probability p = mtf.random_uniform(context.mesh, shape=halting_probability.shape.dims, dtype=context.variable_dtype) else: last_dim_name = state.shape.dimension_names[-1] new_dims = [mtf.Dimension(last_dim_name, 1)] with tf.variable_scope("sigmoid_activation_for_pondering", reuse=tf.AUTO_REUSE): p = mtf.layers.dense(state, variable_dtype=context.variable_dtype, reduced_dims=[state.shape.dims[-1]], new_dims=new_dims, activation=mtf.sigmoid, use_bias=True) if self.act_type == "global": # average over all positions (as a global halting prob) p = mtf.reduce_mean(p, reduced_dim=p.shape.dims[1]) p = mtf.squeeze(p) else: # maintain position-wise probabilities new_shape = p.shape.dims[:-1] p = mtf.reshape(p, new_shape) # Mask for inputs which have not halted yet still_running = mtf.cast(mtf.less(halting_probability, 1.0), context.activation_dtype) # Mask of inputs which halted at this step new_halted = mtf.cast( mtf.greater(halting_probability + p * still_running, threshold), context.activation_dtype) * still_running # Mask of inputs which haven't halted, and didn't halt this step still_running = mtf.cast( mtf.less_equal(halting_probability + p * still_running, threshold), context.activation_dtype) * still_running # Add the halting probability for this step to the halting # probabilities for those input which haven't halted yet halting_probability += p * still_running # Compute remainders for the inputs which halted at this step remainders += new_halted * (1 - halting_probability) # Add the remainders to those inputs which halted at this step halting_probability += new_halted * remainders # Increment n_updates for all inputs which are still running n_updates += still_running + new_halted # Compute the weight to be applied to the new state and output # 0 when the input has already halted # p when the input hasn't halted yet # the remainders when it halted this step input_tensor = p * still_running + new_halted * remainders update_weights = input_tensor # apply transformation on the state 
transformed_state = state for _ in range(self.num_inrecurrence_layers): transformed_state = self.vanilla_transformer_layer( context, transformed_state, mask) # update running part in the weighted state and keep the rest new_state = ((transformed_state * update_weights) + (previous_state * (1 - update_weights))) if self.act_type == "accumulated": # Add in the weighted state new_state = (transformed_state * update_weights) + previous_state step += 1 return (transformed_state, step, halting_probability, remainders, n_updates, new_state) for _ in range(act_max_steps + 1): (state, step, halting_probability, remainders, n_updates, previous_state) = ut_function(state, step, halting_probability, remainders, n_updates, previous_state) ponder_times = n_updates mtf.scalar_summary("ponder_times", mtf.reduce_mean(ponder_times)) return previous_state
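# Illustrative scalar walk-through of the ACT bookkeeping in ut_function above
# (plain Python, a single position, made-up per-step probabilities).  It
# mirrors the still_running / new_halted / remainders updates so the masking
# logic is easier to follow; it is not part of the model code.
def act_halting_demo(step_probs=(0.3, 0.4, 0.5), epsilon=0.01):
  threshold = 1.0 - epsilon
  halting_probability, remainders, n_updates = 0.0, 0.0, 0.0
  update_weights = []
  for p in step_probs:
    still_running = float(halting_probability < 1.0)
    new_halted = float(
        halting_probability + p * still_running > threshold) * still_running
    still_running = float(
        halting_probability + p * still_running <= threshold) * still_running
    halting_probability += p * still_running
    remainders += new_halted * (1.0 - halting_probability)
    halting_probability += new_halted * remainders
    n_updates += still_running + new_halted
    update_weights.append(p * still_running + new_halted * remainders)
  # For the defaults this returns ([0.3, 0.4, 0.3], 3.0); the weights sum to 1.
  return update_weights, n_updates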
def transformer_moe_layer_v1(inputs, output_dim, hparams, train, variable_dtype, layout=None, mesh_shape=None, nonpadding=None): """Local mixture of experts that works well on TPU. Adapted from the paper https://arxiv.org/abs/1701.06538 Note: until the algorithm and inferface solidify, we pass in a hyperparameters dictionary in order not to complicate the interface in mtf_transformer.py . Once this code moves out of "research", we should pass the hyperparameters separately. Hyperparameters used: hparams.moe_num_experts: number of experts hparams.moe_hidden_size: size of hidden layer in each expert hparams.moe_group_size: size of each "group" for gating purposes hparams.moe_capacity_factor_train: a float hparams.moe_capacity_factor_eval: a float hparams.moe_gating: a string + all hyperparmeters used by _top_2_gating() The number of parameters in the gating network is: (input_dim.size * hparams.num_experts) + The number of parameters in the experts themselves is: (hparams.num_experts * (input_dim.size + output_dim.size) * hparams.moe_hidden_size) The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting of the representations of all positions in a batch of sequences. Each position of each sequence is sent to 0-2 experts. The expert choices and the combination weights are determined by a learned gating function. This function returns a small auxiliary loss that should be added to the training loss of the model. This loss helps to balance expert usage. Without the loss, it is very likely that a few experts will be trained and the rest will starve. Several hacks are necessary to get around current TPU limitations: - To ensure static shapes, we enforce (by truncation/padding) that each sequence send the same number of elements to each expert. It would make more sense to enforce this equality over the entire batch, but due to our hacked-up gather-by-matmul implementation, we need to divide the batch into "groups". For each group, the same number of elements are sent to each expert. TODO(noam): Factor this code better. We want to be able to substitute different code for the experts themselves. Dimensions cheat sheet: <B>: batch dims L: original sequence length M: input depth N: output depth G: number of groups S: group size E: number of experts C: expert capacity (u for unsplit dims) Args: inputs: a mtf.Tensor with shape [<batch_dims...>, length_dim, input_dim] output_dim: a mtf.Dimension (for Transformer, this is input_dim) hparams: model hyperparameters train: a boolean variable_dtype: a mtf.VariableDType layout: optional - an input to mtf.convert_to_layout_rules mesh_shape: optional - an input to mtf.convert_to_shape nonpadding: an optional Tensor with shape [<batch_dims>, length_dim] and the same dtype as inputs, consisting of ones(nonpadding) and zeros(padding). Returns: outputs: a Tensor with shape [<batch_dims...>, length_dim, output_dim] loss: a mtf scalar Raises: ValueError: on unrecognized hparams.moe_gating """ # See "Dimensions cheat sheet" # <B>LM Tensor orig_inputs = inputs hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size) experts_dim = mtf.Dimension("experts", hparams.moe_num_experts) # We "cheat" here and look at the mesh shape and layout. This is to ensure # that the number of groups is a multiple of the mesh dimension # over which those groups are split. 
batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1], orig_inputs.shape.dims[-1]) # Hack: we assume that # "outer_batch" == replication of experts # mesh_dim_size can be derived from mesh_shape and orig_batch_dim # # We then reqire num_groups to be a multiple of mesh_dim_size. if orig_inputs.shape.dims[0].name == "outer_batch": outer_batch_dim, orig_batch_dim = orig_inputs.shape.dims[:2] else: outer_batch_dim, orig_batch_dim = (mtf.Dimension("outer_batch", 1), orig_inputs.shape.dims[0]) # Number of MoE inputs (total number of position across batch_and_length_dims # per replica. n = 1 for d in batch_and_length_dims: n *= d.size n = n // outer_batch_dim.size mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, orig_batch_dim) num_groups, group_size = _split_into_groups(n, hparams.moe_group_size, mesh_dim_size) group_size_dim = mtf.Dimension("group", group_size) num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups) moe_input_dims = [ outer_batch_dim, num_groups_dim, group_size_dim, input_dim ] # OGSM Tensor inputs = mtf.reshape(inputs, moe_input_dims) # Each sequence sends expert_capacity positions to each expert. if train: capacity_factor = hparams.moe_capacity_factor_train else: capacity_factor = hparams.moe_capacity_factor_eval expert_capacity = min( group_size_dim.size, int((group_size_dim.size * capacity_factor) / experts_dim.size)) expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity) experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size) batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size) if nonpadding is not None: nonpadding = mtf.zeros(inputs.mesh, batch_and_length_dims, dtype=inputs.dtype) + nonpadding nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1]) if hparams.moe_gating == "top_2": # dispatch_tensor and combine_tensor are # <B>GSEC Tensors dispatch_tensor, combine_tensor, loss = _top_2_gating( inputs=inputs, outer_expert_dims=None, experts_dim=experts_dim_unsplit, expert_capacity_dim=expert_capacity_dim, hparams=hparams, train=train, variable_dtype=variable_dtype, importance=nonpadding) else: raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating) expert_inputs = mtf.einsum([inputs, dispatch_tensor], mtf.Shape([ outer_batch_dim, experts_dim_unsplit, num_groups_dim, expert_capacity_dim, input_dim ])) expert_inputs = mtf.reshape( expert_inputs, mtf.Shape([ outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim, input_dim ])) # Now feed the expert inputs through the experts. h = mtf.layers.dense(expert_inputs, hidden_dim, expert_dims=[experts_dim], activation=mtf.relu, use_bias=False, variable_dtype=variable_dtype, name="wi") expert_output = mtf.layers.dense(h, output_dim, expert_dims=[experts_dim], use_bias=False, variable_dtype=variable_dtype, name="wo") expert_output = mtf.reshape( expert_output, mtf.Shape([ outer_batch_dim, experts_dim_unsplit, num_groups_dim, expert_capacity_dim, output_dim, ])) moe_output_dims = moe_input_dims[:-1] + [output_dim] output = mtf.einsum([expert_output, combine_tensor], mtf.Shape(moe_output_dims)) output = mtf.reshape(output, batch_and_length_dims + [output_dim]) return output, loss * hparams.moe_loss_coef
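# Tiny numeric check of the capacity arithmetic used above: each expert accepts
# at most min(group_size, int(group_size * capacity_factor / num_experts))
# positions per group.  The numbers below are made up for illustration.
def example_expert_capacity(group_size=1024, capacity_factor=1.25,
                            num_experts=16):
  return min(group_size, int((group_size * capacity_factor) / num_experts))

assert example_expert_capacity() == 80  # 1024 * 1.25 / 16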
def benchmark_model(mesh): """ Initializes a 3D volume with random noise and runs the forward model (linear field, LPT displacement and CiC painting). """ # Setup parameters bs = FLAGS.box_size nc = FLAGS.cube_size batch_size = FLAGS.batch_size a0 = FLAGS.a0 a = 1.0 nsteps = FLAGS.pm_steps # Compute a few things first, using simple tensorflow klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0] plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1] ipklin = iuspline(klin, plin) # Initialize the integration steps stages = np.linspace(a0, a, nsteps, endpoint=True) # Generate a batch of 3D initial conditions initial_conditions = flowpm.linear_field( nc, # size of the cube bs, # Physical size of the cube ipklin, # Initial power spectrum batch_size=batch_size) # Compute necessary Fourier kernels kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False) from flowpm.kernels import laplace_kernel, gradient_kernel lap = tf.cast(laplace_kernel(kvec), tf.complex64) grad_x = gradient_kernel(kvec, 0) grad_y = gradient_kernel(kvec, 1) grad_z = gradient_kernel(kvec, 2) # Define the named dimensions # Parameters of the small scales decomposition n_block_x = 8 n_block_y = 4 n_block_z = 1 halo_size = 4 # Parameters of the large scales decomposition downsampling_factor = 2 lnc = nc // 2**downsampling_factor fx_dim = mtf.Dimension("nx", nc) fy_dim = mtf.Dimension("ny", nc) fz_dim = mtf.Dimension("nz", nc) tfx_dim = mtf.Dimension("tx", nc) tfy_dim = mtf.Dimension("ty", nc) tfz_dim = mtf.Dimension("tz", nc) # Dimensions of the low resolution grid tx_dim = mtf.Dimension("tx_lr", nc) ty_dim = mtf.Dimension("ty_lr", nc) tz_dim = mtf.Dimension("tz_lr", nc) nx_dim = mtf.Dimension('nx_block', n_block_x) ny_dim = mtf.Dimension('ny_block', n_block_y) nz_dim = mtf.Dimension('nz_block', n_block_z) sx_dim = mtf.Dimension('sx_block', nc // n_block_x) sy_dim = mtf.Dimension('sy_block', nc // n_block_y) sz_dim = mtf.Dimension('sz_block', nc // n_block_z) batch_dim = mtf.Dimension("batch", batch_size) pk_dim = mtf.Dimension("npk", len(plin)) pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim]) # Import the Fourier kernels computed above onto the mesh kx = mtf.import_tf_tensor(mesh, kvec[0].squeeze().astype('float32'), shape=[tfx_dim]) ky = mtf.import_tf_tensor(mesh, kvec[1].squeeze().astype('float32'), shape=[tfy_dim]) kz = mtf.import_tf_tensor(mesh, kvec[2].squeeze().astype('float32'), shape=[tfz_dim]) kv = [ky, kz, kx] kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False) kx_lr = mtf.import_tf_tensor(mesh, kvec_lr[0].squeeze().astype('float32'), shape=[tx_dim]) ky_lr = mtf.import_tf_tensor(mesh, kvec_lr[1].squeeze().astype('float32'), shape=[ty_dim]) kz_lr = mtf.import_tf_tensor(mesh, kvec_lr[2].squeeze().astype('float32'), shape=[tz_dim]) kv_lr = [ky_lr, kz_lr, kx_lr] shape = [batch_dim, fx_dim, fy_dim, fz_dim] lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim] hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim] part_shape = [batch_dim, fx_dim, fy_dim, fz_dim] initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv) state = mtfpm.lpt_init_single( initc, a0, kv_lr, halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True, ) #state = mtfpm.lpt_init(low, high, 0.1, kv_lr, kv_hr, halo_size, hr_shape, lr_shape, # part_shape[1:], downsampling_factor=downsampling_factor, antialias=True,) # Here we can run our nbody final_state = state #mtfpm.nbody(state, stages, lr_shape, hr_shape, kv_lr, kv_hr, halo_size, downsampling_factor=downsampling_factor) # paint the field final_field = mtf.zeros(mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: final_field = mtf.pad(final_field, [halo_size, halo_size], block_size_dim.name) final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size) # Halo exchange for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]): final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim, halo_size) # Remove borders for block_size_dim in hr_shape[-3:]: final_field = mtf.slice(final_field, halo_size, block_size_dim.size, block_size_dim.name) #final_field = mtf.reshape(final_field, [batch_dim, fx_dim, fy_dim, fz_dim]) # Hack: use a custom slicewise reshape, since mtf.reshape cannot collapse the block dimensions directly final_field = mtf.slicewise(lambda x: x[:, 0, 0, 0], [final_field], output_dtype=tf.float32, output_shape=[batch_dim, fx_dim, fy_dim, fz_dim], name='my_dumb_reshape', splittable_dims=part_shape[:-1] + hr_shape[:4]) return mtf.reduce_sum(final_field)
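The pad / cic_paint / halo_reduce / slice sequence above is mostly size bookkeeping around the blockwise painting. Below is a minimal NumPy sketch of that bookkeeping only, under my own assumptions; the helper name and shapes are illustrative and are not FlowPM or Mesh TensorFlow API.

import numpy as np

def pad_exchange_slice(block, halo_size):
    # Pad the three trailing (spatial) axes by halo_size on each side,
    # mirroring the mtf.pad loop over hr_shape[-3:].
    pad = [(0, 0)] * (block.ndim - 3) + [(halo_size, halo_size)] * 3
    padded = np.pad(block, pad)
    # ... painting and halo reduction with neighbouring blocks would happen here ...
    # Slice the borders off again, mirroring mtf.slice(x, halo_size, size, name),
    # so the block returns at its original size.
    sl = [slice(None)] * (block.ndim - 3) + [
        slice(halo_size, halo_size + s) for s in block.shape[-3:]]
    return padded[tuple(sl)]

block = np.zeros((1, 16, 32, 128))  # [batch, sx_block, sy_block, sz_block]
assert pad_exchange_slice(block, halo_size=4).shape == block.shape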
def lpt_prototype(mesh, nc=FLAGS.nc, bs=FLAGS.box_size, batch_size=FLAGS.batch_size, a0=FLAGS.a0, a=FLAGS.af, nsteps=FLAGS.nsteps): """ Prototype of a function computing the LPT displacement. Returns output TensorFlow and Mesh TensorFlow tensors. """ klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0] plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1] ipklin = iuspline(klin, plin) stages = np.linspace(a0, a, nsteps, endpoint=True) # Define the named dimensions # Parameters of the small scales decomposition n_block_x = FLAGS.nx n_block_y = FLAGS.ny n_block_z = 1 halo_size = FLAGS.hsize if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z): new_size = int(0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)) print('WARNING: REDUCING HALO SIZE from %d to %d' % (halo_size, new_size)) halo_size = new_size # Parameters of the large scales decomposition downsampling_factor = 0 lnc = nc // 2**downsampling_factor fx_dim = mtf.Dimension("nx", nc) fy_dim = mtf.Dimension("ny", nc) fz_dim = mtf.Dimension("nz", nc) tfx_dim = mtf.Dimension("tx", nc) tfy_dim = mtf.Dimension("ty", nc) tfz_dim = mtf.Dimension("tz", nc) tx_dim = mtf.Dimension("tx_lr", nc) ty_dim = mtf.Dimension("ty_lr", nc) tz_dim = mtf.Dimension("tz_lr", nc) nx_dim = mtf.Dimension('nx_block', n_block_x) ny_dim = mtf.Dimension('ny_block', n_block_y) nz_dim = mtf.Dimension('nz_block', n_block_z) sx_dim = mtf.Dimension('sx_block', nc // n_block_x) sy_dim = mtf.Dimension('sy_block', nc // n_block_y) sz_dim = mtf.Dimension('sz_block', nc // n_block_z) k_dims = [tx_dim, ty_dim, tz_dim] batch_dim = mtf.Dimension("batch", batch_size) pk_dim = mtf.Dimension("npk", len(plin)) pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim]) # Compute necessary Fourier kernels kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False) kx = mtf.import_tf_tensor(mesh, kvec[0].squeeze().astype('float32'), shape=[tfx_dim]) ky = mtf.import_tf_tensor(mesh, kvec[1].squeeze().astype('float32'), shape=[tfy_dim]) kz = mtf.import_tf_tensor(mesh, kvec[2].squeeze().astype('float32'), shape=[tfz_dim]) kv = [ky, kz, kx] # kvec for low resolution grid kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False) kx_lr = mtf.import_tf_tensor(mesh, kvec_lr[0].squeeze().astype('float32'), shape=[tx_dim]) ky_lr = mtf.import_tf_tensor(mesh, kvec_lr[1].squeeze().astype('float32'), shape=[ty_dim]) kz_lr = mtf.import_tf_tensor(mesh, kvec_lr[2].squeeze().astype('float32'), shape=[tz_dim]) kv_lr = [ky_lr, kz_lr, kx_lr] shape = [batch_dim, fx_dim, fy_dim, fz_dim] lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim] hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim] part_shape = [batch_dim, fx_dim, fy_dim, fz_dim] # Begin simulation initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv) # # Reshaping array into high resolution mesh # field = mtf.slicewise(lambda x:tf.expand_dims(tf.expand_dims(tf.expand_dims(x, axis=1),axis=1),axis=1), # [initc], # output_dtype=tf.float32, # output_shape=hr_shape, # name='my_reshape', # splittable_dims=lr_shape[:-1]+hr_shape[1:4]+part_shape[1:3]) # state = mtfpm.lpt_init_single( initc, a0, kv_lr, halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True, ) # Here we can run our nbody final_state = state #mtfpm.nbody(state, stages, lr_shape, hr_shape, k_dims, kv_lr, kv_hr, halo_size, downsampling_factor=downsampling_factor) # paint the field final_field = mtf.zeros(mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: final_field = mtf.pad(final_field, [halo_size, halo_size], block_size_dim.name) final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size) # Halo exchange for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]): final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim, halo_size) # Remove borders for block_size_dim in hr_shape[-3:]: final_field = mtf.slice(final_field, halo_size, block_size_dim.size, block_size_dim.name) #final_field = mtf.reshape(final_field, [batch_dim, fx_dim, fy_dim, fz_dim]) # Hack: use a custom slicewise reshape, since mtf.reshape cannot collapse the block dimensions directly final_field = mtf.slicewise( lambda x: x[:, 0, 0, 0], [final_field], output_dtype=tf.float32, output_shape=[batch_dim, fx_dim, fy_dim, fz_dim], name='my_dumb_reshape', splittable_dims=part_shape[:-1] + hr_shape[:4]) return initc, final_field
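The halo clamp at the top of lpt_prototype can be read as a standalone rule: the halo may never reach half of the smallest block side. A minimal plain-Python sketch of that rule follows; the helper name is mine and not part of the repo.

def clamp_halo(halo_size, nc, n_block_x, n_block_y, n_block_z):
    # Mirrors the warning branch in lpt_prototype: if the halo would cover half
    # (or more) of the smallest block, shrink it to half of that block.
    smallest_block = min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= 0.5 * smallest_block:
        new_size = int(0.5 * smallest_block)
        print('WARNING: REDUCING HALO SIZE from %d to %d' % (halo_size, new_size))
        return new_size
    return halo_size

assert clamp_halo(4, 128, 4, 4, 1) == 4    # 4 < 16, unchanged
assert clamp_halo(20, 128, 8, 8, 1) == 8   # clamped to half of the 16-cell block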
def transformer_moe_layer_v2(inputs, output_dim, hparams, train, variable_dtype, layout=None, mesh_shape=None, nonpadding=None): """2-level mixture of experts. Adapted from the paper https://arxiv.org/abs/1701.06538 Note: until the algorithm and interface solidify, we pass in a hyperparameters dictionary in order not to complicate the interface in mtf_transformer.py. Once this code moves out of "research", we should pass the hyperparameters separately. Hyperparameters used: hparams.moe_num_experts: number of experts hparams.moe_hidden_size: size of hidden layer in each expert hparams.moe_group_size: size of each "group" for gating purposes hparams.moe_capacity_factor_train: a float hparams.moe_capacity_factor_eval: a float hparams.moe_capacity_factor_second_level: a float hparams.moe_gating: a string + all hyperparameters used by _top_2_gating() One set of params is used for the experts in the first level, and a different set of params per expert in the second level. The number of parameters in the gating network is: (input_dim.size * hparams.num_experts) + (moe_hidden_size * hparams.num_experts * hparams.num_experts) The number of parameters in the experts themselves is: (hparams.num_experts * (input_dim.size + output_dim.size) * hparams.moe_hidden_size) The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting of the representations of all positions in a batch of sequences. Each position of each sequence is sent to 0-3 experts. The expert choices and the combination weights are determined by a learned gating function. This function returns a small auxiliary loss that should be added to the training loss of the model. This loss helps to balance expert usage. Without the loss, it is very likely that a few experts will be trained and the rest will starve. Several hacks are necessary to get around current TPU limitations: - To ensure static shapes, we enforce (by truncation/padding) that each sequence sends the same number of elements to each expert. It would make more sense to enforce this equality over the entire batch, but due to our hacked-up gather-by-matmul implementation, we need to divide the batch into "groups". For each group, the same number of elements are sent to each expert. TODO(noam): Factor this code better. We want to be able to substitute different code for the experts themselves. Dimensions cheat sheet: a, b: batch size l: original sequence length m: input depth n: output depth g, h: number of groups s, t: group size x, y: number of experts c, d: expert capacity input: [a0, b1, l, m] input: [a0, g1, s, m] dispatch_tensor_x: [a0, g1, s, x, c] expert_input: [a0, g1, x, c, m] alltoall: [a0, g, x1, c, m] alltoall: [a0, g, x1, c, m] transpose: [x1, a0, g, c, m] reshape: [x1, h0, s, m] assignment2: [x1, h0, t, y, d] expert_input2: [x1, h0, y, d, m] alltoall: [x1, h, y0, d, m] ... reverse of that gating params 0: [m, x] gating params 1: [x1, m, y] expert params: [x1, y0, m, hidden] [x1, y0, hidden, n] Args: inputs: a mtf.Tensor with shape [a, b, l, m] output_dim: a mtf.Dimension (for Transformer, this is input_dim) hparams: model hyperparameters train: a boolean variable_dtype: a mtf.VariableDType layout: optional - an input to mtf.convert_to_layout_rules mesh_shape: optional - an input to mtf.convert_to_shape nonpadding: an optional mtf.Tensor with shape [a, b, l] and the same dtype as inputs, consisting of ones (at non-padding positions) and zeros (at padding positions). 
Returns: outputs: a Tensor with shape [a, b, l, n] loss: a mtf scalar Raises: ValueError: on unrecognized hparams.moe_gating """ if nonpadding is not None: nonpadding = mtf.zeros(inputs.mesh, inputs.shape.dims[:-1], dtype=inputs.dtype) + nonpadding insert_outer_batch_dim = (len(inputs.shape.dims) == 3) if insert_outer_batch_dim: inputs = mtf.reshape(inputs, [mtf.Dimension("outer_batch", 1)] + inputs.shape.dims) assert len(hparams.moe_num_experts) == 2 a0, b1, l, m = inputs.shape.dims hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size) x1 = mtf.Dimension("expert_x", hparams.moe_num_experts[0]) y0 = mtf.Dimension("expert_y", hparams.moe_num_experts[1]) x = mtf.Dimension("expert_x_unsplit", hparams.moe_num_experts[0]) y = mtf.Dimension("expert_y_unsplit", hparams.moe_num_experts[1]) n = output_dim # We "cheat" here and look at the mesh shape and layout. This is to ensure # that the number of groups (g.size) is a multiple of the mesh dimension # over which those groups are split. num_groups, group_size = _split_into_groups( b1.size * l.size, hparams.moe_group_size, mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, b1)) g1 = mtf.Dimension(b1.name, num_groups) g = mtf.Dimension(b1.name + "_unsplit", g1.size) s = mtf.Dimension("group_size_x", group_size) # Each sequence sends (at most?) expert_capacity positions to each expert. # Static expert_capacity dimension is needed for expert batch sizes if train: capacity_factor = hparams.moe_capacity_factor_train else: capacity_factor = hparams.moe_capacity_factor_eval expert_capacity = min(s.size, int((s.size * capacity_factor) / x.size)) expert_capacity = max(expert_capacity, 4) c = mtf.Dimension("expert_capacity_x", expert_capacity) # We "cheat" here and look at the mesh shape and layout. This is to ensure # that the number of groups (h.size) is a multiple of the mesh dimension # over which those groups are split. num_groups, group_size = _split_into_groups( a0.size * g.size * c.size, hparams.moe_group_size, mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, a0)) t = mtf.Dimension("group_size_y", group_size) h0 = mtf.Dimension(a0.name, num_groups) h = mtf.Dimension(a0.name + "_unsplit", h0.size) expert_capacity = min( t.size, int((t.size * hparams.moe_capacity_factor_second_level) / y.size)) expert_capacity = max(expert_capacity, 4) d = mtf.Dimension("expert_capacity_y", expert_capacity) # First level of expert routing # Reshape the inner batch size to a multiple of group_dim g1 and # group_size_dim s. inputs = mtf.reshape(inputs, [a0, g1, s, m]) if nonpadding is not None: nonpadding = mtf.reshape(nonpadding, [a0, g1, s]) # Get the assignments for the first level. # dispatch_tensor_x has shape [a0, g1, s, x, c] if hparams.moe_gating == "top_2": dispatch_tensor_x, combine_tensor_x, loss_outer = _top_2_gating( inputs=inputs, outer_expert_dims=None, experts_dim=x, expert_capacity_dim=c, hparams=hparams, train=train, variable_dtype=variable_dtype, name="outer_gating", importance=nonpadding) else: raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating) # Now create expert_inputs based on the assignments. # put num_experts dimension first to make split easier in alltoall expert_inputs_x = mtf.einsum([inputs, dispatch_tensor_x], [x, a0, g1, c, m]) # we construct an "importance" Tensor for the inputs to the second-level # gating. The importance of an input is 1.0 if it represents the # first-choice expert-group and 0.5 if it represents the second-choice expert # group. This is used by the second-level gating. 
importance = mtf.reduce_sum(combine_tensor_x, output_shape=[x, a0, g1, c]) importance = 0.5 * (mtf.to_float(mtf.greater(importance, 0.5)) + mtf.to_float(mtf.greater(importance, 0.0))) # First level, all to all. Here we change the split dimension from g1 to x1. expert_inputs_x = mtf.reshape(expert_inputs_x, mtf.Shape([x1, a0, g, c, m])) importance = mtf.reshape(importance, [x1, a0, g, c]) # Second level of expert routing # Reshape the expert_inputs outer batch dim to be a multiple of group_dim h0 # and group_size_dim t. inputs_y = mtf.reshape(expert_inputs_x, [x1, h0, t, m]) importance = mtf.reshape(importance, [x1, h0, t]) # Get the assignments for the second level. # dispatch_tensor_y has shape [x1, h0, t, y, d] if hparams.moe_gating == "top_2": dispatch_tensor_y, combine_tensor_y, loss_inner = _top_2_gating( inputs=inputs_y, outer_expert_dims=[x1], experts_dim=y, expert_capacity_dim=d, hparams=hparams, train=train, variable_dtype=variable_dtype, importance=importance, name="inner_gating") else: raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating) # Now create expert_inputs based on the assignments. # put num_experts dimension first to make split easier in alltoall expert_inputs_y = mtf.einsum([inputs_y, dispatch_tensor_y], [y, x1, h0, d, m]) # Second level, all to all. Here we change the split dimension from h0 to y0. expert_inputs_y = mtf.reshape(expert_inputs_y, mtf.Shape([y0, x1, h, d, m])) hidden_output = mtf.layers.dense(expert_inputs_y, hidden_dim, expert_dims=[y0, x1], activation=mtf.relu, use_bias=False, variable_dtype=variable_dtype, name="wi") expert_output = mtf.layers.dense(hidden_output, output_dim, expert_dims=[y0, x1], use_bias=False, variable_dtype=variable_dtype, name="wo") # NOW COMBINE EXPERT OUTPUTS (reversing everything we have done) # expert_output has shape [y0, x1, h, d, n] # alltoall expert_output = mtf.reshape(expert_output, mtf.Shape([y, x1, h0, d, n])) # combine results from inner level output_y = mtf.einsum([expert_output, combine_tensor_y], [x1, h0, t, n]) # Reshape the combined tensor from inner level to now contain outer_batch_dim # a0 and group_dim g output = mtf.reshape(output_y, [x1, a0, g, c, n]) # alltoall from expert_dim x to group_dim g1 expert_output_x = mtf.reshape(output, mtf.Shape([x, a0, g1, c, n])) # combine results from outer level output_x = mtf.einsum([expert_output_x, combine_tensor_x], [a0, g1, s, n]) # Reshape the combined tensor to now contain inner_batch_dim # b1 and the original sequence length output = mtf.reshape(output_x, [a0, b1, l, n]) if insert_outer_batch_dim: output = mtf.reshape(output, [b1, l, n]) return output, (loss_outer + loss_inner) * hparams.moe_loss_coef
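Both gating levels above size their expert batches with the same capacity rule: take a fraction of the group, divide it among the experts, and never go below 4 slots. A plain-Python sketch of that arithmetic follows; the function name is illustrative only and not part of the library.

def expert_capacity(group_size, num_experts, capacity_factor, minimum=4):
    # Each group of `group_size` positions may send at most `capacity` positions
    # to each of the `num_experts` experts, mirroring
    # min(s.size, int((s.size * capacity_factor) / x.size)) with a floor of 4.
    capacity = min(group_size, int(group_size * capacity_factor / num_experts))
    return max(capacity, minimum)

# e.g. groups of 1024 positions, 32 experts, capacity factor 1.25 -> 40 slots per expert
assert expert_capacity(1024, 32, 1.25) == 40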
def decode(self, inputs, variable_dtype=mtf.VariableDType(tf.float32), beam_size=1, alpha=0.6, temperature=0.0, sampling_keep_top_k=-1, decode_length_multiplier=1.5, decode_length_constant=10, max_decode_length=None): """Sampling or beam search for Funnel Transformer. Args: inputs: a Tensor with shape [<batch_dims>, beam_dim, length_dim] variable_dtype: a mtf.VariableDType beam_size: an integer >= 1 alpha: a floating point value (length bonus for beam search) temperature: a value between 0 and 1 (must be 0 if beam_size > 1) 0.0 means argmax, 1.0 means sample according to predicted distribution. sampling_keep_top_k: a value between 1 and vocab_size used to sample from only the k most likely logits. Set to -1 to sample from all logits. decode_length_multiplier: a float decode_length_constant: a float max_decode_length: an optional integer Returns: a Tensor with shape [<batch_dims>, beam_dim, length_dim] """ encoder_layer_outputs = [] shared_params = self._shared_params(inputs.mesh, variable_dtype) encoder_sequence_id = mtf.minimum(inputs, 1) encoder_output, encoder_loss = self.encoder.call_simple( inputs=inputs, targets=None, compute_loss=False, mode=tf.estimator.ModeKeys.PREDICT, variable_dtype=variable_dtype, sequence_id=encoder_sequence_id, shared_params=shared_params, layer_outputs=encoder_layer_outputs) del encoder_loss encoder_output = mtf.layers.rename_length_to_memory_length( encoder_output) # The sequence_id is updated inside the layer_stack due to pooling. So we # need to use the updated sequence_id stored in the context. encoder_sequence_id = self.encoder.layer_stack.context.sequence_id encoder_sequence_id = mtf.layers.rename_length_to_memory_length( encoder_sequence_id) batch_dims = inputs.shape[:-1] length_dim = inputs.shape[-1] if max_decode_length is None: decode_length_dim = length_dim else: decode_length_dim = mtf.Dimension("length", max_decode_length) if beam_size == 1: ids_shape = mtf.Shape(batch_dims + [decode_length_dim]) partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32) return self.decoder.sample_autoregressive( partial_sequences, temperature=temperature, sampling_keep_top_k=sampling_keep_top_k, variable_dtype=variable_dtype, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, encoder_inputs=mtf.layers.rename_length_to_memory_length( inputs), shared_params=shared_params, has_partial_sequences=False, encoder_layer_outputs=encoder_layer_outputs) else: if temperature != 0: raise ValueError( "don't know how to beam search with nonzero temperature") if sampling_keep_top_k != -1: raise ValueError( "don't know how to beam search with top-k value other than -1." ) # beam search beam_dim = mtf.Dimension("beam", beam_size) ids_shape = mtf.Shape(batch_dims + [beam_dim, decode_length_dim]) partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32) input_length = mtf.reduce_sum(mtf.to_float( mtf.cast(inputs, tf.bool)), reduced_dim=length_dim) max_input_length = mtf.reduce_max(input_length) decode_length = mtf.cast( max_input_length * decode_length_multiplier + decode_length_constant, tf.int32) return self.decoder.beam_search( partial_sequences, decode_length, variable_dtype=variable_dtype, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, encoder_inputs=inputs, alpha=alpha, shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs)
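The beam-search branch above derives its decode length from the longest non-padding input in the batch. A minimal plain-Python sketch of that arithmetic, with a name of my choosing:

def target_decode_length(max_input_length, decode_length_multiplier=1.5,
                         decode_length_constant=10):
    # Mirrors: decode_length = max_input_length * multiplier + constant,
    # which the mtf graph then casts to int32.
    return int(max_input_length * decode_length_multiplier +
               decode_length_constant)

assert target_decode_length(100) == 160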
def decode(self, inputs, variable_dtype=mtf.VariableDType(tf.float32), beam_size=1, alpha=0.6, temperature=0.0, decode_length_multiplier=1.5, decode_length_constant=10): """Sampling or beam search. TODO(noam): should we make the output length dimension different from the input length dimension? Args: inputs: a Tensor with shape [<batch_dims>, beam_dim, length_dim] variable_dtype: a mtf.VariableDType beam_size: an integer >= 1 alpha: a floating point value (length bonus for beam search) temperature: a value between 0 and 1 (must be 0 if beam_size > 1) 0.0 means argmax, 1.0 means sample according to predicted distribution. decode_length_multiplier: a float decode_length_constant: a float Returns: a Tensor with shape [<batch_dims>, beam_dim, length_dim] """ encoder_layer_outputs = [] shared_params = self._shared_params(inputs.mesh, variable_dtype) encoder_sequence_id = mtf.minimum(inputs, 1) encoder_output, encoder_loss = self.encoder.call_simple( inputs=inputs, targets=None, compute_loss=False, mode=tf.estimator.ModeKeys.PREDICT, variable_dtype=variable_dtype, sequence_id=encoder_sequence_id, shared_params=shared_params, layer_outputs=encoder_layer_outputs) del encoder_loss encoder_output = mtf.layers.rename_length_to_memory_length(encoder_output) encoder_sequence_id = mtf.layers.rename_length_to_memory_length( encoder_sequence_id) if beam_size == 1: ids_shape = inputs.shape partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32) return self.decoder.sample_autoregressive( partial_sequences, temperature=temperature, variable_dtype=variable_dtype, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, shared_params=shared_params, has_partial_sequences=False, encoder_layer_outputs=encoder_layer_outputs) else: if temperature != 0: raise ValueError( "don't know how to beam search with nonzero temperature") # beam search beam_dim = mtf.Dimension("beam", beam_size) batch_dims = inputs.shape[:-1] length_dim = inputs.shape[-1] ids_shape = mtf.Shape(batch_dims + [beam_dim, length_dim]) partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32) input_length = mtf.reduce_sum( mtf.to_float(mtf.cast(inputs, tf.bool)), reduced_dim=length_dim) max_input_length = mtf.reduce_max(input_length) decode_length = mtf.cast( max_input_length * decode_length_multiplier + decode_length_constant, tf.int32) return self.decoder.beam_search( partial_sequences, decode_length, variable_dtype=variable_dtype, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, alpha=alpha, shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs)
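The two branches of decode differ mainly in the static shape of the partial-sequence buffer: sampling reuses the input shape, while beam search inserts a beam dimension before the length dimension. A small sketch of those shapes, assuming only that mesh_tensorflow is importable as mtf:

import mesh_tensorflow as mtf

batch_dim = mtf.Dimension("batch", 2)
length_dim = mtf.Dimension("length", 256)
beam_dim = mtf.Dimension("beam", 4)

sampling_ids_shape = mtf.Shape([batch_dim, length_dim])        # beam_size == 1
beam_ids_shape = mtf.Shape([batch_dim, beam_dim, length_dim])  # beam_size > 1
assert beam_ids_shape.dims[-1] == length_dim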
def _sample(self, features, mesh): hparams = self._hparams (inputs_embedding_var, targets_embedding_var, softmax_var, positional_embedding_var) = self._embedding_and_softmax_vars(mesh) if hparams.transformer_type == "encdec": inputs = features["inputs"] while len(inputs.shape.as_list()) > 2: inputs = tf.squeeze(inputs, axis=2) actual_batch_size = tf.shape(inputs)[0] actual_length = tf.shape(inputs)[1] inputs = tf.pad( inputs, [[0, hparams.batch_size - actual_batch_size], [0, hparams.max_length - actual_length]]) inputs = self._import_to_batch_by_length( inputs, "inputs", mesh, hparams) x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) + mtf.reshape(positional_embedding_var, mtf.Shape([self.length_dim, self.model_dim]))) encoder_attention_mask = ( mtf.layers.attention_mask_ignore_padding( inputs, dtype=self.activation_dtype)) with tf.variable_scope("encoder"): x = self._layer_stack(x, hparams.encoder_layers, self_attention_mask=encoder_attention_mask) encoder_output = mtf.rename_dimension( x, self.length_dim.name, self.memory_length_dim.name) encdec_tensors = [] for layer_num, layer_type in enumerate(hparams.decoder_layers): if layer_type == "enc_att": with tf.variable_scope("decoder/enc_att_%d/enc_att" % layer_num): q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars( mesh, self.heads_dim, self.model_dim, self.kv_dim, self.master_dtype, self.slice_dtype, self.activation_dtype) k = mtf.einsum( [encoder_output, k_var], mtf.Shape( self.batch_dims + [self.heads_dim, self.memory_length_dim, self.kv_dim])) v = mtf.einsum( [encoder_output, v_var], mtf.Shape( self.batch_dims + [self.heads_dim, self.memory_length_dim, self.kv_dim])) encdec_tensors.append((q_var, o_var, k, v)) else: encdec_tensors.append(None) partial_targets = None elif hparams.transformer_type == "decoder": encdec_tensors = None encoder_output = None encoder_attention_mask = None # Prepare partial targets. # In either features["inputs"] or features["targets"]. # We force the outputs to begin with these sequences. 
partial_targets = features.get("inputs", None) if partial_targets is None: partial_targets = features.get("targets", None) if partial_targets is not None: partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2) partial_targets = tf.to_int32(partial_targets) partial_targets_batch = tf.shape(partial_targets)[0] partial_targets_length = tf.shape(partial_targets)[1] partial_targets = tf.pad( partial_targets, [[0, hparams.batch_size - partial_targets_batch], [0, hparams.max_length - partial_targets_length]]) partial_targets = self._import_to_batch_by_length( partial_targets, "partial_targets", mesh, hparams) else: raise ValueError( "hparams.transformer_type = %s not yet supported" % hparams.transformer_type) local_attention_window = mtf.Dimension( "local_attention_window", hparams.local_attention_window_size) if hparams.beam_size == 1: ids_shape = mtf.Shape(self.batch_dims + [self.length_dim]) kv_shape = mtf.Shape(self.batch_dims + [self.heads_dim, self.memory_length_dim, self.kv_dim]) local_kv_shape = mtf.Shape(self.batch_dims + [self.heads_dim, local_attention_window, self.kv_dim]) else: beam_dim = mtf.Dimension("beam", hparams.beam_size) ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim]) kv_shape = mtf.Shape(self.batch_dims + [beam_dim, self.heads_dim, self.memory_length_dim, self.kv_dim]) local_kv_shape = mtf.Shape(self.batch_dims + [beam_dim, self.heads_dim, local_attention_window, self.kv_dim]) initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32) initial_states = [] for layer in hparams.decoder_layers: if layer == "att": initial_states.extend( [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2) elif layer == "local_att": initial_states.extend( [mtf.zeros(mesh, local_kv_shape, dtype=self.activation_dtype)] * 2) def logits_fn(step_num, ids, states): """Produce logits for this step, and new states.""" ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim) x = (mtf.gather(targets_embedding_var, ids_this_step, self.targets_vocab_dim) + mtf.gather(positional_embedding_var, step_num, self.max_length_dim)) with tf.variable_scope("decoder"): x, new_states = self._layer_stack( x, hparams.decoder_layers, encdec_attention_mask=encoder_attention_mask, step_num=step_num, encdec_tensors=encdec_tensors, states=states) logits = mtf.matmul(x, softmax_var) return logits, new_states if hparams.beam_size == 1: temperature = (0.0 if hparams.sampling_method == "argmax" else hparams.sampling_temp) return mtf.beam_search.greedy_decode( logits_fn, initial_ids, temperature=temperature, initial_states=initial_states, forced_ids=partial_targets, use_tpu=hparams.use_tpu) else: if hparams.transformer_type == "encdec": input_length = mtf.reduce_sum( mtf.to_float(mtf.cast(inputs, tf.bool)), reduced_dim=self.length_dim) max_input_length = mtf.reduce_max(input_length) decode_length = mtf.cast( max_input_length * hparams.decode_length_multiplier + hparams.decode_length_constant, tf.int32) else: decode_length = None beams, unused_scores = mtf.beam_search.beam_search( logits_fn, initial_ids, hparams.alpha, states=initial_states, decode_length=decode_length, use_tpu=hparams.use_tpu, dtype=self.activation_dtype) return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
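Both the encoder inputs and the partial targets above are padded up to the static hparams.batch_size and hparams.max_length before being imported to the mesh. A minimal TensorFlow sketch of that step; the helper name is mine, not part of the model code.

import tensorflow as tf

def pad_to_static_shape(x, static_batch_size, static_max_length):
    # Pad a dynamically shaped [batch, length] integer tensor up to the static
    # [static_batch_size, static_max_length] shape expected by the mesh import.
    actual_batch = tf.shape(x)[0]
    actual_length = tf.shape(x)[1]
    return tf.pad(x, [[0, static_batch_size - actual_batch],
                      [0, static_max_length - actual_length]])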