def model(mtf_features, other_features, params, mesh, variable_dtype, context=None):
    """A GPT style model implemented in mesh tensorflow.

    Builds token and positional embeddings, applies ``params["n_layer"]``
    blocks, projects the hidden state to vocabulary logits and, when
    ``params["mode"]`` is ``"train"`` or ``"eval"``, computes the loss.

    Args:
        mtf_features: mapping of input tensors. It is handed to
            ``parse_inputs``; ``mtf_features["labels"]`` is read in
            train/eval mode.
        other_features: mapping handed to ``parse_inputs``; the keys
            ``"attn_bias"`` and ``"memory_length_dim"`` are read here and
            forwarded to every block.
        params: model hyper-parameters. Keys read with ``[]`` (must exist):
            ``axial_pos_emb``, ``embed_dropout``, ``mode``, ``n_layer``,
            ``share_parameters``, ``recompute_grad``, ``no_weight_tie`` and,
            in train/eval mode, ``causal`` and ``num_microbatches``.
            Keys read with ``.get`` (optional): ``z_loss`` (default 1e-4),
            ``entmax_loss`` (default False), ``padding_id`` (default 0).
        mesh: the mtf mesh on which variables and ranges are created.
        variable_dtype: object exposing ``master_dtype``, ``slice_dtype`` and
            ``activation_dtype``.
        context: optional inference context; it is passed to
            ``is_incremental_inference`` and ``context.position`` is read
            when that returns true.

    Returns:
        A tuple ``(logits, loss, loss_batch)``. ``loss`` and ``loss_batch``
        are ``None`` unless ``params["mode"]`` is ``"train"`` or ``"eval"``.
    """
    x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim = parse_inputs(
        mtf_features, other_features)

    if is_incremental_inference(context):
        # reshape inputs if in inference mode
        x = mtf.gather(x, context.position - 1, sequence_dim)
        x = mtf.reshape(x, [batch_dim])

    use_axial_pos_emb = params["axial_pos_emb"] is not None

    if not use_axial_pos_emb:
        # Use standard position encoding
        wpe = mtf.get_variable(
            mesh, "wpe", mtf.Shape([embed_sequence_dim, embd_dim]),
            initializer=tf.random_normal_initializer(stddev=0.01),
            master_dtype=variable_dtype.master_dtype,
            slice_dtype=variable_dtype.slice_dtype,
            activation_dtype=variable_dtype.activation_dtype)
    else:
        wpe = axial_positional_emb(embd_dim, mesh, params, variable_dtype)

    # Text encoding
    wte = mtf.get_variable(
        mesh, "wte", mtf.Shape([vocab_dim, embd_dim]),
        initializer=tf.random_normal_initializer(stddev=0.02),
        master_dtype=variable_dtype.master_dtype,
        slice_dtype=variable_dtype.slice_dtype,
        activation_dtype=variable_dtype.activation_dtype)

    with tf.variable_scope("token_embd"):
        # Text embedding
        h = mtf.gather(wte, x, vocab_dim)
        # Embedding dropout is only applied while training.
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            h = mtf.dropout(h, rate=params["embed_dropout"], name="wte_dropout")

    with tf.variable_scope("pos_embd"):
        # Positional embedding
        # Full range of positions normally; a single position (the previous
        # one) during incremental inference.
        position_indices = mtf.range(
            mesh, sequence_dim, tf.int64) if not is_incremental_inference(context) else (
                context.position - 1)
        pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            # NOTE(review): the op name "wte_dropout" is reused from the token
            # embedding branch; looks like a copy-paste - confirm before renaming.
            pos_emb = mtf.dropout(pos_emb, rate=params["embed_dropout"], name="wte_dropout")
        h += pos_emb

    aux_losses = 0  # instantiate auxiliary losses (for MOE models)

    for layer in range(params["n_layer"]):
        # attn blocks
        # With parameter sharing every block uses the empty scope "";
        # otherwise each layer gets its own "h<layer>" scope.
        share_parameters = exists(
            params["share_parameters"]) and params["share_parameters"] == True
        block_scope = f"h{layer}" if not share_parameters else ""

        block_fn = block(params=params, scope=block_scope, layer_num=layer,
                         bias=other_features["attn_bias"],
                         sequence_dim=sequence_dim,
                         memory_length_dim=other_features["memory_length_dim"],
                         variable_dtype=variable_dtype,
                         context=context)

        # If true and in train mode, enable gradient checkpointing
        recompute_grad = params["recompute_grad"] and (params["mode"] == "train") == True
        h, loss = block_fn(h) if not recompute_grad else mtf.recompute_grad(
            block_fn, [h])
        aux_losses += loss

    no_weight_tie_emb = params["no_weight_tie"] == True
    if no_weight_tie_emb:
        # Untied output projection: a separate linear layer produces logits.
        with tf.variable_scope("wte_final_linear"):
            logits = linear(h, "linear_out", vocab_dim, variable_dtype=variable_dtype, params=params)
    else:
        # Layer normalize & affine transform
        h = layer_norm(h, "ln_f", variable_dtype=variable_dtype)
        seq_dim = sequence_dim if not is_incremental_inference(
            context) else mtf.Dimension("sequence", 1)
        with tf.variable_scope("wte_final_einsum"):
            # Equivalent to tf.matmul
            # Tied output projection: the token embedding table is reused.
            logits = mtf.einsum([h, wte], output_shape=[batch_dim, seq_dim, vocab_dim])

    if params["mode"] in ["train", "eval"]:
        labels = mtf_features["labels"]
        z_loss = params.get(
            "z_loss", 1e-4)  # an auxiliary loss used to stabilize mtf xentropy

        # Go to full precision for the logits
        logits = mtf.cast(logits, tf.float32)

        use_entmax_loss = params.get("entmax_loss", False)
        loss_fn = mtf.layers.softmax_cross_entropy_with_logits if not use_entmax_loss else entmax_cross_entropy_with_logits

        with tf.variable_scope("xentropy_final"):
            loss_batch = loss_fn(logits=logits, targets=labels,
                                 vocab_dim=logits.shape[-1], z_loss=z_loss)

        # For non-autoregressive models (masked language modeling training)
        # Make sure labels with padding tokens are not counted in the loss
        if not params["causal"]:
            padding_id = params.get("padding_id", 0)
            loss_batch = mtf.where(mtf.not_equal(labels, padding_id), loss_batch, mtf.zeros_like(loss_batch))

        with tf.variable_scope("reduce_mean_final"):
            loss = mtf.reduce_mean(loss_batch)

        loss += aux_losses  # Add on auxiliary losses (currently only used for MoE)
        loss /= params["num_microbatches"]

        # Convert to train dtype
        loss = mtf.cast(loss, variable_dtype.slice_dtype)
    else:
        loss = None
        loss_batch = None

    # Cast back to checkpoint dtype
    logits = mtf.cast(logits, variable_dtype.master_dtype)
    return logits, loss, loss_batch
def value_dim(self):
    """Return the "d_v" dimension sized by the configured value head size.

    Raises:
        ValueError: if ``self.config.attention_value_head_size`` is None.
    """
    head_size = self.config.attention_value_head_size
    if head_size is not None:
        return mtf.Dimension("d_v", head_size)
    raise ValueError("The value head size is not defined.")
def transformer_moe_layer_v1(
        inputs, output_dim, hparams, train, variable_dtype,
        layout=None, mesh_shape=None, nonpadding=None):
    """Local mixture of experts that works well on TPU.

    Adapted from the paper https://arxiv.org/abs/1701.06538

    Note: until the algorithm and interface solidify, we pass in a
    hyperparameters dictionary in order not to complicate the interface in
    mtf_transformer.py .  Once this code moves out of "research", we should
    pass the hyperparameters separately.

    Hyperparameters used:
      hparams.moe_num_experts: number of experts
      hparams.moe_hidden_size: size of hidden layer in each expert
      hparams.moe_group_size: size of each "group" for gating purposes
      hparams.moe_capacity_factor_train: a float
      hparams.moe_capacity_factor_eval: a float
      hparams.moe_gating: a string
      hparams.moe_loss_coef: multiplier applied to the returned loss
      + all hyperparameters used by _top_2_gating()

    The number of parameters in the gating network is:
      (input_dim.size * hparams.moe_num_experts)

    The number of parameters in the experts themselves is:
      (hparams.moe_num_experts
       * (input_dim.size + output_dim.size)
       * hparams.moe_hidden_size)

    The input is consisting of the representations of all positions in a
    batch of sequences. Each position of each sequence is sent to 0-2
    experts. The expert choices and the combination weights are determined
    by a learned gating function.

    This function returns a small auxiliary loss that should be added to the
    training loss of the model. This loss helps to balance expert usage.
    Without the loss, it is very likely that a few experts will be trained
    and the rest will starve.

    Several hacks are necessary to get around current TPU limitations:

    - To ensure static shapes, we enforce (by truncation/padding)
      that each sequence send the same number of elements to each expert.

      It would make more sense to enforce this equality over the entire
      batch, but due to our hacked-up gather-by-matmul implementation, we
      need to divide the batch into "groups". For each group, the same
      number of elements are sent to each expert.

    TODO(noam): Factor this code better. We want to be able to substitute
    different code for the experts themselves.

    Args:
      inputs: a mtf.Tensor with shape [batch_dim, length_dim, input_dim].
        This implementation unpacks exactly three dimensions.
      output_dim: a mtf.Dimension (for Transformer, this is input_dim)
      hparams: model hyperparameters
      train: a boolean
      variable_dtype: a mtf.VariableDType
      layout: optional - an input to mtf.convert_to_layout_rules
      mesh_shape: optional - an input to mtf.convert_to_shape
      nonpadding: an optional Tensor with shape [batch_dim, length_dim]
        and the same dtype as inputs, consisting of ones(nonpadding)
        and zeros(padding).

    Returns:
      outputs: a Tensor with shape [batch_dim, length_dim, output_dim]
      loss: a mtf scalar

    Raises:
      ValueError: on unrecognized hparams.moe_gating, or if inputs does not
        have exactly three dimensions.
    """
    orig_inputs = inputs
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
    experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups is a multiple of the mesh dimension
    # over which those groups are split.
    orig_dims = orig_inputs.shape.dims
    if len(orig_dims) != 3:
        # Same exception type the tuple-unpacking used to raise, but with a
        # message that names the actual requirement.
        raise ValueError(
            "inputs must have shape [batch_dim, length_dim, input_dim]; "
            "got %d dimensions" % len(orig_dims))
    orig_batch_dim, orig_length_dim, input_dim = orig_dims
    num_groups, group_size = _split_into_groups(
        orig_batch_dim.size * orig_length_dim.size, hparams.moe_group_size,
        mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, orig_batch_dim))
    group_size_dim = mtf.Dimension("group", group_size)
    batch_dim = mtf.Dimension(orig_batch_dim.name, num_groups)
    inputs = mtf.reshape(inputs, [batch_dim, group_size_dim, input_dim])

    # Each sequence sends expert_capacity positions to each expert.
    if train:
        capacity_factor = hparams.moe_capacity_factor_train
    else:
        capacity_factor = hparams.moe_capacity_factor_eval
    expert_capacity = min(
        group_size_dim.size,
        int((group_size_dim.size * capacity_factor) / experts_dim.size))
    expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)

    # Differently named copies of the experts/batch dimensions; reshaping
    # between the named pairs is what moves data between the two layouts.
    experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
    batch_dim_unsplit = mtf.Dimension("batch_unsplit", batch_dim.size)

    if nonpadding is not None:
        # Adding to zeros of the original shape broadcasts nonpadding to
        # [batch, length] before it is regrouped like the inputs.
        nonpadding = mtf.zeros(
            inputs.mesh, [orig_batch_dim, orig_length_dim],
            dtype=inputs.dtype) + nonpadding
        nonpadding = mtf.reshape(nonpadding, [batch_dim, group_size_dim])

    if hparams.moe_gating == "top_2":
        dispatch_tensor, combine_tensor, loss = _top_2_gating(
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # put num_experts dimension first to make split easier in alltoall
    expert_inputs = mtf.einsum([inputs, dispatch_tensor], mtf.Shape(
        [experts_dim_unsplit, batch_dim, expert_capacity_dim, input_dim]))

    expert_inputs = mtf.reshape(expert_inputs, mtf.Shape(
        [experts_dim, batch_dim_unsplit, expert_capacity_dim, input_dim]))

    # Now feed the expert inputs through the experts.
    h = mtf.layers.dense(
        expert_inputs, hidden_dim, expert_dims=[experts_dim],
        activation=mtf.relu, use_bias=False, variable_dtype=variable_dtype,
        name="wi")

    expert_output = mtf.layers.dense(
        h, output_dim, expert_dims=[experts_dim], use_bias=False,
        variable_dtype=variable_dtype,
        name="wo")

    # The "wo" layer produces output_dim as its last dimension, so the
    # reshape back to the unsplit layout must use output_dim as well.
    # (Using input_dim here was only correct when both dims were identical.)
    expert_output = mtf.reshape(expert_output, mtf.Shape(
        [experts_dim_unsplit, batch_dim, expert_capacity_dim, output_dim]))

    output = mtf.einsum([expert_output, combine_tensor], mtf.Shape(
        [batch_dim, group_size_dim, output_dim]))

    output = mtf.reshape(output, orig_inputs.shape.dims[:-1] + [output_dim])

    return output, loss * hparams.moe_loss_coef
def feedforward_intermediate_dim(self):
    """Return the "intermediate" dimension sized by the feedforward config."""
    intermediate_size = self.config.feedforward_intermediate_size
    return mtf.Dimension("intermediate", intermediate_size)
def max_position_embeddings_dim(self):
    """Return the "max_position_embeddings" dimension from the config."""
    table_size = self.config.max_position_embeddings
    return mtf.Dimension("max_position_embeddings", table_size)
def __init__(self,
             config,
             is_training,
             input_ids,
             input_mask=None,
             token_type_ids=None,
             scope=None,
             mesh_shape="",
             layout=""):
    """Build the embedding, encoder and pooler parts of the model.

    Args:
        config: model configuration; it is deep-copied, so the caller's
            object is never modified.
        is_training: bool. When false, all dropout probabilities on the
            copied config are set to 0.0.
        input_ids: mtf tensor with exactly two dimensions; the second one is
            used as the sequence dimension.
        input_mask: optional mtf tensor. When given it is turned into an
            additive attention bias (masked positions get -10000.0) and is
            also forwarded to "moe" layers.
        token_type_ids: optional mtf tensor; defaults to int32 zeros with the
            shape of ``input_ids``.
        scope: optional variable scope name; defaults to "bert".
        mesh_shape: forwarded unchanged to ``self.moe``.
        layout: forwarded unchanged to ``self.moe``.

    Raises:
        ValueError: if ``config.block_layers`` contains an unknown layer type.
    """
    self.config = copy.deepcopy(config)
    del config
    if not is_training:
        self.config.layer_output_dropout_prob = 0.0
        self.config.attention_probs_dropout_prob = 0.0
        self.config.feedforward_intermediate_dropout_prob = 0.0
    input_shape = input_ids.shape
    assert input_shape.ndims == 2

    self._seq_dim = input_shape.dims[1]
    self._memory_seq_dim = mtf.Dimension("memory_seq", self.seq_dim.size)
    self._extra_losses = []
    mesh = input_ids.mesh

    if token_type_ids is None:
        token_type_ids = mtf.zeros(mesh, input_shape, dtype=tf.int32)

    with tf.variable_scope(scope, default_name="bert"):
        with tf.variable_scope("embeddings"):
            # Perform embedding lookup on the word ids.
            self.embedding_table = mtf.get_variable(
                mesh, "word_embeddings",
                mtf.Shape([self.vocab_dim, self.model_dim]),
                initializer=self.embedding_initializer)
            self.word_embedding_output = mtf.gather(
                self.embedding_table, input_ids, self.vocab_dim)

            # Add positional embeddings and token type embeddings, then layer
            # normalize and perform dropout.
            self.embedding_output = self.word_embedding_output

            token_type_table = mtf.get_variable(
                mesh, "token_type_embeddings",
                mtf.Shape([self.token_type_vocab_dim, self.model_dim]),
                initializer=self.embedding_initializer)
            # token_type_ids is never None here (defaulted to zeros above),
            # so the lookup is unconditional.
            self.embedding_output += mtf.gather(
                token_type_table, token_type_ids, self.token_type_vocab_dim)
            if self.config.position_signal == "embedding":
                full_position_table = mtf.get_variable(
                    mesh, "position_embeddings",
                    mtf.Shape(
                        [self.max_position_embeddings_dim, self.model_dim]),
                    initializer=self.embedding_initializer)
                # Keep only the first seq_len rows and rename the dimension
                # so the table lines up with the sequence dimension.
                short_position_table = mtf.rename_dimension(
                    mtf.slice(full_position_table, 0, self.seq_dim.size,
                              self.max_position_embeddings_dim.name),
                    self.max_position_embeddings_dim.name, self.seq_dim.name)
                self.embedding_output += short_position_table
            self.embedding_output = self.normalize(self.embedding_output)
            self.embedding_output = mtf.dropout(
                self.embedding_output,
                keep_prob=1.0 - self.config.layer_output_dropout_prob)

        with tf.variable_scope("encoder"):
            attention_biases = []
            # Explicit None check: the mask is a tensor object, whose
            # truthiness says nothing about whether a mask was supplied.
            if input_mask is not None:
                # [batch_dim, memory_seq_dim]
                attention_biases.append((1.0 - mtf.to_float(
                    mtf.replace_dimensions(
                        input_mask, self.seq_dim,
                        self.memory_seq_dim))) * -10000.0)
            if self.config.position_signal == "relative_attention_bias":
                buckets_dim = mtf.Dimension("buckets", 32)
                rp_bucket = _relative_position_bucket(
                    mtf.range(mesh, self.memory_seq_dim, tf.int32) -
                    mtf.range(mesh, self.seq_dim, tf.int32),
                    num_buckets=buckets_dim.size)
                bias_var = mtf.get_variable(
                    mesh, "relative_attention_bias",
                    [self.num_heads_dim, buckets_dim],
                    initializer=tf.zeros_initializer())
                attention_biases.append(
                    mtf.gather(bias_var, rp_bucket, buckets_dim))
            attention_bias = mtf.add_n(attention_biases)
            prev_layer_output = self.embedding_output
            self.all_encoder_layers = []
            for block_num in range(self.config.num_blocks):
                with tf.variable_scope("block_%d" % block_num):
                    for layer_idx, layer_type in enumerate(
                            self.config.block_layers):
                        # Repeated layer types within a block get a numeric
                        # suffix ("attention", "attention_1", ...).
                        layer_name = layer_type
                        count = self.config.block_layers[:layer_idx].count(
                            layer_type)
                        if count:
                            layer_name += "_%d" % count
                        with tf.variable_scope(layer_name):
                            x = prev_layer_output
                            if self.config.residual_structure == "direct":
                                x = self.normalize(x)
                            if layer_type == "attention":
                                x = self.self_attention(x, attention_bias)
                            elif layer_type == "feedforward":
                                x = self.feedforward(x)
                            elif layer_type == "moe":
                                x = self.moe(x, layout, mesh_shape,
                                             input_mask, is_training)
                            else:
                                raise ValueError(
                                    "unknown layer type " + layer_type)
                            x = mtf.dropout(
                                x,
                                keep_prob=1.0 -
                                self.config.layer_output_dropout_prob)
                            layer_output = prev_layer_output + x
                            if self.config.residual_structure == "original":
                                layer_output = self.normalize(layer_output)
                            prev_layer_output = layer_output
                # One entry per block: the output of the block's last layer.
                self.all_encoder_layers.append(layer_output)

        self.sequence_output = prev_layer_output
        if self.config.residual_structure == "direct":
            self.sequence_output = self.normalize(self.sequence_output)

        # The "pooler" converts the encoded sequence tensor of shape
        # [batch_dim, seq_dim, hidden_size] to a tensor of shape
        # [batch_dim, hidden_size]. This is necessary for segment-level
        # (or segment-pair-level) classification tasks where we need a fixed
        # dimensional representation of the segment.
        with tf.variable_scope("pooler"):
            # We "pool" the model by simply taking the hidden state
            # corresponding to the first token. We assume that this has been
            # pre-trained
            first_token_tensor = mtf.gather(self.sequence_output, 0,
                                            self.seq_dim)
            tf.logging.info(
                f"[haiqwa-test] first_token_tensor shape: {first_token_tensor.shape}"
            )
            self.pooled_output = mtf.layers.dense(
                first_token_tensor,
                reduced_dims=[self.model_dim],
                new_dims=[self.model_dim],
                activation=mtf.tanh,
                kernel_initializer=self.dense_initializer,
                use_bias=self.config.use_bias)
def model_dim(self):
    """Return the "hidden" dimension sized by ``config.d_model``."""
    width = self.config.d_model
    return mtf.Dimension("hidden", width)
def testConvertToShape(self, inputs):
    """Check that ``inputs`` converts to the shape [x=4, y=8]."""
    converted = mtf.convert_to_shape(inputs)
    dim_x = mtf.Dimension("x", 4)
    dim_y = mtf.Dimension("y", 8)
    self.assertEqual(converted, mtf.Shape([dim_x, dim_y]))
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Builds the Mesh TensorFlow graph for the classifier, lowers it to
    TensorFlow and returns the TPUEstimatorSpec for the requested mode.

    The names ``bert_config``, ``num_labels``, ``learning_rate``,
    ``num_train_steps``, ``num_warmup_steps``, ``init_checkpoint`` and
    ``use_tpu`` are not defined in this function.
    NOTE(review): presumably they come from an enclosing builder function -
    confirm against the surrounding file.

    Args:
        features: mapping with the keys "input_ids", "input_mask",
            "segment_ids", "label_ids" and optionally "is_real_example".
        labels: unused.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: mapping; ``params["context"]`` supplies host/replica
            information and the device assignment.

    Returns:
        A ``tf.estimator.tpu.TPUEstimatorSpec`` for TRAIN, EVAL or (any other
        mode) prediction.
    """
    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf.logging.info(" name = %s, shape = %s" %
                        (name, features[name].shape))

    # MTF setup.
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    tf.logging.info("device_list = %s" % device_list, )
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memeory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memeory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                                mesh_devices,
                                                ctx.device_assignment)
    mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]
    # Per-example weights for the eval metrics; all ones when the feature
    # is absent.
    is_real_example = None
    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"],
                                  dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

    batch_size = input_ids.get_shape()[0].value
    batch_dim = mtf.Dimension("batch", batch_size)
    seq_length = input_ids.get_shape()[1].value
    seq_dim = mtf.Dimension("seq", seq_length)
    # NOTE(review): the labels dimension shares the name "seq" with seq_dim;
    # confirm this is intentional before relying on layout rules for "seq".
    num_labels_dim = mtf.Dimension("seq", num_labels)

    mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids,
                                         [batch_dim, seq_dim])
    mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                          [batch_dim, seq_dim])
    mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                           [batch_dim, seq_dim])
    mtf_label_ids = mtf.import_tf_tensor(mesh, label_ids, [batch_dim])

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    (total_loss, per_example_loss, logits,
     probabilities) = create_model(bert_config, is_training, mtf_input_ids,
                                   mtf_input_mask, mtf_segment_ids,
                                   mtf_label_ids, num_labels_dim,
                                   layout_rules, mesh_shape)
    total_loss = mtf.anonymize(total_loss)
    per_example_loss = mtf.anonymize(per_example_loss)
    logits = mtf.anonymize(logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
        _, update_ops = optimization_lib.create_optimizer(
            total_loss,
            learning_rate,
            num_train_steps,
            num_warmup_steps,
            max_optimized_variable_size=FLAGS.max_optimized_variable_size,
            optimizer=FLAGS.optimizer,
            clip_gradients=FLAGS.clip_gradients)

    # Lowering happens after the optimizer ops were added to the mtf graph.
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_global_step()
        tf_update_ops = [
            lowering.lowered_operation(op) for op in update_ops
        ]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
        train_op = tf.group(tf_update_ops)
    elif mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(per_example_loss, label_ids, logits, is_real_example):
            """Return accuracy and mean loss, weighted by is_real_example."""
            predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
            accuracy = tf.metrics.accuracy(labels=label_ids,
                                           predictions=predictions,
                                           weights=is_real_example)
            loss = tf.metrics.mean(values=per_example_loss,
                                   weights=is_real_example)
            return {
                "eval_accuracy": accuracy,
                "eval_loss": loss,
            }

        eval_metrics = (metric_fn, [
            lowering.export_to_tf_tensor(per_example_loss), label_ids,
            lowering.export_to_tf_tensor(logits), is_real_example
        ])

    tvars = tf.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
        (assignment_map, initialized_variable_names
         ) = bert_lib.get_assignment_map_from_checkpoint(
             tvars, init_checkpoint)
        if use_tpu:
            # On TPU the checkpoint initialization is deferred into the
            # scaffold function handed to the estimator spec.

            def tpu_scaffold():
                tf.train.init_from_checkpoint(init_checkpoint,
                                              assignment_map)
                return tf.train.Scaffold()

            scaffold_fn = tpu_scaffold
        else:
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
        init_string = ""
        if var.name in initialized_variable_names:
            init_string = ", *INIT_FROM_CKPT*"
        tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                        init_string)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                FLAGS.output_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

            return tf.estimator.tpu.TPUEstimatorSpec(
                mode,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook],
                scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            return tf.estimator.tpu.TPUEstimatorSpec(
                mode,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            return tf.estimator.tpu.TPUEstimatorSpec(
                mode,
                prediction_hooks=[restore_hook],
                predictions={
                    "probabilities":
                    lowering.export_to_tf_tensor(probabilities)
                },
                scaffold_fn=scaffold_fn)
def Alexnet(img, labels, num_nodes, num_gpus, args):
    """Build the AlexNet graph in Mesh TensorFlow and return its mean loss.

    Args:
        img: input images, forwarded to ``CreateMeshes``.
        labels: input labels, forwarded to ``CreateMeshes``.
        num_nodes: forwarded to ``CreateMeshes``.
        num_gpus: forwarded to ``CreateMeshes``; also used to round up the
            number of classes for strategies 2 and 3.
        args: object with a ``strategy`` attribute (0, 1, 2 or 3) selecting
            how the fully connected dimensions are named; it is also
            forwarded to ``CreateMeshes``.

    Returns:
        A tuple ``(graph, mesh_to_impl, mtf_loss)``.

    Raises:
        ValueError: if ``args.strategy`` is not one of 0, 1, 2, 3.
    """
    # Fail fast on an unknown strategy; previously this surfaced much later
    # as an unbound local variable when the first FC layer was built.
    strategy = args.strategy
    if strategy not in (0, 1, 2, 3):
        raise ValueError(
            'unknown strategy %r; expected one of 0, 1, 2, 3' % (strategy,))

    num_classes = 1000
    keep_prob = 0.5

    graph, meshes, mesh_to_impl, mtf_img, mtf_labels = CreateMeshes(
        img, labels, num_nodes, num_gpus, args)

    def _rename_fc(x):
        """Give the last dimension of x a fresh random name."""
        return mt.rename_dimension(x, x.shape[-1].name, utils.RandName())

    def _fc_activation(x):
        """ReLU followed by dropout, used after fc6 and fc7."""
        return mtf.dropout(mtf.relu(x), keep_prob)

    if strategy == 0:
        fc6_units = mtf.Dimension(utils.RandName(), 4096)
        fc7_units = mtf.Dimension(utils.RandName(), 4096)
        fc8_units = mtf.Dimension(utils.RandName(), num_classes)

    elif strategy == 1:
        fc6_units = mtf.Dimension('axis1', 4096)
        fc7_units = mtf.Dimension('axis0', 4096)
        fc8_units = mtf.Dimension('axis1', num_classes)

    elif strategy == 2:
        num_classes = utils.RoundUp(num_classes, num_gpus)
        fc6_units = mtf.Dimension('axis0', 4096)
        fc7_units = mtf.Dimension('axis0', 4096)
        fc8_units = mtf.Dimension('axis0', num_classes)

    else:  # strategy == 3
        num_classes = utils.RoundUp(num_classes, num_gpus // 2)
        fc6_units = mtf.Dimension('axis1', 4096)
        fc7_units = mtf.Dimension('axis1', 4096)
        fc8_units = mtf.Dimension('axis1', num_classes)

    # Strategies 2 and 3 rename the FC output dimension after each layer.
    rename_fc_outputs = strategy in (2, 3)

    with tf.variable_scope('alexnet'):
        # Conv1 + ReLU + maxpool1
        conv1 = mt.Conv2d(mtf_img, GetFilterShape(mtf_img, (11, 11, 3, 96)),
                          (4, 4), 'VALID', activation=mtf.relu, name='conv1')
        pool1 = mt.MaxPool(conv1, (3, 3), (2, 2), 'VALID', name='pool1')

        # Conv2 + ReLU + maxpool2
        conv2 = mt.Conv2d(pool1, GetFilterShape(pool1, (5, 5, 96, 256)),
                          (1, 1), 'SAME', activation=mtf.relu, name='conv2')
        pool2 = mt.MaxPool(conv2, (3, 3), (2, 2), name='pool2')

        # Conv3 + ReLU
        conv3 = mt.Conv2d(pool2, GetFilterShape(pool2, (3, 3, 256, 384)),
                          padding='SAME', activation=mtf.relu, name='conv3')

        # Conv4 + ReLU
        conv4 = mt.Conv2d(conv3, GetFilterShape(conv3, (3, 3, 384, 384)),
                          padding='SAME', activation=mtf.relu, name='conv4')

        # Conv5 + ReLU + maxpool5
        conv5 = mt.Conv2d(conv4, GetFilterShape(conv4, (3, 3, 384, 256)),
                          padding='SAME', activation=mtf.relu, name='conv5')
        pool5 = mt.MaxPool(conv5, (3, 3), (2, 2), name='pool5')

        # Rename dims
        if strategy == 1:
            # Flatten everything after the first dimension into one dim.
            k_dim = mtf.Dimension(
                utils.RandName(),
                utils.Prod(pool5.shape.to_integer_list[1:]))
            pool5 = mtf.reshape(pool5, mtf.Shape([pool5.shape[0], k_dim]))
            pool5 = ReplaceMeshWithIndependentAxes(
                pool5, meshes[1], (utils.RandName(), 'axis0'))

        elif strategy == 2:
            pool5 = mt.rename_dimension(pool5, pool5.shape[0].name,
                                        utils.RandName())

        elif strategy == 3:
            assert pool5.shape[0].name == 'axis0'
            #dim_names = pool5.shape.rename_dimension('axis0', utils.RandName())
            #pool5 = ReplaceMeshWithIndependentAxes(pool5, meshes[1], dim_names)
            pool5 = ReplaceMeshWithConcatSplit(pool5, meshes[1])

        # FC + ReLU + dropout
        fc6 = mtf.layers.dense(pool5, fc6_units, activation=_fc_activation,
                               reduced_dims=pool5.shape[1:], name='fc6')
        if rename_fc_outputs:
            fc6 = _rename_fc(fc6)

        fc7 = mtf.layers.dense(fc6, fc7_units, activation=_fc_activation,
                               reduced_dims=fc6.shape.dims[-1:], name='fc7')
        if rename_fc_outputs:
            fc7 = _rename_fc(fc7)

        fc8 = mtf.layers.dense(fc7, fc8_units,
                               reduced_dims=fc7.shape.dims[-1:], name='fc8')
        fc8 = mtf.dropout(fc8, keep_prob)

        if strategy == 1:
            assert fc8.shape[-1].name == 'axis1'
            fc8 = ReplaceMeshWithDuplicates(fc8, meshes[2])

    with tf.variable_scope('loss'):
        # Align the first dimension's name with the labels before the loss.
        if fc8.shape[0] != mtf_labels.shape[0]:
            fc8 = mt.rename_dimension(fc8, fc8.shape[0].name,
                                      mtf_labels.shape[0].name)
        one_hot_labels = mtf.one_hot(mtf_labels, fc8.shape[-1])
        mtf_cross_ent = mtf.layers.softmax_cross_entropy_with_logits(
            fc8, one_hot_labels, fc8.shape[-1])
        mtf_loss = mtf.reduce_mean(mtf_cross_ent)

    return graph, mesh_to_impl, mtf_loss
class MeshTensorFlowTest(parameterized.TestCase, tf.test.TestCase):
    """Tests for mtf conversion helpers, graphs, meshes and lowering."""

    @parameterized.parameters(
        (mtf.Dimension("x", 5), ),
        (("x", 5), ),
    )
    def testConvertToDimension(self, inputs):
        converted = mtf.convert_to_dimension(inputs)
        self.assertEqual(converted.name, "x")
        self.assertEqual(converted.size, 5)

    def testConvertToDimensionGenericInputs(self):
        # None passes through unchanged.
        converted = mtf.convert_to_dimension(None)
        self.assertEqual(converted, None)
        # A bare integer is rejected.
        with self.assertRaises(TypeError):
            mtf.convert_to_dimension(5)

    @parameterized.parameters(
        (mtf.Shape([mtf.Dimension("x", 4),
                    mtf.Dimension("y", 8)]), ),
        ("x:4;y:8", ),
        ("x:4.y:8", ),
        ("x:4 y:8", ),
        ("x:4,y:8", ),
    )
    def testConvertToShape(self, inputs):
        converted = mtf.convert_to_shape(inputs)
        dim_x = mtf.Dimension("x", 4)
        dim_y = mtf.Dimension("y", 8)
        self.assertEqual(converted, mtf.Shape([dim_x, dim_y]))

    def testConvertToShapeGenericInputs(self):
        empty_shape = mtf.convert_to_shape([])
        self.assertEqual(empty_shape.dims, [])
        none_shape = mtf.convert_to_shape(None)
        self.assertEqual(none_shape, None)
        with self.assertRaises(ValueError):
            mtf.convert_to_shape("x;4")

    @parameterized.parameters(
        (mtf.LayoutRules([("d_ff", "model"), ("heads", "model")]), ),
        ("d_ff:model;heads:model", ),
        ("d_ff:model.heads:model", ),
        ("d_ff:model heads:model", ),
        ("d_ff:model,heads:model", ),
        ([("d_ff", "model"), ("heads", "model")], ),
    )
    def testConvertToLayoutRules(self, inputs):
        converted = mtf.convert_to_layout_rules(inputs)
        reference = mtf.LayoutRules([("d_ff", "model"), ("heads", "model")])
        self.assertEqual(converted._pairs, reference._pairs)

    def testConvertToLayoutRulesGenericInputs(self):
        with self.assertRaises(ValueError):
            mtf.convert_to_layout_rules("d_ff;heads")

    def testTensorLayout(self):
        layout = mtf.TensorLayout([0, 2, 1])
        for mesh_axis, tensor_axes in ((0, ()), (1, (0, )), (2, (0, 2))):
            self.assertEqual(layout.mesh_axis_to_tensor_axis(mesh_axis),
                             tensor_axes)

        partly_split = mtf.TensorLayout([None, 0])
        self.assertFalse(partly_split.is_fully_replicated)

        replicated = mtf.TensorLayout([None, None, None])
        self.assertTrue(replicated.is_fully_replicated)

    def testGraph(self):
        graph = mtf.Graph()
        self.assertEmpty(graph.operations)
        self.assertEmpty(graph.trainable_variables)
        self.assertEmpty(graph.all_variables)

        # Importing a tensor adds one operation and no variables.
        mesh = mtf.Mesh(graph, "mesh_test")
        _ = mtf.import_tf_tensor(mesh,
                                 tf_tensor=tf.constant(0.),
                                 shape=mtf.Shape([]))
        self.assertLen(graph.operations, 1)
        self.assertEmpty(graph.trainable_variables)
        self.assertEmpty(graph.all_variables)

        # A trainable variable shows up in both variable lists.
        _ = mtf.get_variable(mesh, "variable_0", mtf.Shape([]),
                             trainable=True)
        self.assertLen(graph.operations, 2)
        self.assertLen(graph.trainable_variables, 1)
        self.assertLen(graph.all_variables, 1)

        # A non-trainable variable only shows up in all_variables.
        _ = mtf.get_variable(mesh, "variable_1", mtf.Shape([]),
                             trainable=False)
        self.assertLen(graph.operations, 3)
        self.assertLen(graph.trainable_variables, 1)
        self.assertLen(graph.all_variables, 2)

    def testGraphNames(self):
        # Standard Usage.
        graph = mtf.Graph()
        for requested, expected in (("a", "a"), ("a", "a_1"), ("a", "a_2")):
            self.assertEqual(graph.unique_name(requested), expected)

        # Edge cases, the user may choose the name "a_1".
        graph = mtf.Graph()
        for requested, expected in (("a", "a"), ("a", "a_1"),
                                    ("a_1", "a_1_1")):
            self.assertEqual(graph.unique_name(requested), expected)

        graph = mtf.Graph()
        for requested, expected in (("a", "a"), ("a_1", "a_1"),
                                    ("a", "a_2")):
            self.assertEqual(graph.unique_name(requested), expected)

        # Case insensitive.
        graph = mtf.Graph()
        for requested, expected in (("a", "a"), ("A", "A_1")):
            self.assertEqual(graph.unique_name(requested), expected)

    @tf.contrib.eager.run_test_in_graph_and_eager_modes()
    def testLowering(self):
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        tf_inputs = tf.constant(0.)
        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          tf_tensor=tf_inputs,
                                          shape=mtf.Shape([]))
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        # A round trip through mtf must return the imported value.
        tf_outputs = lowering.export_to_tf_tensor(mtf_inputs)
        inputs_value, outputs_value = self.evaluate([tf_inputs, tf_outputs])
        self.assertEqual(inputs_value, outputs_value)

        # Check that methods run without error.
        _ = lowering.copy_masters_to_slices()
        _ = lowering.copy_slices_to_masters()

    def testMesh(self):
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        self.assertEqual(mesh.graph, graph)

    def testMeshImpl(self):
        mesh_shape = mtf.Shape(
            [mtf.Dimension("batch", 4),
             mtf.Dimension("model", 8)])
        rules = mtf.LayoutRules([("batch", "batch"), ("d_ff", "model"),
                                 ("heads", "model")])
        mesh_impl = mtf.MeshImpl(shape=mesh_shape, layout_rules=rules)
        self.assertEqual(mesh_impl.shape, mesh_shape)
        self.assertLen(mesh_shape, mesh_impl.ndims)
        self.assertEqual(mesh_impl.layout_rules, rules)
        self.assertEqual(mesh_impl.size, mesh_shape.size)
        self.assertTrue(mesh_impl.supports_control_dependencies)

        batch = mtf.Dimension("batch", 128)
        length = mtf.Dimension("length", 500)
        d_ff = mtf.Dimension("d_ff", 2048)
        heads = mtf.Dimension("heads", 8)
        for tensor_dim, mesh_axis in ((batch, 0), (d_ff, 1), (heads, 1)):
            self.assertEqual(
                mesh_impl.tensor_dimension_to_mesh_axis(tensor_dim),
                mesh_axis)
        self.assertEqual(
            mesh_impl.tensor_layout(mtf.Shape([batch, length, d_ff])),
            mtf.TensorLayout([0, None, 1]))
def recon_model(mesh,
                data,
                R0,
                x0,
                nc=FLAGS.nc,
                bs=FLAGS.box_size,
                batch_size=FLAGS.batch_size,
                a0=FLAGS.a0,
                a=FLAGS.af,
                nsteps=FLAGS.nsteps,
                dtype=tf.float32):
    """Build the reconstruction graph: forward model, prior and loss.

    Prototype of function computing LPT displacement.

    Args:
        mesh: the mtf mesh on which all tensors and variables are created.
        data: observed field; converted with ``tf.convert_to_tensor`` and
            imported with shape [batch, nc, nc, nc].
        R0: smoothing scale applied to the residual; wrapped in
            ``tf.constant``.
        x0: initial value for the 'linear' field variable, or None for a
            unit-variance random normal initialization.
        nc: number of cells per side.
        bs: box size.
        batch_size: size of the batch dimension.
        a0: first entry of the ``stages`` array.
        a: last entry of the ``stages`` array.
        nsteps: number of entries in ``stages``.
        dtype: ``tf.float32`` or ``tf.float64``.

    Returns:
        A tuple ``(fields, metrics, kv)`` where ``fields`` is
        ``[fieldvar, final_field]``, ``metrics`` is ``[chisq, prior, loss]``
        and ``kv`` is the list of imported Fourier-kernel tensors.

    Raises:
        ValueError: if ``dtype`` is neither ``tf.float32`` nor ``tf.float64``.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Previously this fell through and failed on the print below with an
        # unbound local variable.
        raise ValueError(
            "unsupported dtype %r; expected tf.float32 or tf.float64" %
            (dtype, ))
    print(dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo may not reach half of the smallest block; clamp it if it does.
    smallest_block = min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= 0.5 * smallest_block:
        new_size = int(0.5 * smallest_block)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition
    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Only the second column of the table is used below; read the file once.
    plin = np.loadtxt('../data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Begin simulation
    if x0 is None:
        fieldvar = mtf.get_variable(
            mesh,
            'linear',
            part_shape,
            initializer=tf.random_normal_initializer(mean=0.0,
                                                     stddev=1,
                                                     seed=None))
    else:
        fieldvar = mtf.get_variable(
            mesh,
            'linear',
            part_shape,
            initializer=tf.constant_initializer(x0))
    print("\nfieldvar : \n", fieldvar)

    # Here we can run our nbody
    if FLAGS.nbody:
        state = mtfpm.lpt_init_single(
            fieldvar,
            a0,
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )
        # Here we can run our nbody
        final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape,
                                         kv_lr, halo_size)
    else:
        # Without the nbody flag the LPT initialization is evaluated directly
        # at the last stage.
        final_state = mtfpm.lpt_init_single(
            fieldvar,
            stages[-1],
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0],
                                       halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim,
                                      block_size_dim, halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Get prior
    # kv is ordered [ky, kz, kx]; reorder its dimensions to (x, y, z).
    k_dims_pr = [d.shape[0] for d in kv]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        """Divide |kfield| by the square root of the interpolated pk."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 +
                     (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3  #*nc**3

    # Total loss
    diff = (final_field - mtfdata)
    R0 = tf.constant(R0)
    print("R0 in the recon_model : ", R0)

    def _cwise_smooth(kfield, kx, ky, kz):
        """Multiply kfield by a Gaussian weight in k with scale R0."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    # Smooth the residual in Fourier space before computing chi-square.
    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    fields = [fieldvar, final_field]
    metrics = [chisq, prior, loss]

    return fields, metrics, kv
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """Build the TPUEstimatorSpec for a start/end span model on a SIMD mesh.

    The computation is described with Mesh TensorFlow, lowered to TensorFlow,
    and wrapped in a spec for either TRAIN or PREDICT.

    NOTE(review): `bert_config`, `learning_rate`, `num_train_steps`,
    `num_warmup_steps`, `init_checkpoint` and `use_tpu` are read here but are
    not parameters; presumably they come from an enclosing builder -- confirm.

    Args:
      features: mapping holding "unique_ids", "input_ids", "input_mask" and
        "segment_ids"; TRAIN additionally reads "start_positions" and
        "end_positions".
      labels: unused.
      mode: a tf.estimator.ModeKeys value.
      params: mapping whose "context" entry provides host count, replica
        count, host placement function and device assignment.

    Returns:
      A tf.estimator.tpu.TPUEstimatorSpec.

    Raises:
      ValueError: if `mode` is neither TRAIN nor PREDICT.
    """
    training = mode == tf.estimator.ModeKeys.TRAIN
    predicting = mode == tf.estimator.ModeKeys.PREDICT

    tf.logging.info("*** Features ***")
    for feature_name in sorted(features.keys()):
        tf.logging.info(" name = %s, shape = %s",
                        feature_name, features[feature_name].shape)

    # MTF setup.
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    tpu_context = params["context"]
    host_count = tpu_context.num_hosts
    place_on_host = tpu_context.tpu_host_placement_function
    host_devices = [place_on_host(host_id=host) for host in range(host_count)]
    tf.logging.info("device_list = %s", host_devices)

    # Worker 0 caches all the TPU binaries (300M per replica), so it is
    # reported as already loaded to the balanced variable placer.
    per_replica_cache = 300 * 1000000
    preloaded_memory = [per_replica_cache * tpu_context.num_replicas]
    preloaded_memory += [0] * (host_count - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(host_devices,
                                                  preloaded_memory)
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, [""] * mesh_shape.size,
        tpu_context.device_assignment)
    mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

    unique_ids = features["unique_ids"]
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]

    static_shape = input_ids.get_shape()
    batch_dim = mtf.Dimension("batch", static_shape[0].value)
    seq_dim = mtf.Dimension("seq", static_shape[1].value)
    token_shape = [batch_dim, seq_dim]

    mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids, token_shape)
    mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask, token_shape)
    mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids, token_shape)

    start_logits, end_logits = create_model(
        bert_config=bert_config,
        is_training=training,
        input_ids=mtf_input_ids,
        input_mask=mtf_input_mask,
        segment_ids=mtf_segment_ids)

    if training:

        def span_loss(logits, positions):
            # Mean negative log-likelihood of the gold position over seq_dim.
            targets = mtf.one_hot(positions, output_dim=seq_dim)
            log_probs = mtf.log_softmax(logits, seq_dim)
            picked = mtf.reduce_sum(targets * log_probs, reduced_dim=seq_dim)
            return -mtf.reduce_mean(picked)

        mtf_start_positions = mtf.import_tf_tensor(
            mesh, features["start_positions"], [batch_dim])
        mtf_end_positions = mtf.import_tf_tensor(
            mesh, features["end_positions"], [batch_dim])

        start_loss = span_loss(start_logits, mtf_start_positions)
        end_loss = span_loss(end_logits, mtf_end_positions)
        total_loss = (start_loss + end_loss) / 2.0

        _, update_ops = optimization_lib.create_optimizer(
            total_loss,
            learning_rate,
            num_train_steps,
            num_warmup_steps,
            max_optimized_variable_size=FLAGS.max_optimized_variable_size,
            optimizer=FLAGS.optimizer,
            clip_gradients=FLAGS.clip_gradients)
    elif predicting:
        start_logits = mtf.anonymize(start_logits)
        end_logits = mtf.anonymize(end_logits)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    if training:
        tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))
        global_step = tf.train.get_global_step()
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
        train_op = tf.group(tf_update_ops)

    trainable = tf.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
        assignment_map, initialized_variable_names = (
            bert_lib.get_assignment_map_from_checkpoint(trainable,
                                                        init_checkpoint))
        if use_tpu:

            def tpu_scaffold():
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                return tf.train.Scaffold()

            scaffold_fn = tpu_scaffold
        else:
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for variable in trainable:
        restored = variable.name in initialized_variable_names
        init_string = ", *INIT_FROM_CKPT*" if restored else ""
        tf.logging.info(" name = %s, shape = %s%s", variable.name,
                        variable.shape, init_string)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)

        if training:
            saver = tf.train.Saver(
                tf.global_variables(),
                sharded=True,
                max_to_keep=10,
                keep_checkpoint_every_n_hours=2,
                defer_build=False,
                save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_hook = tf.train.CheckpointSaverHook(
                FLAGS.output_dir,
                save_steps=1000,
                saver=saver,
                listeners=[mtf.MtfCheckpointSaverListener(lowering)])
            return tf.estimator.tpu.TPUEstimatorSpec(
                mode,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook],
                scaffold_fn=scaffold_fn)

        if predicting:
            predictions = {
                "unique_ids": unique_ids,
                "start_logits": lowering.export_to_tf_tensor(start_logits),
                "end_logits": lowering.export_to_tf_tensor(end_logits),
            }
            return tf.estimator.tpu.TPUEstimatorSpec(
                mode,
                prediction_hooks=[restore_hook],
                predictions=predictions,
                scaffold_fn=scaffold_fn)

        raise ValueError("Only TRAIN and PREDICT modes are supported: %s" %
                         (mode))
def synthetic_attention(q, k, v, memory_length_dim, key_dim, value_dim,
                        bias=None, dropout_rate=0.0,
                        dropout_broadcast_dims=None, extra_logit=None,
                        synthesize=True, synthesize_mode="random_plus_alpha",
                        factorized_dim=16, max_length=512, context=None):
    """Synthetic Attention from Synthesizers (https://arxiv.org/abs/2005.00743).

    key_dim is a Dimension representing the channels in the queries and keys
    value_dim is a Dimension representing the channels in values
    memory_length_dim is a Dimension representing the different key/value pairs.

    Dimensions of q: other_query_dims + {key_dim}
    Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
    Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
    other_memory_dims is a subset of other_query_dims

    Typically, other_query_dims={batch, heads, length}
    Typically, other_memory_dims={batch, heads}

    Args:
      q: a Tensor
      k: a Tensor
      v: a Tensor
      memory_length_dim: a Dimension
      key_dim: a Dimension
      value_dim: a Dimension
      bias: a Tensor to be added into the attention logits.
      dropout_rate: a float.
      dropout_broadcast_dims: an optional list of mtf.Dimension
      extra_logit: an optional scalar or tensor
      synthesize: flag to use synthetic attention or not. When False the
        logits are the plain q.k dot product.
      synthesize_mode: which variant of synthesizer to use; one of "random",
        "factorized", "dense_minus", "random_plus_alpha", "dense_plus_alpha".
      factorized_dim: factorized dim for synthesizers
      max_length: max length of input sequence (used by the "random" and
        "dense_minus" variants; the others use a fixed 512).
      context: context supplying mesh, variable_dtype, mode and position.
        Required when synthesize is True.

    Returns:
      Tensor with shape q.shape - key_dim + value_dim

    Raises:
      ValueError: if synthesize is True and context is None, or if
        synthesize_mode is not a supported variant.
    """
    known_modes = ("random", "factorized", "dense_minus",
                   "random_plus_alpha", "dense_plus_alpha")

    def _select_query_rows(r):
        """Restrict a [length, heads, memory_length] table to live queries."""
        if context.mode == "incremental":
            return mtf.gather(r, context.position,
                              r.shape.get_dim_by_name("length"))
        length_dim = q.shape.get_dim_by_name("length")
        return mtf.slice(r, 0, length_dim.size, "length")

    def _mixing_alpha():
        """Learnable scalar in (0, 1) blending dot-product and synthetic logits."""
        alpha = mtf.get_variable(context.mesh, "alpha",
                                 mtf.Shape([mtf.Dimension("alpha", 1)]),
                                 initializer=tf.zeros_initializer(),
                                 dtype=context.variable_dtype)
        return mtf.sigmoid(alpha)

    if synthesize:
        # Fail early with a clear message instead of an AttributeError /
        # unbound-local error deep inside graph construction.
        if context is None:
            raise ValueError("synthetic_attention requires a context when "
                             "synthesize is True.")
        if synthesize_mode not in known_modes:
            raise ValueError(
                "Unknown synthesize_mode: %s" % (synthesize_mode,))

        num_heads = v.shape.get_dim_by_name("heads")
        tf.logging.info("Using synthesizer")
        if synthesize_mode == "random":
            tf.logging.info("Using Random Synthesizers")
            # mtf.Dimension takes exactly (name, size).
            r_shape = mtf.Shape([mtf.Dimension("length", max_length),
                                 mtf.Dimension("heads", num_heads.size),
                                 mtf.Dimension("memory_length", max_length)])
            r = mtf.get_variable(context.mesh, "R", r_shape,
                                 initializer=None,
                                 dtype=context.variable_dtype)
            r = mtf.slice(r, 0, memory_length_dim.size,
                          memory_length_dim.name)
            logits = _select_query_rows(r)
        elif synthesize_mode == "factorized":
            tf.logging.info("Using Factorized Random Synthesizers")
            # Named `rank` so the keys tensor `k` is not shadowed.
            rank = factorized_dim
            # NOTE(review): 512 is hard-coded here and in the "*_plus_alpha"
            # variants rather than using max_length; left unchanged so the
            # variable shapes stay as they were.
            r1_shape = mtf.Shape([mtf.Dimension("tmp", rank),
                                  mtf.Dimension("heads", num_heads.size),
                                  mtf.Dimension("memory_length", 512)])
            r2_shape = mtf.Shape([mtf.Dimension("tmp", rank),
                                  mtf.Dimension("heads", num_heads.size),
                                  mtf.Dimension("memory_length", 512)])
            r_shape = mtf.Shape([mtf.Dimension("length", 512),
                                 mtf.Dimension("heads", num_heads.size),
                                 mtf.Dimension("memory_length", 512)])
            initializer = tf.random_normal_initializer()
            r1 = mtf.get_variable(context.mesh, "R1", r1_shape,
                                  initializer=initializer,
                                  dtype=context.variable_dtype)
            r2 = mtf.get_variable(context.mesh, "R2", r2_shape,
                                  initializer=initializer,
                                  dtype=context.variable_dtype)
            r = mtf.einsum([r1, r2], r_shape)
            r = mtf.slice(r, 0, memory_length_dim.size,
                          memory_length_dim.name)
            # The query length dimension is looked up here; it was previously
            # referenced without ever being assigned in this branch.
            logits = _select_query_rows(r)
        elif synthesize_mode == "dense_minus":
            # Dense Synthesizer Model
            tmp_dim = mtf.Dimension("memory_length", max_length)
            logits = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                                      use_bias=False,
                                      name="pi",
                                      reduced_dims=[key_dim],
                                      variable_dtype=None)
            logits = mtf.slice(logits, 0, memory_length_dim.size,
                               memory_length_dim.name)
            if context.mode != "incremental":
                length_dim = q.shape.get_dim_by_name("length")
                logits = mtf.slice(logits, 0, length_dim.size, "length")
        elif synthesize_mode == "random_plus_alpha":
            # Mixture Random Synthesizer with learnable Alpha
            tf.logging.info("Using Random Plus Alpha")
            logits = mtf.einsum([q, k], reduced_dims=[key_dim])
            num_heads = logits.shape.get_dim_by_name("heads")
            r_shape = mtf.Shape([mtf.Dimension("length", 512),
                                 mtf.Dimension("heads", num_heads.size),
                                 mtf.Dimension("memory_length", 512)])
            r = mtf.get_variable(context.mesh, "R", r_shape,
                                 initializer=None,
                                 dtype=context.variable_dtype)
            r = mtf.slice(r, 0, memory_length_dim.size,
                          memory_length_dim.name)
            r = _select_query_rows(r)
            alpha = _mixing_alpha()
            logits = ((1 - alpha) * logits) + (alpha * r)
        else:  # "dense_plus_alpha"
            # Mixture Dense Synthesizer with learnable alpha
            tf.logging.info("Using Dense Plus Alpha Scaling")
            logits = mtf.einsum([q, k], reduced_dims=[key_dim])
            tmp_dim = mtf.Dimension("memory_length", 512)
            r = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                                 use_bias=False,
                                 name="pi",
                                 reduced_dims=[key_dim],
                                 variable_dtype=None)
            r = mtf.slice(r, 0, memory_length_dim.size,
                          memory_length_dim.name)
            if context.mode != "incremental":
                length_dim = q.shape.get_dim_by_name("length")
                r = mtf.slice(r, 0, length_dim.size, "length")
            alpha = _mixing_alpha()
            logits = ((1 - alpha) * logits) + (alpha * r)
    else:
        # No synthesizer requested: fall back to dot-product logits so the
        # rest of the function has something to normalise.
        logits = mtf.einsum([q, k], reduced_dims=[key_dim])

    if bias is not None:
        logits += bias
    weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)
    if dropout_rate != 0.0:
        weights = mtf.dropout(
            weights, 1.0 - dropout_rate,
            noise_shape=weights.shape - dropout_broadcast_dims)

    if synthesize and "plus" not in synthesize_mode:
        # Purely synthetic variants never contract q with k, so the output
        # shape is assembled from q's leading dims.
        if synthesize_mode == "dense_minus":
            outputs_shape = mtf.Shape(q.shape.dims[:-1] + [value_dim])
        else:
            outputs_shape = mtf.Shape(
                q.shape.dims[:-1] + [num_heads, value_dim])
    else:
        outputs_shape = q.shape - [key_dim] + value_dim

    outputs = mtf.einsum([weights, v], outputs_shape)
    return outputs
def cifar_model(features, labels, mesh):
    """Small blocked-convolution classifier built with Mesh TensorFlow.

    Args:
      features: mapping read for the keys 'image' and 'label'. It is
        shallow-copied first. The image is reshaped to
        [FLAGS.batch_size, 4, 8, 4, 8, 3] after `bnorm`.
      labels: ignored as passed in; the value is replaced by
        features['label'].
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor over [batch, classes] (10 classes)
      loss: the mean softmax cross-entropy as a mtf.Tensor, or None when
        features['label'] is None
    """
    features = copy.copy(features)

    batch = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks = mtf.Dimension("row_blocks", 4)
    col_blocks = mtf.Dimension("col_blocks", 4)
    rows = mtf.Dimension("rows_size", 8)
    cols = mtf.Dimension("cols_size", 8)
    classes = mtf.Dimension("classes", 10)
    channels = mtf.Dimension("one_channel", 3)

    normalized = bnorm(features['image'])
    blocked = tf.reshape(normalized, [FLAGS.batch_size, 4, 8, 4, 8, 3])
    x = mtf.import_tf_tensor(
        mesh, blocked,
        mtf.Shape([batch, row_blocks, rows, col_blocks, cols, channels]))
    # Bring both block axes in front of the within-block axes.
    x = mtf.transpose(
        x, [batch, row_blocks, col_blocks, rows, cols, channels])

    # Add some convolutional layers to demonstrate that convolution works.
    filter_h = mtf.Dimension("fh", 7)
    filter_w = mtf.Dimension("fw", 7)
    filters1 = mtf.Dimension("filters1", 32)
    filters2 = mtf.Dimension("filters2", 32)

    kernel1 = mtf.get_variable(mesh, "kernel1",
                               [filter_h, filter_w, channels, filters1])
    kernel2 = mtf.get_variable(mesh, "kernel2",
                               [filter_h, filter_w, filters1, filters2])

    conv_kwargs = dict(strides=[1, 1, 1, 1],
                       padding="SAME",
                       h_blocks_dim=row_blocks,
                       w_blocks_dim=col_blocks)
    f1 = mtf.relu(mtf.conv2d_with_blocks(x, kernel1, **conv_kwargs))
    f2 = mtf.relu(mtf.conv2d_with_blocks(f1, kernel2, **conv_kwargs))
    x = mtf.reduce_mean(f2, reduced_dim=filters2)

    # Add some fully-connected dense layers.
    hidden1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    hidden2 = mtf.Dimension("hidden2", FLAGS.hidden_size)
    h1 = mtf.layers.dense(x, hidden1,
                          reduced_dims=x.shape.dims[-4:],
                          activation=mtf.relu,
                          name="hidden1")
    h2 = mtf.layers.dense(h1, hidden2, activation=mtf.relu, name="hidden2")
    logits = mtf.layers.dense(h2, classes, name="logits")

    labels = features['label']
    if labels is None:
        return logits, None

    labels = mtf.import_tf_tensor(mesh,
                                  tf.reshape(labels, [FLAGS.batch_size]),
                                  mtf.Shape([batch]))
    per_example = mtf.layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(labels, classes), classes)
    loss = mtf.reduce_mean(per_example)
    return logits, loss
def testPool(self, pooling_method):
    """Compare mtf 2D/3D max/avg pooling and its gradient with Keras pooling.

    Args:
      pooling_method: one of "MAX_2D", "AVG_2D", "MAX_3D", "AVG_3D".
    """
    batch, depth, height, width, channels = 2, 3, 4, 6, 3
    stride_d, stride_h, stride_w = 3, 2, 3
    full_shape = [batch, depth, height, width, channels]
    element_count = batch * depth * height * width * channels

    tf.random.set_random_seed(1234)
    inputs = tf.random_normal(full_shape)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    dim_names = ("batch", "depth", "height", "width", "channels")
    dims = [mtf.Dimension(name, size)
            for name, size in zip(dim_names, full_shape)]
    mtf_inputs = mtf.import_tf_tensor(mesh, inputs, shape=mtf.Shape(dims))

    pools_2d = {
        "MAX_2D": (mtf.layers.max_pool2d, tf.keras.layers.MaxPooling2D),
        "AVG_2D": (mtf.layers.avg_pool2d, tf.keras.layers.AveragePooling2D),
    }
    pools_3d = {
        "MAX_3D": (mtf.layers.max_pool3d, tf.keras.layers.MaxPooling3D),
        "AVG_3D": (mtf.layers.avg_pool3d, tf.keras.layers.AveragePooling3D),
    }

    if pooling_method in pools_2d:
        mtf_pool, keras_pool = pools_2d[pooling_method]
        window_strides = (stride_h, stride_w)
        mtf_outputs = mtf_pool(mtf_inputs, ksize=(stride_h, stride_w))
        # Keras pools 4-D input, so depth is folded into batch and restored.
        inputs = tf.reshape(inputs, [batch * depth, height, width, channels])
        expected_outputs = keras_pool((stride_h, stride_w))(inputs)
        expected_outputs = tf.reshape(
            expected_outputs,
            [batch, depth, height // stride_h, width // stride_w, channels])
    elif pooling_method in pools_3d:
        mtf_pool, keras_pool = pools_3d[pooling_method]
        window_strides = (stride_d, stride_h, stride_w)
        mtf_outputs = mtf_pool(mtf_inputs,
                               ksize=[stride_d, stride_h, stride_w])
        expected_outputs = keras_pool([stride_d, stride_h, stride_w])(inputs)

    mtf_gradient = mtf.gradients([mtf_outputs], [mtf_inputs])[0]

    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
    actual_gradient = lowering.export_to_tf_tensor(mtf_gradient)

    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    self.evaluate(init)
    self.evaluate(tf_group)

    actual, expected = self.evaluate([actual_outputs, expected_outputs])
    self.assertAllClose(actual, expected)

    actual = self.evaluate(actual_gradient)
    if pooling_method in ("MAX_2D", "MAX_3D"):
        # One non-zero gradient entry is expected per pooling window.
        window_volume = 1
        for stride in window_strides:
            window_volume *= stride
        expected_non_zeros = element_count / window_volume
        self.assertEqual(np.count_nonzero(actual), expected_non_zeros)
    elif pooling_method in ("AVG_2D", "AVG_3D"):
        # Each input contributes equally: 1 / (window volume).
        expected = np.ones((batch, depth, height, width, channels),
                           dtype=np.float32)
        for stride in window_strides:
            expected = expected / stride
        self.assertAllClose(actual, expected)
def test_get_laidout_tensors(self, is_eval_mode):
    """Check which sub-batches SimdMeshImplInputReader feeds to two cores.

    Args:
      is_eval_mode: forwarded to SimdMeshImplInputReader; selects which rows
        core 1 is expected to receive.
    """
    batch_io_dim = 4
    top_rows = np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32)
    bottom_rows = np.array([[0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32)

    with tf.Session() as sess:
        topology, num_cores = self.initialize_system(sess)

        # Get a device_assignment object for mtf.
        d_assignment = device_assignment.device_assignment(
            topology, computation_shape=[1, 1, 1], num_replicas=num_cores)

        # Hacked dataset creator: creates different datasets for the first
        # and second call, in order to test SimdMeshImplInputReader.
        self.sub_batch_created_times = 0

        def stateful_ds_creator():
            identity = tf.eye(batch_io_dim, dtype=tf.float32)
            first_row = self.sub_batch_created_times * 2
            sub_batch = tf.slice(identity, [first_row, 0], [2, 4])
            self.sub_batch_created_times += 1
            dataset = tf.data.Dataset.from_tensors(sub_batch)
            return dataset.repeat().unbatch()

        mtf_input_shapes = [
            mtf.Shape([mtf.Dimension("batch", batch_io_dim),
                       mtf.Dimension("io", batch_io_dim)])
        ]

        # Get mesh_impl.
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mtf.convert_to_shape("mesh_x:2, mesh_y:1"),
            mtf.convert_to_layout_rules("batch:mesh_x, io:mesh_y"),
            None,
            d_assignment)

        simd_input_reader = input_reader.SimdMeshImplInputReader(
            mesh_impl, stateful_ds_creator, mtf_input_shapes,
            external_worker=False,
            is_eval_mode=is_eval_mode)

        def model_fn(features):
            return features

        replicated_computation = tpu.replicate(
            computation=model_fn,
            inputs=[[]] * num_cores,
            infeed_queue=simd_input_reader.infeed_queue,
            device_assignment=d_assignment)

        simd_input_reader.start_infeed_thread(sess, 1)
        results = sess.run(replicated_computation)
        print("results: {}".format(results))

        core_0_data, core_1_data = results[0][0], results[1][0]
        print("core_0_data: {}".format(core_0_data))
        print("core_1_data: {}".format(core_1_data))

        # Eval mode: only one dataset object, so stateful_ds_creator() should
        # be called once and both cores see the first two rows.
        # Otherwise: two dataset objects, so it should be called twice and
        # core 1 sees the last two rows.
        expected_core_1 = top_rows if is_eval_mode else bottom_rows
        self.assertAllClose(top_rows, core_0_data)
        self.assertAllClose(expected_core_1, core_1_data)

        sess.run(tf.tpu.shutdown_system())
def _maybe_reshape_attention_input_for_2d_sharding(context, q, k, v, bias,
                                                   unsplittable_dims):
    """Reshape the inputs to attention to split over an unused mesh dimension.

    In the case where the attention computation is unnecessarily replicated,
    this function reshapes the attention inputs to remove the unnecessary
    replication.

    This becomes relevant when doing 2-dimensional model parallelism.
    d_model is sharded over one mesh dimension and [vocab, num_heads, d_ff] are
    sharded over the other mesh dimension. This fully distributes all of the
    einsum operations, except for the internals of the attention computation.

    To distribute that computation, this function creates a new
    tensor-dimension from the low bits of either the batch dimension or the
    num_heads dimension, and then splits that dimension over the unused mesh
    dimension.

    Args:
      context: a transformer.Context
      q: a Tensor
      k: a Tensor
      v: a Tensor
      bias: a Tensor, or None
      unsplittable_dims: a list of tensor-dimensions not to split. The
        key/value dimensions should be passed here.

    Returns:
      reshaped_q: a Tensor
      reshaped_k: a Tensor
      reshaped_v: a Tensor
      reshaped_bias: a Tensor (None if bias was None)
      The inputs are returned untouched whenever no reshape applies.
    """
    original_inputs = q, k, v, bias
    # we need to know the layout and mesh-shape to figure out what to do.
    if (not context or not context.model.layout
            or not context.model.mesh_shape):
        return original_inputs
    mesh_shape = mtf.convert_to_shape(context.model.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(context.model.layout)

    # find a mesh dim that is unused (no tensor-dimension is split across it)
    mesh_axis_used = [False] * mesh_shape.ndims
    for x in original_inputs:
        if x is None:
            # bias is optional; an absent tensor occupies no mesh axis.
            continue
        tensor_layout = layout_rules.tensor_layout(x.shape, mesh_shape)
        for mesh_axis in tensor_layout.tensor_axis_to_mesh_axis:
            if mesh_axis is not None:
                mesh_axis_used[mesh_axis] = True
    if False not in mesh_axis_used:
        return original_inputs
    mesh_dim = mesh_shape.dims[mesh_axis_used.index(False)]

    # Choose an appropriate name for the new tensor-dimension so that the
    # layout will know to split it across the unused mesh dimension.
    candidate_names = (
        layout_rules.mesh_dimension_name_to_tensor_dimension_names(
            mesh_dim.name))
    if not candidate_names:
        return original_inputs
    tensor_dim_name = candidate_names[0]

    # Find a tensor-dimension that we can further split, by breaking off the
    # lower bits into our new tensor-dimension.
    # This resplittable tensor-dimension must be present in all of q, k, v
    # and must be large enough to be further split.
    resplittable_dim = None
    num_splits = None
    for d in q.shape.dims:
        if (d in k.shape.dims and d in v.shape.dims
                and d not in unsplittable_dims):
            num_splits = mtf.tensor_dim_to_mesh_dim_size(
                context.model.layout, context.model.mesh_shape, d)
            if d.size % (num_splits * mesh_dim.size) == 0:
                resplittable_dim = d
                break
    if resplittable_dim is None:
        return original_inputs

    new_dim_high = mtf.Dimension(resplittable_dim.name, num_splits)
    new_dim_low = mtf.Dimension(tensor_dim_name,
                                resplittable_dim.size // num_splits)

    def _my_reshape(x):
        if x is not None and resplittable_dim in x.shape.dims:
            return mtf.replace_dimensions(
                x, resplittable_dim, [new_dim_high, new_dim_low])
        return x

    return _my_reshape(q), _my_reshape(k), _my_reshape(v), _my_reshape(bias)
def vocab_dim(self):
    """Return the "vocab" Dimension, sized up to the next multiple of 128."""
    # Padding to a multiple of 128 keeps the dimension splittable.
    # TODO(noam): This creates issues in checkpoint compatibility
    vocab_size = self.config.vocab_size
    padding = -vocab_size % 128
    return mtf.Dimension("vocab", vocab_size + padding)
def hybrid_attention(q, k, v, context, memory_length_dim, key_dim, value_dim,
                     bias=None, dropout_rate=0.0, dropout_broadcast_dims=None,
                     extra_logit=None):
    """Attention mixing the usual softmax with a doubly-normalised one.

    The weights are `c * doubly + (1 - c) * upper`, where `upper` is a softmax
    over memory_length_dim, `doubly` is a softmax over memory_length_dim of a
    log-softmax taken over a "length" dimension of the same size, and `c` is
    the variable "doubly_coeff" (initialised to 0.5) clipped to [0, 1].

    key_dim is a Dimension representing the channels in the queries and keys
    value_dim is a Dimension representing the channels in values
    memory_length_dim is a Dimension representing the different key/value pairs.

    Dimensions of q: other_query_dims + {key_dim}
    Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
    Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
    other_memory_dims is a subset of other_query_dims

    Args:
      q: a Tensor
      k: a Tensor
      v: a Tensor
      context: context of the attention layer (supplies mesh and
        variable_dtype).
      memory_length_dim: a Dimension
      key_dim: a Dimension
      value_dim: a Dimension
      bias: a Tensor to be added into the attention logits.
      dropout_rate: a float.
      dropout_broadcast_dims: an optional list of mtf.Dimension
      extra_logit: an optional scalar or tensor

    Returns:
      Tensor with shape q.shape - key_dim + value_dim
    """
    scores = mtf.layers.us_einsum([q, k], reduced_dims=[key_dim])
    if bias is not None:
        scores += bias

    query_length_dim = mtf.Dimension("length", memory_length_dim.size)

    mix = mtf.get_variable(context.mesh, "doubly_coeff", [],
                           initializer=tf.constant_initializer(0.5),
                           dtype=context.variable_dtype)
    # Clip the learned coefficient into [0, 1].
    mix = mtf.minimum(mix, 1.)
    mix = mtf.maximum(mix, 0.)

    upper_weights = mtf.softmax(scores, memory_length_dim,
                                extra_logit=extra_logit)
    lower_log_weights = mtf.log_softmax(scores, query_length_dim,
                                        extra_logit=extra_logit)
    doubly_weights = mtf.softmax(lower_log_weights, memory_length_dim,
                                 extra_logit=extra_logit)

    doubly_part = mix * doubly_weights
    upper_part = (1. - mix) * upper_weights
    weights = doubly_part + upper_part

    if dropout_rate != 0.0:
        noise_shape = weights.shape - dropout_broadcast_dims
        weights = mtf.dropout(weights, 1.0 - dropout_rate,
                              noise_shape=noise_shape)

    return mtf.einsum([weights, v], q.shape - key_dim + value_dim)
def token_type_vocab_dim(self):
    """Return the "token_type_vocab" Dimension sized by config.type_vocab_size."""
    type_vocab_size = self.config.type_vocab_size
    return mtf.Dimension("token_type_vocab", type_vocab_size)
def gradient_based_subword_tokenization(x,
                                        length_dim,
                                        max_subword_length=4,
                                        downsample=None,
                                        use_offsets=False,
                                        consider_chars_as_blocks=False,
                                        use_block_pos_embedding=False,
                                        share_block_kernel=False,
                                        memory_embeddings=0,
                                        context=None,
                                        block_mixing_mode=None,
                                        activation="softmax",
                                        downsample_function="mean"):
    """Implements GBSWT from Charformer.

    Candidate blocks of width 2, 3 and 4 (plus 1 when
    consider_chars_as_blocks) are mean-pooled, scored by a dense layer,
    expanded back to length_dim and combined per position using the scores.

    Args:
      x: a Tensor containing length_dim
      length_dim: a Dimension
      max_subword_length: integer (currently ignored).
      downsample: integer. When set, the output is mean-pooled along its
        "length" dimension by this factor (padding first if needed).
      use_offsets: boolean.
      consider_chars_as_blocks: boolean.
      use_block_pos_embedding: boolean.
      share_block_kernel: boolean.
      memory_embeddings: integer (currently ignored).
      context: Context. Read when share_block_kernel or
        use_block_pos_embedding is set.
      block_mixing_mode: Str for block mixing.
      activation: Str for block ranking.
      downsample_function: Str; only "mean" is implemented.

    Returns:
      a Tensor with the same shape as x when downsample is not set;
      otherwise the length dimension is reduced by the pooling.

    Raises:
      ValueError: if downsample is set and downsample_function is not "mean".
    """
    # don't use this for now.
    del max_subword_length
    del memory_embeddings
    all_blocks = []
    all_scores = []
    tf.logging.info("GSW block layer")

    def _tile(x, n, tile_dim):
        # Simple tile function in MTF.
        return mtf.concat([x] * n, tile_dim.name)

    def _repeat(x, n, repeat_dim):
        # repeat function in MTF
        tmp_dim = mtf.Dimension("tmp", 1)
        expand_shape = mtf.Shape(x.shape.dims + [tmp_dim])
        x = mtf.reshape(x, expand_shape)
        x = _tile(x, n, tmp_dim)
        output_shape = []
        for dim in x.shape.dims:
            if dim.name == "tmp":
                continue
            if dim.name == repeat_dim.name:
                dim = mtf.Dimension(dim.name, dim.size * n)
            output_shape.append(dim)
        output_shape = mtf.Shape(output_shape)
        x = mtf.reshape(x, output_shape)
        return x

    def _combined_dim(dims):
        return mtf.Dimension(dims[0].name, mtf.Shape(dims).size)

    # compute all subword blocks
    # TODO(yitay): handle offsets to get all blocks
    if activation == "sigtanh":
        # one score for sigmoid
        tmp_dim = mtf.Dimension("block_score", 2)
    else:
        tmp_dim = mtf.Dimension("block_score", 1)

    model_dim = x.shape[-1]
    subword_blocks_width = [2, 3, 4]

    if consider_chars_as_blocks:
        subword_blocks_width += [1]

    if share_block_kernel:
        block_kernel_shape = mtf.Shape([model_dim, tmp_dim])
        block_kernel = mtf.get_variable(
            x.mesh, "block_kernel", block_kernel_shape, initializer=None,
            dtype=context.variable_dtype)
    else:
        block_kernel = None

    for subword_len in subword_blocks_width:
        if use_block_pos_embedding:
            # this is turn off by default. It is meant to support cases like
            # parameterized pooling or other features.
            block_len_dim = mtf.Dimension(length_dim.name, subword_len)
            # TODO(vqtran): Consider other positional embeddings.
            block_pos_emb = sinusoid_positional_embedding_weights(
                context.mesh, block_len_dim, x.shape[-1],
                context.variable_dtype.activation_dtype)
            block_pos_emb = _repeat(
                block_pos_emb,
                math.ceil(length_dim.size / float(subword_len)),
                block_len_dim)
        if use_offsets:
            offset_space = subword_len
        else:
            offset_space = 1
        for offsets in range(offset_space):
            if offsets > 0:
                xoff = mtf.shift(x, offsets, length_dim, wrap=False)
                if use_block_pos_embedding:
                    block_pos_emb = mtf.shift(
                        block_pos_emb, offsets, block_pos_emb.shape[-2],
                        wrap=False)
            else:
                xoff = x
            tf.logging.info("SW len=%d offset=%d", subword_len, offsets)
            if length_dim.size % subword_len != 0:
                tf.logging.info("Not divisible by length")
                # add extra padding tokens
                pad_amt = int(subword_len) - int(
                    length_dim.size % subword_len)
                kp = mtf.pad(xoff, [0, pad_amt], length_dim.name)
            else:
                kp = xoff

            if use_block_pos_embedding:
                kp += block_pos_emb

            bx = mtf.pool_tensor_1d(
                kp,
                pool_dim=kp.shape.get_dim_by_name("length"),
                reduce_fn=mtf.reduce_mean,
                pool_size=int(subword_len))
            block_score = mtf.layers.dense(
                bx, [tmp_dim],
                use_bias=False,
                name="bx",
                reduced_dims=[model_dim],
                variable_dtype=None,
                kernel_weights=block_kernel)

            expand_bx = _repeat(bx, subword_len, length_dim)
            expand_scores = _repeat(block_score, subword_len, length_dim)
            if offsets > 0:
                # add offset.
                expand_bx = mtf.pad(expand_bx, [offsets, 0], length_dim.name)
                expand_scores = mtf.pad(expand_scores, [offsets, 0],
                                        length_dim.name)
            new_len = expand_bx.shape.get_dim_by_name(length_dim.name)
            if new_len.size < length_dim.size:
                # Pad up to the target length; the amount is the (positive)
                # shortfall, target minus current.
                pad_amt = length_dim.size - new_len.size
                expand_bx = mtf.pad(expand_bx, [0, pad_amt], length_dim.name)
                expand_scores = mtf.pad(expand_scores, [0, pad_amt],
                                        length_dim.name)
            elif new_len.size > length_dim.size:
                expand_bx = mtf.slice(expand_bx, 0, length_dim.size,
                                      length_dim.name)
                expand_scores = mtf.slice(expand_scores, 0, length_dim.size,
                                          length_dim.name)

            # Give every candidate its own size-1 axis so candidates can be
            # concatenated along it below.
            new_tmp_dim = mtf.Dimension("extra_dim", 1)
            expand_shape = mtf.Shape(expand_bx.shape.dims + [new_tmp_dim])
            expand_scores_shape = mtf.Shape(
                expand_scores.shape.dims + [new_tmp_dim])
            expand_bx = mtf.reshape(expand_bx, expand_shape)
            expand_scores = mtf.reshape(expand_scores, expand_scores_shape)
            all_blocks.append(expand_bx)
            all_scores.append(expand_scores)

    all_blocks = mtf.concat(all_blocks, new_tmp_dim.name)
    all_scores = mtf.concat(all_scores, new_tmp_dim.name)
    tf.logging.info(all_blocks)
    new_tmp_dim = all_blocks.shape.get_dim_by_name("extra_dim")
    combined_dim = _combined_dim([new_tmp_dim, tmp_dim])
    block_net_shape = all_scores.shape - tmp_dim - new_tmp_dim + combined_dim
    block_net = mtf.reshape(all_scores, block_net_shape)

    if block_mixing_mode == "score_attention":
        tf.logging.info("Using score attention")
        att = mtf.einsum([block_net, block_net], reduced_dims=[new_tmp_dim])
        tf.logging.info(block_net)
        att = mtf.softmax(att, reduced_dim=att.shape[-1])
        block_net = mtf.einsum([att, block_net], output_shape=block_net.shape)
        tf.logging.info(block_net)

    if activation == "softmax":
        block_net = mtf.softmax(block_net, reduced_dim=new_tmp_dim)
    elif activation == "tanh":
        tf.logging.info("Using tanh")
        block_net = mtf.tanh(block_net)

    all_blocks = block_net * all_blocks
    all_blocks = mtf.reduce_sum(all_blocks, reduced_dim=new_tmp_dim)
    output = all_blocks

    if downsample:
        output_length = output.shape.get_dim_by_name("length")
        if output_length.size % int(downsample) != 0:
            pad_amt = int(downsample) - int(
                output_length.size % int(downsample))
            output = mtf.pad(output, [0, pad_amt], output_length.name)
        if downsample_function == "mean":
            output = mtf.pool_tensor_1d(
                output,
                pool_dim=output.shape.get_dim_by_name("length"),
                reduce_fn=mtf.reduce_mean,
                pool_size=int(downsample))
        else:
            raise ValueError("Downsampling function not implemeneted.")

    return output
def num_heads_dim(self):
    """Return the "num_heads" Dimension sized by config.attention_num_heads."""
    head_count = self.config.attention_num_heads
    return mtf.Dimension("num_heads", head_count)
def _combined_dim(dims):
    """Return one Dimension named after dims[0], sized as mtf.Shape(dims).size."""
    merged_name = dims[0].name
    merged_size = mtf.Shape(dims).size
    return mtf.Dimension(merged_name, merged_size)
def key_dim(self):
    """Return the "d_k" Dimension for attention keys.

    Raises:
      ValueError: if config.attention_key_head_size is None.
    """
    head_size = self.config.attention_key_head_size
    if head_size is not None:
        return mtf.Dimension("d_k", head_size)
    raise ValueError("The key head size is not defined.")
def local_attention_1d(q,
                       k,
                       v,
                       length_dim,
                       key_dim,
                       value_dim,
                       fully_autoregressive=True,
                       length_dim_num_splits=1,
                       radius=128,
                       sequence_id=1,
                       write_priority=None,
                       read_priority=None,
                       attention_kwargs=None,
                       context=None):
    """Attention to a neighborhood around the source.

    If fully_autoregressive, then query position p can only see memory
    positions in the range (p - radius, p].

    If not fully_autoregressive, then query position p can only see memory
    positions in the range (p - radius, p + radius].

    In addition, if write_priority and read_priority are both provided, then
    attention is limited to position pairs where
    read_priority[query position] >= write_priority[memory position]

    Args:
        q: a Tensor containing length_dim
        k: a Tensor containing length_dim
        v: an optional Tensor containing length_dim.  If None then uses v=k.
        length_dim: a Dimension
        key_dim: a Dimension (the channels dimension of q and k)
        value_dim: a Dimension (the channels dimension of v)
        fully_autoregressive: a boolean
        length_dim_num_splits: an optional integer indicating how many ways the
            length dimension is split
        radius: an integer
        sequence_id: a Tensor or an integer (None is treated as 1)
        write_priority: an optional Tensor containing length_dim
        read_priority: an optional Tensor containing length_dim
        attention_kwargs: optional dict of keyword arguments forwarded to
            attention.attention(); None is treated as an empty dict.
        context: optional context, forwarded to attention.attention().

    Returns:
        a Tensor with the shape q.shape - key_dim + value_dim

    Raises:
        ValueError: if channels or depth don't match.
    """
    tf.logging.info(attention_kwargs)
    # The default is None, which cannot be unpacked with `**`; normalize it
    # here so the call to attention.attention() below works without kwargs.
    if attention_kwargs is None:
        attention_kwargs = {}

    # Choose a suitable block size.
    # We choose the greatest divisor of length_per_split less than or equal
    # to max(radius, 128)
    length_per_split = length_dim.size // length_dim_num_splits
    block_length = max(radius, 128)
    while length_per_split % block_length != 0:
        block_length -= 1
    query_block_length = mtf.Dimension("query_block_length", block_length)
    memory_block_length = mtf.Dimension("memory_block_length", block_length)
    # The num_blocks dimension gets the same name as the length dimension,
    # so it will be split in the same way.
    num_blocks = mtf.Dimension(length_dim.name, length_dim.size // block_length)

    def _reshape_query(x):
        # Split length_dim into [num_blocks, query_block_length].
        return mtf.replace_dimensions(
            x, length_dim, [num_blocks, query_block_length])

    def _reshape_memory(x):
        # Split length_dim into blocks, then widen each memory block with
        # `radius` positions from its neighbour(s): left only when fully
        # autoregressive, both sides otherwise.
        x = mtf.replace_dimensions(
            x, length_dim, [num_blocks, memory_block_length])
        exchange = (mtf.left_halo_exchange
                    if fully_autoregressive else mtf.halo_exchange)
        return exchange(x, num_blocks, memory_block_length, radius)

    q = _reshape_query(q)
    k = _reshape_memory(k)
    # Explicit None test: the contract is "v is optional", not "v is truthy".
    if v is not None:
        v = _reshape_memory(v)
    else:
        v = k

    if sequence_id is None:
        sequence_id = 1
    if (not isinstance(sequence_id, mtf.Tensor) or
            length_dim not in sequence_id.shape.dims):
        # Broadcast a scalar / length-less id over length_dim.
        sequence_id += mtf.zeros(q.mesh, [length_dim], tf.int32)
    q_sequence_id = _reshape_query(sequence_id)
    m_sequence_id = _reshape_memory(sequence_id)

    pos = mtf.range(q.mesh, length_dim, dtype=tf.int32)
    q_pos = _reshape_query(pos)
    m_pos = _reshape_memory(pos)

    # Memory blocks carry one halo (autoregressive) or two (bidirectional).
    padded_memory_block_length = mtf.Dimension(
        "memory_block_length",
        (1 if fully_autoregressive else 2) * radius + block_length)

    relative_position = m_pos - q_pos
    visible = mtf.equal(q_sequence_id, m_sequence_id)
    visible = mtf.logical_and(visible, mtf.greater(relative_position, -radius))
    visible = mtf.logical_and(
        visible,
        mtf.less_equal(relative_position,
                       0 if fully_autoregressive else radius))
    # Priority masking needs both tensors; with only one of them supplied
    # there is nothing to compare against, so the restriction is skipped.
    if read_priority is not None and write_priority is not None:
        write_priority = _reshape_memory(write_priority)
        read_priority = _reshape_query(read_priority)
        visible = mtf.logical_and(
            visible, mtf.greater_equal(read_priority, write_priority))

    bias = attention.visibility_mask_to_attention_bias(visible, q.dtype)
    o = attention.attention(q, k, v, padded_memory_block_length,
                            key_dim, value_dim, bias,
                            context=context, **attention_kwargs)
    # Merge the block dimensions back into the original length dimension.
    return mtf.replace_dimensions(o, [num_blocks, query_block_length],
                                  length_dim)
def value_heads_dims(self):
    """Return the "value_heads" dimension sized by the number of value heads.

    Raises:
        ValueError: if ``self.config.attention_num_value_heads`` is None.
    """
    value_head_count = self.config.attention_num_value_heads
    if value_head_count is not None:
        return mtf.Dimension("value_heads", value_head_count)
    raise ValueError("The number of value heads is not defined.")
def lpt_prototype(mesh,
                  nc=FLAGS.nc,
                  bs=FLAGS.box_size,
                  batch_size=FLAGS.batch_size,
                  a0=FLAGS.a0,
                  a=FLAGS.af,
                  nsteps=FLAGS.nsteps):
    """Prototype of a function computing the LPT displacement.

    Builds a linear initial field, initialises particles with
    ``mtfpm.lpt_init_single`` and paints them back onto a block-decomposed
    mesh with halo exchange.  The N-body evolution step is currently disabled.

    Args:
        mesh: the mtf mesh on which all tensors are created.
        nc: number of cells per side of the cubic grid.
        bs: box size, forwarded to ``mtfpm.linear_field``.
        batch_size: size of the batch dimension.
        a0: initial scale factor, forwarded to ``mtfpm.lpt_init_single``.
        a: final scale factor (only used to build ``stages``).
        nsteps: number of stages between ``a0`` and ``a``.

    Returns:
        A tuple ``(initc, final_field)`` of mesh tensorflow tensors: the
        initial conditions and the painted final field.
    """
    # Parse the power-spectrum table once; the second column is the power
    # spectrum imported below (the first column is not used in this function).
    pk_table = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T
    plin = pk_table[1]

    # Retained for the N-body step that is disabled further down.
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # Clamp the halo when it reaches half of the smallest per-block size.
    smallest_block = min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= 0.5 * smallest_block:
        new_size = int(0.5 * smallest_block)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition
    # (retained for the disabled N-body step).
    downsampling_factor = 0

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    # Retained for the disabled N-body step.
    k_dims = [tx_dim, ty_dim, tz_dim]

    batch_dim = mtf.Dimension("batch", batch_size)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    # NOTE(review): the (y, z, x) ordering is kept exactly as in the original;
    # presumably it matches the layout expected by mtfpm - confirm.
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Begin simulation
    initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv)

    state = mtfpm.lpt_init_single(
        initc,
        a0,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )

    # Here we can run our nbody
    final_state = state  #mtfpm.nbody(state, stages, lr_shape, hr_shape, k_dims, kv_lr, kv_hr, halo_size, downsampling_factor=downsampling_factor)

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    # Pad every per-block spatial dimension with a halo on both sides.
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)

    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)

    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    # Hack: a slicewise op is used in place of
    # mtf.reshape(final_field, [batch_dim, fx_dim, fy_dim, fz_dim]).
    final_field = mtf.slicewise(lambda x: x[:, 0, 0, 0], [final_field],
                                output_dtype=tf.float32,
                                output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
                                name='my_dumb_reshape',
                                splittable_dims=part_shape[:-1] + hr_shape[:4])

    return initc, final_field
def transformer_moe_layer_v2(
    inputs, output_dim, hparams, train, variable_dtype,
    layout=None, mesh_shape=None, nonpadding=None):
    """2-level mixture of experts.

    Adapted from the paper https://arxiv.org/abs/1701.06538

    Note: until the algorithm and interface solidify, we pass in a
    hyperparameters dictionary in order not to complicate the interface in
    mtf_transformer.py .  Once this code moves out of "research", we should
    pass the hyperparameters separately.

    Hyperparameters used:
        hparams.moe_num_experts: number of experts (a length-2 sequence: the
            first-level and second-level expert counts)
        hparams.moe_hidden_size: size of hidden layer in each expert
        hparams.moe_group_size: size of each "group" for gating purposes
        hparams.moe_capacity_factor_train: a float
        hparams.moe_capacity_factor_eval: a float
        hparams.moe_capacity_factor_second_level: a float
        hparams.moe_gating: a string
        hparams.moe_loss_coef: multiplier applied to the summed gating losses
        + all hyperparameters used by _top_2_gating()

    One set of params for experts in first level and different of hparams
    per expert in the second level.
    The number of parameters in the gating network is:
        (input_dim.size * (hparams.num_experts) +
         (moe_hidden_size * hparams.num_experts) * hparams.num_experts

    The number of parameters in the experts themselves is:
        (hparams.num_experts
         * (input_dim.size + output_dim.size)
         * hparams.moe_hidden_size)

    The input is n-dimensional: [<batch_and_length_dims>, input_dim],
    consisting of the representations of all positions in a batch of
    sequences.

    Each position of each sequence is sent to 0-3 experts.  The expert
    choices and the combination weights are determined by a learned gating
    function.

    This function returns a small auxiliary loss that should be added to the
    training loss of the model.  This loss helps to balance expert usage.
    Without the loss, it is very likely that a few experts will be trained and
    the rest will starve.

    Several hacks are necessary to get around current TPU limitations:

    - To ensure static shapes, we enforce (by truncation/padding)
      that each sequence send the same number of elements to each expert.

      It would make more sense to enforce this equality over the entire batch,
      but due to our hacked-up gather-by-matmul implementation, we need to
      divide the batch into "groups".  For each group, the same number of
      elements are sent to each expert.

    TODO(noam): Factor this code better.  We want to be able to substitute
    different code for the experts themselves.

    Dimensions cheat sheet:
    a, b: batch size
    l: original sequence length
    m: input depth
    n: output depth
    g, h: number of groups
    s, t: group size
    x, y: number of experts
    c, d: expert capacity

    input: [a0, b1, l, m]
    input: [a0, g1, s, m]
    dispatch_tensor_x: [a0, g1, s, x, c]
    expert_input: [a0, g1, x, c, m]
    alltoall: [a0, g, x1, c, m]
    alltoall: [a0, g, x1, c, m]
    transpose: [x1, a0, g, c, m]
    reshape: [x1, h0, s, m]
    assignment2: [x1, h0, t, y, d]
    expert_input2: [x1, h0, y, d, m]
    alltoall: [x1, h, y0, d, m]
    ...
    reverse of that

    gating params 0: [m, x]
    gating params 1: [x1, m, y]

    expert params:
       [x1, y0, m, hidden]
       [x1, y0, hidden, n]

    Args:
        inputs: a mtf.Tensor with shape [a, b, l, m].  A 3-dimensional input
            is also accepted: a size-1 "outer_batch" dimension is prepended
            for the computation and dropped again from the returned output.
        output_dim: a mtf.Dimension (for Transformer, this is input_dim)
        hparams: model hyperparameters
        train: a boolean
        variable_dtype: a mtf.VariableDType
        layout: optional - an input to mtf.convert_to_layout_rules
        mesh_shape: optional - an input to mtf.convert_to_shape
        nonpadding: an optional mtf.Tensor with shape [a, b, l]
            and the same dtype as inputs, consisting of ones(nonpadding)
            and zeros(padding).

    Returns:
        outputs: a Tensor with shape [a, b, l, n]
        loss: a mtf scalar

    Raises:
        ValueError: on unrecognized hparams.moe_gating
    """
    if nonpadding is not None:
        # Broadcast nonpadding to all of the inputs' dimensions but the last.
        nonpadding = mtf.zeros(inputs.mesh, inputs.shape.dims[:-1],
                               dtype=inputs.dtype) + nonpadding
    insert_outer_batch_dim = (len(inputs.shape.dims) == 3)
    if insert_outer_batch_dim:
        inputs = mtf.reshape(
            inputs, [mtf.Dimension("outer_batch", 1)] + inputs.shape.dims)

    assert len(hparams.moe_num_experts) == 2
    a0, b1, l, m = inputs.shape.dims
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
    # x1/x and y0/y are pairs of equally sized dimensions with different
    # names; reshaping between the two names is what the comments below call
    # an "alltoall".
    x1 = mtf.Dimension("expert_x", hparams.moe_num_experts[0])
    y0 = mtf.Dimension("expert_y", hparams.moe_num_experts[1])
    x = mtf.Dimension("expert_x_unsplit", hparams.moe_num_experts[0])
    y = mtf.Dimension("expert_y_unsplit", hparams.moe_num_experts[1])
    n = output_dim

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups (g.size) is a multiple of the mesh dimension
    # over which those groups are split.
    num_groups, group_size = _split_into_groups(
        b1.size * l.size, hparams.moe_group_size,
        mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, b1))
    g1 = mtf.Dimension(b1.name, num_groups)
    g = mtf.Dimension(b1.name + "_unsplit", g1.size)
    s = mtf.Dimension("group_size_x", group_size)

    # Each sequence sends (at most?) expert_capacity positions to each expert.
    # Static expert_capacity dimension is needed for expert batch sizes
    if train:
        capacity_factor = hparams.moe_capacity_factor_train
    else:
        capacity_factor = hparams.moe_capacity_factor_eval
    expert_capacity = min(s.size, int((s.size * capacity_factor) / x.size))
    # The first-level capacity is never allowed to drop below 4.
    expert_capacity = max(expert_capacity, 4)
    c = mtf.Dimension("expert_capacity_x", expert_capacity)

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups (h.size) is a multiple of the mesh dimension
    # over which those groups are split.
    num_groups, group_size = _split_into_groups(
        a0.size * g.size * c.size,
        hparams.moe_group_size,
        mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, a0))
    t = mtf.Dimension("group_size_y", group_size)
    h0 = mtf.Dimension(a0.name, num_groups)
    h = mtf.Dimension(a0.name + "_unsplit", h0.size)

    expert_capacity = min(
        t.size,
        int((t.size * hparams.moe_capacity_factor_second_level) / y.size))
    # Same floor of 4 for the second-level capacity.
    expert_capacity = max(expert_capacity, 4)
    d = mtf.Dimension("expert_capacity_y", expert_capacity)

    # First level of expert routing
    # Reshape the inner batch size to a multiple of group_dim g1 and
    # group_size_dim s.
    inputs = mtf.reshape(inputs, [a0, g1, s, m])
    if nonpadding is not None:
        nonpadding = mtf.reshape(nonpadding, [a0, g1, s])

    # Get the assignments for the first level.
    # dispatch_tensor_x has shape [a0, g1, s, x, c]
    if hparams.moe_gating == "top_2":
        dispatch_tensor_x, combine_tensor_x, loss_outer = _top_2_gating(
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=x,
            expert_capacity_dim=c,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            name="outer_gating",
            importance=nonpadding)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # Now create expert_inputs based on the assignments.
    # put num_experts dimension first to make split easier in alltoall
    expert_inputs_x = mtf.einsum([inputs, dispatch_tensor_x], [x, a0, g1, c, m])

    # we construct an "importance" Tensor for the inputs to the second-level
    # gating.  The importance of an input is 1.0 if it represents the
    # first-choice expert-group and 0.5 if it represents the second-choice expert
    # group.  This is used by the second-level gating.
    importance = mtf.reduce_sum(combine_tensor_x, output_shape=[x, a0, g1, c])
    importance = 0.5 * (
        mtf.to_float(mtf.greater(importance, 0.5)) +
        mtf.to_float(mtf.greater(importance, 0.0)))

    # First level, all to all. Here we change the split dimension from g1 to x1.
    expert_inputs_x = mtf.reshape(expert_inputs_x, mtf.Shape(
        [x1, a0, g, c, m]))
    importance = mtf.reshape(importance, [x1, a0, g, c])

    # Second level of expert routing
    # Reshape the expert_inputs outer batch dim to be a multiple of group_dim h0
    # and group_size_dim t.
    inputs_y = mtf.reshape(expert_inputs_x, [x1, h0, t, m])
    importance = mtf.reshape(importance, [x1, h0, t])

    # Get the assignments for the second level.
    # dispatch_tensor_y has shape [x1, h0, t, y, d]
    if hparams.moe_gating == "top_2":
        dispatch_tensor_y, combine_tensor_y, loss_inner = _top_2_gating(
            inputs=inputs_y,
            outer_expert_dims=[x1],
            experts_dim=y,
            expert_capacity_dim=d,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=importance,
            name="inner_gating")
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # Now create expert_inputs based on the assignments.
    # put num_experts dimension first to make split easier in alltoall
    expert_inputs_y = mtf.einsum([inputs_y, dispatch_tensor_y], [y, x1, h0, d, m])

    # Second level, all to all. Here we change the split dimension from h0 to y0.
    expert_inputs_y = mtf.reshape(expert_inputs_y, mtf.Shape(
        [y0, x1, h, d, m]))

    # Two bias-free dense layers (ReLU in between), with separate weights per
    # (y0, x1) expert pair.
    hidden_output = mtf.layers.dense(
        expert_inputs_y, hidden_dim, expert_dims=[y0, x1],
        activation=mtf.relu, use_bias=False,
        variable_dtype=variable_dtype, name="wi")
    expert_output = mtf.layers.dense(
        hidden_output, output_dim, expert_dims=[y0, x1],
        use_bias=False, variable_dtype=variable_dtype, name="wo")

    # NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
    # expert_output has shape [y0, x1, h, d, n]

    # alltoall
    expert_output = mtf.reshape(expert_output, mtf.Shape(
        [y, x1, h0, d, n]))

    # combine results from inner level
    output_y = mtf.einsum([expert_output, combine_tensor_y], [x1, h0, t, n])

    # Reshape the combined tensor from inner level to now contain outer_batch_dim
    # a0 and group_dim g
    output = mtf.reshape(output_y, [x1, a0, g, c, n])

    # alltoall from expert_dim x to group_dim g1
    expert_output_x = mtf.reshape(output, mtf.Shape([x, a0, g1, c, n]))

    # combine results from outer level
    output_x = mtf.einsum([expert_output_x, combine_tensor_x], [a0, g1, s, n])

    # Reshape the combined tensor to now contain inner_batch_dim
    # b1 and the original sequence length
    output = mtf.reshape(output_x, [a0, b1, l, n])
    if insert_outer_batch_dim:
        # Drop the size-1 "outer_batch" dimension that was added on entry.
        output = mtf.reshape(output, [b1, l, n])
    # The auxiliary loss is the sum of both gating losses, scaled by
    # hparams.moe_loss_coef.
    return output, (loss_outer + loss_inner) * hparams.moe_loss_coef
def fn(x):
    """Apply one transformer block to ``x``.

    The block is configured entirely through names taken from the enclosing
    scope (``scope``, ``params``, ``layer_num``, ``bias``, ``sequence_dim``,
    ``memory_length_dim``, ``variable_dtype``, ``context`` and the
    ``use_rezero`` / ``use_scale_norm`` / ``macaron_attention`` /
    ``use_mlp_glu`` / ``use_moe`` switches).

    Steps, in order:
      1. optional "macaron" half-weighted MLP added to the input;
      2. pre-normalised attention (skipped when the layer's attention type is
         "none") added as a residual;
      3. pre-normalised feed-forward (MoE layer or plain MLP) added as a
         residual, scaled by 0.5 when macaron attention is enabled.

    Args:
        x: the input tensor; its last dimension is used as the model width.

    Returns:
        A tuple ``(x, aux_loss)``: the block output and the auxiliary loss
        (the MoE load-balancing loss, or a zero scalar when MoE is not used).
    """
    with tf.variable_scope(scope):
        nx = x.shape[-1]  # Grab last dimension from input

        # Rezero replaces the pre-normalisation with an identity and instead
        # transforms the residual branch.
        if use_rezero:
            prenorm = identity
        elif use_scale_norm:
            prenorm = scale_norm
        else:
            prenorm = layer_norm
        pre_residual_fn = rezero if use_rezero else identity

        attention_type = params["attention_types"][layer_num]

        # MLP flavour and hidden width, shared by the macaron MLP and the
        # regular feed-forward below: 4x the model width, 8x for the GLU MLP.
        mlp_fn = mlp_glu if use_mlp_glu else mlp
        intermediate_size = nx.size * 4 * (2 if use_mlp_glu else 1)
        # Define intermediate layer of mlp - to split
        dim_intermediate_expanded = mtf.Dimension("intermediate_expanded",
                                                  intermediate_size)

        if macaron_attention:
            mult = 0.5
            m = mlp_fn(x, "mlp_macaron", dim_intermediate_expanded,
                       variable_dtype=variable_dtype, params=params)
            x = x + (m * mult)
        else:
            mult = 1

        if attention_type != "none":
            res_x = prenorm(x, "norm_1", variable_dtype=variable_dtype,
                            params=params)
            a = attn(res_x, "attn", nx, attention_type=attention_type,
                     params=params, bias=bias, dim_seq=sequence_dim,
                     memory_length_dim=memory_length_dim,
                     variable_dtype=variable_dtype, context=context)
        else:
            a = x

        x = x + pre_residual_fn(a, "norm_rezero_1", dtype=variable_dtype)

        res_x = prenorm(x, "norm_2", variable_dtype=variable_dtype,
                        params=params)

        if use_moe:
            moe_params = mtf.transformer.moe.HParams()
            mtf.transformer.moe.set_default_moe_hparams(moe_params)
            moe_params.add_hparam("moe_min_expert_capacity", 1)
            moe_params.add_hparam("moe_use_experts_attention", False)

            # Override defaults
            for k, v in params["moe_params"].items():
                moe_params.add_hparam(k, v)

            moe_train = params["mode"] == "train"

            m, aux_loss = mtf.transformer.moe.transformer_moe_layer_v1(
                res_x, x.shape[-1], moe_params,
                train=moe_train,
                mesh_shape=params["mesh_shape"],
                layout=params["layout"],
                activation=params.get("moe_activation", "relu"),
                variable_dtype=variable_dtype,
                num_microbatches=params["num_microbatches"])
            # Dropout is a training-time regulariser: apply it only in train
            # mode and only for a positive rate, as the embedding dropout in
            # model() does.
            if moe_train and params["res_dropout"] > 0:
                m = mtf.dropout(m, rate=params["res_dropout"],
                                name="moe_dropout")
        else:
            m = mlp_fn(res_x, "mlp", dim_intermediate_expanded,
                       variable_dtype=variable_dtype, params=params)
            # No auxiliary loss without MoE: return a zero scalar instead.
            aux_loss = mtf.zeros(x.mesh, mtf.Shape([]),
                                 dtype=variable_dtype.slice_dtype)

        x = x + pre_residual_fn((m * mult), "norm_rezero_2", variable_dtype)
        return x, aux_loss